diff --git a/.codespellignore b/.codespellignore deleted file mode 100644 index 26c58f9..0000000 --- a/.codespellignore +++ /dev/null @@ -1,4 +0,0 @@ -coo -daty -anormal -wheight diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2b5c781..2b04065 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,7 +15,7 @@ jobs: fail-fast: false matrix: os: [ubuntu, macos, windows] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] name: ${{ matrix.os }} - py${{ matrix.python-version }} runs-on: ${{ matrix.os }}-latest defaults: @@ -28,6 +28,9 @@ jobs: uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} + - name: Install system OpenGL/EGL libraries (Ubuntu) + if: matrix.os == 'ubuntu' + run: sudo apt-get install -y --no-install-recommends libegl1 libgl1 - name: Install dependencies run: | python -m pip install --progress-bar off --upgrade pip setuptools wheel diff --git a/.github/workflows/code-style.yml b/.github/workflows/code-style.yml index 889727e..ecc3e00 100644 --- a/.github/workflows/code-style.yml +++ b/.github/workflows/code-style.yml @@ -15,10 +15,11 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v6 - - name: Setup Python 3.10 + - name: Setup Python 3.13 uses: actions/setup-python@v6 with: - python-version: '3.10' + python-version: '3.13' + cache: 'pip' - name: Install dependencies run: | python -m pip install --progress-bar off --upgrade pip setuptools wheel @@ -27,10 +28,5 @@ jobs: run: ruff check . - name: Run codespell uses: codespell-project/actions-codespell@master - with: - check_filenames: true - check_hidden: true - skip: './.git,./build,./.mypy_cache,./.pytest_cache' - ignore_words_file: ./.codespellignore - # - name: Run pydocstyle - # run: pydocstyle . + - name: Run pydocstyle + run: pydocstyle . 
diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 2290615..bf01390 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -20,30 +20,35 @@ jobs: uses: actions/checkout@v6 with: path: ./main - - name: Setup Python 3.10 + - name: Setup Python 3.13 uses: actions/setup-python@v6 with: - python-version: '3.10' - - name: Install Linux dependencies - if: runner.os == 'Linux' + python-version: '3.13' + cache: pip + - name: Install system dependencies run: | - sudo apt update sudo apt-get update - sudo apt-get -y install libgl1 libegl1 - sudo apt install libxcb-cursor0 + sudo apt-get -y install libgl1 libegl1 libxcb-cursor0 pandoc - name: Install package run: | python -m pip install --progress-bar off --upgrade pip setuptools wheel python -m pip install --progress-bar off main/.[doc] - name: Display system information run: whippersnappy-sys_info --developer + - name: Cache WhipperSnapPy sample data + uses: actions/cache@v4 + with: + path: ~/.cache/whippersnappy + key: sample-data-sub-rs-data-v1.0 - name: Build doc run: TZ=UTC sphinx-build ./main/doc ./doc-build/dev -W --keep-going - name: Upload documentation uses: actions/upload-artifact@v6 with: name: doc-dev - path: ./doc-build/dev + path: | + doc-build/dev + !doc-build/dev/.doctrees deploy: if: github.event_name == 'push' diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index f427086..59f2311 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: os: [ubuntu, macos, windows] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] name: ${{ matrix.os }} - py${{ matrix.python-version }} runs-on: ${{ matrix.os }}-latest defaults: @@ -32,6 +32,9 @@ jobs: uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} + - name: Install system OpenGL/EGL libraries (Ubuntu) + if: matrix.os == 'ubuntu' + run: sudo apt-get install -y 
--no-install-recommends libegl1 libgl1 - name: Install package run: | python -m pip install --progress-bar off --upgrade pip setuptools wheel @@ -39,7 +42,7 @@ jobs: - name: Display system information run: whippersnappy-sys_info --developer - name: Run pytest - run: pytest whippersnappy --cov=whippersnappy --cov-report=xml --cov-config=pyproject.toml + run: pytest tests --cov=whippersnappy --cov-report=xml --cov-config=pyproject.toml - name: Upload to codecov if: ${{ matrix.os == 'ubuntu' && matrix.python-version == '3.10' && github.repository == 'Deep-MI/WhipperSnapPy' }} uses: codecov/codecov-action@v5 diff --git a/DOCKER.md b/DOCKER.md new file mode 100644 index 0000000..fbdbc33 --- /dev/null +++ b/DOCKER.md @@ -0,0 +1,170 @@ +# Docker Guide + +The Docker image provides a fully headless rendering environment with EGL +off-screen support. No display server or `xvfb` is required. + +The default entry point is `whippersnap4` (four-view batch rendering). +`whippersnap1` (single-view snapshot and rotation video) can be invoked by +overriding the entry point. + +--- + +## Building the image + +From the repository root: + +```bash +docker build --rm -t whippersnappy -f Dockerfile . +``` + +--- + +## Running — four-view batch rendering (`whippersnap4`) + +`whippersnap4` renders lateral and medial views of both hemispheres and writes +a single composed PNG image. 
+ +Mount your local directories into the container and pass the in-container paths +as arguments: + +```bash +docker run --rm --init \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + -lh /subject/surf/lh.thickness \ + -rh /subject/surf/rh.thickness \ + -sd /subject \ + -o /output/snap4.png +``` + +### With an annotation file instead of an overlay + +```bash +docker run --rm --init \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + --lh_annot /subject/label/lh.aparc.annot \ + --rh_annot /subject/label/rh.aparc.annot \ + -sd /subject \ + -o /output/snap4_annot.png +``` + +### With a caption and custom thresholds + +```bash +docker run --rm --init \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + -lh /subject/surf/lh.thickness \ + -rh /subject/surf/rh.thickness \ + -sd /subject \ + --fthresh 2.0 --fmax 4.0 \ + --caption "Cortical thickness" \ + -o /output/snap4_thickness.png +``` + +### All `whippersnap4` options + +``` +docker run --rm whippersnappy --help +``` + +--- + +## Running — single-view snapshot (`whippersnap1`) + +Override the entry point with `--entrypoint whippersnap1`: + +```bash +docker run --rm --init \ + --entrypoint whippersnap1 \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + --mesh /subject/surf/lh.white \ + --overlay /subject/surf/lh.thickness \ + --bg-map /subject/surf/lh.curv \ + --view left \ + --fthresh 2.0 --fmax 4.0 \ + -o /output/snap1.png +``` + +### All `whippersnap1` options + +```bash +docker run --rm --entrypoint whippersnap1 whippersnappy --help +``` + +--- + +## Running — 360° rotation video (`whippersnap1 --rotate`) + +`whippersnap1 --rotate` renders a full 360° rotation video and writes an +`.mp4`, `.webm`, or `.gif` file. 
`imageio-ffmpeg` is bundled in the image — +no system ffmpeg is required. + +### MP4 (H.264, recommended) + +```bash +docker run --rm --init \ + --entrypoint whippersnap1 \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + --mesh /subject/surf/lh.white \ + --overlay /subject/surf/lh.thickness \ + --bg-map /subject/surf/lh.curv \ + --rotate \ + --rotate-frames 72 \ + --rotate-fps 24 \ + -o /output/rotation.mp4 +``` + +### Animated GIF (no ffmpeg needed) + +```bash +docker run --rm --init \ + --entrypoint whippersnap1 \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + --mesh /subject/surf/lh.white \ + --overlay /subject/surf/lh.thickness \ + --rotate \ + --rotate-frames 36 \ + --rotate-fps 12 \ + -o /output/rotation.gif +``` + +--- + +## Path mapping summary + +| Host path | Container path | Purpose | +|-----------|---------------|---------| +| `/path/to/subject` | `/subject` | FreeSurfer subject directory (contains `surf/`, `label/`) | +| `/path/to/output` | `/output` | Directory where output files are written | + +All output files are written to the container path you pass via `-o`; mount the +parent directory to retrieve them on the host. + +--- + +## Notes + +- The `--init` flag is recommended so that signals (e.g. `Ctrl-C`) are handled + correctly inside the container. +- `--user $(id -u):$(id -g)` ensures output files are owned by your host user, + not root. +- The interactive GUI (`whippersnap`) is **not** available in the Docker image — + it requires a display server and PyQt6, which are not installed. 
+ diff --git a/Dockerfile b/Dockerfile index 8a12393..ea56592 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,18 +1,18 @@ -FROM ubuntu:20.04 +FROM python:3.11-slim -# Install packages RUN apt-get update && apt-get install -y --no-install-recommends \ - python3 pip xvfb libglib2.0-0 libxkbcommon-x11-0 libgl1 libegl1 \ - libfontconfig1 libdbus-1-3 && \ - apt clean && \ - rm -rf /var/libs/apt/lists/* /tmp/* /var/tmp/* + libegl1 \ + libgl1 \ + libglib2.0-0 \ + libfontconfig1 \ + libdbus-1-3 && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* -# Install python packages RUN pip install --upgrade pip -RUN pip install pyopengl glfw pillow numpy pyrr PyQt6 COPY . /WhipperSnapPy -RUN pip install /WhipperSnapPy +RUN pip install /WhipperSnapPy[video] -ENTRYPOINT ["xvfb-run","whippersnap"] +ENTRYPOINT ["whippersnap4"] CMD ["--help"] diff --git a/README.md b/README.md index 1b2895f..0a7f81a 100644 --- a/README.md +++ b/README.md @@ -1,80 +1,169 @@ -# WhipperSnapPy - -WhipperSnapPY is a small Python OpenGL program to render FreeSurfer and -FastSurfer surface models and color overlays and generate screen shots. - -## Contents: - -- Capture 4x4 surface plots (front & back, left and right) -- OpenGL window for interactive visualization - -## Installation: - -The `WhipperSnapPy` package can be installed from pypi via -``` -python3 -m pip install whippersnappy -``` - -Note, that currently no off-screen rendering is natively supported. Even in snap -mode an invisible window will be created to render the openGL output -and capture the contents to an image. In order to run this on a headless -server, inside Docker, or via ssh we recommend to install xvfb and run - -``` -sudo apt update && apt install -y python3 python3-pip xvfb libxcb-xinerama0 -pip3 install pyopengl glfw pillow numpy pyrr PyQt6 -pip3 install whippersnappy -xvfb-run whippersnap ... 
-``` - -## Usage: - -### Local: - -After installing the Python package, the whippersnap program can be run using -the installed command line tool such as in the following example: -``` -whippersnap -lh $OVERLAY_DIR/$LH_OVERLAY_FILE \ - -rh $OVERLAY_DIR/$RH_OVERLAY_FILE \ - -sd $SURF_SUBJECT_DIR \ - --fmax 4 --fthresh 2 --invert \ - --caption caption.txt \ - -o $OUTPUT_DIR/whippersnappy_image.png \ -``` - -For more options see `whippersnap --help`. -Note, that adding the `--interactive` flag will start an interactive GUI that -includes a visualization of one hemisphere side and a simple application through -which color threshold values can be configured. - -### Docker: - -The whippersnap program can be run within a docker container to capture -a snapshot by building the provided Docker image and running a container as -follows: -``` -docker build --rm=true -t whippersnappy -f ./Dockerfile . -``` -``` -docker run --rm --init --name my_whippersnappy -v $SURF_SUBJECT_DIR:/surf_subject_dir \ - -v $OVERLAY_DIR:/overlay_dir \ - -v $OUTPUT_DIR:/output_dir \ - --user $(id -u):$(id -g) whippersnappy:latest \ - --lh_overlay /overlay_dir/$LH_OVERLAY_FILE \ - --rh_overlay /overlay_dir/$RH_OVERLAY_FILE \ - --sdir /surf_subject_dir \ - --output_path /output_dir/whippersnappy_image.png -``` - -In this example: `$SURF_SUBJECT_DIR` contains the surface files, `$OVERLAY_DIR` contains the overlays to be loaded on to the surfaces, `$OUTPUT_DIR` is the local output directory in which the snapshot will be saved, and `${LH/RH}_OVERLAY_FILE` point to the specific overlay files to load. - -**Note:** The `--init` flag to Docker is needed for the `xvfb-run` tool to be used correctly for off screen rendering. - - -## API Documentation - -The API Documentation can be found at https://deep-mi.org/WhipperSnapPy . 
- -## Links: - -We also invite you to check out our lab webpage at https://deep-mi.org +# WhipperSnapPy + +WhipperSnapPy is a Python/OpenGL tool to render triangular surface meshes +with color overlays or parcellations and generate screenshots — from the +command line, in Jupyter notebooks, or via a desktop GUI. + +It works with FreeSurfer and FastSurfer brain surfaces as well as any +triangle mesh in OFF, legacy ASCII VTK PolyData, ASCII PLY, or GIfTI (.gii, .surf.gii) format, or +passed directly as a NumPy ``(vertices, faces)`` tuple. + +## Installation + +```bash +pip install whippersnappy +``` + +For rotation video support (MP4/WebM): + +```bash +pip install 'whippersnappy[video]' +``` + +For the interactive desktop GUI: + +```bash +pip install 'whippersnappy[gui]' +``` + +For interactive 3D in Jupyter notebooks: + +```bash +pip install 'whippersnappy[notebook]' +``` + +Off-screen (headless) rendering is supported natively via EGL on Linux — no +`xvfb` required. See the Docker guide for headless usage. 
+ +## Command-Line Usage + +After installation the following commands are available: + +### Four-view snapshot (`whippersnap4`) + +Renders lateral and medial views of both hemispheres into a single composed image: + +```bash +whippersnap4 -lh $LH_OVERLAY \ + -rh $RH_OVERLAY \ + -sd $SUBJECT_DIR \ + --fmax 4 --fthresh 2 \ + --caption "Cortical Thickness" \ + -o snap4.png +``` + +### Single-view snapshot (`whippersnap1`) + +Renders one view of any triangular surface mesh: + +```bash +whippersnap1 --mesh $SUBJECT_DIR/surf/lh.white \ + --overlay $LH_OVERLAY \ + --bg-map $SUBJECT_DIR/surf/lh.curv \ + --roi $SUBJECT_DIR/label/lh.cortex.label \ + --view left \ + -o snap1.png + +# Also works with OFF / VTK / PLY +whippersnap1 --mesh mesh.off --overlay values.mgh -o snap1.png +whippersnap1 --mesh surface.surf.gii --overlay overlay.func.gii -o snap1.png +``` + +### Rotation video (`whippersnap1 --rotate`) + +Renders a 360° animation of any triangular surface mesh: + +```bash +whippersnap1 --mesh $SUBJECT_DIR/surf/lh.white \ + --overlay $LH_OVERLAY \ + --rotate \ + -o rotation.mp4 +``` + +### Desktop GUI (`whippersnap`) + +Launches an interactive Qt window with live threshold controls. + +**General mode** — any triangular mesh: + +```bash +pip install 'whippersnappy[gui]' +whippersnap --mesh mesh.off --overlay values.mgh +whippersnap --mesh lh.white --overlay lh.thickness --bg-map lh.curv +``` + +**FreeSurfer shortcut** — derive all paths from a subject directory: + +```bash +whippersnap -sd $SUBJECT_DIR --hemi lh -lh $LH_OVERLAY +whippersnap -sd $SUBJECT_DIR --hemi rh --annot rh.aparc.annot +``` + +For all options run `whippersnap4 --help`, `whippersnap1 --help`, or `whippersnap --help`. 
+ +## Python API + +```python +from whippersnappy import snap1, snap4, snap_rotate, plot3d +``` + +| Function | Description | +|---|---| +| `snap1` | Single-view snapshot of any triangular mesh → PIL Image | +| `snap4` | Four-view composed image (FreeSurfer subject, lateral/medial both hemispheres) | +| `snap_rotate` | 360° rotation video of any triangular surface mesh (MP4, WebM, or GIF) | +| `plot3d` | Interactive 3D WebGL viewer for Jupyter notebooks | + +**Supported mesh inputs for `snap1`, `snap_rotate`, and `plot3d`:** +FreeSurfer binary surfaces (e.g. `lh.white`), OFF (`.off`), legacy ASCII VTK PolyData (`.vtk`), ASCII PLY (`.ply`), GIfTI surface (`.gii`, `.surf.gii`), or a `(vertices, faces)` NumPy array tuple. + +**Supported overlay/label inputs:** +FreeSurfer morph (`.curv`, `.thickness`), MGH/MGZ, ASCII (`.txt`, `.csv`), NumPy (`.npy`, `.npz`), GIfTI functional/label (`.func.gii`, `.label.gii`, `.gii`). + +### Example + +```python +from whippersnappy import snap1, snap4 + +# FreeSurfer surface with overlay +img = snap1('lh.white', + overlay='lh.thickness', + bg_map='lh.curv', + roi='lh.cortex.label') +img.save('snap1.png') + +# Four-view overview (FreeSurfer subject directory) +img = snap4(sdir='/path/to/subject', + lh_overlay='/path/to/lh.thickness', + rh_overlay='/path/to/rh.thickness', + colorbar=True, caption='Cortical Thickness (mm)') +img.save('snap4.png') + +# OFF / VTK / PLY / GIfTI mesh +img = snap1('mesh.off', overlay='values.mgh') +img = snap1('surface.surf.gii', overlay='overlay.func.gii') + +# Array inputs (e.g. from LaPy or trimesh) +import numpy as np +v = np.random.randn(1000, 3).astype(np.float32) +f = np.array([[0, 1, 2]], dtype=np.uint32) +overlay = np.random.randn(1000).astype(np.float32) +img = snap1((v, f), overlay=overlay) +``` + + +See `tutorials/whippersnappy_tutorial.ipynb` for complete notebook examples. 
+ + +## Docker + +The Docker image provides a fully headless EGL rendering environment — no +display server or `xvfb` required. See DOCKER.md for details. + +## API Documentation + +https://deep-mi.org/WhipperSnapPy + +## Links + +Lab webpage: https://deep-mi.org diff --git a/doc/DOCKER.md b/doc/DOCKER.md new file mode 120000 index 0000000..6916342 --- /dev/null +++ b/doc/DOCKER.md @@ -0,0 +1 @@ +../DOCKER.md \ No newline at end of file diff --git a/doc/README.md b/doc/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/doc/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/doc/_templates/autosummary/function.rst b/doc/_templates/autosummary/function.rst index cdbecc4..27ca987 100644 --- a/doc/_templates/autosummary/function.rst +++ b/doc/_templates/autosummary/function.rst @@ -3,6 +3,3 @@ .. currentmodule:: {{ module }} .. autofunction:: {{ objname }} - -.. minigallery:: {{ fullname }} - :add-heading: diff --git a/doc/api/cli.rst b/doc/api/cli.rst new file mode 100644 index 0000000..6363ba2 --- /dev/null +++ b/doc/api/cli.rst @@ -0,0 +1,23 @@ +Command-Line Interfaces +======================= + +whippersnap1 +------------ + +.. automodule:: whippersnappy.cli.whippersnap1 + :members: + :member-order: bysource + +whippersnap4 +------------ + +.. automodule:: whippersnappy.cli.whippersnap4 + :members: + :member-order: bysource + +whippersnap +----------- + +.. automodule:: whippersnappy.cli.whippersnap + :members: + :member-order: bysource diff --git a/doc/api/index.rst b/doc/api/index.rst index 2b668c7..deed527 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -1,14 +1,13 @@ -API References -============== +.. _api_ref: +API Reference +============= .. currentmodule:: whippersnappy -.. autosummary:: - :toctree: generated/ +.. 
toctree:: + :maxdepth: 2 - - config_app.ConfigWindow - core - cli.whippersnap - + cli.rst + snap.rst + plot3d.rst diff --git a/doc/api/plot3d.rst b/doc/api/plot3d.rst new file mode 100644 index 0000000..3546ce2 --- /dev/null +++ b/doc/api/plot3d.rst @@ -0,0 +1,6 @@ +plot3d +====== + +.. automodule:: whippersnappy.plot3d + :members: + :member-order: bysource diff --git a/doc/api/snap.rst b/doc/api/snap.rst new file mode 100644 index 0000000..5c8eea8 --- /dev/null +++ b/doc/api/snap.rst @@ -0,0 +1,6 @@ +snap +==== + +.. automodule:: whippersnappy.snap + :members: + :member-order: bysource diff --git a/doc/conf.py b/doc/conf.py index 1bc8980..4707b81 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -12,8 +12,6 @@ from importlib import import_module from typing import Dict, Optional -# from sphinx_gallery.sorting import FileNameSortKey - import whippersnappy project = "WhipperSnapPy" @@ -37,6 +35,7 @@ # extensions coming with Sphinx (named "sphinx.ext.*") or your custom # ones. extensions = [ + "myst_parser", "sphinx.ext.autodoc", "sphinx.ext.autosectionlabel", "sphinx.ext.autosummary", @@ -46,16 +45,25 @@ "sphinxcontrib.bibtex", "sphinx_copybutton", "sphinx_design", + "nbsphinx", + "IPython.sphinxext.ipython_console_highlighting", ] -templates_path = ["_templates"] +# .md files are included via '.. include:: :parser: myst_parser.sphinx_' +# in the RST stubs; they must NOT be registered as standalone Sphinx source +# documents or autosectionlabel will produce duplicate-label warnings from +# both the real file and the doc/ symlink. exclude_patterns = [ "_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints", + "README.md", # symlinked from root — included inline via rst, not as a page + "../*.md", # exclude root-level .md files ] +templates_path = ["_templates"] + # Sphinx will warn about all references where the target cannot be found. 
nitpicky = False nitpick_ignore = [] @@ -93,7 +101,8 @@ } # -- autosummary ------------------------------------------------------------- -autosummary_generate = True +# API stubs use automodule directly — no generated/ dir needed. +autosummary_generate = False # -- autodoc ----------------------------------------------------------------- autodoc_typehints = "none" @@ -101,6 +110,26 @@ autodoc_warningiserror = True autoclass_content = "class" +# Mock modules that may not be available in the doc builder environment +# (PyQt6, OpenGL, GLFW, pythreejs, etc.). Adjust this list if your builder +# provides any of these packages. +autodoc_mock_imports = [ + "PyQt6", + "PyQt6.QtWidgets", + "PyQt6.QtCore", + "PyQt6.QtGui", + "glfw", + "OpenGL", + "OpenGL.GL", + "OpenGL.GL.shaders", + "pythreejs", + "ipywidgets", + "pyopengl", + "pyrr", + "PIL", + "matplotlib", +] + # -- intersphinx ------------------------------------------------------------- intersphinx_mapping = { "matplotlib": ("https://matplotlib.org/stable", None), @@ -162,6 +191,9 @@ r"\.__iter__", r"\.__div__", r"\.__neg__", + # Imported third-party objects exposed in plot3d module + r"\.HTML$", + r"\.VBox$", } # -- sphinxcontrib-bibtex ---------------------------------------------------- @@ -232,6 +264,19 @@ def linkcode_resolve(domain: str, info: Dict[str, str]) -> Optional[str]: # } +# -- nbsphinx ---------------------------------------------------------------- +# Re-execute notebooks during the Sphinx build so outputs appear in the docs. +# Notebooks are executed with the kernel specified by nbsphinx_kernel_name. +# The sample data is fetched from the GitHub release assets (or from the +# local sub-rs/ directory in the repo when the release is not yet published). +nbsphinx_execute = "auto" + +# Kernel to use for execution (must be installed: pip install ipykernel). +nbsphinx_kernel_name = "python3" + +# Maximum execution time per cell (seconds). 
+nbsphinx_timeout = 600 + # -- make sure pandoc gets installed ----------------------------------------- from inspect import getsourcefile import os @@ -240,20 +285,18 @@ def linkcode_resolve(domain: str, info: Dict[str, str]) -> Optional[str]: DOCS_DIRECTORY = os.path.dirname(os.path.abspath(getsourcefile(lambda: 0))) def ensure_pandoc_installed(_): - import pypandoc - - # Download pandoc if necessary. If pandoc is already installed and on - # the PATH, the installed version will be used. Otherwise, we will - # download a copy of pandoc into docs/bin/ and add that to our PATH. - pandoc_dir = os.path.join(DOCS_DIRECTORY, "bin") - # Add dir containing pandoc binary to the PATH environment variable - if pandoc_dir not in os.environ["PATH"].split(os.pathsep): - os.environ["PATH"] += os.pathsep + pandoc_dir - pypandoc.ensure_pandoc_installed( - targetfolder=pandoc_dir, - delete_installer=True, - ) + try: + import pypandoc + pandoc_dir = os.path.join(DOCS_DIRECTORY, "bin") + if pandoc_dir not in os.environ["PATH"].split(os.pathsep): + os.environ["PATH"] += os.pathsep + pandoc_dir + pypandoc.ensure_pandoc_installed( + targetfolder=pandoc_dir, + delete_installer=True, + ) + except Exception: + pass # pandoc already on PATH (CI) or download failed (local SSL) — continue + def setup(app): app.connect("builder-inited", ensure_pandoc_installed) - diff --git a/doc/index.rst b/doc/index.rst index 439547d..ed11c86 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,96 +1,11 @@ .. include:: ./links.inc - -.. whippersnappy documentation master file, created by - sphinx-quickstart on Wed Jul 12 09:51:03 2023. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -Welcome to whippersnappy's documentation! -========================================= - -WhipperSnapPY is a small Python OpenGL program to render -FreeSurfer and FastSurfer surface models and color overlays and generate screen shots. 
- -License -------- - -`Whippersnappy `_ is licensed under the `MIT license`_. -A full copy of the license can be found `on GitHub `_. - -Contents --------- - -- Capture 4x4 surface plots (front & back, left and right) -- OpenGL window for interactive visualization - -Note, that currently no off-screen rendering is supported. Even in snap mode an invisible window will be created to render the openGL output and capture the contents to an image. In order to run this on a headless server, inside Docker, or via ssh we recommend to install xvfb and run - -.. code-block:: bash - - apt update && apt install -y python3 python3-pip xvfb - pip3 install pyopengl glfw pillow numpy pyrr PyQt5==5.15.6 - pip3 install . - xvfb-run whippersnap ... - -Installation ------------- - -The `Whippersnappy `_ package can be installed from this repository using: - -.. code-block:: bash - - python3 -m pip install . - -Usage ------ - -Local -''''' - -After installing the Python package, the whippersnap program can be run using the installed command line tool such as in the following example: - -.. code-block:: bash - - whippersnap -lh $OVERLAY_DIR/$LH_OVERLAY_FILE \ - -rh $OVERLAY_DIR/$RH_OVERLAY_FILE \ - -sd $SURF_SUBJECT_DIR \ - -o $OUTPUT_DIR/whippersnappy_image.png - -Note that adding the `--interactive` flag will start an interactive GUI that includes a visualization of one hemisphere side and a simple application through which color threshold values can be configured. - -Docker -'''''' - -the whippersnap program can be run within a docker container to capture a snapshot by building the provided Docker image and running a container as follows: - -.. code-block:: bash - - docker build --rm=true -t whippersnappy -f ./Dockerfile . - -.. 
code-block:: bash - - docker run --rm --init --name my_whippersnappy -v $SURF_SUBJECT_DIR:/surf_subject_dir \ - -v $OVERLAY_DIR:/overlay_dir \ - -v $OUTPUT_DIR:/output_dir \ - --user $(id -u):$(id -g) whippersnappy:latest \ - --lh_overlay /overlay_dir/$LH_OVERLAY_FILE \ - --rh_overlay /overlay_dir/$RH_OVERLAY_FILE \ - --sdir /surf_subject_dir \ - --output_path /output_dir/whippersnappy_image.png - -In this example: `$SURF_SUBJECT_DIR` contains the surface files, `$OVERLAY_DIR` contains the overlays to be loaded on to the surfaces, `$OUTPUT_DIR` is the local output directory in which the snapshot will be saved, and `${LH/RH}_OVERLAY_FILE` point to the specific overlay files to load. - -**Note:** The `--init` flag is needed for the `xvfb-run` tool to be used correctly. - -Links ------ - -We also invite you to check out our lab webpage at https://deep-mi.org +.. include:: README.md + :parser: myst_parser.sphinx_ .. toctree:: :hidden: + DOCKER.md + tutorials/index api/index - - diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst new file mode 100644 index 0000000..aa08c0f --- /dev/null +++ b/doc/tutorials/index.rst @@ -0,0 +1,12 @@ +Tutorials +========= + +Hands-on notebooks demonstrating WhipperSnapPy's surface visualization +functionality. Each notebook can be downloaded and run locally — just +replace ``sdir = ""`` with the path to your own FastSurfer subject directory. + +.. 
toctree:: + :maxdepth: 1 + + whippersnappy_tutorial.ipynb + diff --git a/doc/tutorials/whippersnappy_tutorial.ipynb b/doc/tutorials/whippersnappy_tutorial.ipynb new file mode 120000 index 0000000..9a82e1b --- /dev/null +++ b/doc/tutorials/whippersnappy_tutorial.ipynb @@ -0,0 +1 @@ +../../tutorials/whippersnappy_tutorial.ipynb \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 88dafe6..e2fc827 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = 'setuptools.build_meta' [project] name = 'whippersnappy' -version = '1.4.0-dev' +version = '2.0.0-dev' description = 'A package to plot and capture FastSurfer and FreeSurfer-style surface overlays.' readme = 'README.md' license = {file = 'LICENSE'} @@ -26,10 +26,10 @@ classifiers = [ 'Operating System :: Unix', 'Operating System :: MacOS', 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'Natural Language :: English', 'License :: OSI Approved :: MIT License', 'Intended Audience :: Science/Research', @@ -38,10 +38,9 @@ dependencies = [ 'glfw', 'numpy>=1.21', 'pyrr', - 'pillow', - 'pyopengl==3.1.6', + 'pillow>=9.1', + 'pyopengl>=3.1.8', 'nibabel', - 'PyQt6', 'psutil' ] @@ -54,17 +53,34 @@ doc = [ 'furo!=2023.8.17', 'matplotlib', 'memory-profiler', + 'myst-parser', 'numpydoc', 'sphinx!=7.2.*', 'sphinxcontrib-bibtex', 'sphinx-copybutton', 'sphinx-design', - 'sphinx-gallery', 'sphinx-issues', 'pypandoc', 'nbsphinx', 'IPython', # For syntax highlighting in notebooks 'ipykernel', + # Needed to execute the tutorial notebook via nbsphinx: + 'pooch>=1.6', + 'pythreejs', + 'ipywidgets', + 'imageio>=2.28', +] +notebook = [ + 'pythreejs', # Three.js for interactive 3D (works in all Jupyter environments) + 'ipywidgets', # Required for pythreejs + 'pooch>=1.6', +] +gui = [ + 
'PyQt6', +] +video = [ + 'imageio>=2.28', + 'imageio-ffmpeg>=0.4.9', # bundles its own ffmpeg binary ] style = [ 'bibclean', @@ -82,6 +98,9 @@ all = [ 'whippersnappy[doc]', 'whippersnappy[style]', 'whippersnappy[test]', + 'whippersnappy[notebook]', + 'whippersnappy[gui]', + 'whippersnappy[video]', ] full = [ 'whippersnappy[all]', @@ -94,18 +113,20 @@ source = 'https://github.com/Deep-MI/WhipperSnapPy' tracker = 'https://github.com/Deep-MI/WhipperSnapPy/issues' [project.scripts] -whippersnap = 'whippersnappy.cli:run' +whippersnap = 'whippersnappy.cli.whippersnap:run' +whippersnap1 = 'whippersnappy.cli.whippersnap1:run' +whippersnap4 = 'whippersnappy.cli.whippersnap4:run' whippersnappy-sys_info = 'whippersnappy.commands.sys_info:run' [tool.setuptools] -include-package-data = false +include-package-data = true [tool.setuptools.packages.find] include = ['whippersnappy*'] exclude = ['whippersnappy*tests'] [tool.setuptools.package-data] -whippersnappy = ['*.ttf'] +whippersnappy = ['resources/fonts/*.ttf'] [tool.pydocstyle] convention = 'numpy' @@ -135,6 +156,12 @@ select = [ [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] +[tool.codespell] +ignore-words-list = 'aNormal,wheight' +check-filenames = true +check-hidden = true +skip = './.git,./build,./.mypy_cache,./.pytest_cache' + [tool.pytest.ini_options] minversion = '6.0' addopts = '--durations 20 --junit-xml=junit-results.xml --verbose' diff --git a/whippersnappy/utils/tests/__init__.py b/tests/__init__.py similarity index 100% rename from whippersnappy/utils/tests/__init__.py rename to tests/__init__.py diff --git a/tests/data/tetra.off b/tests/data/tetra.off new file mode 100644 index 0000000..def3b95 --- /dev/null +++ b/tests/data/tetra.off @@ -0,0 +1,12 @@ +OFF +# Minimal tetrahedron — 4 vertices, 4 triangular faces +4 4 6 +0.0 0.0 0.0 +1.0 0.0 0.0 +0.0 1.0 0.0 +0.0 0.0 1.0 +3 0 2 1 +3 0 1 3 +3 0 3 2 +3 1 2 3 + diff --git a/tests/test_array_and_rendering.py b/tests/test_array_and_rendering.py new file 
"""Tests for the array-input pathway introduced in v2.0-rc.

The resolver and geometry-preparation tests run on small synthetic
triangle meshes entirely in memory, without touching any file on disk.
The ``TestSnap1Rendering`` class at the end additionally exercises
end-to-end OpenGL rendering and is skipped automatically when no OpenGL
context can be created.
"""

import numpy as np
import pytest

from whippersnappy.geometry.inputs import (
    resolve_annot,
    resolve_bg_map,
    resolve_mesh,
    resolve_overlay,
    resolve_roi,
)
from whippersnappy.geometry.prepare import (
    estimate_overlay_thresholds,
    prepare_geometry,
    prepare_geometry_from_arrays,
)

# ---------------------------------------------------------------------------
# Minimal synthetic mesh (tetrahedron)
# ---------------------------------------------------------------------------

_V = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
_F = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.uint32)
_N = _V.shape[0]  # 4 vertices


# ---------------------------------------------------------------------------
# resolve_mesh
# ---------------------------------------------------------------------------

class TestResolveMesh:
    def test_valid_inputs(self):
        v, f = resolve_mesh((_V, _F))
        assert v.shape == (4, 3) and v.dtype == np.float32
        assert f.shape == (4, 3) and f.dtype == np.uint32
        # list input should also work
        v2, f2 = resolve_mesh([_V, _F])
        assert v2.shape == (4, 3) and f2.shape == (4, 3)

    def test_invalid_inputs_raise(self):
        with pytest.raises(TypeError):
            resolve_mesh(42)
        with pytest.raises(ValueError):
            resolve_mesh((np.ones((4, 4), dtype=np.float32), _F))
        with pytest.raises(ValueError):
            resolve_mesh((_V, np.ones((4, 4), dtype=np.uint32)))


# ---------------------------------------------------------------------------
# resolve_overlay / resolve_bg_map (identical logic, tested together)
# ---------------------------------------------------------------------------

class TestResolveScalarOverlay:
    """Tests for resolve_overlay and resolve_bg_map (same logic)."""

    @pytest.mark.parametrize("fn", [resolve_overlay, resolve_bg_map])
    def test_none_returns_none(self, fn):
        assert fn(None, n_vertices=_N) is None

    @pytest.mark.parametrize("fn", [resolve_overlay, resolve_bg_map])
    def test_array_input_shape_and_dtype(self, fn):
        arr = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32)
        result = fn(arr, n_vertices=_N)
        assert result.shape == (_N,) and result.dtype == np.float32

    @pytest.mark.parametrize("fn", [resolve_overlay, resolve_bg_map])
    def test_shape_mismatch_raises(self, fn):
        with pytest.raises(ValueError):
            fn(np.ones(2), n_vertices=_N)

    def test_n_vertices_none_skips_shape_check(self):
        arr = np.array([0.1, 0.5], dtype=np.float32)
        assert resolve_overlay(arr, n_vertices=None).shape == (2,)


# ---------------------------------------------------------------------------
# resolve_roi
# ---------------------------------------------------------------------------

class TestResolveRoi:
    def test_none_returns_none(self):
        assert resolve_roi(None, n_vertices=_N) is None

    def test_bool_array(self):
        roi = np.array([True, True, True, False], dtype=bool)
        result = resolve_roi(roi, n_vertices=_N)
        assert result.dtype == bool and result.shape == (_N,)
        # Equality, not identity: NumPy does not guarantee that indexing a
        # bool array returns the interned ``np.False_`` singleton object.
        assert result[3] == np.bool_(False)

    def test_shape_mismatch_raises(self):
        with pytest.raises(ValueError):
            resolve_roi(np.ones(2, dtype=bool), n_vertices=_N)


# ---------------------------------------------------------------------------
# resolve_annot
# ---------------------------------------------------------------------------

class TestResolveAnnot:
    def test_none_returns_none(self):
        assert resolve_annot(None, n_vertices=_N) is None

    def test_two_and_three_tuple(self):
        labels = np.array([0, 1, 0, 1])
        ctab = np.array([[255, 0, 0, 0, 0], [0, 255, 0, 0, 1]])
        # two-tuple: names should be None
        r2 = resolve_annot((labels, ctab), n_vertices=_N)
        assert len(r2) == 3 and r2[2] is None
        # three-tuple: names passed through
        names = ["a", "b"]
        r3 = resolve_annot((labels, ctab, names), n_vertices=_N)
        assert r3[2] == names

    def test_invalid_inputs_raise(self):
        with pytest.raises(TypeError):
            resolve_annot(42, n_vertices=_N)
        with pytest.raises(ValueError):
            resolve_annot(
                (np.zeros(2, dtype=int), np.array([[255, 0, 0, 0, 0]])),
                n_vertices=_N,
            )


# ---------------------------------------------------------------------------
# estimate_overlay_thresholds
# ---------------------------------------------------------------------------

class TestEstimateOverlayThresholds:
    def test_auto_and_passthrough(self):
        arr = np.array([1.0, 2.0, 3.0, -1.5], dtype=np.float32)
        fmin, fmax = estimate_overlay_thresholds(arr)
        assert fmin >= 0 and fmax == pytest.approx(3.0)
        # explicit values passed through unchanged
        fmin2, fmax2 = estimate_overlay_thresholds(arr, minval=0.5, maxval=5.0)
        assert fmin2 == pytest.approx(0.5) and fmax2 == pytest.approx(5.0)


# ---------------------------------------------------------------------------
# prepare_geometry_from_arrays — pure array pipeline
# ---------------------------------------------------------------------------

class TestPrepareGeometryFromArrays:
    def test_no_overlay(self):
        vdata, tris, fmin, fmax, pos, neg = prepare_geometry_from_arrays(_V, _F)
        assert vdata.shape == (_N, 9) and tris.shape == (4, 3)
        assert fmin is None and fmax is None

    def test_with_overlay_bg_map_roi(self):
        overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32)
        bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32)
        roi = np.array([True, True, True, False], dtype=bool)
        vdata, tris, fmin, fmax, pos, neg = prepare_geometry_from_arrays(
            _V, _F, overlay=overlay, bg_map=bg, roi=roi
        )
        assert vdata.shape == (_N, 9) and fmin is not None

    def test_overlay_shape_mismatch_raises(self):
        with pytest.raises(ValueError):
            prepare_geometry_from_arrays(_V, _F, overlay=np.array([0.1, 0.5]))


# ---------------------------------------------------------------------------
# prepare_geometry — thin wrapper (array path)
# ---------------------------------------------------------------------------

class TestPrepareGeometry:
    def test_tuple_mesh_various_inputs(self):
        """One call covers: tuple mesh, overlay, roi, bg_map — all together."""
        overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32)
        roi = np.array([True, True, True, False], dtype=bool)
        bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32)
        vdata, tris, fmin, fmax, pos, neg = prepare_geometry(
            (_V, _F), overlay=overlay, roi=roi, bg_map=bg
        )
        assert vdata.shape == (_N, 9) and fmin is not None

    def test_invalid_mesh_type_raises(self):
        with pytest.raises(TypeError):
            prepare_geometry(12345)


# ---------------------------------------------------------------------------
# snap1 rendering — actual OpenGL image output
# ---------------------------------------------------------------------------

def _snap1_offscreen(**kwargs):
    """Call snap1 with an invisible (offscreen) GLFW context.

    On macOS a visible GLFW window goes through the Cocoa compositor; the
    first glReadPixels call may return all-black before the compositor has
    finished its first composite pass. An invisible context renders
    directly to the driver framebuffer and reads back correctly.

    Skips the test automatically if no OpenGL context can be created
    (headless CI without GPU or EGL support).
    """
    import whippersnappy.gl.utils as gl_utils  # noqa: PLC0415

    original = gl_utils.create_window_with_fallback

    def _invisible(*args, **kw):
        kw["visible"] = False
        return original(*args, **kw)

    gl_utils.create_window_with_fallback = _invisible
    try:
        from whippersnappy import snap1  # noqa: PLC0415

        return snap1(width=200, height=200, colorbar=False, **kwargs)
    except RuntimeError as exc:
        if "context" in str(exc).lower() or "opengl" in str(exc).lower():
            pytest.skip(f"No OpenGL context available: {exc}")
        raise
    finally:
        gl_utils.create_window_with_fallback = original


class TestSnap1Rendering:
    """End-to-end rendering tests: snap1 must return a non-empty PIL Image.

    All tests use the tetrahedron mesh (_V, _F) — a true 3-D shape that is
    visible from any camera direction, unlike a flat surface which can
    appear edge-on and produce an all-black image.

    Tests use an offscreen GLFW context (see ``_snap1_offscreen``) and are
    skipped automatically when no OpenGL context is available.
    """

    def test_snap1_basic(self):
        """snap1 returns the right size and renders a non-uniform image."""
        img = _snap1_offscreen(mesh=(_V, _F))
        assert img.width == 200 and img.height == 200
        arr = np.array(img)
        assert arr.min() != arr.max(), (
            "Rendered image is completely uniform — shading is not working."
        )

    def test_snap1_with_overlay_and_roi(self):
        """Overlay + ROI mask: image is non-uniform and correct size."""
        overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32)
        roi = np.array([True, True, True, False], dtype=bool)
        img = _snap1_offscreen(
            mesh=(_V, _F), overlay=overlay, roi=roi, fthresh=0.0, fmax=1.0
        )
        assert img.width == 200
        assert np.array(img).min() != np.array(img).max(), (
            "Image with overlay+ROI is completely uniform."
        )

    def test_snap1_with_bg_map(self):
        """bg_map array: image is non-uniform."""
        bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32)
        arr = np.array(_snap1_offscreen(mesh=(_V, _F), bg_map=bg))
        assert arr.min() != arr.max(), "Image with bg_map is completely uniform."

    def test_label_map_and_lut_rendering(self):
        """Label map + LUT: image is non-uniform and has a valid size.

        Reuses the tetra mesh (_V, _F) as in the other tests; skipped on
        platforms where no OpenGL context can be created.
        """
        labels = np.array([1, 2, 1, 2], dtype=int)
        lut = np.array([[1, 255, 0, 0], [2, 0, 255, 0]], dtype=float)
        lut[:, 1:] = lut[:, 1:] / 255.0  # RGB columns normalized to [0, 1]
        annot = (labels, lut)
        img = _snap1_offscreen(mesh=(_V, _F), annot=annot)
        assert img is not None
        arr = np.array(img)
        assert arr.shape[0] > 0 and arr.shape[1] > 0
        assert np.any(arr != arr[0, 0])  # not uniform


# NOTE(review): the same change-set also moved
# whippersnappy/utils/tests/test_config.py to tests/test_config.py and
# replaced its relative import with
# ``from whippersnappy._config import sys_info``.
+""" + +import os +import tempfile + +import numpy as np +import pytest + +from whippersnappy.geometry.inputs import resolve_mesh +from whippersnappy.geometry.mesh_io import ( + read_gifti_surface, + read_mesh, + read_off, + read_ply_ascii, + read_vtk_ascii_polydata, +) + +# --------------------------------------------------------------------------- +# Shared sample content strings +# --------------------------------------------------------------------------- + +_TETRA_OFF = """\ +OFF +# tetrahedron +4 4 6 +0.0 0.0 0.0 +1.0 0.0 0.0 +0.0 1.0 0.0 +0.0 0.0 1.0 +3 0 2 1 +3 0 1 3 +3 0 3 2 +3 1 2 3 +""" + +_TETRA_VTK = """\ +# vtk DataFile Version 3.0 +tetrahedron +ASCII +DATASET POLYDATA +POINTS 4 float +0.0 0.0 0.0 +1.0 0.0 0.0 +0.0 1.0 0.0 +0.0 0.0 1.0 +POLYGONS 4 16 +3 0 2 1 +3 0 1 3 +3 0 3 2 +3 1 2 3 +""" + +_TETRA_PLY = """\ +ply +format ascii 1.0 +comment tetrahedron +element vertex 4 +property float x +property float y +property float z +element face 4 +property list uchar int vertex_indices +end_header +0.0 0.0 0.0 +1.0 0.0 0.0 +0.0 1.0 0.0 +0.0 0.0 1.0 +3 0 2 1 +3 0 1 3 +3 0 3 2 +3 1 2 3 +""" + +_SAMPLES = { + ".off": _TETRA_OFF, + ".vtk": _TETRA_VTK, + ".ply": _TETRA_PLY, +} + + +def _write_tmp(content, suffix): + """Write *content* to a named temp file, return its path.""" + fd, path = tempfile.mkstemp(suffix=suffix) + with os.fdopen(fd, "w") as fh: + fh.write(content) + return path + + +def _expected_verts(): + return np.array( + [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32 + ) + + +def _expected_faces(): + return np.array( + [[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]], dtype=np.uint32 + ) + + +# --------------------------------------------------------------------------- +# read_off +# --------------------------------------------------------------------------- + +class TestReadOff: + def test_basic(self): + path = _write_tmp(_TETRA_OFF, ".off") + try: + v, f = read_off(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) and 
v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 + np.testing.assert_array_equal(v, _expected_verts()) + np.testing.assert_array_equal(f, _expected_faces()) + + def test_bundled_sample(self): + """Verify the bundled tests/data/tetra.off file loads correctly.""" + here = os.path.dirname(__file__) + v, f = read_off(os.path.join(here, "data", "tetra.off")) + assert v.shape == (4, 3) and f.shape == (4, 3) + + def test_error_cases(self): + cases = [ + ("", ".off", "empty"), + ("NOFF\n4 4 6\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n3 0 1 2\n3 0 1 3\n3 0 2 3\n3 1 2 3\n", ".off", "OFF"), + ("OFF\n4 1 4\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n4 0 1 2 3\n", ".off", "triangles"), + ("OFF\n3 1 3\n0 0 0\n1 0 0\n0 1 0\n3 0 1 99\n", ".off", "out of range"), + ] + for content, suffix, match in cases: + path = _write_tmp(content, suffix) + try: + with pytest.raises(ValueError, match=match): + read_off(path) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# read_vtk_ascii_polydata +# --------------------------------------------------------------------------- + +class TestReadVtkAsciiPolydata: + def test_basic(self): + path = _write_tmp(_TETRA_VTK, ".vtk") + try: + v, f = read_vtk_ascii_polydata(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) and v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 + np.testing.assert_array_equal(v, _expected_verts()) + np.testing.assert_array_equal(f, _expected_faces()) + + def test_error_cases(self): + cases = [ + ("# vtk DataFile Version 3.0\ntest\nBINARY\nDATASET POLYDATA\n", "BINARY"), + ("# vtk DataFile Version 3.0\ntest\nASCII\nDATASET UNSTRUCTURED_GRID\n", "POLYDATA"), + ( + "# vtk DataFile Version 3.0\ntest\nASCII\nDATASET POLYDATA\n" + "POINTS 4 float\n0 0 0\n1 0 0\n0 1 0\n0 0 1\nPOLYGONS 1 5\n4 0 1 2 3\n", + "triangles", + ), + ( + "# vtk DataFile Version 3.0\ntest\nASCII\nDATASET POLYDATA\n" + "POLYGONS 1 4\n3 0 1 2\n", + "POINTS", + ), + 
] + for content, match in cases: + path = _write_tmp(content, ".vtk") + try: + with pytest.raises(ValueError, match=match): + read_vtk_ascii_polydata(path) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# read_ply_ascii +# --------------------------------------------------------------------------- + +class TestReadPlyAscii: + def test_basic(self): + path = _write_tmp(_TETRA_PLY, ".ply") + try: + v, f = read_ply_ascii(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) and v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 + np.testing.assert_array_equal(v, _expected_verts()) + np.testing.assert_array_equal(f, _expected_faces()) + + def test_extra_vertex_props(self): + """PLY with extra per-vertex properties (e.g. nx ny nz) should still load.""" + content = """\ +ply +format ascii 1.0 +element vertex 3 +property float x +property float y +property float z +property float nx +property float ny +property float nz +element face 1 +property list uchar int vertex_indices +end_header +0.0 0.0 0.0 0.0 0.0 1.0 +1.0 0.0 0.0 0.0 0.0 1.0 +0.0 1.0 0.0 0.0 0.0 1.0 +3 0 1 2 +""" + path = _write_tmp(content, ".ply") + try: + v, f = read_ply_ascii(path) + finally: + os.unlink(path) + assert v.shape == (3, 3) and f.shape == (1, 3) + + def test_error_cases(self): + quad_ply = """\ +ply +format ascii 1.0 +element vertex 4 +property float x +property float y +property float z +element face 1 +property list uchar int vertex_indices +end_header +0.0 0.0 0.0 +1.0 0.0 0.0 +0.0 1.0 0.0 +0.0 0.0 1.0 +4 0 1 2 3 +""" + cases = [ + ("ply\nformat binary_little_endian 1.0\nelement vertex 4\nend_header\n", "binary"), + ("OFF\n4 4 6\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n3 0 1 2\n", "ply"), + (quad_ply, "triangles"), + ] + for content, match in cases: + path = _write_tmp(content, ".ply") + try: + with pytest.raises(ValueError, match=match): + read_ply_ascii(path) + finally: + os.unlink(path) + + +# 
--------------------------------------------------------------------------- +# read_mesh dispatcher + resolve_mesh routing (combined) +# Each format is tested end-to-end through resolve_mesh (highest-level call). +# read_mesh itself is only tested for error cases not covered above. +# --------------------------------------------------------------------------- + +class TestMeshDispatchAndRouting: + @pytest.mark.parametrize("suffix,content", list(_SAMPLES.items())) + def test_resolve_mesh_path(self, suffix, content): + """resolve_mesh routes each format to the right reader and returns correct dtypes.""" + path = _write_tmp(content, suffix) + try: + v, f = resolve_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) and v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 + + def test_case_insensitive_extension(self): + path = _write_tmp(_TETRA_OFF, ".OFF") + try: + v, f = read_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + + def test_unknown_extension_raises(self): + with pytest.raises(ValueError, match="Unsupported"): + read_mesh("/some/file.stl") + + def test_array_tuple_still_works(self): + v_in = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) + f_in = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.uint32) + v, f = resolve_mesh((v_in, f_in)) + assert v.shape == (4, 3) and f.shape == (4, 3) + + def test_array_out_of_range_raises(self): + v_in = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32) + with pytest.raises(ValueError, match="out of range"): + resolve_mesh((v_in, np.array([[0, 1, 99]], dtype=np.uint32))) + + +# --------------------------------------------------------------------------- +# GIfTI surface reader +# --------------------------------------------------------------------------- + +def _make_surf_gii(verts, faces, suffix=".surf.gii"): + """Write a minimal GIfTI surface file and return its path.""" + import nibabel as nib + coords_da = 
nib.gifti.GiftiDataArray( + data=verts.astype(np.float32), intent=1008, datatype="NIFTI_TYPE_FLOAT32", + ) + faces_da = nib.gifti.GiftiDataArray( + data=faces.astype(np.int32), intent=1009, datatype="NIFTI_TYPE_INT32", + ) + img = nib.gifti.GiftiImage(darrays=[coords_da, faces_da]) + fd, path = tempfile.mkstemp(suffix=suffix) + os.close(fd) + nib.save(img, path) + return path + + +_V4 = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) +_F4 = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.uint32) + + +class TestReadGiftiSurface: + def test_basic_and_dispatch(self): + """read_gifti_surface, read_mesh, and resolve_mesh all load .surf.gii correctly.""" + path = _make_surf_gii(_V4, _F4, ".surf.gii") + try: + v, f = read_gifti_surface(path) + assert v.shape == (4, 3) and v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 + np.testing.assert_allclose(v, _V4, atol=1e-6) + # dispatch through read_mesh and resolve_mesh also work + v2, _ = read_mesh(path) + assert v2.shape == (4, 3) + v3, f3 = resolve_mesh(path) + assert v3.dtype == np.float32 and f3.dtype == np.uint32 + finally: + os.unlink(path) + + def test_plain_gii_extension(self): + path = _make_surf_gii(_V4, _F4, ".gii") + try: + v, f = read_gifti_surface(path) + assert v.shape == (4, 3) and f.shape == (4, 3) + # dispatch also works for plain .gii + v2, _ = read_mesh(path) + assert v2.shape == (4, 3) + finally: + os.unlink(path) + + def test_missing_arrays_raise(self): + """Missing POINTSET or TRIANGLE array raises a clear ValueError.""" + import nibabel as nib + # No POINTSET — only a scalar array + scalar_da = nib.gifti.GiftiDataArray( + data=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), intent=0, + ) + img = nib.gifti.GiftiImage(darrays=[scalar_da]) + fd, path = tempfile.mkstemp(suffix=".gii") + os.close(fd) + nib.save(img, path) + try: + with pytest.raises(ValueError, match="POINTSET"): + read_gifti_surface(path) + finally: + os.unlink(path) + + # 
POINTSET but no TRIANGLE + coords_da = nib.gifti.GiftiDataArray(data=_V4.astype(np.float32), intent=1008) + img2 = nib.gifti.GiftiImage(darrays=[coords_da]) + fd, path2 = tempfile.mkstemp(suffix=".gii") + os.close(fd) + nib.save(img2, path2) + try: + with pytest.raises(ValueError, match="TRIANGLE"): + read_gifti_surface(path2) + finally: + os.unlink(path2) diff --git a/tests/test_overlay_io.py b/tests/test_overlay_io.py new file mode 100644 index 0000000..f5bb5e4 --- /dev/null +++ b/tests/test_overlay_io.py @@ -0,0 +1,267 @@ +"""Tests for whippersnappy/geometry/overlay_io.py and the updated +_load_overlay_from_file routing in inputs.py. + +All tests write content to temporary files so no external data is required. +""" + +import os +import tempfile + +import numpy as np +import pytest + +from whippersnappy.geometry.inputs import resolve_bg_map, resolve_overlay, resolve_roi +from whippersnappy.geometry.overlay_io import ( + read_npy, + read_npz, + read_overlay, + read_txt, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _write_tmp(content, suffix, binary=False): + fd, path = tempfile.mkstemp(suffix=suffix) + mode = "wb" if binary else "w" + with os.fdopen(fd, mode) as fh: + fh.write(content) + return path + + +_FLOAT_VALUES = [0.1, -1.5, 2.0, 0.0] +_INT_VALUES = [0, 1, 3, 2] + + +# --------------------------------------------------------------------------- +# read_txt +# --------------------------------------------------------------------------- + +class TestReadTxt: + def test_float_basic(self): + path = _write_tmp("\n".join(str(v) for v in _FLOAT_VALUES) + "\n", ".txt") + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.shape == (4,) and arr.dtype == np.float32 + np.testing.assert_allclose(arr, _FLOAT_VALUES, atol=1e-6) + + def test_integer_promoted_to_int32(self): + path = _write_tmp("\n".join(str(v) for v in 
_INT_VALUES) + "\n", ".txt") + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.dtype == np.int32 + np.testing.assert_array_equal(arr, _INT_VALUES) + + def test_mixed_float_stays_float32(self): + path = _write_tmp("0\n1\n2.5\n3\n", ".txt") + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.dtype == np.float32 + + def test_headers_and_comments_skipped(self): + # hash comment, text header, CSV first column + for content, suffix, expected_len in [ + ("# comment\n0.5\n1.5\n", ".txt", 2), + ("value\n0.1\n0.2\n0.3\n", ".txt", 3), + ("label,ignore\n1.0,extra\n2.0,extra\n", ".csv", 2), + ]: + path = _write_tmp(content, suffix) + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.shape == (expected_len,) + + def test_error_cases(self): + for content, match in [ + ("", "No numeric"), + ("1.0\nbadvalue\n3.0\n", "Could not parse"), + ]: + path = _write_tmp(content, ".txt") + try: + with pytest.raises(ValueError, match=match): + read_txt(path) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# read_npy / read_npz +# --------------------------------------------------------------------------- + +class TestReadNpyNpz: + def test_npy_basic(self): + arr_in = np.array([1.0, 2.0, 3.0], dtype=np.float32) + fd, path = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path, arr_in) + try: + np.testing.assert_array_equal(read_npy(path), arr_in) + finally: + os.unlink(path) + + def test_npy_column_vector_squeezed(self): + fd, path = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path, np.ones((5, 1), dtype=np.float32)) + try: + assert read_npy(path).shape == (5,) + finally: + os.unlink(path) + + def test_npy_2d_raises(self): + fd, path = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path, np.ones((3, 4), dtype=np.float32)) + try: + with pytest.raises(ValueError, match="1-D"): + read_npy(path) + finally: + os.unlink(path) + + def 
test_npz_data_key_and_fallback(self): + arr_in = np.array([0, 1, 2], dtype=np.int32) + # named 'data' key + fd, path = tempfile.mkstemp(suffix=".npz") + os.close(fd) + np.savez(path, data=arr_in, other=np.zeros(3)) + try: + np.testing.assert_array_equal(read_npz(path), arr_in) + finally: + os.unlink(path) + # first-array fallback + arr2 = np.array([9.0, 8.0], dtype=np.float32) + fd, path2 = tempfile.mkstemp(suffix=".npz") + os.close(fd) + np.savez(path2, arr_0=arr2) + try: + np.testing.assert_array_equal(read_npz(path2), arr2) + finally: + os.unlink(path2) + + def test_npz_empty_raises(self): + fd, path = tempfile.mkstemp(suffix=".npz") + os.close(fd) + np.savez(path) + try: + with pytest.raises(ValueError, match="no arrays"): + read_npz(path) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# read_overlay dispatcher +# Dispatch is implicitly covered by TestResolveOverlayRouting below; here we +# only test the error cases and the gifti rejection that are not exercised +# through resolve_overlay. 
+# --------------------------------------------------------------------------- + +class TestReadOverlayDispatcher: + def test_unknown_extension_raises(self): + with pytest.raises(ValueError, match="Unsupported"): + read_overlay("/some/file.xyz") + + def test_case_insensitive_extension(self): + path = _write_tmp("1.0\n2.0\n", ".TXT") + try: + assert read_overlay(path).shape == (2,) + finally: + os.unlink(path) + + def test_surface_gii_rejected_with_helpful_error(self): + import nibabel as nib + coords_da = nib.gifti.GiftiDataArray( + data=np.array([[0,0,0],[1,0,0],[0,1,0],[0,0,1]], dtype=np.float32), + intent=1008, + ) + faces_da = nib.gifti.GiftiDataArray( + data=np.array([[0,1,2],[0,1,3]], dtype=np.int32), + intent=1009, + ) + img = nib.gifti.GiftiImage(darrays=[coords_da, faces_da]) + fd, path = tempfile.mkstemp(suffix=".surf.gii") + os.close(fd) + nib.save(img, path) + try: + with pytest.raises(ValueError, match="surface geometry"): + from whippersnappy.geometry.overlay_io import read_gifti + read_gifti(path) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# resolve_overlay / resolve_bg_map / resolve_roi routing via inputs.py +# --------------------------------------------------------------------------- + +class TestResolveOverlayRouting: + """End-to-end: file path → resolve_overlay / resolve_bg_map / resolve_roi.""" + + @pytest.mark.parametrize("suffix,content,n", [ + (".txt", "0.1\n0.5\n0.9\n0.3\n", 4), + (".csv", "0.5\n1.5\n", 2), + ]) + def test_txt_csv_routed(self, suffix, content, n): + path = _write_tmp(content, suffix) + try: + arr = resolve_overlay(path, n_vertices=n) + finally: + os.unlink(path) + assert arr.shape == (n,) and arr.dtype == np.float32 + + def test_npy_routed(self): + arr_in = np.array([1.0, 2.0, 3.0], dtype=np.float32) + fd, path = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path, arr_in) + try: + np.testing.assert_array_equal(resolve_overlay(path, 
n_vertices=3), arr_in) + finally: + os.unlink(path) + + def test_shape_mismatch_raises(self): + path = _write_tmp("0.1\n0.5\n", ".txt") + try: + with pytest.raises(ValueError, match="vertices"): + resolve_overlay(path, n_vertices=5) + finally: + os.unlink(path) + + def test_bg_map_and_roi_routed(self): + """resolve_bg_map and resolve_roi also correctly route .txt and .npy.""" + # bg_map from txt — always float32 + path = _write_tmp("1\n-1\n1\n-1\n", ".txt") + try: + arr = resolve_bg_map(path, n_vertices=4) + finally: + os.unlink(path) + assert arr.shape == (4,) and arr.dtype == np.float32 + + # roi from bool npy + arr_in = np.array([True, False, True, True]) + fd, path2 = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path2, arr_in) + try: + roi = resolve_roi(path2, n_vertices=4) + finally: + os.unlink(path2) + assert roi.dtype == bool + np.testing.assert_array_equal(roi, arr_in) + + def test_integer_txt_as_overlay(self): + """Integer txt (parcellation) values are numerically preserved.""" + path = _write_tmp("3\n0\n1\n3\n", ".txt") + try: + arr = resolve_overlay(path, n_vertices=4) + finally: + os.unlink(path) + np.testing.assert_array_equal(arr, [3.0, 0.0, 1.0, 3.0]) diff --git a/tutorials/whippersnappy_tutorial.ipynb b/tutorials/whippersnappy_tutorial.ipynb new file mode 100644 index 0000000..5898a06 --- /dev/null +++ b/tutorials/whippersnappy_tutorial.ipynb @@ -0,0 +1,311 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2c795daa", + "metadata": {}, + "source": [ + "# WhipperSnapPy Tutorial\n", + "\n", + "This notebook demonstrates static and interactive 3D brain surface visualization\n", + "using WhipperSnapPy. It covers single-view snapshots (`snap1`), four-view overview\n", + "images (`snap4`), and interactive WebGL rendering (`plot3d`).\n", + "\n", + "**Tutorial data** from the Rhineland Study (Koch et al.),\n", + "[Zenodo: https://doi.org/10.5281/zenodo.11186582](https://doi.org/10.5281/zenodo.11186582), CC BY 4.0." 
+ ] + }, + { + "cell_type": "markdown", + "id": "8bd0a771", + "metadata": {}, + "source": [ + "## Subject Directory\n", + "\n", + "Set `sdir` to your own FreeSurfer subject directory.\n", + "If you leave it empty, the sample subject **sub-rs** (one anonymized subject\n", + "from the Rhineland Study) is downloaded automatically (~20 MB, cached locally\n", + "after the first run)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30c14cc9", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from whippersnappy import fetch_sample_subject\n", + "\n", + "# Set sdir to your FreeSurfer subject directory.\n", + "# Leave empty (\"\") to automatically use the sample subject (sub-rs,\n", + "# one anonymized subject from the Rhineland Study). It is used directly\n", + "# from the repository when available, otherwise downloaded (~20 MB) and\n", + "# cached locally after the first run.\n", + "sdir = \"\"\n", + "# sdir = \"/path/to/your/subject\"\n", + "\n", + "if not sdir:\n", + " sdir = fetch_sample_subject()[\"sdir\"]\n", + "\n", + "print(\"Subject directory:\", sdir)\n" + ] + }, + { + "cell_type": "markdown", + "id": "73dd4d58", + "metadata": {}, + "source": [ + "### Derive file paths from `sdir`\n", + "\n", + "All paths are constructed from `sdir`, so switching between subjects only\n", + "requires changing the single variable above." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec559dcd", + "metadata": {}, + "outputs": [], + "source": [ + "# Surfaces\n", + "lh_white = os.path.join(sdir, \"surf\", \"lh.white\")\n", + "rh_white = os.path.join(sdir, \"surf\", \"rh.white\")\n", + "\n", + "# Curvature\n", + "lh_curv = os.path.join(sdir, \"surf\", \"lh.curv\")\n", + "rh_curv = os.path.join(sdir, \"surf\", \"rh.curv\")\n", + "\n", + "# Thickness overlay\n", + "lh_thickness = os.path.join(sdir, \"surf\", \"lh.thickness\")\n", + "rh_thickness = os.path.join(sdir, \"surf\", \"rh.thickness\")\n", + "\n", + "# Cortex label (mask for overlay)\n", + "lh_label = os.path.join(sdir, \"label\", \"lh.cortex.label\")\n", + "rh_label = os.path.join(sdir, \"label\", \"rh.cortex.label\")\n", + "\n", + "# Parcellation annotation (DKTatlas)\n", + "lh_annot = os.path.join(sdir, \"label\", \"lh.aparc.DKTatlas.mapped.annot\")\n", + "rh_annot = os.path.join(sdir, \"label\", \"rh.aparc.DKTatlas.mapped.annot\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "22be9ebf", + "metadata": {}, + "source": [ + "## snap1 — Basic Single View\n", + "\n", + "`snap1` renders a single static view of a surface mesh into a PIL Image.\n", + "Here we render the left hemisphere with curvature texturing only (no overlay),\n", + "which gives the classic sulcal depth shading." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "783e547b", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display\n", + "\n", + "from whippersnappy import snap1\n", + "\n", + "img = snap1(lh_white, bg_map=lh_curv)\n", + "display(img)\n" + ] + }, + { + "cell_type": "markdown", + "id": "7173d312", + "metadata": {}, + "source": [ + "## snap1 — With Thickness Overlay\n", + "\n", + "By passing `overlay` and `roi`, the surface is colored by cortical\n", + "thickness values, masked to the cortex label. The `view` parameter selects\n", + "the lateral view of the left hemisphere." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13407b67", + "metadata": {}, + "outputs": [], + "source": [ + "from whippersnappy.utils.types import ViewType\n", + "\n", + "img = snap1(\n", + " lh_white,\n", + " overlay=lh_thickness,\n", + " bg_map=lh_curv,\n", + " roi=lh_label,\n", + " view=ViewType.LEFT,\n", + ")\n", + "display(img)\n" + ] + }, + { + "cell_type": "markdown", + "id": "4217c291", + "metadata": {}, + "source": [ + "## snap1 — With Parcellation Annotation\n", + "\n", + "`annot` accepts a FreeSurfer `.annot` file and colors each vertex by\n", + "its parcellation label. This example uses the DKTatlas parcellation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7271e902", + "metadata": {}, + "outputs": [], + "source": [ + "img = snap1(\n", + " lh_white,\n", + " annot=lh_annot,\n", + " bg_map=lh_curv,\n", + ")\n", + "display(img)\n" + ] + }, + { + "cell_type": "markdown", + "id": "620c6c43", + "metadata": {}, + "source": [ + "## snap4 — Four-View Overview\n", + "\n", + "`snap4` renders lateral and medial views of both hemispheres and stitches\n", + "them into a single composed image. Here we color both hemispheres by\n", + "cortical thickness, masked to the cortex label." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "903514b8", + "metadata": {}, + "outputs": [], + "source": [ + "from whippersnappy import snap4\n", + "\n", + "img = snap4(\n", + " sdir=sdir,\n", + " lh_overlay=lh_thickness,\n", + " rh_overlay=rh_thickness,\n", + " colorbar=True,\n", + " caption=\"Cortical Thickness (mm)\",\n", + ")\n", + "display(img)\n" + ] + }, + { + "cell_type": "markdown", + "id": "5d98c87b", + "metadata": {}, + "source": [ + "## plot3d — Interactive 3D Viewer\n", + "\n", + "`plot3d` creates an interactive Three.js/WebGL viewer that works in all\n", + "Jupyter environments. You can rotate, zoom, and pan with the mouse.\n", + "Requires `pip install 'whippersnappy[notebook]'`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d5a09c7", + "metadata": {}, + "outputs": [], + "source": [ + "from whippersnappy import plot3d\n", + "\n", + "viewer = plot3d(\n", + " mesh=lh_white,\n", + " bg_map=lh_curv,\n", + " overlay=lh_thickness,\n", + ")\n", + "display(viewer)\n" + ] + }, + { + "cell_type": "markdown", + "id": "dc3970c3", + "metadata": {}, + "source": [ + "## snap_rotate — Rotating 360° Animation\n", + "\n", + "`snap_rotate` renders a full 360° rotation of the surface. We output an\n", + "animated GIF so it displays inline in all Jupyter environments including\n", + "PyCharm. Use `.mp4` as `outpath` instead for a smaller file when playing\n", + "outside the notebook.\n", + "This cell takes the longest to run — execute it last." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "928a68ea", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image\n", + "\n", + "from whippersnappy import snap_rotate\n", + "\n", + "outpath_gif = \"/tmp/lh_thickness_rotate.gif\"\n", + "\n", + "snap_rotate(\n", + " mesh=lh_white,\n", + " outpath=outpath_gif,\n", + " overlay=lh_thickness,\n", + " bg_map=lh_curv,\n", + " roi=lh_label,\n", + " n_frames=72,\n", + " fps=24,\n", + " width=800,\n", + " height=600,\n", + ")\n", + "print(\"GIF saved to:\", outpath_gif)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcd38db4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "Image(filename=outpath_gif)\n" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index f7f862b..d656cbc 
100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -1,2 +1,105 @@ -from ._version import __version__ # noqa: F401 -from .utils._config import sys_info # noqa: F401 +"""WhipperSnapPy: Plot and capture FastSurfer and FreeSurfer-style surface overlays. + +WhipperSnapPy provides tools for rendering brain surface meshes with statistical +overlays and annotations. It includes: + +- **Static rendering**: `snap1()` and `snap4()` functions for publication-quality images +- **3D plotting**: For Jupyter notebooks with mouse-controlled 3D (via Three.js) +- **GUI**: Interactive desktop viewer via the ``whippersnap`` command +- **CLI tools**: ``whippersnap1`` and ``whippersnap4`` for batch processing +- **Local mesh IO**: OFF, VTK ASCII PolyData, and PLY in addition to FreeSurfer surfaces + +For static image generation:: + + from whippersnappy import snap1, snap4 + from whippersnappy.utils.types import ViewType + from IPython.display import display + + img = snap1('path/to/lh.white', view=ViewType.LEFT) + display(img) + +For interactive 3D in Jupyter notebooks:: + + # Requires: pip install 'whippersnappy[notebook]' + from whippersnappy import plot3d + + viewer = plot3d( + mesh='path/to/lh.white', + bg_map='path/to/lh.curv', + overlay='path/to/lh.thickness', # optional: for colors + ) + display(viewer) + +For the interactive desktop GUI:: + + # Requires: pip install 'whippersnappy[gui]' + # General mode — any mesh file: + whippersnap --mesh lh.white --overlay lh.thickness --bg-map lh.curv + # FreeSurfer shortcut — derive paths from subject directory: + whippersnap -sd path/to/subject_dir --hemi lh -lh lh.thickness + +""" + +import os +import sys + + +def _check_display(): + """Return True if a display is available or the platform handles GL natively. + + On macOS (CGL) and Windows (WGL) PyOpenGL does not use EGL, so we always + return True to avoid setting ``PYOPENGL_PLATFORM=egl`` on those systems. + On Linux we probe for an X11/Wayland display server. 
+ """ + if sys.platform != "linux": + # macOS uses CGL, Windows uses WGL — no EGL needed on either. + return True + display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY") + if not display: + return False + try: + import ctypes + import ctypes.util + libx11 = ctypes.CDLL(ctypes.util.find_library("X11") or "libX11.so.6") + libx11.XOpenDisplay.restype = ctypes.c_void_p + libx11.XOpenDisplay.argtypes = [ctypes.c_char_p] + libx11.XCloseDisplay.restype = None + libx11.XCloseDisplay.argtypes = [ctypes.c_void_p] + dpy = libx11.XOpenDisplay(display.encode()) + if dpy: + libx11.XCloseDisplay(dpy) + return True + return False + except Exception: + return False + +if "PYOPENGL_PLATFORM" not in os.environ: + if not _check_display(): + os.environ["PYOPENGL_PLATFORM"] = "egl" + +from ._config import sys_info # noqa: F401, E402 +from ._version import __version__ # noqa: F401, E402 +from .snap import snap1, snap4, snap_rotate # noqa: E402 +from .utils.datasets import fetch_sample_subject # noqa: E402 +from .utils.types import ViewType # noqa: E402 + +# 3D plotting for notebooks (Three.js-based, works in all Jupyter environments) +try: + from .plot3d import plot3d # noqa: E402 + _has_plot3d = True +except ImportError: + _has_plot3d = False + +# Export list +__all__ = [ + "__version__", + "sys_info", + "snap1", + "snap4", + "snap_rotate", + "fetch_sample_subject", + "ViewType", +] + +if _has_plot3d: + __all__.append("plot3d") diff --git a/whippersnappy/_config.py b/whippersnappy/_config.py new file mode 100644 index 0000000..6617445 --- /dev/null +++ b/whippersnappy/_config.py @@ -0,0 +1,177 @@ +"""Configuration and system-info helpers (top-level module).""" + +import platform +import re +import sys +from functools import partial +from importlib.metadata import requires, version +from typing import IO, Callable, Optional + +import psutil + + +def sys_info(fid: Optional[IO] = None, developer: bool = False): + """Print the system information for debugging. 
+ + Parameters + ---------- + fid : file-like, default=None + The file to write to, passed to :func:`print`. + Can be None to use :data:`sys.stdout`. + developer : bool, default=False + If True, display information about optional dependencies. + """ + ljust = 26 + out = partial(print, end="", file=fid) + package = __package__.split(".")[0] + + # OS information - requires python 3.8 or above + out("Platform:".ljust(ljust) + platform.platform() + "\n") + # Python information + out("Python:".ljust(ljust) + sys.version.replace("\n", " ") + "\n") + out("Executable:".ljust(ljust) + sys.executable + "\n") + # CPU information + out("CPU:".ljust(ljust) + platform.processor() + "\n") + out("Physical cores:".ljust(ljust) + str(psutil.cpu_count(False)) + "\n") + out("Logical cores:".ljust(ljust) + str(psutil.cpu_count(True)) + "\n") + # Memory information + out("RAM:".ljust(ljust)) + out(f"{psutil.virtual_memory().total / float(2 ** 30):0.1f} GB\n") + out("SWAP:".ljust(ljust)) + out(f"{psutil.swap_memory().total / float(2 ** 30):0.1f} GB\n") + + # dependencies + out("\nDependencies info\n") + # package version may not be available when running tests from the + # repository root (package not installed). Handle gracefully. + try: + pkg_version = version(package) + except Exception: + pkg_version = "Not installed." + out(f"{package}:".ljust(ljust) + pkg_version + "\n") + + try: + raw_requires = requires(package) or [] + except Exception: + raw_requires = [] + + # If package metadata is not present (e.g. running from source), try to + # read dependencies declared in pyproject.toml so tests running against + # the tree without installing still report expected deps. 
+ if not raw_requires: + try: + from pathlib import Path + try: + import tomllib as _toml + except Exception: + _toml = None + + repo_root = Path(__file__).resolve().parents[1] + pyproject_path = repo_root / "pyproject.toml" + if _toml is not None and pyproject_path.exists(): + with pyproject_path.open("rb") as fh: + data = _toml.load(fh) + proj = data.get("project", {}) + deps = proj.get("dependencies", []) or [] + # dependencies may be in the form 'pkg>=1.2' etc. + raw_requires = deps + except Exception: + raw_requires = [] + + dependencies = [elt.split(";")[0].rstrip() for elt in raw_requires if "extra" not in elt] + _list_dependencies_info(out, ljust, dependencies) + + # extras + if developer: + keys = ( + "build", + "doc", + "test", + "style", + ) + for key in keys: + _from_pyproject = False + try: + raw_requires = requires(package) or [] + except Exception: + raw_requires = [] + + # If package metadata missing, fall back to pyproject.toml optional-dependencies + if not raw_requires: + try: + from pathlib import Path + try: + import tomllib as _toml + except Exception: + _toml = None + + repo_root = Path(__file__).resolve().parents[1] + pyproject_path = repo_root / "pyproject.toml" + if _toml is not None and pyproject_path.exists(): + with pyproject_path.open("rb") as fh: + data = _toml.load(fh) + proj = data.get("project", {}) + opt = proj.get("optional-dependencies", {}) or {} + deps = opt.get(key, []) or [] + raw_requires = deps + _from_pyproject = True + except Exception: + raw_requires = [] + + if _from_pyproject: + # pyproject.toml entries have no extras markers — use as-is + dependencies = [elt.split(";")[0].rstrip() for elt in raw_requires] + else: + dependencies = [ + elt.split(";")[0].rstrip() + for elt in raw_requires + if f"extra == '{key}'" in elt or f"extra == \"{key}\"" in elt + ] + if len(dependencies) == 0: + continue + out(f"\nOptional '{key}' info\n") + _list_dependencies_info(out, ljust, dependencies) + + +def _list_dependencies_info(out: 
Callable, ljust: int, dependencies: list[str]): + """List dependencies names and versions. + + Parameters + ---------- + out : Callable + output function + ljust : int + length of returned string + dependencies : List[str] + list of dependencies + + """ + for dep in dependencies: + # handle dependencies with version specifiers + specifiers_pattern = r"(~=|==|!=|<=|>=|<|>|===)" + specifiers = re.findall(specifiers_pattern, dep) + if len(specifiers) != 0: + dep, _ = dep.split(specifiers[0]) + while not dep[-1].isalpha(): + dep = dep[:-1] + # handle dependencies provided with a [key], e.g. pydocstyle[toml] + if "[" in dep: + dep = dep.split("[")[0] + try: + version_ = version(dep) + except Exception: + version_ = "Not found." + + # handle special dependencies with backends, C dep, .. + if dep in ("matplotlib", "seaborn") and version_ != "Not found.": + try: + import importlib + plt = importlib.import_module("matplotlib.pyplot") + backend = plt.get_backend() + except Exception: + backend = "Not found" + + out(f"{dep}:".ljust(ljust) + version_ + f" (backend: {backend})\n") + + else: + out(f"{dep}:".ljust(ljust) + version_ + "\n") diff --git a/whippersnappy/_version.py b/whippersnappy/_version.py index 77f24d8..3481ece 100644 --- a/whippersnappy/_version.py +++ b/whippersnappy/_version.py @@ -1,5 +1,8 @@ """Version number.""" -from importlib.metadata import version - -__version__ = version(__package__) +try: + from importlib.metadata import version + __version__ = version(__package__) +except Exception: + # Fallback when package is not installed (e.g., running from source) + __version__ = "dev" diff --git a/whippersnappy/cli/__init__.py b/whippersnappy/cli/__init__.py index a6d22a9..b9d0b91 100644 --- a/whippersnappy/cli/__init__.py +++ b/whippersnappy/cli/__init__.py @@ -1 +1 @@ -from .whippersnap import run +# CLI modules should not be imported here; keep __init__.py empty. 
diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 5af061e..9e9bd19 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -1,380 +1,645 @@ #!/usr/bin/python3 -""" -Executes the whippersnappy program in an interactive or non-interactive mode. +"""Interactive GUI viewer for WhipperSnapPy. + +Opens a live OpenGL window for any triangular surface mesh together with a +Qt-based configuration panel that allows adjusting overlay thresholds at +runtime. -The non-interactive mode (the default) creates an image that contains four -views of the surface, an optional color bar, and a configurable caption. +Two input modes are supported: -The interactive mode (--interactive) opens a simple GUI with a controllable -view of one of the hemispheres. In addition, the view through a separate -configuration app which allows adjusting thresholds, etc. during runtime. +**General mode** — pass any mesh file directly:: -Usage: - $ python3 whippersnap.py -lh $LH_OVERLAY_FILE -rh $RH_OVERLAY_FILE \ - -sd $SURF_SUBJECT_DIR -o $OUTPUT_PATH + whippersnap --mesh mesh.off --overlay values.mgh + whippersnap --mesh lh.white --overlay lh.thickness --bg-map lh.curv -(See help for full list of arguments.) +**FreeSurfer shortcut** — pass a subject directory and hemisphere; all +FreeSurfer paths are derived automatically:: -@Author1 : Martin Reuter -@Author2 : Ahmed Faisal Abdelrahman -@Created : 16.03.2022 + whippersnap -sd --hemi lh --overlay lh.thickness + whippersnap -sd --hemi rh --annot rh.aparc.annot + +See ``whippersnap --help`` for the full list of options. +For non-interactive four-view batch rendering use ``whippersnap4``. +For single-view non-interactive snapshots use ``whippersnap1``. 
""" import argparse -import math +import logging import os -import signal import sys -import threading + +if __name__ == "__main__" and __package__ is None: + # Replace the current process with `python -m whippersnappy.cli.whippersnap` + # so that relative imports work. os.execv replaces this process in-place + # (no child process, no blocking wait, signals work correctly). + os.execv(sys.executable, [sys.executable, "-m", "whippersnappy.cli.whippersnap"] + sys.argv[1:]) import glfw import OpenGL.GL as gl import pyrr -from PyQt6.QtWidgets import QApplication -from whippersnappy.config_app import ConfigWindow -from whippersnappy.core import ( - get_surf_name, +try: + from PyQt6.QtWidgets import QApplication +except Exception: + # GUI dependency missing; raise a clear error at runtime + QApplication = None + +from .._version import __version__ +from ..geometry import get_surf_name, prepare_geometry +from ..gl import ( + ViewState, + arcball_rotation_matrix, + arcball_vector, + capture_window, + compute_view_matrix, + get_view_matrices, init_window, - prepare_geometry, setup_shader, - snap4, ) +from ..utils.types import ViewType -# Global variables for config app configuration state: +# Module logger +logger = logging.getLogger(__name__) + +# Global thresholds shared between the GL render loop and the Qt config panel. +# All access is from the main thread — no locking needed. current_fthresh_ = None current_fmax_ = None -app_ = None -app_window_ = None -app_window_closed_ = False def show_window( - hemi, - overlaypath=None, - annotpath=None, - sdir=None, - caption=None, + mesh, + overlay=None, + annot=None, + bg_map=None, + roi=None, invert=False, - labelname="cortex.label", - surfname=None, - curvname="curv", specular=True, + view=ViewType.LEFT, + app=None, + config_window=None, ): - """ - Start an interactive window in which an overlay can be viewed. + """Start a live interactive OpenGL+Qt window for viewing a triangular mesh. 
+ + On macOS both GLFW/Cocoa and Qt require the main thread. This function + creates a GLFW window, registers GLFW input callbacks, then hands control + to a ``QTimer``-driven render loop so GLFW polling and Qt event processing + share the main thread. Parameters ---------- - hemi : str - Hemisphere; one of: ['lh', 'rh']. - overlaypath : str - Path to the overlay file for the specified hemi (FreeSurfer format). - annotpath : str - Path to the annotation file for the specified hemi (FreeSurfer format). - sdir : str - Subject dir containing surf files. - caption : str - Caption text to be placed on the image. - invert : bool - Invert color (blue positive, red negative). - labelname : str - Label for masking, usually cortex.label. - surfname : str - Surface to display values on, usually pial_semi_inflated from fsaverage. - curvname : str - Curvature file for texture in non-colored regions (default curv). + mesh : str or tuple of (array-like, array-like) + Path to any mesh file or a ``(vertices, faces)`` array tuple. + overlay : str, array-like, or None, optional + Per-vertex scalar overlay. + annot : str, tuple, or None, optional + FreeSurfer ``.annot`` file or ``(labels, ctab[, names])`` tuple. + bg_map : str, array-like, or None, optional + Per-vertex scalar file or array for background shading. + roi : str, array-like, or None, optional + FreeSurfer label file or boolean array masking overlay coloring. + invert : bool, optional + Invert the overlay color mapping. Default is ``False``. specular : bool, optional - If True, enable specular. - - Returns - ------- - None - This function does not return any value. + Enable specular highlights. Default is ``True``. + view : ViewType, optional + Initial camera view direction. Default is ``ViewType.LEFT``. + app : QApplication + The already-created ``QApplication`` instance. + config_window : ConfigWindow + The already-created Qt configuration panel. 
+ + Raises + ------ + RuntimeError + If the GLFW window or OpenGL context could not be created. + + Notes + ----- + **Mouse and keyboard interaction:** + + * **Left-drag** — arcball rotation (view-relative, no gimbal lock). + * **Right-drag / Middle-drag** — pan in screen space. + * **Scroll wheel** — zoom in/out. + * **Arrow keys** — rotate in 3° increments (faster while held). + * **R / double-click** — reset view to initial preset. + * **S** — save snapshot (opens a file-save dialog). + * **Q / ESC** — quit. """ - global current_fthresh_, current_fmax_, app_, app_window_, app_window_closed_ + global current_fthresh_, current_fmax_ + + import numpy as np # noqa: PLC0415 + from PyQt6.QtCore import QTimer # noqa: PLC0415 - wwidth = 720 - weight = 600 - window = init_window(wwidth, weight, "WhipperSnapPy", visible=True) + wwidth = 720 + wheight = 600 + window = init_window(wwidth, wheight, "WhipperSnapPy", visible=True) if not window: - return False - - if surfname is None: - print("[INFO] No surf_name provided. Looking for options in surf directory...") - found_surfname = get_surf_name(sdir, hemi) - if found_surfname is None: - print( - f"[ERROR] Could not find a valid surf file in {sdir} for hemi: {hemi}!" - ) - sys.exit(0) - meshpath = os.path.join(sdir, "surf", hemi + "." + found_surfname) - else: - meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) - - curvpath = None - if curvname: - curvpath = os.path.join(sdir, "surf", hemi + "." + curvname) - labelpath = None - if labelname: - labelpath = os.path.join(sdir, "label", hemi + "." 
+ labelname) - - # set up matrices to show object left and right side: - rot_z = pyrr.Matrix44.from_z_rotation(-0.5 * math.pi) - rot_x = pyrr.Matrix44.from_x_rotation(0.5 * math.pi) - viewLeft = rot_x * rot_z - # rot_y = pyrr.Matrix44.from_y_rotation(math.pi) - # viewRight = rot_y * viewLeft - rot_y = pyrr.Matrix44.from_y_rotation(0) - - meshdata, triangles, fthresh, fmax, neg = prepare_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, current_fthresh_, current_fmax_ + raise RuntimeError( + "Could not create a GLFW window/context. OpenGL context unavailable." + ) + + # ------------------------------------------------------------------ + # Initialise view state and base view matrix + # ------------------------------------------------------------------ + view_mats = get_view_matrices() + base_view = view_mats[view] # fixed orientation preset (→ transform uniform) + vs = ViewState(zoom=0.0) # zoom/pan packed into transform + + _last_left_press_time = [0.0] + + def _reset_view(): + vs.rotation = np.eye(4, dtype=np.float32) + vs.pan = np.zeros(2, dtype=np.float32) + vs.zoom = 0.0 + + # ------------------------------------------------------------------ + # Load mesh and compile shader + # ------------------------------------------------------------------ + meshdata, triangles, _fthresh, _fmax, _pos, _neg = prepare_geometry( + mesh, overlay, annot, bg_map, roi, + current_fthresh_, current_fmax_, + invert=invert, ) - shader = setup_shader(meshdata, triangles, wwidth, weight, specular=specular) - - print() - print("Keys:") - print("Left - Right : Rotate Geometry") - print("ESC : Quit") - print() - - ypos = 0 - while glfw.get_key( - window, glfw.KEY_ESCAPE - ) != glfw.PRESS and not glfw.window_should_close(window): - # Terminate if config app window was closed: - if app_window_closed_: - break + shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular) - glfw.poll_events() + logger.info( + "Mouse: left-drag=rotate right/middle-drag=pan 
scroll=zoom " + "R/double-click=reset S=snapshot Q/ESC=quit" + ) + + # ------------------------------------------------------------------ + # GLFW input callbacks + # ------------------------------------------------------------------ + + def _mouse_button_cb(win, button, action, _mods): + x, y = glfw.get_cursor_pos(win) + pos = np.array([x, y], dtype=np.float64) + + if button == glfw.MOUSE_BUTTON_LEFT: + vs.left_button_down = (action == glfw.PRESS) + if action == glfw.PRESS: + # Double-click detection (threshold: 300 ms) + now = glfw.get_time() + if now - _last_left_press_time[0] < 0.3: + _reset_view() + _last_left_press_time[0] = now + vs.last_mouse_pos = pos + else: + vs.last_mouse_pos = None + + elif button == glfw.MOUSE_BUTTON_RIGHT: + vs.right_button_down = (action == glfw.PRESS) + vs.last_mouse_pos = pos if action == glfw.PRESS else None + + elif button == glfw.MOUSE_BUTTON_MIDDLE: + vs.middle_button_down = (action == glfw.PRESS) + vs.last_mouse_pos = pos if action == glfw.PRESS else None + + def _cursor_pos_cb(win, x, y): + if vs.last_mouse_pos is None: + return + dx = x - vs.last_mouse_pos[0] + dy = y - vs.last_mouse_pos[1] + + if vs.left_button_down: + # Arcball rotation — amplify drag for snappier feel + _sensitivity = 2.5 + mx, my = vs.last_mouse_pos + v1 = arcball_vector(mx, my, wwidth, wheight) + v2 = arcball_vector( + mx + (x - mx) * _sensitivity, + my + (y - my) * _sensitivity, + wwidth, wheight, + ) + delta = arcball_rotation_matrix(v2, v1) + vs.rotation = vs.rotation @ delta + + elif vs.right_button_down or vs.middle_button_down: + # Pan in camera space — scale to normalised mesh units + pan_sensitivity = 1.0 / min(wwidth, wheight) + vs.pan[0] += dx * pan_sensitivity + vs.pan[1] -= dy * pan_sensitivity # y is flipped + + vs.last_mouse_pos = np.array([x, y], dtype=np.float64) + + def _scroll_cb(_win, _x_off, y_off): + # scroll up (y_off > 0) → move camera closer (positive Z in camera space) + vs.zoom += y_off * 0.05 + # Allow much further zoom out 
(e.g. -20.0) + vs.zoom = float(np.clip(vs.zoom, -20.0, 4.5)) + + def _save_snapshot(): + """Capture the current frame and open a Qt save-file dialog.""" + from PyQt6.QtWidgets import QFileDialog # noqa: PLC0415 + path, _ = QFileDialog.getSaveFileName( + None, + "Save snapshot", + "snapshot.png", + "Images (*.png *.jpg *.jpeg *.tiff *.bmp);;All files (*)", + ) + if not path: + return # user cancelled + img = capture_window(window) + img.save(path) + logger.info("Snapshot saved to %s", path) + + _arrow_keys = {glfw.KEY_RIGHT, glfw.KEY_LEFT, glfw.KEY_UP, glfw.KEY_DOWN} + + def _key_cb(win, key, _scancode, action, _mods): + if action not in (glfw.PRESS, glfw.REPEAT): + # On key release, restore normal render rate + if key in _arrow_keys: + timer.setInterval(16) + return + delta = np.radians(3.0) + if key == glfw.KEY_RIGHT: + rot = np.array(pyrr.Matrix44.from_y_rotation(-delta), dtype=np.float32) + elif key == glfw.KEY_LEFT: + rot = np.array(pyrr.Matrix44.from_y_rotation(+delta), dtype=np.float32) + elif key == glfw.KEY_UP: + rot = np.array(pyrr.Matrix44.from_x_rotation(+delta), dtype=np.float32) + elif key == glfw.KEY_DOWN: + rot = np.array(pyrr.Matrix44.from_x_rotation(-delta), dtype=np.float32) + elif key == glfw.KEY_R: + _reset_view() + return + elif key == glfw.KEY_Q: + glfw.set_window_should_close(win, True) + return + elif key == glfw.KEY_S and action == glfw.PRESS: + _save_snapshot() + return + else: + return + # Speed up render loop while key is held for smooth rotation + if key in _arrow_keys: + timer.setInterval(8) + vs.rotation = vs.rotation @ rot + + glfw.set_mouse_button_callback(window, _mouse_button_cb) + glfw.set_cursor_pos_callback(window, _cursor_pos_cb) + glfw.set_scroll_callback(window, _scroll_cb) + glfw.set_key_callback(window, _key_cb) + + from PyQt6.QtCore import QEventLoop # noqa: PLC0415 + loop = QEventLoop() + + _quitting = [False] # guard so we only shut down once + + def _begin_quit(): + """Stop rendering and GLFW, then exit the event 
loop.""" + if _quitting[0]: + return + _quitting[0] = True + timer.stop() + try: + glfw.terminate() + except Exception: + pass + # Defer loop.quit() so this callback fully unwinds first, + # avoiding QThreadStorage destruction mid-stack. + QTimer.singleShot(0, loop.quit) + + def _render_frame(): + """Called by QTimer every 16 ms on the main thread.""" + global current_fthresh_, current_fmax_ + nonlocal meshdata, triangles, shader + + if _quitting[0]: + return + + if ( + glfw.get_key(window, glfw.KEY_ESCAPE) == glfw.PRESS + or glfw.window_should_close(window) + ): + _begin_quit() + return + glfw.poll_events() gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - if app_window_ is not None: - if ( - app_window_.get_fthresh_value() != current_fthresh_ - or app_window_.get_fmax_value() != current_fmax_ - ): - current_fthresh_ = app_window_.get_fthresh_value() - current_fmax_ = app_window_.get_fmax_value() - meshdata, triangles, fthresh, fmax, neg = prepare_geometry( - meshpath, - overlaypath, - curvpath, - labelpath, - current_fthresh_, - current_fmax_, + # Re-prepare geometry if Qt sliders changed thresholds + if config_window is not None: + new_fthresh = config_window.get_fthresh_value() + new_fmax = config_window.get_fmax_value() + if new_fthresh != current_fthresh_ or new_fmax != current_fmax_: + current_fthresh_ = new_fthresh + current_fmax_ = new_fmax + meshdata, triangles, _ft, _fm, _p, _n = prepare_geometry( + mesh, overlay, annot, bg_map, roi, + current_fthresh_, current_fmax_, + invert=invert, ) shader = setup_shader( - meshdata, triangles, wwidth, weight, specular=specular + meshdata, triangles, wwidth, wheight, specular=specular ) - transformLoc = gl.glGetUniformLocation(shader, "transform") - gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, rot_y * viewLeft) - - if glfw.get_key(window, glfw.KEY_RIGHT) == glfw.PRESS: - ypos = ypos + 0.0004 - if glfw.get_key(window, glfw.KEY_LEFT) == glfw.PRESS: - ypos = ypos - 0.0004 - rot_y = 
pyrr.Matrix44.from_y_rotation(ypos) - - # Draw + # Identical to snap_rotate: transl * rotation * base_view → transform uniform. + # model and view uniforms are left as set by setup_shader (identity / camera). + gl.glUniformMatrix4fv( + gl.glGetUniformLocation(shader, "transform"), 1, gl.GL_FALSE, + compute_view_matrix(vs, base_view), + ) gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - glfw.swap_buffers(window) - glfw.terminate() - app_.quit() + timer = QTimer() + timer.timeout.connect(_render_frame) + timer.start(16) + # If the Qt config panel is closed, shut down GLFW and exit the loop. + if config_window is not None: + config_window.destroyed.connect(lambda: _begin_quit()) -def config_app_exit_handler(): - global app_window_closed_ - app_window_closed_ = True + loop.exec() def run(): - global current_fthresh_, current_fmax_, app_, app_window_ - - parser = argparse.ArgumentParser() - parser.add_argument( - "-lh", - "--lh_overlay", - type=str, - default=None, - required=False, - help="Absolute path to the lh overlay file.", + """Command-line entry point for the WhipperSnapPy interactive GUI. + + Parses command-line arguments, validates them, then spawns the OpenGL + viewer thread and launches the PyQt6 configuration window in the main + thread. + + Two mutually exclusive input modes are supported: + + **General mode** — supply a mesh file directly with ``--mesh``. + Works with FreeSurfer binary surfaces, OFF, VTK, and PLY files. + + **FreeSurfer shortcut** — supply ``-sd``/``--sdir`` and + ``--hemi``; the surface, curvature, and cortex-label paths are all + derived automatically from the subject directory. + + Raises + ------ + RuntimeError + If PyQt6 is not installed (``pip install 'whippersnappy[gui]'``). + ValueError + For invalid or mutually exclusive argument combinations. 
+ + Notes + ----- + **General mode** (``--mesh`` required): + + * ``--mesh`` — path to any triangular mesh file (FreeSurfer binary, + ``.off``, ``.vtk``, ``.ply``). + * ``--overlay`` — per-vertex scalar overlay file path or ``.mgh``. + * ``--annot`` — FreeSurfer ``.annot`` file. + * ``--bg-map`` — per-vertex scalar file for background shading. + * ``--roi`` — FreeSurfer label file or boolean mask for overlay region. + + **FreeSurfer shortcut** (``-sd``/``--sdir`` + ``--hemi`` required): + + * ``-sd`` / ``--sdir`` — subject directory containing ``surf/`` and + ``label/`` subdirectories. + * ``--hemi`` — hemisphere to display: ``lh`` or ``rh``. + * ``-lh`` / ``--lh_overlay`` or ``-rh`` / ``--rh_overlay`` — overlay + for the respective hemisphere (shorthand; equivalent to ``--overlay``). + * ``--annot`` — annotation file (full path). + * ``-s`` / ``--surf_name`` — surface basename (e.g. ``white``); + auto-detected if not provided. + * ``--curv-name`` — curvature basename (default: ``curv``). + * ``--label-name`` — cortex label basename (default: ``cortex.label``). + + **Common options** (both modes): + + * ``--fthresh`` — overlay threshold (default: 2.0); adjustable live in the GUI. + * ``--fmax`` — overlay saturation (default: 4.0); adjustable live in the GUI. + * ``--invert`` — invert the color scale. + * ``--diffuse`` — use diffuse-only shading (no specular highlights). + * ``--view`` — initial camera view (default: ``left``). + + Requires ``pip install 'whippersnappy[gui]'``. + For non-interactive batch rendering use ``whippersnap4`` or ``whippersnap1``. + """ + global current_fthresh_, current_fmax_ + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + _VIEW_CHOICES = {v.name.lower(): v for v in ViewType} + + parser = argparse.ArgumentParser( + prog="whippersnap", + description=( + "Interactive GUI viewer for any triangular surface mesh. 
" + "Pass --mesh for a direct mesh file, or -sd/--sdir + --hemi " + "for the FreeSurfer subject-directory shortcut. " + "For non-interactive batch rendering use whippersnap4 or whippersnap1." + ), ) - parser.add_argument( - "-rh", - "--rh_overlay", - type=str, - default=None, - required=False, - help="Absolute path to the rh overlay file.", + parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") + + # --- General mesh mode --- + general = parser.add_argument_group( + "general mode", + "Load any triangular mesh directly (OFF, VTK, PLY, FreeSurfer binary).", ) - parser.add_argument( - "--lh_annot", - type=str, - default=None, - required=False, - help="Absolute path to the lh annotation file.", + general.add_argument( + "--mesh", type=str, default=None, + help="Path to any triangular mesh file (.off, .vtk, .ply, or FreeSurfer binary).", ) - parser.add_argument( - "--rh_annot", - type=str, - default=None, - required=False, - help="Absolute path to the rh annotation file.", + general.add_argument( + "--bg-map", dest="bg_map", type=str, default=None, + help="Per-vertex scalar file for background shading (sign → light/dark).", ) - parser.add_argument( - "-sd", - "--sdir", - type=str, - required=True, - help="Absolute path to subject directory from which surfaces will be loaded. 
" - "This is assumed to contain the surface files in a surf/ sub-directory.", + general.add_argument( + "--roi", type=str, default=None, + help="FreeSurfer label file or boolean mask restricting overlay coloring.", ) - parser.add_argument( - "-s", - "--surf_name", - type=str, - default=None, - help="Name of the surface file to load.", + + # --- FreeSurfer shortcut mode --- + fs = parser.add_argument_group( + "FreeSurfer shortcut", + "Derive mesh, curvature, and label paths from a subject directory.", + ) + fs.add_argument( + "-sd", "--sdir", type=str, default=None, + help="Subject directory containing surf/ and label/ subdirectories.", + ) + fs.add_argument( + "--hemi", type=str, default=None, choices=["lh", "rh"], + help="Hemisphere to display: lh or rh.", + ) + fs.add_argument( + "-s", "--surf_name", type=str, default=None, + help="Surface basename (e.g. 'white'); auto-detected if not provided.", + ) + fs.add_argument( + "--curv-name", dest="curv_name", type=str, default="curv", + help="Curvature file basename for background shading (default: curv).", + ) + fs.add_argument( + "--label-name", dest="label_name", type=str, default="cortex.label", + help="Cortex label basename for overlay masking (default: cortex.label).", ) - parser.add_argument( - "-o", - "--output_path", - type=str, - default="/tmp/whippersnappy_snap.png", - help="Absolute path to the output file (snapshot image), " - "if not running interactive mode.", + + # Hemisphere-prefixed overlay shortcuts (FreeSurfer convention) + fs.add_argument( + "-lh", "--lh_overlay", type=str, default=None, + help="Shorthand for --overlay when using lh hemisphere (e.g. lh.thickness).", ) - parser.add_argument( - "-c", "--caption", type=str, default="", help="Caption to place on the figure" + fs.add_argument( + "-rh", "--rh_overlay", type=str, default=None, + help="Shorthand for --overlay when using rh hemisphere (e.g. 
rh.thickness).", ) - parser.add_argument( - "--no-colorbar", - dest="no_colorbar", - action="store_true", - default=False, - help="Switch off colorbar.") - parser.add_argument("--fmax", type=float, default=4.0) - parser.add_argument("--fthresh", type=float, default=2.0) - parser.add_argument( - "-i", - "--interactive", - dest="interactive", - action="store_true", - help="Start an interactive GUI session.", + + # --- Inputs common to both modes --- + common = parser.add_argument_group("overlay / annotation (both modes)") + common.add_argument( + "--overlay", type=str, default=None, + help="Per-vertex scalar overlay file path.", ) - parser.add_argument( - "--invert", dest="invert", action="store_true", help="Invert the color scale." + common.add_argument( + "--annot", type=str, default=None, + help="FreeSurfer .annot file for parcellation coloring.", ) - parser.add_argument( - "--diffuse", - dest="specular", - action="store_false", - default=True, - help="Diffuse surface reflection (switch-off specular).", + common.add_argument("--lut", type=str, default=None, + help="Path to a label look-up-table (LUT) file (csv/txt) with label IDs " + "and RGB(A) colors. 
Required if --annot is a csv/txt label map.") + + # --- Appearance / rendering --- + rend = parser.add_argument_group("appearance") + rend.add_argument("--fmax", type=float, default=4.0, + help="Overlay saturation value (default: 4.0).") + rend.add_argument("--fthresh", type=float, default=2.0, + help="Overlay threshold value (default: 2.0).") + rend.add_argument("--invert", action="store_true", + help="Invert the color scale.") + rend.add_argument("--diffuse", dest="specular", action="store_false", default=True, + help="Diffuse-only shading (no specular highlights).") + rend.add_argument( + "--view", type=str, default="left", choices=list(_VIEW_CHOICES), + help="Initial camera view direction (default: left).", ) args = parser.parse_args() - # check for mutually exclusive arguments - if (args.lh_overlay or args.rh_overlay) and (args.lh_annot or args.rh_annot): - print("[ERROR] Cannot use lh_overlay/rh_overlay and lh_annot/rh_annot arguments at the same time.") - sys.exit(0) - # check if at least one variant is present - if args.lh_overlay is None and args.rh_overlay is None and args.lh_annot is None and args.rh_annot is None: - print("[ERROR] Either lh_overlay/rh_overlay or lh_annot/rh_annot must be present.") - sys.exit(0) - # check if both hemis are present - if (args.lh_overlay is None and args.rh_overlay is not None) or \ - (args.lh_overlay is not None and args.rh_overlay is None) or \ - (args.lh_annot is None and args.rh_annot is not None) or \ - (args.lh_annot is not None and args.rh_annot is None): - print("[ERROR] If lh_overlay or lh_annot is present, rh_overlay or rh_annot must also be present " \ - "(and vice versa).") - sys.exit(0) - - # - if not args.interactive: - snap4( - lhoverlaypath=args.lh_overlay, - rhoverlaypath=args.rh_overlay, - lhannotpath=args.lh_annot, - rhannotpath=args.rh_annot, - sdir=args.sdir, - caption=args.caption, - surfname=args.surf_name, - fthresh=args.fthresh, - fmax=args.fmax, - invert=args.invert, - 
colorbar=not(args.no_colorbar), - outpath=args.output_path, - specular=args.specular, - ) - else: - current_fthresh_ = args.fthresh - current_fmax_ = args.fmax - - # Starting interactive OpenGL window in a separate thread: - thread = threading.Thread( - target=show_window, - args=( - "lh", - args.lh_overlay, - args.lh_annot, - args.sdir, - None, - False, - "cortex.label", - args.surf_name, - "curv", - args.specular, - ), - ) - thread.start() - - # Setting up and running config app window (must be main thread): - app_ = QApplication([]) - app_.setStyle("Fusion") # the default - app_.signals.aboutToQuit.connect(config_app_exit_handler) - - screen_geometry = app_.primaryScreen().availableGeometry() - app_window_ = ConfigWindow( - screen_dims=(screen_geometry.width(), screen_geometry.height()), - initial_fthresh_value=current_fthresh_, - initial_fmax_value=current_fmax_, - ) + # ------------------------------------------------------------------ + # Resolve the two modes and build the final mesh / bg_map / roi paths + # ------------------------------------------------------------------ + fs_mode = args.sdir is not None or args.hemi is not None + general_mode = args.mesh is not None + + try: + if fs_mode and general_mode: + raise ValueError( + "Cannot combine --mesh with -sd/--sdir or --hemi. " + "Use either general mode (--mesh) or FreeSurfer mode (-sd + --hemi)." + ) + if not fs_mode and not general_mode: + raise ValueError( + "Either --mesh (general mode) or both -sd/--sdir and --hemi " + "(FreeSurfer shortcut) must be provided." + ) + + if fs_mode: + if args.sdir is None or args.hemi is None: + raise ValueError( + "FreeSurfer mode requires both -sd/--sdir and --hemi." 
+ ) - # The following is a way to allow CTRL+C termination of the app: - signal.signal(signal.SIGINT, signal.SIG_DFL) + # Resolve overlay: --overlay takes precedence; -lh/-rh are shorthands + overlay = args.overlay + if overlay is None: + if args.hemi == "lh" and args.lh_overlay: + overlay = args.lh_overlay + elif args.hemi == "rh" and args.rh_overlay: + overlay = args.rh_overlay + elif args.lh_overlay and not fs_mode: + raise ValueError( + "-lh/--lh_overlay is only valid in FreeSurfer mode (with --hemi lh)." + ) + elif args.rh_overlay and not fs_mode: + raise ValueError( + "-rh/--rh_overlay is only valid in FreeSurfer mode (with --hemi rh)." + ) - app_window_.show() - app_.exec() + if overlay and args.annot: + raise ValueError("Cannot combine --overlay/hemisphere overlay and --annot.") + if not overlay and not args.annot and not general_mode: + raise ValueError( + "Either an overlay (-lh/-rh/--overlay) or --annot must be provided." + ) + except ValueError as e: + parser.error(str(e)) + + # Build resolved mesh / bg_map / roi + if fs_mode: + hemi = args.hemi + sdir = args.sdir + if args.surf_name is None: + found = get_surf_name(sdir, hemi) + if found is None: + parser.error(f"Could not find a valid surface in {sdir} for hemi {hemi!r}.") + mesh_path = os.path.join(sdir, "surf", f"{hemi}.{found}") + else: + mesh_path = os.path.join(sdir, "surf", f"{hemi}.{args.surf_name}") + bg_map = os.path.join(sdir, "surf", f"{hemi}.{args.curv_name}") if args.curv_name else None + roi = os.path.join(sdir, "label", f"{hemi}.{args.label_name}") if args.label_name else None + view = ViewType.RIGHT if hemi == "rh" else ViewType.LEFT + else: + mesh_path = args.mesh + bg_map = args.bg_map + roi = args.roi + view = _VIEW_CHOICES[args.view] + + # ------------------------------------------------------------------ + # Start the Qt app + OpenGL thread + # ------------------------------------------------------------------ + if QApplication is None: + print( + "ERROR: Interactive mode requires 
PyQt6. " + "Install with: pip install 'whippersnappy[gui]'", + file=sys.stderr, + ) + raise RuntimeError( + "Interactive mode requires PyQt6. " + "Install with: pip install 'whippersnappy[gui]'" + ) -# headless docker test using xvfb: -# Note, xvfb is a display server implementing the X11 protocol, and performing -# all graphics on memory. -# glfw needs a windows to render even if that is invisible, so above code -# will not work via ssh or on a headless server. xvfb can solve this by wrapping: -# docker run --name headless_test -ti -v$(pwd):/test ubuntu /bin/bash -# apt update && apt install -y python3 python3-pip xvfb -# pip3 install pyopengl glfw pillow numpy pyrr -# xvfb-run python3 test4.py + try: + from ..gui import ConfigWindow # noqa: PLC0415 + except ModuleNotFoundError as e: + raise RuntimeError( + "Interactive mode requires PyQt6. " + "Install with: pip install 'whippersnappy[gui]'" + ) from e + + current_fthresh_ = args.fthresh + current_fmax_ = args.fmax + + # Both QApplication/Qt and GLFW/Cocoa require the main thread on macOS. + # Create Qt objects here on the main thread, then pass them into + # show_window which drives rendering via a QTimer (no extra threads). + app = QApplication(sys.argv) + app.setStyle("Fusion") + + screen_geometry = app.primaryScreen().availableGeometry() + config_window = ConfigWindow( + screen_dims=(screen_geometry.width(), screen_geometry.height()), + initial_fthresh_value=current_fthresh_, + initial_fmax_value=current_fmax_, + ) + config_window.show() + + # show_window creates the GLFW window, sets up a QTimer render loop, + # then calls app.exec() — returns when either window is closed. + show_window( + mesh=mesh_path, + overlay=overlay, + annot=args.annot, + bg_map=bg_map, + roi=roi, + invert=args.invert, + specular=args.specular, + view=view, + app=app, + config_window=config_window, + ) -# instead of the above one could really do headless off screen rendering via -# EGL (preferred) or OSMesa. The latter looks doable. 
EGL looks tricky. -# EGL is part of any modern NVIDIA driver -# OSMesa needs to be installed, but should work almost everywhere -# using EGL maybe like this: -# https://github.com/eduble/gl -# or via these bindings: -# https://github.com/perey/pegl +if __name__ == "__main__": + run() -# or OSMesa -# https://github.com/AntonOvsyannikov/DockerGL diff --git a/whippersnappy/cli/whippersnap1.py b/whippersnappy/cli/whippersnap1.py new file mode 100644 index 0000000..f3684e3 --- /dev/null +++ b/whippersnappy/cli/whippersnap1.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 +"""CLI entry point for single-mesh snapshot and rotation video via snap1/snap_rotate. + +Renders any triangular surface mesh from a chosen viewpoint and saves it as a +PNG image. Alternatively, pass ``--rotate`` to produce a full 360° rotation +video (MP4, WebM, or GIF). + +The mesh can be a FreeSurfer binary surface (e.g. ``lh.white``), an ASCII OFF +file (``mesh.off``), a legacy ASCII VTK PolyData file (``mesh.vtk``), or an +ASCII PLY file (``mesh.ply``). + +Usage:: + + # FreeSurfer surface — lateral view with thickness overlay + whippersnap1 --mesh /surf/lh.white \\ + --overlay /surf/lh.thickness \\ + --bg-map /surf/lh.curv \\ + --roi /label/lh.cortex.label \\ + --view left --fthresh 1.5 --fmax 4.0 \\ + -o snap1.png + + # OFF / VTK / PLY mesh with a numpy-saved overlay + whippersnap1 --mesh mesh.off --overlay values.mgh -o snap1.png + whippersnap1 --mesh mesh.vtk -o snap1.png + whippersnap1 --mesh mesh.ply --overlay values.mgh -o snap1.png + + # 360° rotation video + whippersnap1 --mesh /surf/lh.white \\ + --overlay /surf/lh.thickness \\ + --rotate --rotate-frames 72 --rotate-fps 24 \\ + -o rotation.mp4 + + # Parcellation annotation + whippersnap1 --mesh /surf/lh.white \\ + --annot /label/lh.aparc.annot \\ + --view left -o snap_annot.png + +See ``whippersnap1 --help`` for the full list of options. +For four-view batch rendering use ``whippersnap4``. +For the interactive GUI use ``whippersnap``. 
+""" + +import argparse +import logging +import os +import tempfile + +import numpy as np + +if __name__ == "__main__" and __package__ is None: + import sys + os.execv(sys.executable, [sys.executable, "-m", "whippersnappy.cli.whippersnap1"] + sys.argv[1:]) + +from .. import snap1, snap_rotate +from .._version import __version__ +from ..utils.types import ColorSelection, OrientationType, ViewType + +_VIEW_CHOICES = {v.name.lower(): v for v in ViewType} +_ORIENT_CHOICES = {o.name.lower(): o for o in OrientationType} +_COLOR_CHOICES = {c.name.lower(): c for c in ColorSelection} + + +def run(): + """Command-line entry point for single-view snapshot or rotation video. + + Parses command-line arguments, validates them, and calls either + :func:`whippersnappy.snap1` (static snapshot) or + :func:`whippersnappy.snap_rotate` (360° rotation video) depending on + whether ``--rotate`` is passed. + All input is read from ``sys.argv`` via :mod:`argparse`. + + Raises + ------ + FileNotFoundError + If the mesh file or any overlay/annotation/label file cannot be found. + RuntimeError + If the OpenGL context cannot be initialised. + ValueError + For invalid argument combinations. + + Notes + ----- + **Snapshot options** (default mode): + + * ``--mesh`` — path to any triangular surface mesh: FreeSurfer binary + (e.g. ``lh.white``), ASCII OFF (``.off``), legacy ASCII VTK PolyData + (``.vtk``), or ASCII PLY (``.ply``). + * ``--overlay`` — per-vertex scalar overlay (e.g. ``lh.thickness`` or a ``.mgh`` file). + * ``--annot`` — FreeSurfer ``.annot`` parcellation file. + * ``--roi`` — FreeSurfer label file defining vertices to include in overlay coloring. + * ``--bg-map`` — per-vertex scalar file whose sign controls light/dark background shading. + * ``--view`` — camera direction: ``left``, ``right``, ``front``, ``back``, + ``top``, ``bottom`` (default: ``left``). + * ``--fthresh`` / ``--fmax`` — overlay threshold and saturation values. + * ``--invert`` — invert the color scale. 
+ * ``--no-colorbar`` — suppress the color bar. + * ``--caption`` — text label placed on the image. + * ``--width`` / ``--height`` — output resolution in pixels (default: 700×500). + * ``-o`` / ``--output`` — output file path (default: temp ``.png``). + + **Rotation video options** (pass ``--rotate``): + + * ``--rotate-frames`` — number of frames for one full rotation (default: 72). + * ``--rotate-fps`` — output frame rate (default: 24). + * ``--rotate-start-view`` — starting camera direction (default: ``left``). + * ``-o`` — output path; extension controls format: ``.mp4``, ``.webm``, + or ``.gif`` (GIF requires no ffmpeg). + """ + + parser = argparse.ArgumentParser( + prog="whippersnap1", + description=( + "Render a single-view screenshot of any triangular surface mesh " + "(FreeSurfer or otherwise) without a GUI. " + "Pass --rotate to produce a 360° rotation video instead." + ), + ) + parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") + + # --- Mesh input: --mesh flag (preferred) or bare positional (legacy) --- + parser.add_argument( + "--mesh", + type=str, + default=None, + help=( + "Path to the surface mesh file. Supported formats: " + "FreeSurfer binary surface (e.g. lh.white, rh.pial), " + "ASCII OFF (.off), legacy ASCII VTK PolyData (.vtk), ASCII PLY (.ply), " + "GIfTI surface (.gii, .surf.gii)." + ), + ) + # Keep positional for backward compatibility (silently accepted) + parser.add_argument( + "_mesh_positional", + nargs="?", + default=None, + metavar="MESH", + help=argparse.SUPPRESS, + ) + + # --- Output --- + parser.add_argument( + "-o", "--output", + type=str, + default=None, + help=( + "Output file path. For snapshots defaults to a temp .png file; " + "for rotation videos defaults to a temp .mp4 file. " + "Use .gif for an animated GIF (no ffmpeg required)." 
+ ), + ) + + # --- Optional overlay / annotation / roi / bg-map --- + parser.add_argument("--overlay", type=str, default=None, help="Per-vertex overlay file.") + parser.add_argument("--annot", type=str, default=None, help="FreeSurfer .annot file.") + parser.add_argument("--roi", type=str, default=None, + help="Path to a FreeSurfer label file defining the region of interest " + "(vertices to include in overlay coloring).") + parser.add_argument("--bg-map", type=str, default=None, dest="bg_map", + help="Path to a per-vertex scalar file used as background shading " + "(sign determines light/dark).") + parser.add_argument("--lut", type=str, default=None, + help="Path to a label look-up-table (LUT) file (csv/txt) with label IDs and RGB(A) colors. " + "Required if --annot is a csv/txt label map.") + + # --- View --- + parser.add_argument( + "--view", + type=str, + default="left", + choices=list(_VIEW_CHOICES), + help="Pre-defined view direction (default: left).", + ) + + # --- Appearance --- + parser.add_argument("--width", type=int, default=700) + parser.add_argument("--height", type=int, default=500) + parser.add_argument("--fthresh", type=float, default=None, help="Overlay threshold.") + parser.add_argument("--fmax", type=float, default=None, help="Overlay saturation value.") + parser.add_argument("--caption", type=str, default=None) + parser.add_argument("--invert", action="store_true", help="Invert color scale.") + parser.add_argument("--no-colorbar", dest="no_colorbar", action="store_true") + parser.add_argument( + "--color-mode", + type=str, + default="both", + choices=list(_COLOR_CHOICES), + help="Which overlay sign to display (default: both).", + ) + parser.add_argument( + "--orientation", + type=str, + default="horizontal", + choices=list(_ORIENT_CHOICES), + help="Colorbar orientation (default: horizontal).", + ) + parser.add_argument("--diffuse", dest="specular", action="store_false", default=True, + help="Use diffuse-only shading (no specular).") + 
parser.add_argument("--brain-scale", type=float, default=1.5, + help="Geometry scale factor (default: 1.5).") + parser.add_argument("--ambient", type=float, default=0.0, + help="Ambient light strength (default: 0.0).") + parser.add_argument("--font", type=str, default=None, + help="Path to a TTF font for captions.") + + # --- Rotation video --- + rotate_group = parser.add_argument_group("rotation video (--rotate)") + rotate_group.add_argument( + "--rotate", + action="store_true", + help="Produce a 360° rotation video instead of a static snapshot.", + ) + rotate_group.add_argument( + "--rotate-frames", + type=int, + default=72, + metavar="N", + help="Number of frames for a full rotation (default: 72, i.e. 5° per frame).", + ) + rotate_group.add_argument( + "--rotate-fps", + type=int, + default=24, + metavar="FPS", + help="Frame rate of the output video (default: 24).", + ) + rotate_group.add_argument( + "--rotate-start-view", + type=str, + default="left", + choices=list(_VIEW_CHOICES), + metavar="VIEW", + help="Starting view for the rotation (default: left).", + ) + + args = parser.parse_args() + + # Resolve mesh: --mesh takes precedence over bare positional argument + mesh_path = args.mesh or args._mesh_positional + if mesh_path is None: + parser.error("A mesh file is required: use --mesh .") + + log = logging.getLogger(__name__) + + try: + if args.rotate: + outpath = args.output or os.path.join( + tempfile.gettempdir(), "whippersnappy_rotation.mp4" + ) + snap_rotate( + mesh=mesh_path, + outpath=outpath, + n_frames=args.rotate_frames, + fps=args.rotate_fps, + width=args.width, + height=args.height, + overlay=args.overlay, + bg_map=args.bg_map, + annot=args.annot, + roi=args.roi, + fthresh=args.fthresh, + fmax=args.fmax, + invert=args.invert, + specular=args.specular, + ambient=args.ambient, + brain_scale=args.brain_scale, + start_view=_VIEW_CHOICES[args.rotate_start_view], + color_mode=_COLOR_CHOICES[args.color_mode], + ) + log.info("Rotation video saved to %s", 
outpath) + else: + outpath = args.output or os.path.join( + tempfile.gettempdir(), "whippersnappy_snap1.png" + ) + if args.annot is not None and args.lut is not None: + # Load label map + labels = np.loadtxt(args.annot, delimiter=None, dtype=int) + # Load LUT + lut = np.loadtxt(args.lut, delimiter=None) + # Normalize color values if needed + if lut.shape[1] in (4,5): # label + RGB(A) + rgb = lut[:,1:] + if np.any(rgb > 1): + rgb = rgb / 255.0 + lut[:,1:] = rgb + annot_tuple = (labels, lut) + else: + annot_tuple = args.annot + img = snap1( + mesh=mesh_path, + outpath=outpath, + overlay=args.overlay, + annot=annot_tuple, + roi=args.roi, + bg_map=args.bg_map, + view=_VIEW_CHOICES[args.view], + width=args.width, + height=args.height, + fthresh=args.fthresh, + fmax=args.fmax, + caption=args.caption, + invert=args.invert, + colorbar=not args.no_colorbar, + color_mode=_COLOR_CHOICES[args.color_mode], + orientation=_ORIENT_CHOICES[args.orientation], + font_file=args.font, + specular=args.specular, + brain_scale=args.brain_scale, + ambient=args.ambient, + ) + log.info("Snapshot saved to %s (%dx%d)", outpath, img.width, img.height) + except (RuntimeError, FileNotFoundError, ValueError, ImportError) as e: + parser.error(str(e)) + + +if __name__ == "__main__": + run() + diff --git a/whippersnappy/cli/whippersnap4.py b/whippersnappy/cli/whippersnap4.py new file mode 100644 index 0000000..5457b34 --- /dev/null +++ b/whippersnappy/cli/whippersnap4.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +"""CLI entry point for four-view batch rendering via snap4. + +Renders left and right hemisphere surfaces in lateral and medial views, +stitches them into a single image and writes it to disk. + +Usage:: + + whippersnap4 -lh -rh -sd -o out.png + whippersnap4 --lh_annot --rh_annot -sd + +See ``whippersnap4 --help`` for the full list of options. +For the interactive GUI use ``whippersnap``. 
+""" + +import argparse +import logging +import os +import tempfile + +import numpy as np + +if __name__ == "__main__" and __package__ is None: + import sys + os.execv(sys.executable, [sys.executable, "-m", "whippersnappy.cli.whippersnap4"] + sys.argv[1:]) + +from .. import snap4 +from .._version import __version__ + +# Module logger +logger = logging.getLogger(__name__) + + +def run(): + """Command-line entry point for WhipperSnapPy four-view batch rendering. + + Parses command-line arguments, validates them, and calls + :func:`whippersnappy.snap4` to produce a four-view composed image. + + Raises + ------ + ValueError + For invalid or mutually exclusive argument combinations. + RuntimeError + If the OpenGL context cannot be initialised. + FileNotFoundError + If required surface files cannot be found. + + Notes + ----- + **Required:** + + * ``-sd`` / ``--sdir`` — subject directory containing ``surf/`` and + ``label/`` subdirectories. + * One of the following (not both): + + * ``-lh`` / ``--lh_overlay`` **and** ``-rh`` / ``--rh_overlay`` — per-vertex + scalar overlay files for left and right hemispheres (e.g. ``lh.thickness``). + * ``--lh_annot`` **and** ``--rh_annot`` — FreeSurfer ``.annot`` parcellation + files for left and right hemispheres. + + **Output:** + + * ``-o`` / ``--output_path`` — output image path (default: temp ``.png``). + + **Overlay appearance:** + + * ``--fthresh`` / ``--fmax`` — threshold and saturation values + (auto-estimated if not set). + * ``--invert`` — invert the color scale. + * ``--no-colorbar`` — suppress the color bar. + * ``-c`` / ``--caption`` — text label placed on the figure. + + **Rendering:** + + * ``-s`` / ``--surf_name`` — surface basename (e.g. ``white``); + auto-detected if not provided. + * ``--diffuse`` — use diffuse-only shading (no specular highlights). + * ``--brain-scale`` — geometry scale factor (default: 1.85). + * ``--ambient`` — ambient light strength (default: 0.0). 
+ * ``--font`` — path to a TTF font file for captions. + """ + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + parser = argparse.ArgumentParser( + prog="whippersnap4", + description=( + "Render a four-view (left/right hemisphere, lateral/medial) " + "batch snapshot without a GUI. " + "For the interactive GUI use whippersnap." + ), + ) + parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") + + # --- Overlay / annotation inputs --- + parser.add_argument("-lh", "--lh_overlay", type=str, default=None, + help="Path to the lh overlay file.") + parser.add_argument("-rh", "--rh_overlay", type=str, default=None, + help="Path to the rh overlay file.") + parser.add_argument("--lh_annot", type=str, default=None, + help="Path to the lh annotation (.annot) file.") + parser.add_argument("--rh_annot", type=str, default=None, + help="Path to the rh annotation (.annot) file.") + parser.add_argument("--lh_lut", type=str, default=None, + help="Path to the lh label look-up-table (LUT) file (csv/txt) with label IDs " + "and RGB(A) colors. Required if --lh_annot is a csv/txt label map.") + parser.add_argument("--rh_lut", type=str, default=None, + help="Path to the rh label look-up-table (LUT) file (csv/txt) with label IDs " + "and RGB(A) colors. Required if --rh_annot is a csv/txt label map.") + + # --- Subject directory / surface --- + parser.add_argument("-sd", "--sdir", type=str, required=True, + help="Subject directory containing surf/ and label/ subdirectories.") + parser.add_argument("-s", "--surf_name", type=str, default=None, + help="Surface basename to load (e.g. 
'white'); " + "auto-detected if not provided.") + + # --- Output --- + parser.add_argument( + "-o", "--output_path", + type=str, + default=os.path.join(tempfile.gettempdir(), "whippersnappy_snap4.png"), + help="Output image path (default: temp file).", + ) + + # --- Overlay appearance --- + parser.add_argument("--fmax", type=float, default=None, + help="Overlay saturation value (auto-estimated if not set).") + parser.add_argument("--fthresh", type=float, default=None, + help="Overlay threshold value (auto-estimated if not set).") + parser.add_argument("--invert", action="store_true", + help="Invert the color scale.") + parser.add_argument("--no-colorbar", dest="no_colorbar", action="store_true", + default=False, help="Suppress the colorbar.") + parser.add_argument("-c", "--caption", type=str, default="", + help="Caption text to place on the figure.") + + # --- Rendering --- + parser.add_argument("--diffuse", dest="specular", action="store_false", default=True, + help="Diffuse-only shading (no specular).") + parser.add_argument("--ambient", type=float, default=0.0, + help="Ambient light strength (default: 0.0).") + parser.add_argument("--brain-scale", type=float, default=1.85, + help="Geometry scale factor (default: 1.85).") + parser.add_argument("--font", type=str, default=None, + help="Path to a TTF font for captions.") + + args = parser.parse_args() + + try: + if (args.lh_overlay or args.rh_overlay) and (args.lh_annot or args.rh_annot): + raise ValueError( + "Cannot use lh_overlay/rh_overlay and lh_annot/rh_annot at the same time." + ) + if not any([args.lh_overlay, args.rh_overlay, args.lh_annot, args.rh_annot]): + raise ValueError( + "Either lh_overlay/rh_overlay or lh_annot/rh_annot must be present." 
+ ) + if (args.lh_overlay is None) != (args.rh_overlay is None): + raise ValueError("Both -lh and -rh overlays must be provided together.") + if (args.lh_annot is None) != (args.rh_annot is None): + raise ValueError("Both --lh_annot and --rh_annot must be provided together.") + except ValueError as e: + parser.error(str(e)) + + logger.debug("Parsed args: %s", vars(args)) + + try: + if args.lh_annot is not None and args.lh_lut is not None: + labels = np.loadtxt(args.lh_annot, delimiter=None, dtype=int) + lut = np.loadtxt(args.lh_lut, delimiter=None) + if lut.shape[1] in (4,5): + rgb = lut[:,1:] + if np.any(rgb > 1): + rgb = rgb / 255.0 + lut[:,1:] = rgb + lh_annot_tuple = (labels, lut) + else: + lh_annot_tuple = args.lh_annot + if args.rh_annot is not None and args.rh_lut is not None: + labels = np.loadtxt(args.rh_annot, delimiter=None, dtype=int) + lut = np.loadtxt(args.rh_lut, delimiter=None) + if lut.shape[1] in (4,5): + rgb = lut[:,1:] + if np.any(rgb > 1): + rgb = rgb / 255.0 + lut[:,1:] = rgb + rh_annot_tuple = (labels, lut) + else: + rh_annot_tuple = args.rh_annot + + img = snap4( + lh_overlay=args.lh_overlay, + rh_overlay=args.rh_overlay, + lh_annot=lh_annot_tuple, + rh_annot=rh_annot_tuple, + sdir=args.sdir, + caption=args.caption, + surfname=args.surf_name, + fthresh=args.fthresh, + fmax=args.fmax, + invert=args.invert, + colorbar=not args.no_colorbar, + outpath=args.output_path, + font_file=args.font, + specular=args.specular, + ambient=args.ambient, + brain_scale=args.brain_scale, + ) + logger.info( + "Snapshot saved to %s (%dx%d)", args.output_path, img.width, img.height + ) + except (RuntimeError, FileNotFoundError, ValueError) as e: + parser.error(str(e)) + + +if __name__ == "__main__": + run() + diff --git a/whippersnappy/commands/sys_info.py b/whippersnappy/commands/sys_info.py index c1607cf..b000cfa 100644 --- a/whippersnappy/commands/sys_info.py +++ b/whippersnappy/commands/sys_info.py @@ -4,7 +4,21 @@ def run(): - """Run sys_info() command.""" + 
"""Run the sys_info command-line helper. + + Parses CLI arguments and delegates to the package-level `sys_info` + function which prints system and dependency information. + + Parameters + ---------- + None + + Returns + ------- + None + This function prints information to stdout via the `sys_info` + helper and does not return a value. + """ parser = argparse.ArgumentParser( prog=f"{__package__.split('.')[0]}-sys_info", description="sys_info" ) diff --git a/whippersnappy/core.py b/whippersnappy/core.py deleted file mode 100644 index 5201251..0000000 --- a/whippersnappy/core.py +++ /dev/null @@ -1,1498 +0,0 @@ -"""Contains the core functionalities of WhipperSnapPy. - -Dependencies: - numpy, glfw, pyrr, PyOpenGL, pillow - -@Author : Martin Reuter -@Created : 27.02.2022 -@Revised : 02.10.2025 - -""" - -import math -import os -import sys - -import glfw -import numpy as np -import OpenGL.GL as gl -import OpenGL.GL.shaders as shaders -import pyrr -from PIL import Image, ImageDraw, ImageFont - -from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data -from .types import ColorSelection, OrientationType, ViewType - - -def normalize_mesh(v, scale=1.0): - """ - Normalize mesh vertex coordinates. - - - Center their bounding box at the origin. - - Ensure that the longest side-length is equal to the scale variable (default 1). - - Parameters - ---------- - v : numpy.ndarray - Vertex array (Nvert X 3). - scale : float - Scaling constant. - - Returns - ------- - v: numpy.ndarray - Normalized vertex array (Nvert X 3). - """ - # center bounding box at origin - # scale longest side to scale (default 1) - bbmax = np.max(v, axis=0) - bbmin = np.min(v, axis=0) - v = v - 0.5 * (bbmax + bbmin) - v = scale * v / np.max(bbmax - bbmin) - return v - - -# adopted from lapy -def vertex_normals(v, t): - """ - Compute vertex normals. - - Triangle normals around each vertex are averaged, weighted by the angle - that they contribute. 
- Vertex ordering is important in t: counterclockwise when looking at the - triangle from above, so that normals point outwards. - - Parameters - ---------- - v : numpy.ndarray - Vertex array (Nvert X 3). - t : numpy.ndarray - Triangle array (Ntria X 3). - - Returns - ------- - normals: numpy.ndarray - Normals array: n - normals (Nvert X 3). - """ - # Compute vertex coordinates and a difference vector for each triangle: - v0 = v[t[:, 0], :] - v1 = v[t[:, 1], :] - v2 = v[t[:, 2], :] - v1mv0 = v1 - v0 - v2mv1 = v2 - v1 - v0mv2 = v0 - v2 - # Compute cross product at every vertex - # will point into the same direction with lengths depending on spanned area - cr0 = np.cross(v1mv0, -v0mv2) - cr1 = np.cross(v2mv1, -v1mv0) - cr2 = np.cross(v0mv2, -v2mv1) - # Add normals at each vertex (there can be duplicate indices in t at vertex i) - n = np.zeros(v.shape) - np.add.at(n, t[:, 0], cr0) - np.add.at(n, t[:, 1], cr1) - np.add.at(n, t[:, 2], cr2) - # Normalize normals - ln = np.sqrt(np.sum(n * n, axis=1)) - ln[ln < sys.float_info.epsilon] = 1 # avoid division by zero - n = n / ln.reshape(-1, 1) - return n - - -def heat_color(values, invert=False): - """ - Convert an array of float values into RBG heat color values. - - Only values between -1 and 1 will receive gradient and colors will - max-out at -1 and 1. Negative values will be blue and positive - red (unless invert is passed to flip the heatmap). Masked values - (nan) will map to masked colors (nan,nan,nan). - - Parameters - ---------- - values : numpy.ndarray - Float values of function on the surface mesh (length Nvert). - invert : bool - Whether to invert the heat map (blue is positive and red negative). - - Returns - ------- - colors: numpy.ndarray - (Nvert x 3) array of RGB of heat map as 0.0 .. 1.0 floats. 
- """ - # values (1 dim array length n) will receive gradient between -1 and 1 - # nan will return (nan,nan,nan) - # returns colors (r,g,b) as n x 3 array - if invert: - values = -1.0 * values - vabs = np.abs(values) - colors = np.zeros((vabs.size, 3), dtype=np.float32) - crb = 0.5625 + 3 * 0.4375 * vabs - cg = 1.5 * (vabs - (1.0 / 3.0)) - n1 = values < -1.0 - nm = (values >= -1.0) & (values < -(1.0 / 3.0)) - n0 = (values >= -(1.0 / 3.0)) & (values < 0) - p0 = (values >= 0) & (values < (1.0 / 3.0)) - pm = (values >= (1.0 / 3.0)) & (values < 1.0) - p1 = values >= 1.0 - # fill in colors for the 5 blocks - colors[n1, 1:3] = 1.0 # bright blue - colors[nm, 1] = cg[nm] # cg increasing green channel - colors[nm, 2] = 1.0 # and keeping blue on full - colors[n0, 2] = crb[n0] # crb increasing blue channel - colors[p0, 0] = crb[p0] # crb increasing red channel - colors[pm, 1] = cg[pm] # cg increasing green channel - colors[pm, 0] = 1.0 # and keeping red on full - colors[p1, 0:2] = 1.0 # yellow - colors[np.isnan(values), :] = np.nan - return colors - -def mask_sign(values, color_mode): - """ - Mask values don't have the same sign as the color_mode. - - The masked values will be replaced by nan. - - Parameters - ---------- - values : numpy.ndarray - Float values of function on the surface mesh (length Nvert). - color_mode : ColorSelection - Select which values to color, can be ColorSelection.BOTH, ColorSelection.POSITIVE - or ColorSelection.NEGATIVE. Default: ColorSelection.BOTH. - - Returns - ------- - values: numpy.ndarray - Float array of input function on mesh (length Nvert). - """ - masked_values = np.copy(values) - if color_mode == ColorSelection.POSITIVE: - masked_values[masked_values < 0] = np.nan - elif color_mode == ColorSelection.NEGATIVE: - masked_values[masked_values > 0] = np.nan - return masked_values - -def rescale_overlay(values, minval=None, maxval=None): - """ - Rescale values for color map computation. 
- - minval and maxval are two positive floats (maxval>minval). - Values between -minval and minval will be masked (np.nan); - others will be shifted towards zero (from both sides) - and scaled so that -maxval and maxval are at -1 and +1. - - Parameters - ---------- - values : numpy.ndarray - Float values of function on the surface (length Nvert). - minval : float - Minimum value. - maxval : float - Maximum value. - - Returns - ------- - values: numpy.ndarray - Float array of input function on mesh (length Nvert). - minval: float - Positive minimum value (crop values whose absolute value is below). - maxval: float - Positive maximum value (saturate color at maxval and -maxval). - pos: bool - Whether positive values are present at all after cropping. - neg: bool - Whether negative values are present at all after cropping. - """ - valsign = np.sign(values) - valabs = np.abs(values) - - if maxval < 0 or minval < 0: - print("resacle_overlay ERROR: min and maxval should both be positive!") - exit(1) - - # Mask values below minval - values[valabs < minval] = np.nan - - # Rescale map symmetrically to -1 .. 1 with the minval = 0 - # Any arithmetic operation containing NaN values results in NaN - range_val = maxval - minval - if range_val == 0: - values = np.zeros_like(values) - else: - values = values - valsign * minval - values = values / range_val - - # Check if there are any positive or negative values - pos = np.any(values[~np.isnan(values)] > 0) - neg = np.any(values[~np.isnan(values)] < 0) - - return values, minval, maxval, pos, neg - - -def binary_color(values, thres, color_low, color_high): - """ - Create a binary colormap based on a threshold value. - - This function assigns colors to input values based on whether they are - below or equal to the threshold (thres) or greater than the threshold. - - Values below thres are color_low, others are color_high. - color_low and color_high can be float (gray scale), or 1x3 array of RGB. 
- - Parameters - ---------- - values : numpy.ndarray - Input vertex function as float array (length Nvert). - thres : float - Threshold value. - color_low : float or numpy.ndarray - Lower color value(s). - color_high : float or numpy.ndarray - Higher color value(s). - - Returns - ------- - colors : numpy.ndarray - Binary colormap. - """ - if np.isscalar(color_low): - color_low = np.array((color_low, color_low, color_low), dtype=np.float32) - if np.isscalar(color_high): - color_high = np.array((color_high, color_high, color_high), dtype=np.float32) - colors = np.empty((values.size, 3), dtype=np.float32) - colors[values < thres, :] = color_low - colors[values >= thres, :] = color_high - return colors - - -def mask_label(values, labelpath=None): - """ - Apply a labelfile as a mask. - - Labelfile freesurfer format has indices of values that should be kept; - all other values will be set to np.nan. - - Parameters - ---------- - values : numpy.ndarray - Float values of function defined at vertices (a 1-dim array). - labelpath : str - Absolute path to label file. - - Returns - ------- - values: numpy.ndarray - Masked surface function values. - """ - if not labelpath: - return values - # this is the mask of vertices to keep, e.g. cortex labels - maskvids = np.loadtxt(labelpath, dtype=int, skiprows=2, usecols=[0]) - imask = np.ones(values.shape, dtype=bool) - imask[maskvids] = False - values[imask] = np.nan - return values - - -def prepare_geometry( - surfpath, - overlaypath=None, - annotpath=None, - curvpath=None, - labelpath=None, - minval=None, - maxval=None, - invert=False, - scale=1.85, - color_mode=ColorSelection.BOTH -): - """ - Prepare meshdata for upload to GPU. - - Vertex coordinates, vertex normals and color values are concatenated into - large vertexdata array. Also returns triangles, minimum and maximum overlay - values as well as whether negative values are present or not in triangles. 
- - Parameters - ---------- - surfpath : str - Path to surface file (usually lh or rh.pial_semi_inflated). - overlaypath : str - Path to overlay file. - annotpath : str - Path to annotation file. - curvpath : str - Path to curvature file (usually lh or rh.curv). - labelpath : str - Path to label file (mask; usually cortex.label). - minval : float - Minimum threshold to stop coloring (-minval used for neg values). - maxval : float - Maximum value to saturate (-maxval used for negative values). - invert : bool - Invert color map. - scale : float - Global scaling factor. Default: 1.85. - color_mode : ColorSelection - Select which values to color, can be ColorSelection.BOTH, ColorSelection.POSITIVE - or ColorSelection.NEGATIVE. Default: ColorSelection.BOTH. - - Returns - ------- - vertexdata: numpy.ndarray - Concatenated array with vertex coords, vertex normals and colors - as a (Nvert X 9) float32 array. - triangles: numpy.ndarray - Triangle array as a (Ntria X 3) uint32 array. - fmin: float - Minimum value of overlay function after rescale. - fmax: float - Maximum value of overlay function after rescale. - pos: bool - Whether positive values are there after rescale/cropping. - neg: bool - Whether negative values are there after rescale/cropping. 
- """ - - # read vertices and triangles - surf = read_geometry(surfpath, read_metadata=False) - vertices = normalize_mesh(np.array(surf[0], dtype=np.float32), scale) - triangles = np.array(surf[1], dtype=np.uint32) - # compute vertex normals - vnormals = np.array(vertex_normals(vertices, triangles), dtype=np.float32) - # read curvature - if curvpath: - curv = read_morph_data(curvpath) - sulcmap = binary_color(curv, 0.0, color_low=0.5, color_high=0.33) - else: - # if no curv pattern, color mesh in mid-gray - sulcmap = 0.5 * np.ones(vertices.shape, dtype=np.float32) - # read map (stats etc) or annotation - if overlaypath: - _, file_extension = os.path.splitext(overlaypath) - - if file_extension == ".mgh": - mapdata = read_mgh_data(overlaypath) - else: - mapdata = read_morph_data(overlaypath) - - valabs = np.abs(mapdata) - if maxval is None: - maxval = np.max(valabs) if np.any(valabs) else 0 - if minval is None: - minval = max(0.0, np.min(valabs) if np.any(valabs) else 0) - - # Mask map and get either positive and/or negative values - mapdata = mask_sign(mapdata, color_mode) - - # Rescale the map with minval and maxval - mapdata, fmin, fmax, pos, neg = rescale_overlay(mapdata, minval, maxval) - - # mask map with label - mapdata = mask_label(mapdata, labelpath) - - # compute color - colors = heat_color(mapdata, invert) - - missing = np.isnan(mapdata) - colors[missing, :] = sulcmap[missing, :] - elif annotpath: - annot, ctab, names = read_annot_data(annotpath) - # compute color - colors = ctab[annot, 0:3] / np.max(ctab[:, 0:3]) - # annot can contain -1 indices; these indicate non-annotated - # regions, but are valid indices in Python; need to recode - # them as missing - colors[annot==-1,:] = sulcmap[annot==-1,:] - colors = colors.astype(np.float32) - fmin = None - fmax = None - pos = None - neg = None - else: - colors = sulcmap - fmin = None - fmax = None - pos = None - neg = None - # concatenate matrices - vertexdata = np.concatenate((vertices, vnormals, colors), 
axis=1) - return vertexdata, triangles, fmin, fmax, pos, neg - - -def init_window(width, height, title="PyOpenGL", visible=True): - """ - Create window with width, height, title. - - If visible False, hide window. - - Parameters - ---------- - width : int - Window width. - height : int - Window height. - title : str - Window title. - visible : bool - Window visibility. - - Returns - ------- - window: glfw.LP__GLFWwindow - GUI window. - """ - if not glfw.init(): - return False - - glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 3) - glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3) - glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, True) - glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE) - if not visible: - glfw.window_hint(glfw.VISIBLE, glfw.FALSE) - window = glfw.create_window(width, height, title, None, None) - if not window: - glfw.terminate() - return False - # Enable key events - glfw.set_input_mode(window, glfw.STICKY_KEYS, gl.GL_TRUE) - # Enable key event callback - # glfw.set_key_callback(window,key_event) - glfw.make_context_current(window) - # vsync and glfw do not play nice. when vsync is enabled mouse movement is jittery. - glfw.swap_interval(0) - return window - - -def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0): - """ - Create vertex and fragment shaders. - - Set up data and parameters (such as the initial view matrix) on the GPU. - - In meshdata: - - the first 3 columns are the vertex coordinates - - the next 3 columns are the vertex normals - - the final 3 columns are the color RGB values - - Parameters - ---------- - meshdata : numpy.ndarray - Mesh array (shape: n x 9, dtype: np.float32). - triangles : bool - Triangle indices array (shape: m x 3). - width : int - Window width (to set perspective projection). - height : int - Window height (to set perspective projection). - specular : Boolean - By default specular is set as True. 
- ambient : float - Ambient light strength, by default 0: use only diffuse light sources. - - Returns - ------- - shader: ShaderProgram - Compiled OpenGL shader program. - """ - - VERTEX_SHADER = """ - - #version 330 - - layout (location = 0) in vec3 aPos; - layout (location = 1) in vec3 aNormal; - layout (location = 2) in vec3 aColor; - - out vec3 FragPos; - out vec3 Normal; - out vec3 Color; - - uniform mat4 transform; - uniform mat4 model; - uniform mat4 view; - uniform mat4 projection; - - void main() - { - gl_Position = projection * view * model * transform * vec4(aPos, 1.0f); - FragPos = vec3(model * transform * vec4(aPos, 1.0)); - // normal matrix should be computed outside and passed! - Normal = mat3(transpose(inverse(view * model * transform))) * aNormal; - Color = aColor; - } - - """ - - FRAGMENT_SHADER = """ - #version 330 - - in vec3 Normal; - in vec3 FragPos; - in vec3 Color; - - out vec4 FragColor; - - uniform vec3 lightColor = vec3(1.0, 1.0, 1.0); - uniform bool doSpecular = true; - uniform float ambientStrength = 0.0; - - void main() - { - // ambient - vec3 ambient = ambientStrength * lightColor; - - // diffuse - vec3 norm = normalize(Normal); - // values for overhead, front, below, back lights - //vec4 diffweights = vec4(0.4, 0.6, 0.4, 0.4); //more light below - vec4 diffweights = vec4(0.6, 0.4, 0.4, 0.3); //orig more shadow - - // key light (overhead) - vec3 lightPos1 = vec3(0.0,5.0,5.0); - vec3 lightDir = normalize(lightPos1 - FragPos); - float diff = max(dot(norm, lightDir), 0.0); - vec3 diffuse = diffweights[0] * diff * lightColor; - - // headlight (at camera) - vec3 lightPos2 = vec3(0.0,0.0,5.0); - lightDir = normalize(lightPos2 - FragPos); - vec3 ohlightDir = lightDir; // needed for specular - diff = max(dot(norm, lightDir), 0.0); - diffuse = diffuse + diffweights[1] * diff * lightColor; - - // fill light (from below) - vec3 lightPos3 = vec3(0.0,-5.0,5.0); - lightDir = normalize(lightPos3 - FragPos); - diff = max(dot(norm, lightDir), 0.0); - 
diffuse = diffuse + diffweights[2] * diff * lightColor; - - // left right back lights (both are same brightness) - vec3 lightPos4 = vec3(5.0,0.0,-5.0); - lightDir = normalize(lightPos4 - FragPos); - diff = max(dot(norm, lightDir), 0.0); - diffuse = diffuse + diffweights[3] * diff * lightColor; - - vec3 lightPos5 = vec3(-5.0,0.0,-5.0); - lightDir = normalize(lightPos5 - FragPos); - diff = max(dot(norm, lightDir), 0.0); - diffuse = diffuse + diffweights[3] * diff * lightColor; - - // specular - vec3 result; - if (doSpecular) - { - float specularStrength = 0.5; - // the viewer is always at (0,0,0) in view-space, - // so viewDir is (0,0,0) - Position => -Position - vec3 viewDir = normalize(-FragPos); - vec3 reflectDir = reflect(ohlightDir, norm); - float spec = pow(max(dot(viewDir, reflectDir), 0.0), 32); - vec3 specular = specularStrength * spec * lightColor; - // final color - result = (ambient + diffuse + specular) * Color; - } - else - { - // final color no specular - result = (ambient + diffuse) * Color; - } - FragColor = vec4(result, 1.0); - } - - """ - - # Create Vertex Buffer object in gpu - VBO = gl.glGenBuffers(1) - # Bind the buffer - gl.glBindBuffer(gl.GL_ARRAY_BUFFER, VBO) - gl.glBufferData(gl.GL_ARRAY_BUFFER, meshdata.nbytes, meshdata, gl.GL_STATIC_DRAW) - - # Create Vertex Array object - VAO = gl.glGenVertexArrays(1) - # Bind array - gl.glBindVertexArray(VAO) - gl.glBufferData(gl.GL_ARRAY_BUFFER, meshdata.nbytes, meshdata, gl.GL_STATIC_DRAW) - - # Create Element Buffer Object - EBO = gl.glGenBuffers(1) - gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, EBO) - gl.glBufferData( - gl.GL_ELEMENT_ARRAY_BUFFER, triangles.nbytes, triangles, gl.GL_STATIC_DRAW - ) - - # Compile The Program and shaders - shader = gl.shaders.compileProgram( - shaders.compileShader(VERTEX_SHADER, gl.GL_VERTEX_SHADER), - shaders.compileShader(FRAGMENT_SHADER, gl.GL_FRAGMENT_SHADER), - ) - - # get the position from shader - position = gl.glGetAttribLocation(shader, "aPos") - 
gl.glVertexAttribPointer( - position, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(0) - ) - gl.glEnableVertexAttribArray(position) - - vnormalpos = gl.glGetAttribLocation(shader, "aNormal") - gl.glVertexAttribPointer( - vnormalpos, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(3 * 4) - ) - gl.glEnableVertexAttribArray(vnormalpos) - - colorpos = gl.glGetAttribLocation(shader, "aColor") - gl.glVertexAttribPointer( - colorpos, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(6 * 4) - ) - gl.glEnableVertexAttribArray(colorpos) - - gl.glUseProgram(shader) - - gl.glClearColor(0.0, 0.0, 0.0, 1.0) - gl.glEnable(gl.GL_DEPTH_TEST) - - # Creating Projection Matrix - view = pyrr.matrix44.create_from_translation(pyrr.Vector3([0.0, 0.0, -5.0])) - projection = pyrr.matrix44.create_perspective_projection( - 20.0, width / height, 0.1, 100.0 - ) - model = pyrr.matrix44.create_from_translation(pyrr.Vector3([0.0, 0.0, 0.0])) - - # Set matrices in vertex shader - view_loc = gl.glGetUniformLocation(shader, "view") - proj_loc = gl.glGetUniformLocation(shader, "projection") - model_loc = gl.glGetUniformLocation(shader, "model") - gl.glUniformMatrix4fv(view_loc, 1, gl.GL_FALSE, view) - gl.glUniformMatrix4fv(proj_loc, 1, gl.GL_FALSE, projection) - gl.glUniformMatrix4fv(model_loc, 1, gl.GL_FALSE, model) - - # setup doSpecular in fragment shader - specular_loc = gl.glGetUniformLocation(shader, "doSpecular") - gl.glUniform1i(specular_loc, specular) - - # setup light color in fragment shader - lightColor_loc = gl.glGetUniformLocation(shader, "lightColor") - gl.glUniform3f(lightColor_loc, 1.0, 1.0, 1.0) - - # setup ambient light strength (default=0) - ambientLight_loc = gl.glGetUniformLocation(shader, "ambientStrength") - gl.glUniform1f(ambientLight_loc, ambient) - - return shader - - -def capture_window(width, height): - """ - Capture the GL region (0,0) .. (width,height) into PIL Image. - - Parameters - ---------- - width : int - Window width. 
- height : int - Window height. - - Returns - ------- - image: PIL.Image.Image - Captured image. - """ - if sys.platform == "darwin": - # not sure why on mac the drawing area is 4 times as large (2x2): - width = 2 * width - height = 2 * height - gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) # may not be needed - img_buf = gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE) - image = Image.frombytes("RGB", (width, height), img_buf) - image = image.transpose(Image.FLIP_TOP_BOTTOM) - if sys.platform == "darwin": - image.thumbnail((0.5 * width, 0.5 * height), Image.Resampling.LANCZOS) - return image - -def text_size(caption, font): - """ - Get the size of the text. - - Parameters - ---------- - caption : str - Text that is to be rendered. - font : PIL.ImageFont.FreeTypeFont - Font of the labels. - - Returns - ------- - text_width: int - Width of the text in pixels. - text_height: int - Height of the text in pixels. - """ - dummy_img = Image.new("L", (1, 1)) - draw = ImageDraw.Draw(dummy_img) - bbox = draw.textbbox((0, 0), caption, font=font, anchor="lt") - text_width = bbox[2] - bbox[0] - text_height = bbox[3] - bbox[1] - return text_width, text_height - -def get_colorbar_label_positions( - font, - labels, - colorbar_rect, - gapspace=0, - pos=True, - neg=True, - orientation=OrientationType.HORIZONTAL -): - """ - Get the positions of the labels for the colorbar. - - Parameters - ---------- - font : PIL.ImageFont.FreeTypeFont - Font of the labels. - labels : dict - Label texts that are to be rendered. - colorbar_rect : tuple - The coordinate values of the colorbar edges. - gapspace : int - Length of the gray space representing the threshold. Default : 0. - pos : bool - Show positive axis. Default: True. - neg : bool - Show negative axis. Default: True. - orientation : OrientationType - Orientation of the colorbar, can be OrientationType.HORIZONTAL or - OrientationType.VERTICAL. Default : OrientationType.HORIZONTAL. 
- - Returns - ------- - positions: dict - Positions of all labels. - """ - positions = {} - cb_x, cb_y, cb_width, cb_height = colorbar_rect - cb_labels_gap = 5 - - if orientation == OrientationType.HORIZONTAL: - label_y = cb_y + cb_height + cb_labels_gap - - # Upper - w, h = text_size(labels["upper"], font) - if pos: - positions["upper"] = (cb_x + cb_width - w, label_y) - else: - upper_x = cb_x + cb_width - w - int(gapspace) if gapspace > 0 else cb_x + cb_width - w - positions["upper"] = (upper_x, label_y) - - # Lower - w, h = text_size(labels["lower"], font) - if neg: - positions["lower"] = (cb_x, label_y) - else: - lower_x = cb_x + int(gapspace) if gapspace > 0 else cb_x - positions["lower"] = (lower_x, label_y) - - # Middle - if neg and pos: - if gapspace == 0: - # Single middle - w, h = text_size(labels["middle"], font) - positions["middle"] = (cb_x + cb_width // 2 - w // 2, label_y) - else: - # Middle Negative - w, h = text_size(labels["middle_neg"], font) - positions["middle_neg"] = (cb_x + cb_width // 2 - w - int(gapspace), label_y) - - # Middle Positive - w, h = text_size(labels["middle_pos"], font) - positions["middle_pos"] = (cb_x + cb_width // 2 + int(gapspace), label_y) - - else: # orientation == OrientationType.VERTICAL - label_x = cb_x + cb_width + cb_labels_gap - - # Upper - w, h = text_size(labels["upper"], font) - if pos: - positions["upper"] = (label_x, cb_y) - else: - upper_y = cb_y + int(gapspace) if gapspace > 0 else cb_y - positions["upper"] = (label_x, upper_y) - - # Lower - w, h = text_size(labels["lower"], font) - if neg: - positions["lower"] = (label_x, cb_y + cb_height - 1.5 * h) - else: - lower_y = cb_y + cb_height - int(gapspace) - 1.5 * h if gapspace > 0 else cb_y + cb_height - 1.5 * h - positions["lower"] = (label_x, lower_y) - - # Middle labels - if neg and pos: - if gapspace == 0: - # Single middle - w, h = text_size(labels["middle"], font) - positions["middle"] = (label_x, cb_y + cb_height // 2 - h // 2) - else: - # Middle Positive 
- w, h = text_size(labels["middle_pos"], font) - positions["middle_pos"] = (label_x, cb_y + cb_height // 2 - 1.5 * h - int(gapspace)) - - # Middle Negative - w, h = text_size(labels["middle_neg"], font) - positions["middle_neg"] = (label_x, cb_y + cb_height // 2 + int(gapspace)) - - return positions - -def create_colorbar( - fmin, - fmax, - invert, - orientation=OrientationType.HORIZONTAL, - colorbar_scale=1, - pos=True, - neg=True, - font_file=None -): - """ - Create colorbar image with text indicating min and max values. - - Parameters - ---------- - fmin : int - Absolute min value that receives color (threshold). - fmax : int - Absolute max value where color saturates. - invert : bool - Color invert. - orientation : OrientationType - Orientation of the colorbar, can be OrientationType.HORIZONTAL or - OrientationType.VERTICAL. Default : OrientationType.HORIZONTAL. - colorbar_scale : number - Colorbar scaling factor. Default: 1. - pos : bool - Show positive axis. - neg : bool - Show negative axis. - font_file : str - Path to the file describing the font to be used. - - Returns - ------- - image: PIL.Image.Image - Colorbar image. 
- """ - cwidth = int(200 * colorbar_scale) - cheight = int(30 * colorbar_scale) - gapspace = 0 - - # Add gray gap if needed - if fmin > 0.01: - # Leave gray gap - num = int(0.42 * cwidth) - gapspace = 0.08 * cwidth - else: - num = int(0.5 * cwidth) - if not neg or not pos: - num = num * 2 - gapspace = gapspace * 2 - - # Set the values for the colorbar - values = np.nan * np.ones(cwidth) - steps = np.linspace(0.01, 1, num) - if pos and not neg: - values[-steps.size :] = steps - elif not pos and neg: - values[: steps.size] = -1.0 * np.flip(steps) - else: - values[: steps.size] = -1.0 * np.flip(steps) - values[-steps.size :] = steps - - # Set the colors - colors = heat_color(values, invert) - colors[np.isnan(values), :] = 0.33 * np.ones((1, 3)) - img_bar = np.uint8(np.tile(colors, (cheight, 1, 1)) * 255) - - # Pad with black - pad_top, pad_left = 3, 10 - img_buf = np.zeros((cheight + 2 * pad_top, cwidth + 2 * pad_left, 3), dtype=np.uint8) - img_buf[pad_top : cheight + pad_top, pad_left : cwidth + pad_left, :] = img_bar - image = Image.fromarray(img_buf) - - # Get the font for the labels - if font_file is None: - script_dir = "/".join(str(__file__).split("/")[:-1]) - font_file = os.path.join(script_dir, "Roboto-Regular.ttf") - font = ImageFont.truetype(font_file, int(12 * colorbar_scale)) - - # Labels for the colorbar - labels = {} - labels["upper"] = f">{fmax:.2f}" if pos else (f"{-fmin:.2f}" if gapspace != 0 else "0") - labels["lower"] = f"<{-fmax:.2f}" if neg else (f"{fmin:.2f}" if gapspace != 0 else "0") - if neg and pos and gapspace != 0: - labels["middle_neg"] = f"{-fmin:.2f}" - labels["middle_pos"] = f"{fmin:.2f}" - elif neg and pos and gapspace == 0: - labels["middle"] = "0" - - # Maximum caption sizes - caption_sizes = [text_size(caption, font) for caption in labels.values()] - max_caption_width = int(max([caption_size[0] for caption_size in caption_sizes])) - max_caption_height = int(max([caption_size[1] for caption_size in caption_sizes])) - - # Extend 
colorbar image by the maximum caption size to fit the labels and rotate image if needed - if orientation == OrientationType.VERTICAL: - image = image.rotate(90, expand=True) - - new_width = image.width + int(max_caption_width) - new_image = Image.new("RGB", (new_width, image.height), (0, 0, 0)) - new_image.paste(image, (0, 0)) - image = new_image - - colorbar_rect = (pad_top, pad_left, cheight, cwidth) - else: - new_height = image.height + int(max_caption_height * 2) - new_image = Image.new("RGB", (image.width, new_height), (0, 0, 0)) - new_image.paste(image, (0, 0)) - image = new_image - - colorbar_rect = (pad_left, pad_top, cwidth, cheight) - - # Get positions of the labels - positions = get_colorbar_label_positions(font, labels, colorbar_rect, gapspace, pos, neg, orientation) - - # Draw the labels - draw = ImageDraw.Draw(image) - for label_key, position in positions.items(): - draw.text((int(position[0]), int(position[1])), labels[label_key], fill=(220, 220, 220), font=font) - - return image - -def snap1( - meshpath, - outpath, - overlaypath=None, - annotpath=None, - labelpath=None, - curvpath=None, - view=ViewType.LEFT, - viewmat=None, - width=None, - height=None, - fthresh=None, - fmax=None, - caption=None, - caption_x=None, - caption_y=None, - caption_scale=1, - invert=False, - colorbar=True, - colorbar_x=None, - colorbar_y=None, - colorbar_scale=1, - orientation=OrientationType.HORIZONTAL, - color_mode=ColorSelection.BOTH, - font_file=None, - specular=True, - brain_scale=1, - ambient=0.0, -): - """ - Snap one view (view and hemisphere is determined by the user). - - Colorbar, caption, and saving are optional. - - Parameters - ---------- - meshpath : str - Path to the surface file (FreeSurfer format). - outpath : str - Path to the output image file. - overlaypath : str - Path to the overlay file (FreeSurfer format). - annotpath : str - Path to the annotation file (FreeSurfer format). - labelpath : str - Path to the label file (FreeSurfer format). 
- curvpath : str - Path to the curvature file for texture in non-colored regions. - view : ViewType - Predefined views, can be ViewType.LEFT, ViewType.RIGHT, ViewType.BACK, - ViewType.FRONT, ViewType.TOP or ViewType.BOTTOM. Default: ViewType.LEFT. - viewmat : array-like - User-defined 4x4 viewing matrix. Overwrites view. - width : number - Width of the image. Default: automatically chosen. - height : number - Height of the image. Default: automatically chosen. - fthresh : float - Pos absolute value under which no color is shown. - fmax : float - Pos absolute value above which color is saturated. - caption : str - Caption text to be placed on the image. - caption_x : number - Normalized horizontal position of the caption. Default: automatically chosen. - caption_y : number - Normalized vertical position of the caption. Default: automatically chosen. - caption_scale : number - Caption scaling factor. Default: 1. - invert : bool - Invert color (blue positive, red negative). - colorbar : bool - Show colorbar on image. - colorbar_x : number - Normalized horizontal position of the colorbar. Default: automatically chosen. - colorbar_y : number - Normalized vertical position of the colorbar. Default: automatically chosen. - colorbar_scale : number - Colorbar scaling factor. Default: 1. - orientation : OrientationType - Orientation of the colorbar and caption, can be OrientationType.VERTICAL or - OrientationType.HORIZONTAL. Default: OrientationType.HORIZONTAL. - color_mode : ColorSelection - Select which values to color, can be ColorSelection.BOTH, ColorSelection.POSITIVE - or ColorSelection.NEGATIVE. Default: ColorSelection.BOTH. - font_file : str - Path to the file describing the font to be used in captions. - specular : bool - Specular is by default set as True. - brain_scale : float - Brain scaling factor. Default: 1. - ambient : float - Ambient light, default 0, only use diffuse light sources. - - Returns - ------- - None - This function returns None. 
- """ - # Setup base image - REFWWIDTH = 700 - REFWHEIGHT = 500 - WWIDTH = REFWWIDTH if width is None else width - WHEIGHT = REFWHEIGHT if height is None else height - UI_SCALE = min(WWIDTH / REFWWIDTH, WHEIGHT / REFWHEIGHT) - - # Check screen resolution - if not glfw.init(): - print( - "[ERROR] Could not init glfw!" - ) - sys.exit(1) - primary_monitor = glfw.get_primary_monitor() - mode = glfw.get_video_mode(primary_monitor) - screen_width = mode.size.width - screen_height = mode.size.height - if WWIDTH > screen_width: - print( - f"[INFO] Requested width {WWIDTH} exceeds screen width {screen_width}, expect black bars" - ) - elif WHEIGHT > screen_height: - print( - f"[INFO] Requested height {WHEIGHT} exceeds screen height {screen_height}, expect black bars" - ) - - # Create the base image - image = Image.new("RGB", (WWIDTH, WHEIGHT)) - - # Setup brain image - # (keep aspect ratio, as the mesh scale and distances are set accordingly) - BWIDTH = int(540 * brain_scale * UI_SCALE) - BHEIGHT = int(450 * brain_scale * UI_SCALE) - brain_display_width = min(BWIDTH, WWIDTH) - brain_display_height = min(BHEIGHT, WHEIGHT) - - visible = True - window = init_window(brain_display_width, brain_display_height, "WhipperSnapPy 2.0", visible) - if not window: - return False # need raise error here in future - - viewLeft = np.array([[ 0, 0,-1, 0], [-1, 0, 0, 0], [ 0, 1, 0, 0], [ 0, 0, 0, 1]]) # left w top up // right - viewRight = np.array([[ 0, 0, 1, 0], [ 1, 0, 0, 0], [ 0, 1, 0, 0], [ 0, 0, 0, 1]]) # right w top up // right - viewBack = np.array([[ 1, 0, 0, 0], [ 0, 0,-1, 0], [ 0, 1, 0, 0], [ 0, 0, 0, 1]]) # back w top up // back - viewFront = np.array([[-1, 0 ,0, 0], [ 0, 0, 1, 0], [ 0, 1, 0, 0], [ 0, 0, 0, 1]]) # front w top up // front - viewBottom = np.array([[-1, 0, 0, 0], [ 0, 1, 0, 0], [ 0, 0,-1, 0], [ 0, 0, 0, 1]]) # bottom ant up // bottom - viewTop = np.array([[ 1, 0, 0, 0], [ 0, 1, 0, 0], [ 0, 0, 1, 0], [ 0, 0, 0, 1]]) # top w ant up // top - - transl = 
pyrr.Matrix44.from_translation((0, 0, 0.4)) - - # Load and colorize data - meshdata, triangles, fthresh, fmax, pos, neg = prepare_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, - scale=brain_scale, color_mode=color_mode - ) - - # Check if there is data to display - if overlaypath is not None: - if color_mode == ColorSelection.POSITIVE: - if not pos and neg: - print( - "[Error] Overlay has no values to display with positive color_mode" - ) - sys.exit(1) - neg = False - elif color_mode == ColorSelection.NEGATIVE: - if pos and not neg: - print( - "[Error] Overlay has no values to display with negative color_mode" - ) - sys.exit(1) - pos = False - if not pos and not neg: - print( - "[Error] Overlay has no values to display" - ) - sys.exit(1) - - # Upload to GPU and compile shaders - shader = setup_shader(meshdata, triangles, brain_display_width, brain_display_height, - specular=specular, ambient=ambient) - - # Draw - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - transformLoc = gl.glGetUniformLocation(shader, "transform") - if viewmat is None: - if view == ViewType.LEFT: - viewmat = transl * viewLeft - elif view == ViewType.RIGHT: - viewmat = transl * viewRight - elif view == ViewType.BACK: - viewmat = transl * viewBack - elif view == ViewType.FRONT: - viewmat = transl * viewFront - elif view == ViewType.BOTTOM: - viewmat = transl * viewBottom - elif view == ViewType.TOP: - viewmat = transl * viewTop - else: - viewmat = transl * viewmat - - gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, viewmat) - gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - - im1 = capture_window(brain_display_width, brain_display_height) - - # Center brain - brain_x = 0 if WWIDTH < BWIDTH else (WWIDTH - BWIDTH) // 2 - brain_y = 0 if WHEIGHT < BHEIGHT else (WHEIGHT - BHEIGHT) // 2 - - image.paste(im1, (brain_x, brain_y)) - - # Create colorbar - bar = None - bar_w = bar_h = 0 - if overlaypath is not None 
and colorbar: - bar = create_colorbar(fthresh, fmax, invert, orientation, colorbar_scale * UI_SCALE, - pos, neg, font_file=font_file) - bar_w, bar_h = bar.size - - # Create caption - font = None - text_w = text_h = 0 - if caption: - if font_file is None: - script_dir = "/".join(str(__file__).split("/")[:-1]) - font_file = os.path.join(script_dir, "Roboto-Regular.ttf") - font = ImageFont.truetype(font_file, 20 * caption_scale * UI_SCALE) - text_w, text_h = text_size(caption, font) - - text_w = int(text_w) - text_h = int(text_h) - - # Constants defining the position of the caption and colorbar - BOTTOM_PAD = int(20 * UI_SCALE) - RIGHT_PAD = int(20 * UI_SCALE) - GAP = int(4 * UI_SCALE) - - if orientation == OrientationType.HORIZONTAL: - # Place the colorbar - if bar is not None: - if colorbar_x is None: - bx = int(0.5 * (image.width - bar_w)) - else: - bx = int(colorbar_x * WWIDTH) - if colorbar_y is None: - gap_and_caption = (GAP + text_h) if caption and caption_y is None else 0 - by = image.height - BOTTOM_PAD - gap_and_caption - bar_h - else: - by = int(colorbar_y * WHEIGHT) - image.paste(bar, (bx, by)) - - # Place the caption - if caption: - if caption_x is None: - cx = int(0.5 * (image.width - text_w)) - else: - cx = int(caption_x * WWIDTH) - if caption_y is None: - cy = image.height - BOTTOM_PAD - text_h - else: - cy = int(caption_y * WHEIGHT) - ImageDraw.Draw(image).text( - (cx, cy), caption, (220, 220, 220), font=font, anchor="lt" - ) - else: # orientation == OrientationType.VERTICAL - # Place the colorbar - if bar is not None: - if colorbar_x is None: - gap_and_caption = (GAP + text_h) if caption and caption_x is None else 0 - bx = image.width - RIGHT_PAD - gap_and_caption - bar_w - else: - bx = int(colorbar_x * WWIDTH) - if colorbar_y is None: - by = int(0.5 * (image.height - bar_h)) - else: - by = int(colorbar_y * WHEIGHT) - image.paste(bar, (bx, by)) - - # Place the caption - if caption: - # Create a new transparent image and rotate it - temp_caption_img = 
Image.new("RGBA", (text_w, text_h), (0,0,0,0)) - ImageDraw.Draw(temp_caption_img).text((0, 0), caption, font=font, anchor="lt") - rotated_caption = temp_caption_img.rotate(90, expand=True, fillcolor=(0,0,0,0)) - rotated_w, rotated_h = rotated_caption.size - - if caption_x is None: - cx = image.width - RIGHT_PAD - rotated_w - else: - cx = int(caption_x * WWIDTH) - if caption_y is None: - cy = int(0.5 * (image.height - rotated_h)) - else: - cy = int(caption_y * WHEIGHT) - - image.paste(rotated_caption, (cx, cy), rotated_caption) - - # save image - print(f"[INFO] Saving snapshot to {outpath}") - image.save(outpath) - - glfw.terminate() - - return None - -def snap4( - lhoverlaypath=None, - rhoverlaypath=None, - lhannotpath=None, - rhannotpath=None, - fthresh=None, - fmax=None, - sdir=None, - caption=None, - invert=False, - labelname="cortex.label", - surfname=None, - curvname="curv", - colorbar=True, - outpath=None, - font_file=None, - specular=True, - ambient=0.0, -): - """ - Snap four views (front and back for left and right hemispheres). - - Save an image that includes the views and a color bar. - - Parameters - ---------- - lhoverlaypath : str - Path to the overlay files for left hemi (FreeSurfer format). - rhoverlaypath : str - Path to the overlay files for right hemi (FreeSurfer format). - lhannotpath : str - Path to the annotation files for left hemi (FreeSurfer format). - rhannotpath : str - Path to the annotation files for right hemi (FreeSurfer format). - fthresh : float - Pos absolute value under which no color is shown. - fmax : float - Pos absolute value above which color is saturated. - sdir : str - Subject dir containing surf files. - caption : str - Caption text to be placed on the image. - invert : bool - Invert color (blue positive, red negative). - labelname : str - Label for masking, usually cortex.label. - surfname : str - Surface to display values on, usually pial_semi_inflated from fsaverage. 
- curvname : str - Curvature file for texture in non-colored regions (default curv). - colorbar : bool - Show colorbar on image. Will be ignored for annotation files. - outpath : str - Path to the output image file. - font_file : str - Path to the file describing the font to be used in captions. - specular : bool - Specular is by default set as True. - ambient : float - Ambient light, default 0, only use diffuse light sources. - - Returns - ------- - None - This function returns None. - """ - # setup window - # (keep aspect ratio, as the mesh scale and distances are set accordingly) - wwidth = 540 - wheight = 450 - visible = True - window = init_window(wwidth, wheight, "WhipperSnapPy 2.0", visible) - if not window: - return False # need raise error here in future - - # set up matrices to show object left and right side: - rot_z = pyrr.Matrix44.from_z_rotation(-0.5 * math.pi) - rot_x = pyrr.Matrix44.from_x_rotation(0.5 * math.pi) - # rot_y = pyrr.Matrix44.from_y_rotation(math.pi/6) - viewLeft = rot_x * rot_z - rot_y = pyrr.Matrix44.from_y_rotation(math.pi) - viewRight = rot_y * viewLeft - transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) - - for hemi in ("lh", "rh"): - if surfname is None: - print( - "[INFO] No surf_name provided. Looking for options in surf directory..." - ) - - if sdir is None: - sdir = os.environ.get("SUBJECTS_DIR") - if not sdir: - print( - "[INFO] No surf_name or subjects directory (sdir) \ -provided, can not find surf file" - ) - sys.exit(1) - - found_surfname = get_surf_name(sdir, hemi) - - if found_surfname is None: - print( - f"[ERROR] Could not find valid surface in {sdir} for hemi: {hemi}!" - ) - sys.exit(1) - meshpath = os.path.join(sdir, "surf", hemi + "." + found_surfname) - else: - meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) - - curvpath = None - if curvname: - curvpath = os.path.join(sdir, "surf", hemi + "." + curvname) - labelpath = None - if labelname: - labelpath = os.path.join(sdir, "label", hemi + "." 
+ labelname) - if hemi == "lh": - overlaypath = lhoverlaypath - annotpath = lhannotpath - else: - overlaypath = rhoverlaypath - annotpath = rhannotpath - - # load and colorize data - meshdata, triangles, fthresh, fmax, pos, neg = prepare_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert - ) - - # Check if there is something to display - if pos == 0 and neg == 0: - print( - "[Error] Overlay has no values to display" - ) - sys.exit(1) - - # upload to GPU and compile shaders - shader = setup_shader(meshdata, triangles, wwidth, wheight, - specular=specular, ambient=ambient) - - # draw - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - transformLoc = gl.glGetUniformLocation(shader, "transform") - viewmat = viewLeft - if hemi == "lh": - viewmat = transl * viewmat - gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, viewmat) - gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - - im1 = capture_window(wwidth, wheight) - - glfw.swap_buffers(window) - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - viewmat = viewRight - if hemi == "rh": - viewmat = transl * viewmat - gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, viewmat) - gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - - im2 = capture_window(wwidth, wheight) - - if hemi == "lh": - lhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) - lhimg.paste(im1, (0, 0)) - lhimg.paste(im2, (0, im1.height)) - else: - rhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) - rhimg.paste(im2, (0, 0)) - rhimg.paste(im1, (0, im2.height)) - - image = Image.new("RGB", (lhimg.width + rhimg.width, lhimg.height)) - image.paste(lhimg, (0, 0)) - image.paste(rhimg, (im1.width, 0)) - - if caption: - if font_file is None: - script_dir = "/".join(str(__file__).split("/")[:-1]) - font_file = os.path.join(script_dir, "Roboto-Regular.ttf") - font = ImageFont.truetype(font_file, 20) - xpos = 0.5 * (image.width 
- font.getlength(caption)) - ImageDraw.Draw(image).text( - (xpos, image.height - 40), caption, (220, 220, 220), font=font - ) - - if lhannotpath is None and rhannotpath is None and colorbar: - bar = create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) - xpos = int(0.5 * (image.width - bar.width)) - ypos = int(0.5 * (image.height - bar.height)) - image.paste(bar, (xpos, ypos)) - - if outpath: - print(f"[INFO] Saving snapshot to {outpath}") - image.save(outpath) - - glfw.terminate() - - return None - -def get_surf_name(sdir, hemi): - """ - Find a valid surface file in the specified subject directory. - - A valid file can be one of: ['pial_semi_inflated', 'white', 'inflated']. - - Parameters - ---------- - sdir : str - Subject directory. - hemi : str - Hemisphere; one of: ['lh', 'rh']. - - Returns - ------- - surfname: str - Valid and existing surf file's name; otherwise, None. - """ - for surf_name_option in ["pial_semi_inflated", "white", "inflated"]: - if os.path.exists(os.path.join(sdir, "surf", hemi + "." + surf_name_option)): - print("[INFO] Found {}".format(hemi + "." + surf_name_option)) - return surf_name_option - else: - print("[INFO] No {} file found".format(hemi + "." + surf_name_option)) - else: - return None diff --git a/whippersnappy/geometry/__init__.py b/whippersnappy/geometry/__init__.py new file mode 100644 index 0000000..3307510 --- /dev/null +++ b/whippersnappy/geometry/__init__.py @@ -0,0 +1,83 @@ +"""Geometry subpackage — mesh IO, overlay IO, and rendering preparation. + +Architecture +------------ +The subpackage has three layers: + +**Layer 1 — low-level format readers** (one file per format family): + +* :mod:`~whippersnappy.geometry.freesurfer_io` — FreeSurfer binary formats: + surface geometry (``read_geometry``), morphometry scalars + (``read_morph_data``), MGH overlays (``read_mgh_data``), and annotation + files (``read_annot_data``). Contains derived nibabel code (MIT licence). 
+* :mod:`~whippersnappy.geometry.mesh_io` — open ASCII mesh formats: + OFF (``read_off``), legacy VTK PolyData (``read_vtk_ascii_polydata``), + ASCII PLY (``read_ply_ascii``), and GIfTI surface (``read_gifti_surface``). + Pure stdlib + numpy, except GIfTI which uses nibabel. +* :mod:`~whippersnappy.geometry.overlay_io` — open scalar/label formats: + plain ASCII (``read_txt``), NumPy binary (``read_npy``, ``read_npz``), + and GIfTI functional/label (``read_gifti``). Pure stdlib + numpy, except + GIfTI. + +Each family also exposes a dispatcher (``read_mesh`` / ``read_overlay``) +that routes by file extension. + +**Layer 2 — resolvers** (:mod:`~whippersnappy.geometry.inputs`): + +``resolve_mesh``, ``resolve_overlay``, ``resolve_bg_map``, ``resolve_roi``, +``resolve_annot`` — the **single public interface** for the rest of the +package. Each resolver accepts a file path *or* a numpy array *or* ``None``, +dispatches to the correct layer-1 reader, validates shapes and dtypes, and +returns a clean numpy array. All higher-level code (``prepare.py``, +``snap.py``, ``plot3d.py``, CLIs) should go through resolvers only. + +**Layer 3 — geometry preparation** (:mod:`~whippersnappy.geometry.prepare`): + +``prepare_geometry`` / ``prepare_geometry_from_arrays`` — load, normalise, +colour, and pack vertex data into the GPU-ready format consumed by the +OpenGL shaders. 
+""" +from .freesurfer_io import read_annot_data, read_geometry, read_mgh_data, read_morph_data +from .inputs import resolve_annot, resolve_bg_map, resolve_mesh, resolve_overlay, resolve_roi +from .mesh_io import read_gifti_surface, read_mesh, read_off, read_ply_ascii, read_vtk_ascii_polydata +from .overlay_io import read_gifti, read_npy, read_npz, read_overlay, read_txt +from .prepare import ( + estimate_overlay_thresholds, + prepare_and_validate_geometry, + prepare_geometry, + prepare_geometry_from_arrays, +) +from .surf_name import get_surf_name + +__all__ = [ + # Layer 3 — geometry preparation + 'prepare_geometry', + 'prepare_geometry_from_arrays', + 'prepare_and_validate_geometry', + 'estimate_overlay_thresholds', + # Layer 2 — resolvers (preferred public interface) + 'resolve_mesh', + 'resolve_overlay', + 'resolve_bg_map', + 'resolve_roi', + 'resolve_annot', + # Layer 1 — mesh readers + 'read_mesh', + 'read_off', + 'read_vtk_ascii_polydata', + 'read_ply_ascii', + 'read_gifti_surface', + # Layer 1 — overlay/scalar readers + 'read_overlay', + 'read_txt', + 'read_npy', + 'read_npz', + 'read_gifti', + # Layer 1 — FreeSurfer binary readers + 'read_geometry', + 'read_morph_data', + 'read_mgh_data', + 'read_annot_data', + # Utilities + 'get_surf_name', +] diff --git a/whippersnappy/read_geometry.py b/whippersnappy/geometry/freesurfer_io.py similarity index 82% rename from whippersnappy/read_geometry.py rename to whippersnappy/geometry/freesurfer_io.py index 2387ada..7d7fd39 100644 --- a/whippersnappy/read_geometry.py +++ b/whippersnappy/geometry/freesurfer_io.py @@ -1,8 +1,26 @@ -"""Read FreeSurfer geometry (fix for dev, ll 126-128); - -Code was taken from nibabel.freesurfer package -(https://github.com/nipy/nibabel/blob/master/nibabel/freesurfer/io.py). -This software is licensed under the following license: +"""FreeSurfer binary IO — surface geometry, morphometry, MGH, and annotation. 
+ +This module contains readers for all FreeSurfer binary formats: + +* :func:`read_geometry` — triangle surface (e.g. ``lh.white``, ``rh.pial``) +* :func:`read_morph_data` — per-vertex morphometry scalar (e.g. ``lh.curv``, + ``lh.thickness``); the format FreeSurfer internally calls "curv files" +* :func:`read_mgh_data` — MGH/MGZ volumetric image used as per-vertex + overlay (shape ``N×1×1``) +* :func:`read_annot_data` — FreeSurfer parcellation annotation (``.annot``) + +These functions are **low-level readers**. In normal use you should go +through the resolver functions in :mod:`whippersnappy.geometry.inputs` +(:func:`~whippersnappy.geometry.inputs.resolve_mesh`, +:func:`~whippersnappy.geometry.inputs.resolve_overlay`, etc.) which handle +format dispatch, dtype conversion, and shape validation for you. + +License notice +-------------- +The binary parsing code in this module is derived from the ``nibabel`` +FreeSurfer IO module +(https://github.com/nipy/nibabel/blob/master/nibabel/freesurfer/io.py), +used under the MIT licence reproduced below. 
The MIT License @@ -14,12 +32,12 @@ Copyright (c) 2011-2019 Yaroslav Halchenko Copyright (c) 2015-2019 Chris Markiewicz -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. @@ -28,9 +46,9 @@ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. """ import warnings diff --git a/whippersnappy/geometry/inputs.py b/whippersnappy/geometry/inputs.py new file mode 100644 index 0000000..3c3afa5 --- /dev/null +++ b/whippersnappy/geometry/inputs.py @@ -0,0 +1,301 @@ +"""Input resolver functions for WhipperSnapPy geometry loading. 
+ +This module is the **single entry point** for all input loading and +validation. No other module should call ``read_geometry``, +``read_morph_data``, ``read_mgh_data``, ``read_annot_data``, or +``mask_label`` directly — all calls go through the resolver functions +defined here. + +Each resolver accepts ``None``, a file path (``str``), or a numpy +array-like, validates shape and dtype, and returns a clean numpy array +(or ``None``). Format dispatch is handled internally by: + +* :mod:`~whippersnappy.geometry.freesurfer_io` — FreeSurfer binary formats +* :mod:`~whippersnappy.geometry.mesh_io` — OFF / VTK / PLY / GIfTI surfaces +* :mod:`~whippersnappy.geometry.overlay_io` — TXT / CSV / NPY / NPZ / GIfTI scalars +""" + +import os + +import numpy as np + +from ..utils.colormap import mask_label +from .freesurfer_io import read_annot_data, read_geometry, read_mgh_data, read_morph_data +from .mesh_io import read_mesh as _read_mesh_by_ext +from .overlay_io import read_overlay as _read_overlay_by_ext + +# Extensions handled by the lightweight ASCII mesh readers in mesh_io.py +# (includes GIfTI surface via nibabel) +_MESH_IO_EXTS = frozenset({".off", ".vtk", ".ply", ".gii"}) + +# Extensions handled by overlay_io.py (everything except FreeSurfer morph / MGH) +_OVERLAY_IO_EXTS = frozenset({".txt", ".csv", ".npy", ".npz", ".gii"}) + + +def resolve_mesh(mesh): + """Resolve a mesh input to ``(vertices, faces)`` numpy arrays. + + Parameters + ---------- + mesh : str or tuple/list of two array-likes + * ``str`` — path to a mesh file. Files with extensions ``.off``, + ``.vtk``, or ``.ply`` are loaded by the lightweight ASCII readers + in :mod:`whippersnappy.geometry.mesh_io`. All other paths (e.g. + FreeSurfer surfaces such as ``lh.white``) are loaded via + :func:`whippersnappy.geometry.read_geometry`. + * Two-element tuple/list — ``(vertices, faces)`` array-likes + converted to ``float32`` and ``uint32`` numpy arrays respectively. 
+ + Returns + ------- + vertices : numpy.ndarray + Vertex coordinate array of shape (N, 3), dtype float32. + faces : numpy.ndarray + Triangle face index array of shape (M, 3), dtype uint32. + + Raises + ------ + TypeError + If *mesh* is neither a ``str`` nor a two-element tuple/list. + ValueError + If the resulting arrays do not have the expected shapes or if face + indices are out of range. + """ + if isinstance(mesh, str): + lower = mesh.lower() + # Compound extension must be checked before os.path.splitext + if lower.endswith(".surf.gii") or os.path.splitext(lower)[1] in _MESH_IO_EXTS: + vertices, faces = _read_mesh_by_ext(mesh) + else: + v_raw, f_raw = read_geometry(mesh, read_metadata=False) + vertices = np.asarray(v_raw, dtype=np.float32) + faces = np.asarray(f_raw, dtype=np.uint32) + elif isinstance(mesh, (tuple, list)) and len(mesh) == 2: + vertices = np.asarray(mesh[0], dtype=np.float32) + faces = np.asarray(mesh[1], dtype=np.uint32) + else: + raise TypeError( + f"mesh must be a file path (str) or a (vertices, faces) tuple/list, " + f"got {type(mesh).__name__!r}." + ) + + if vertices.ndim != 2 or vertices.shape[1] != 3: + raise ValueError( + f"vertices must be an array of shape (N, 3), got shape {vertices.shape}." + ) + if faces.ndim != 2 or faces.shape[1] != 3: + raise ValueError( + f"faces must be an array of shape (M, 3), got shape {faces.shape}." + ) + # Bounds check for array inputs (file readers do their own check) + if not isinstance(mesh, str) and faces.size > 0: + n_verts = vertices.shape[0] + if int(faces.max()) >= n_verts or int(faces.min()) < 0: + raise ValueError( + f"Face indices out of range [0, {n_verts}): " + f"min={int(faces.min())}, max={int(faces.max())}." + ) + return vertices, faces + + +def _load_overlay_from_file(path): + """Load a 1-D per-vertex overlay array from a file path. 
+ + Routing logic: + + * ``.mgh`` / ``.mgz`` → :func:`read_mgh_data` + * ``.txt``, ``.csv``, ``.npy``, ``.npz``, ``.gii``, + ``.func.gii``, ``.label.gii`` → :func:`overlay_io.read_overlay` + * anything else → :func:`read_morph_data` (FreeSurfer binary morph) + """ + lower = path.lower() + # Compound GIfTI extensions must be checked before splitext + if lower.endswith(".func.gii") or lower.endswith(".label.gii"): + return _read_overlay_by_ext(path) + _, ext = os.path.splitext(lower) + if ext in (".mgh", ".mgz"): + return read_mgh_data(path) + if ext in _OVERLAY_IO_EXTS: + return _read_overlay_by_ext(path) + # Default: FreeSurfer binary morph (curv, thickness, etc. — often no ext) + return read_morph_data(path) + + +def resolve_overlay(overlay, *, n_vertices): + """Resolve an overlay input to a 1-D float32 numpy array, or ``None``. + + Parameters + ---------- + overlay : None, str, or array-like + * ``None`` — no overlay; returns ``None``. + * ``str`` — path to an overlay file (.mgh or FreeSurfer morph format). + * array-like — converted to ``np.float32``; must have shape + ``(n_vertices,)`` when *n_vertices* is not ``None``. + n_vertices : int or None + Expected number of vertices. Shape validation is skipped when + ``None`` (useful for ``estimate_overlay_thresholds``). + + Returns + ------- + numpy.ndarray of shape (n_vertices,) or None + + Raises + ------ + ValueError + If the loaded/converted array does not match *n_vertices*. + """ + if overlay is None: + return None + if isinstance(overlay, str): + arr = _load_overlay_from_file(overlay).astype(np.float32) + else: + arr = np.asarray(overlay, dtype=np.float32) + if n_vertices is not None and arr.shape != (n_vertices,): + raise ValueError( + f"overlay has shape {arr.shape} but mesh has {n_vertices} vertices." + ) + return arr + + +def resolve_bg_map(bg_map, *, n_vertices): + """Resolve a background-map input to a 1-D float32 numpy array, or ``None``. + + Identical logic to :func:`resolve_overlay`. 
+ + Parameters + ---------- + bg_map : None, str, or array-like + Background shading data (typically curvature). + n_vertices : int or None + Expected number of vertices for shape validation. + + Returns + ------- + numpy.ndarray of shape (n_vertices,) or None + """ + if bg_map is None: + return None + if isinstance(bg_map, str): + arr = _load_overlay_from_file(bg_map).astype(np.float32) + else: + arr = np.asarray(bg_map, dtype=np.float32) + if n_vertices is not None and arr.shape != (n_vertices,): + raise ValueError( + f"bg_map has shape {arr.shape} but mesh has {n_vertices} vertices." + ) + return arr + + +def resolve_roi(roi, *, n_vertices): + """Resolve a region-of-interest input to a boolean numpy array, or ``None``. + + The returned boolean array has ``True`` for vertices that are *included* + in the overlay coloring and ``False`` for vertices that fall back to + background (``bg_map``) shading. + + Parameters + ---------- + roi : None, str, or array-like + * ``None`` — no masking; returns ``None``. + * ``str`` — file path. Routing by extension: + + - ``.txt``, ``.csv``, ``.npy``, ``.npz``, ``.gii``, + ``.func.gii``, ``.label.gii`` — loaded via + :func:`_load_overlay_from_file` and cast to ``bool``. + Non-zero / ``True`` values → vertex included. + - Any other path (including FreeSurfer label files with no + standard extension, e.g. ``lh.cortex.label``) — loaded via + :func:`~whippersnappy.utils.colormap.mask_label`. + + * array-like — converted to ``np.bool_``; must have shape + ``(n_vertices,)``. + n_vertices : int + Expected number of vertices. + + Returns + ------- + numpy.ndarray of shape (n_vertices,) bool, or None + + Raises + ------ + ValueError + If the resolved array does not match *n_vertices*. 
+ """ + if roi is None: + return None + if isinstance(roi, str): + lower = roi.lower() + use_overlay_io = ( + lower.endswith(".func.gii") + or lower.endswith(".label.gii") + or os.path.splitext(lower)[1] in _OVERLAY_IO_EXTS + ) + if use_overlay_io: + raw = _load_overlay_from_file(roi) + arr = raw.astype(bool) + else: + # FreeSurfer label file: vertices listed → True, rest → False + sentinel = np.ones(n_vertices, dtype=np.float32) + masked = mask_label(sentinel, roi) + arr = ~np.isnan(masked) + else: + arr = np.asarray(roi, dtype=bool) + if arr.shape != (n_vertices,): + raise ValueError( + f"roi has shape {arr.shape} but mesh has {n_vertices} vertices." + ) + return arr + + +def resolve_annot(annot, *, n_vertices): + """Resolve an annotation input to ``(labels, ctab, names)`` or ``None``. + + Parameters + ---------- + annot : None, str, or tuple of length 2 or 3 + * ``None`` — no annotation; returns ``None``. + * ``str`` — path to a FreeSurfer .annot file; loaded via + :func:`read_annot_data`. + * 2-tuple ``(labels, ctab)`` — validated and returned as + ``(labels, ctab, None)``. + * 3-tuple ``(labels, ctab, names)`` — validated and returned as-is. + n_vertices : int + Expected number of vertices for shape validation of *labels*. + + Returns + ------- + tuple of (labels, ctab, names) or None + * *labels* — integer array of shape (n_vertices,). + * *ctab* — color table array of shape (n_labels, ≥3). + * *names* — list of label names, or ``None``. + + Raises + ------ + TypeError + If *annot* is not one of the accepted types. + ValueError + If *labels* does not match *n_vertices*. 
+ """ + if annot is None: + return None + if isinstance(annot, str): + labels, ctab, names = read_annot_data(annot) + elif isinstance(annot, (tuple, list)) and len(annot) == 2: + labels = np.asarray(annot[0]) + ctab = np.asarray(annot[1]) + names = None + elif isinstance(annot, (tuple, list)) and len(annot) == 3: + labels = np.asarray(annot[0]) + ctab = np.asarray(annot[1]) + names = annot[2] + else: + raise TypeError( + f"annot must be a file path (str), a (labels, ctab) tuple, " + f"or a (labels, ctab, names) tuple; got {type(annot).__name__!r}." + ) + if labels.shape != (n_vertices,): + raise ValueError( + f"annot labels have shape {labels.shape} but mesh has {n_vertices} vertices." + ) + return labels, ctab, names + diff --git a/whippersnappy/geometry/mesh_io.py b/whippersnappy/geometry/mesh_io.py new file mode 100644 index 0000000..105b7af --- /dev/null +++ b/whippersnappy/geometry/mesh_io.py @@ -0,0 +1,621 @@ +"""Lightweight ASCII mesh readers for common open formats. + +This module implements pure-Python (stdlib + numpy only) readers for: + +* **OFF** — Object File Format, ASCII triangles +* **VTK legacy ASCII PolyData** — ``DATASET POLYDATA`` with POINTS/POLYGONS +* **PLY ASCII** — Stanford PLY, ASCII encoding, triangles only + +And a nibabel-backed reader for: + +* **GIfTI surface** (``.surf.gii`` / ``.gii``) — loaded via nibabel; + requires the ``NIFTI_INTENT_POINTSET`` (1008) + ``NIFTI_INTENT_TRIANGLE`` + (1009) data arrays that every standard surface GIfTI file contains. + +All readers return ``(vertices, faces)`` where + +* ``vertices`` — ``float32`` array of shape ``(N, 3)`` +* ``faces`` — ``uint32`` array of shape ``(M, 3)`` + +The public dispatcher :func:`read_mesh` routes by file extension +(``.off``, ``.vtk``, ``.ply``, ``.surf.gii``, ``.gii``). 
For FreeSurfer +surfaces (no standard extension) use the existing +:func:`whippersnappy.geometry.read_geometry` directly, or go through +:func:`whippersnappy.geometry.inputs.resolve_mesh` which handles the routing +automatically. +""" + +import numpy as np + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _non_empty_lines(path): + """Yield stripped, non-empty, non-comment lines from a text file.""" + with open(path, encoding="utf-8", errors="replace") as fh: + for raw in fh: + line = raw.strip() + if line and not line.startswith("#"): + yield line + + +# --------------------------------------------------------------------------- +# OFF reader +# --------------------------------------------------------------------------- + +def read_off(path): + """Read an ASCII OFF (Object File Format) triangle mesh. + + Parameters + ---------- + path : str + Path to the ``.off`` file. + + Returns + ------- + vertices : numpy.ndarray, shape (N, 3), dtype float32 + faces : numpy.ndarray, shape (M, 3), dtype uint32 + + Raises + ------ + ValueError + If the file does not start with ``OFF``, if any face is not a + triangle, or if the declared counts don't match the data. + IOError + If the file cannot be opened. + + Notes + ----- + Comments (lines starting with ``#``) and blank lines are ignored + anywhere in the file. The optional ``COFF`` / ``NOFF`` / ``CNOFF`` + variants are *not* supported; only plain ``OFF`` is accepted. + """ + lines = list(_non_empty_lines(path)) + if not lines: + raise ValueError(f"OFF file is empty: {path!r}") + + # First line must be exactly "OFF" + header = lines[0].upper() + if header != "OFF": + raise ValueError( + f"Expected 'OFF' header on first non-comment line, got {lines[0]!r} " + f"in {path!r}. Only plain ASCII OFF is supported." 
+ ) + + if len(lines) < 2: + raise ValueError(f"OFF file has no count line after header: {path!r}") + + # Second line: n_vertices n_faces n_edges + parts = lines[1].split() + if len(parts) < 2: + raise ValueError( + f"OFF count line must have at least 2 integers (n_vertices n_faces), " + f"got {lines[1]!r} in {path!r}." + ) + try: + n_verts = int(parts[0]) + n_faces = int(parts[1]) + except ValueError as exc: + raise ValueError( + f"Could not parse OFF count line {lines[1]!r} in {path!r}." + ) from exc + + # Validate we have enough data lines + data_lines = lines[2:] + if len(data_lines) < n_verts + n_faces: + raise ValueError( + f"OFF file declares {n_verts} vertices and {n_faces} faces " + f"but only {len(data_lines)} data lines follow in {path!r}." + ) + + # Parse vertices + vertices = np.empty((n_verts, 3), dtype=np.float32) + for i in range(n_verts): + try: + coords = data_lines[i].split()[:3] + vertices[i] = [float(c) for c in coords] + except (ValueError, IndexError) as exc: + raise ValueError( + f"Could not parse vertex {i} in OFF file {path!r}: " + f"{data_lines[i]!r}" + ) from exc + + # Parse faces + faces = np.empty((n_faces, 3), dtype=np.uint32) + for j in range(n_faces): + try: + tokens = data_lines[n_verts + j].split() + n_poly = int(tokens[0]) + except (ValueError, IndexError) as exc: + raise ValueError( + f"Could not parse face {j} in OFF file {path!r}: " + f"{data_lines[n_verts + j]!r}" + ) from exc + if n_poly != 3: + raise ValueError( + f"OFF face {j} has {n_poly} vertices; only triangles (3) are " + f"supported in {path!r}. Convert to a triangle mesh first." 
+ ) + try: + faces[j] = [int(tokens[k]) for k in range(1, 4)] + except (ValueError, IndexError) as exc: + raise ValueError( + f"Could not parse face indices at face {j} in {path!r}: " + f"{data_lines[n_verts + j]!r}" + ) from exc + + # Bounds check + if n_faces > 0: + if int(faces.max()) >= n_verts or int(faces.min()) < 0: + raise ValueError( + f"OFF face indices out of range [0, {n_verts}) in {path!r}." + ) + + return vertices, faces + + +# --------------------------------------------------------------------------- +# Legacy VTK ASCII PolyData reader +# --------------------------------------------------------------------------- + +def read_vtk_ascii_polydata(path): + """Read a legacy ASCII VTK PolyData triangle mesh. + + Only the *ASCII legacy format* with ``DATASET POLYDATA`` is supported. + Binary VTK files are explicitly rejected. + + Parameters + ---------- + path : str + Path to the ``.vtk`` file. + + Returns + ------- + vertices : numpy.ndarray, shape (N, 3), dtype float32 + faces : numpy.ndarray, shape (M, 3), dtype uint32 + + Raises + ------ + ValueError + If the file is binary, is not POLYDATA, contains non-triangle + polygons, or if required sections are missing. + IOError + If the file cannot be opened. + """ + with open(path, encoding="utf-8", errors="replace") as fh: + raw_lines = fh.readlines() + + # Legacy VTK format: + # line 0: "# vtk DataFile Version x.x" + # line 1: title (arbitrary free text) + # line 2: "ASCII" or "BINARY" + # line 3: "DATASET " + if len(raw_lines) < 3: + raise ValueError(f"VTK file too short: {path!r}") + + fmt_line = raw_lines[2].strip().upper() + if "BINARY" in fmt_line: + raise ValueError( + f"Only ASCII legacy VTK POLYDATA is supported; " + f"file appears to be BINARY: {path!r}. " + f"Convert with: vtk-convert or meshio-convert." + ) + if "ASCII" not in fmt_line: + raise ValueError( + f"Could not determine VTK format from line 3 (expected 'ASCII' or " + f"'BINARY'): {raw_lines[2]!r} in {path!r}." 
+ ) + + # Scan for DATASET POLYDATA + dataset_found = False + for line in raw_lines: + if line.strip().upper().startswith("DATASET"): + if "POLYDATA" not in line.upper(): + raise ValueError( + f"Only POLYDATA VTK datasets are supported, " + f"got: {line.strip()!r} in {path!r}." + ) + dataset_found = True + break + if not dataset_found: + raise ValueError(f"No DATASET line found in VTK file {path!r}.") + + # Tokenise everything into a flat list for easy sectioned parsing + lines = [raw_ln.strip() for raw_ln in raw_lines if raw_ln.strip() and not raw_ln.strip().startswith("#")] + + vertices = None + faces = None + i = 0 + while i < len(lines): + upper = lines[i].upper() + + if upper.startswith("POINTS"): + parts = lines[i].split() + n_pts = int(parts[1]) + # Collect 3*n_pts floats; they may span multiple lines + floats = [] + i += 1 + while len(floats) < 3 * n_pts and i < len(lines): + # Stop at next keyword section + if lines[i].upper().split()[0] in ( + "POLYGONS", "LINES", "STRIPS", "VERTICES", + "POINT_DATA", "CELL_DATA", "FIELD", "NORMALS", + "TEXTURE_COORDINATES", "SCALARS", "LOOKUP_TABLE", + ): + break + floats.extend(float(x) for x in lines[i].split()) + i += 1 + if len(floats) < 3 * n_pts: + raise ValueError( + f"Expected {3 * n_pts} floats for POINTS but got " + f"{len(floats)} in {path!r}." 
+                )
+            vertices = np.array(floats[: 3 * n_pts], dtype=np.float32).reshape(n_pts, 3)
+            continue  # i already advanced
+
+        elif upper.startswith("POLYGONS"):
+            parts = lines[i].split()
+            n_polys = int(parts[1])
+            face_list = []
+            i += 1
+            while len(face_list) < n_polys and i < len(lines):
+                if lines[i].upper().split()[0] in (
+                    "POINTS", "LINES", "STRIPS", "VERTICES",
+                    "POINT_DATA", "CELL_DATA", "FIELD", "NORMALS",
+                    "TEXTURE_COORDINATES", "SCALARS", "LOOKUP_TABLE",
+                ):
+                    break
+                tokens = lines[i].split()
+                n_poly = int(tokens[0])
+                if n_poly != 3:
+                    raise ValueError(
+                        f"VTK polygon {len(face_list)} has {n_poly} vertices; "
+                        f"only triangles (3) are supported in {path!r}. "
+                        f"Triangulate the mesh before loading."
+                    )
+                face_list.append([int(tokens[1]), int(tokens[2]), int(tokens[3])])
+                i += 1
+            if len(face_list) < n_polys:
+                raise ValueError(
+                    f"Expected {n_polys} polygons but only parsed "
+                    f"{len(face_list)} in {path!r}."
+                )
+            faces = np.array(face_list, dtype=np.uint32)
+            continue  # i already advanced
+
+        else:
+            i += 1
+
+    if vertices is None:
+        raise ValueError(f"No POINTS section found in VTK file {path!r}.")
+    if faces is None:
+        raise ValueError(f"No POLYGONS section found in VTK file {path!r}.")
+
+    # Bounds check
+    n_verts = vertices.shape[0]
+    if faces.size > 0 and (int(faces.max()) >= n_verts or int(faces.min()) < 0):
+        raise ValueError(
+            f"VTK face indices out of range [0, {n_verts}) in {path!r}."
+        )
+
+    return vertices, faces
+
+
+# ---------------------------------------------------------------------------
+# PLY ASCII reader
+# ---------------------------------------------------------------------------
+
+def read_ply_ascii(path):
+    """Read an ASCII PLY triangle mesh.
+
+    Only ASCII PLY files are supported. Binary PLY files are explicitly
+    rejected with a helpful error message.
+
+    Parameters
+    ----------
+    path : str
+        Path to the ``.ply`` file.
+
+    Returns
+    -------
+    vertices : numpy.ndarray, shape (N, 3), dtype float32
+    faces : numpy.ndarray, shape (M, 3), dtype uint32
+
+    Raises
+    ------
+    ValueError
+        If the file is binary PLY, if faces are not triangles, or if the
+        header is malformed.
+    IOError
+        If the file cannot be opened.
+    """
+    with open(path, encoding="utf-8", errors="replace") as fh:
+        raw_lines = fh.readlines()
+
+    if not raw_lines or raw_lines[0].strip() != "ply":
+        raise ValueError(
+            f"File does not start with 'ply' magic; not a PLY file: {path!r}."
+        )
+
+    # Check encoding
+    # Only header lines 2-4 are scanned for the "format" declaration; the PLY
+    # spec places the format line directly after the "ply" magic, so this is
+    # sufficient for conforming files.
+    for line in raw_lines[1:4]:
+        stripped = line.strip().lower()
+        if stripped.startswith("format"):
+            if "ascii" not in stripped:
+                raise ValueError(
+                    f"PLY binary format not supported; only ASCII PLY is "
+                    f"accepted: {path!r}. "
+                    f"Convert with: plyconvert or meshio-convert."
+                )
+            break
+
+    # Parse header
+    n_verts = None
+    n_faces = None
+    vertex_props = []  # ordered list of property names for the vertex element
+    in_vertex = False
+    header_end = 0
+
+    for idx, line in enumerate(raw_lines):
+        stripped = line.strip()
+        lower = stripped.lower()
+
+        if lower == "end_header":
+            header_end = idx + 1
+            break
+        if lower.startswith("element vertex"):
+            n_verts = int(stripped.split()[2])
+            in_vertex = True
+        elif lower.startswith("element face"):
+            n_faces = int(stripped.split()[2])
+            in_vertex = False
+        elif lower.startswith("property") and in_vertex:
+            # e.g. "property float x"
+            parts = stripped.split()
+            if len(parts) >= 3:
+                vertex_props.append(parts[-1])
+        elif (lower.startswith("element")
+              and not lower.startswith("element vertex")
+              and not lower.startswith("element face")):
+            in_vertex = False
+
+    if n_verts is None:
+        raise ValueError(f"No 'element vertex' found in PLY header: {path!r}.")
+    if n_faces is None:
+        raise ValueError(f"No 'element face' found in PLY header: {path!r}.")
+
+    # Determine column indices for x, y, z
+    try:
+        xi = vertex_props.index("x")
+        yi = vertex_props.index("y")
+        zi = vertex_props.index("z")
+    except ValueError as exc:
+        raise ValueError(
+            f"PLY vertex element missing x/y/z properties in {path!r}; "
+            f"found: {vertex_props!r}."
+        ) from exc
+
+    # Blank lines are dropped here, so vertex and face rows are addressed by
+    # position only (vertices first, then faces, as declared in the header).
+    data_lines = [raw_line.strip() for raw_line in raw_lines[header_end:] if raw_line.strip()]
+
+    if len(data_lines) < n_verts + n_faces:
+        raise ValueError(
+            f"PLY file has {len(data_lines)} data lines but expects "
+            f"{n_verts} vertices + {n_faces} faces in {path!r}."
+        )
+
+    # Parse vertices
+    vertices = np.empty((n_verts, 3), dtype=np.float32)
+    for i in range(n_verts):
+        try:
+            tokens = data_lines[i].split()
+            vertices[i] = [float(tokens[xi]), float(tokens[yi]), float(tokens[zi])]
+        except (ValueError, IndexError) as exc:
+            raise ValueError(
+                f"Could not parse PLY vertex {i} in {path!r}: "
+                f"{data_lines[i]!r}"
+            ) from exc
+
+    # Parse faces
+    faces = np.empty((n_faces, 3), dtype=np.uint32)
+    for j in range(n_faces):
+        try:
+            tokens = data_lines[n_verts + j].split()
+            count = int(tokens[0])
+        except (ValueError, IndexError) as exc:
+            raise ValueError(
+                f"Could not parse PLY face {j} in {path!r}: "
+                f"{data_lines[n_verts + j]!r}"
+            ) from exc
+        if count != 3:
+            raise ValueError(
+                f"PLY face {j} has {count} vertices; only triangles (3) are "
+                f"supported in {path!r}. Triangulate the mesh first."
+            )
+        try:
+            faces[j] = [int(tokens[1]), int(tokens[2]), int(tokens[3])]
+        except (ValueError, IndexError) as exc:
+            raise ValueError(
+                f"Could not parse PLY face indices at face {j} in {path!r}: "
+                f"{data_lines[n_verts + j]!r}"
+            ) from exc
+
+    # Bounds check
+    if n_faces > 0 and (int(faces.max()) >= n_verts or int(faces.min()) < 0):
+        raise ValueError(
+            f"PLY face indices out of range [0, {n_verts}) in {path!r}."
+        )
+
+    return vertices, faces
+
+
+# ---------------------------------------------------------------------------
+# GIfTI surface reader
+# ---------------------------------------------------------------------------
+
+# NIFTI intent codes relevant to GIfTI surface files
+_GIFTI_INTENT_POINTSET = 1008  # NIFTI_INTENT_POINTSET — vertex coordinates
+_GIFTI_INTENT_TRIANGLE = 1009  # NIFTI_INTENT_TRIANGLE — face indices
+
+
+def read_gifti_surface(path):
+    """Read a GIfTI surface file (``.surf.gii`` or ``.gii``).
+
+    A standard GIfTI surface file contains exactly two data arrays:
+
+    * One with intent ``NIFTI_INTENT_POINTSET`` (1008) — vertex coordinates,
+      shape ``(N, 3)``, dtype ``float32``.
+    * One with intent ``NIFTI_INTENT_TRIANGLE`` (1009) — face indices,
+      shape ``(M, 3)``, dtype ``int32``.
+
+    This is the format produced by FreeSurfer's ``mris_convert``,
+    Connectome Workbench, fMRIPrep, and most HCP pipelines.
+
+    Parameters
+    ----------
+    path : str
+        Path to a GIfTI surface file.
+
+    Returns
+    -------
+    vertices : numpy.ndarray, shape (N, 3), dtype float32
+    faces : numpy.ndarray, shape (M, 3), dtype uint32
+
+    Raises
+    ------
+    ImportError
+        If ``nibabel`` is not installed.
+    ValueError
+        If the file does not contain the expected POINTSET and TRIANGLE
+        data arrays, or if the arrays have unexpected shapes.
+        Also raised if the file appears to be a functional/label GIfTI
+        rather than a surface GIfTI (i.e. no POINTSET array found) — in
+        that case, use ``resolve_overlay()`` instead.
+    IOError
+        If the file cannot be opened or is not a valid GIfTI file.
+
+    Examples
+    --------
+    >>> v, f = read_gifti_surface('lh.white.surf.gii')
+    >>> v.shape  # (N, 3)
+    >>> f.shape  # (M, 3)
+    """
+    try:
+        import nibabel as nib  # noqa: PLC0415
+    except ImportError as exc:
+        raise ImportError(
+            "Reading GIfTI surface files requires nibabel. "
+            "Install with: pip install nibabel"
+        ) from exc
+
+    img = nib.load(path)
+    if not hasattr(img, "darrays") or not img.darrays:
+        raise ValueError(
+            f"GIfTI file {path!r} contains no data arrays. "
+            f"Expected a surface file with POINTSET and TRIANGLE arrays."
+        )
+
+    # First POINTSET / first TRIANGLE array wins; later duplicates are ignored.
+    coords_arr = None
+    faces_arr = None
+    for da in img.darrays:
+        if da.intent == _GIFTI_INTENT_POINTSET and coords_arr is None:
+            coords_arr = da.data
+        elif da.intent == _GIFTI_INTENT_TRIANGLE and faces_arr is None:
+            faces_arr = da.data
+
+    if coords_arr is None:
+        raise ValueError(
+            f"GIfTI file {path!r} has no POINTSET (intent 1008) array. "
+            f"If this is a functional or label file, use resolve_overlay() instead."
+        )
+    if faces_arr is None:
+        raise ValueError(
+            f"GIfTI file {path!r} has no TRIANGLE (intent 1009) array. "
+            f"A valid surface GIfTI must contain both POINTSET and TRIANGLE arrays."
+        )
+
+    vertices = np.asarray(coords_arr, dtype=np.float32)
+    faces = np.asarray(faces_arr, dtype=np.uint32)
+
+    if vertices.ndim != 2 or vertices.shape[1] != 3:
+        raise ValueError(
+            f"GIfTI POINTSET array in {path!r} has unexpected shape "
+            f"{vertices.shape}; expected (N, 3)."
+        )
+    if faces.ndim != 2 or faces.shape[1] != 3:
+        raise ValueError(
+            f"GIfTI TRIANGLE array in {path!r} has unexpected shape "
+            f"{faces.shape}; expected (M, 3)."
+        )
+    if faces.size > 0:
+        n_verts = vertices.shape[0]
+        if int(faces.max()) >= n_verts or int(faces.min()) < 0:
+            raise ValueError(
+                f"GIfTI face indices out of range [0, {n_verts}) in {path!r}."
+            )
+
+    return vertices, faces
+
+
+# ---------------------------------------------------------------------------
+# Dispatcher
+# ---------------------------------------------------------------------------
+
+# NOTE: os.path.splitext never yields ".surf.gii" (it returns ".gii"), so the
+# compound key below is only reachable via the endswith() check in read_mesh.
+_READERS = {
+    ".off": read_off,
+    ".vtk": read_vtk_ascii_polydata,
+    ".ply": read_ply_ascii,
+    ".gii": read_gifti_surface,
+    ".surf.gii": read_gifti_surface,  # compound extension — checked first
+}
+
+_SUPPORTED = ", ".join(sorted(_READERS))
+
+
+def read_mesh(path):
+    """Read a triangle mesh from an OFF, VTK, PLY, or GIfTI surface file.
+
+    Dispatches to the appropriate reader based on the file extension.
+    For FreeSurfer binary surfaces (which typically have no standard
+    extension, e.g. ``lh.white``) use
+    :func:`whippersnappy.geometry.read_geometry` directly, or pass the
+    path through :func:`whippersnappy.geometry.inputs.resolve_mesh` which
+    handles the routing automatically.
+
+    Parameters
+    ----------
+    path : str
+        Path to a mesh file. Extension must be one of:
+        ``.off``, ``.vtk``, ``.ply``, ``.surf.gii``, ``.gii``
+        (case-insensitive).
+
+    Returns
+    -------
+    vertices : numpy.ndarray, shape (N, 3), dtype float32
+    faces : numpy.ndarray, shape (M, 3), dtype uint32
+
+    Raises
+    ------
+    ValueError
+        If the extension is not recognised.
+    """
+    import os as _os
+    lower = path.lower()
+    # Check compound extension first (.surf.gii before .gii)
+    if lower.endswith(".surf.gii"):
+        return read_gifti_surface(path)
+    ext = _os.path.splitext(lower)[1]
+    reader = _READERS.get(ext)
+    if reader is None:
+        raise ValueError(
+            f"Unsupported mesh file extension {ext!r} for {path!r}. "
+            f"Supported formats: {_SUPPORTED}. "
+            f"For FreeSurfer surfaces (no extension) use resolve_mesh() "
+            f"or read_geometry() directly."
+        )
+    return reader(path)
+
+
+
+
diff --git a/whippersnappy/geometry/overlay_io.py b/whippersnappy/geometry/overlay_io.py
new file mode 100644
index 0000000..04ea510
--- /dev/null
+++ b/whippersnappy/geometry/overlay_io.py
@@ -0,0 +1,345 @@
+"""Lightweight per-vertex scalar and label readers for common open formats.
+
+This module implements pure-Python (stdlib + numpy only) readers for simple
+per-vertex data files, plus a GIfTI reader that reuses the nibabel dependency
+already present in the project.
+
+Supported formats
+-----------------
+* **ASCII text** (``.txt``, ``.csv``) — one numeric value per line; optional
+  single non-numeric header line (skipped automatically); whitespace or
+  comma-separated. Integer values are loaded as ``int32``; all others as
+  ``float32``.
+
+* **NumPy array** (``.npy``) — single 1-D array saved with
+  ``numpy.save``. Any numeric dtype is accepted and kept as-is; callers
+  cast to the required dtype.
+
+* **NumPy archive** (``.npz``) — multi-array archive saved with
+  ``numpy.savez``. The array named ``"data"`` is used if present,
+  otherwise the first array in the archive is used.
+
+* **GIfTI functional / label** (``.func.gii``, ``.label.gii``, ``.gii``) —
+  loaded via ``nibabel``; the first data array is returned. Covers HCP,
+  fMRIPrep, and Connectome Workbench outputs.
+
+The public dispatcher :func:`read_overlay` routes by file extension.
+FreeSurfer binary morph files and MGH/MGZ files are *not* handled here —
+they are loaded by :mod:`whippersnappy.geometry.freesurfer_io` and
+dispatched from :func:`whippersnappy.geometry.inputs._load_overlay_from_file`.
+
+All readers return a flat ``numpy.ndarray`` of shape ``(N,)``. The caller
+is responsible for casting to the desired dtype (``float32`` for overlays and
+background maps, ``bool`` for ROI masks, ``int32`` for label/parcellation
+maps).
+""" + +import os + +import numpy as np + +# --------------------------------------------------------------------------- +# ASCII text / CSV reader +# --------------------------------------------------------------------------- + +def read_txt(path): + """Read a per-vertex scalar file in plain ASCII format. + + The file must contain exactly one numeric value per line. An optional + single-line text header (non-numeric first line) is silently skipped. + Whitespace and comma separators are both accepted; only the *first* + value on each line is used (allowing simple CSV files with a single + data column). + + Integer-valued files (every value equal to its ``int`` cast) are + returned as ``int32``; all others as ``float32``. + + Parameters + ---------- + path : str + Path to the ``.txt`` or ``.csv`` file. + + Returns + ------- + numpy.ndarray, shape (N,), dtype float32 or int32 + + Raises + ------ + ValueError + If no numeric values can be parsed from the file. + IOError + If the file cannot be opened. 
+ + Examples + -------- + A valid ``overlay.txt``:: + + # optional comment line (skipped) + 0.123 + -1.456 + 2.0 + + A valid ``labels.csv`` (first column used, header skipped):: + + label + 3 + 0 + 1 + 3 + """ + values = [] + with open(path, encoding="utf-8", errors="replace") as fh: + for lineno, raw in enumerate(fh, start=1): + line = raw.strip() + if not line or line.startswith("#"): + continue + # Take only the first token (handles CSV with a single data column) + token = line.split(",")[0].split()[0] + try: + values.append(float(token)) + except ValueError as exc: + if lineno == 1: + # Treat the very first non-numeric line as a header and skip it + continue + raise ValueError( + f"Could not parse numeric value on line {lineno} of {path!r}: " + f"{raw.strip()!r}" + ) from exc + + if not values: + raise ValueError(f"No numeric values found in {path!r}.") + + arr = np.array(values, dtype=np.float32) + + # Promote to int32 if all values are integers (label / parcellation file) + if np.all(arr == arr.astype(np.int32)): + return arr.astype(np.int32) + return arr + + +# --------------------------------------------------------------------------- +# NumPy readers +# --------------------------------------------------------------------------- + +def read_npy(path): + """Read a per-vertex scalar array from a NumPy ``.npy`` file. + + Parameters + ---------- + path : str + Path to the ``.npy`` file. + + Returns + ------- + numpy.ndarray, shape (N,) + The stored array, squeezed to 1-D. + + Raises + ------ + ValueError + If the stored array is not 1-D after squeezing, or is empty. + IOError + If the file cannot be opened. + """ + arr = np.load(path) + arr = np.squeeze(arr) + if arr.ndim != 1: + raise ValueError( + f"NumPy file {path!r} contains an array of shape {arr.shape}; " + f"expected a 1-D per-vertex array." 
+ ) + if arr.size == 0: + raise ValueError(f"NumPy file {path!r} contains an empty array.") + return arr + + +def read_npz(path): + """Read a per-vertex scalar array from a NumPy ``.npz`` archive. + + The array named ``"data"`` is returned if it exists; otherwise the + first array in the archive is used. + + Parameters + ---------- + path : str + Path to the ``.npz`` file. + + Returns + ------- + numpy.ndarray, shape (N,) + The selected array, squeezed to 1-D. + + Raises + ------ + ValueError + If no arrays are found, or the selected array is not 1-D after + squeezing. + IOError + If the file cannot be opened. + """ + archive = np.load(path) + keys = list(archive.keys()) + if not keys: + raise ValueError(f"NumPy archive {path!r} contains no arrays.") + + key = "data" if "data" in keys else keys[0] + arr = np.squeeze(archive[key]) + if arr.ndim != 1: + raise ValueError( + f"NumPy archive {path!r}, array {key!r} has shape {arr.shape}; " + f"expected a 1-D per-vertex array." + ) + if arr.size == 0: + raise ValueError(f"NumPy archive {path!r}, array {key!r} is empty.") + return arr + + +# --------------------------------------------------------------------------- +# GIfTI reader +# --------------------------------------------------------------------------- + +def read_gifti(path): + """Read a per-vertex scalar array from a GIfTI functional or label file. + + Supports ``.func.gii`` (continuous scalars, e.g. HCP thickness) and + ``.label.gii`` (integer parcellation labels, e.g. HCP parcellation). + Plain ``.gii`` files are also accepted provided they contain a scalar + data array — **not** a surface geometry file. For surface GIfTI files + (``.surf.gii`` or ``.gii`` files with POINTSET+TRIANGLE arrays) use + :func:`whippersnappy.geometry.mesh_io.read_gifti_surface` or pass the + path to :func:`whippersnappy.geometry.inputs.resolve_mesh`. + + The first non-POINTSET, non-TRIANGLE data array in the file is returned. 
+
+    Parameters
+    ----------
+    path : str
+        Path to a GIfTI file.
+
+    Returns
+    -------
+    numpy.ndarray, shape (N,)
+        The first scalar data array, squeezed to 1-D.
+
+    Raises
+    ------
+    ImportError
+        If ``nibabel`` is not installed.
+    ValueError
+        If the file is a surface GIfTI (POINTSET+TRIANGLE only), contains
+        no usable scalar arrays, or the first scalar array is not 1-D.
+    IOError
+        If the file cannot be opened or is not a valid GIfTI file.
+    """
+    try:
+        import nibabel as nib  # noqa: PLC0415
+    except ImportError as exc:
+        raise ImportError(
+            "Reading GIfTI files requires nibabel. "
+            "Install with: pip install nibabel"
+        ) from exc
+
+    img = nib.load(path)
+    if not hasattr(img, "darrays") or not img.darrays:
+        raise ValueError(
+            f"GIfTI file {path!r} contains no data arrays."
+        )
+
+    # Intent codes for surface geometry — skip these
+    _SURFACE_INTENTS = {1008, 1009}  # POINTSET, TRIANGLE
+
+    # First non-surface data array wins; surface arrays are only tracked to
+    # produce a better error message below.
+    scalar_da = None
+    has_surface_arrays = False
+    for da in img.darrays:
+        if da.intent in _SURFACE_INTENTS:
+            has_surface_arrays = True
+        elif scalar_da is None:
+            scalar_da = da
+
+    if scalar_da is None:
+        if has_surface_arrays:
+            raise ValueError(
+                f"GIfTI file {path!r} appears to be a surface geometry file "
+                f"(contains only POINTSET/TRIANGLE arrays). "
+                f"Use resolve_mesh() or read_gifti_surface() to load it as a mesh."
+            )
+        raise ValueError(
+            f"GIfTI file {path!r} contains no scalar data arrays."
+        )
+
+    arr = np.squeeze(scalar_da.data)
+    if arr.ndim != 1:
+        raise ValueError(
+            f"GIfTI file {path!r}: first scalar data array has shape "
+            f"{scalar_da.data.shape}; expected a 1-D per-vertex array."
+        )
+    if arr.size == 0:
+        raise ValueError(f"GIfTI file {path!r}: first scalar data array is empty.")
+    return arr
+
+
+# ---------------------------------------------------------------------------
+# Dispatcher
+# ---------------------------------------------------------------------------
+
+# Map from lower-case file extension to reader function.
+# Note: ".func.gii" and ".label.gii" have a compound extension; we handle
+# them by matching the last *two* dot-separated components as well.
+_READERS = {
+    ".txt": read_txt,
+    ".csv": read_txt,
+    ".npy": read_npy,
+    ".npz": read_npz,
+    ".gii": read_gifti,
+    ".func.gii": read_gifti,
+    ".label.gii": read_gifti,
+}
+
+_SUPPORTED = ", ".join(sorted(_READERS))
+
+
+def read_overlay(path):
+    """Read a per-vertex scalar or label array from a file.
+
+    Dispatches to the appropriate reader based on the file extension.
+    FreeSurfer binary morph files (e.g. ``lh.curv``, ``lh.thickness``) and
+    MGH/MGZ files are **not** handled here — pass them through
+    :func:`whippersnappy.geometry.inputs._load_overlay_from_file` which
+    already routes those formats via :mod:`~whippersnappy.geometry.freesurfer_io`.
+
+    Parameters
+    ----------
+    path : str
+        Path to an overlay/label file. Recognised extensions:
+
+        * ``.txt``, ``.csv`` — plain ASCII, one value per line
+        * ``.npy`` — NumPy binary array
+        * ``.npz`` — NumPy archive (key ``"data"`` or first array)
+        * ``.gii``, ``.func.gii``, ``.label.gii`` — GIfTI
+
+    Returns
+    -------
+    numpy.ndarray, shape (N,)
+
+    Raises
+    ------
+    ValueError
+        If the file extension is not recognised.
+    """
+    # Check compound extensions first (.func.gii, .label.gii)
+    lower = path.lower()
+    for compound in (".func.gii", ".label.gii"):
+        if lower.endswith(compound):
+            return _READERS[compound](path)
+
+    # Simple extensions resolve via splitext (compound GIfTI handled above).
+    ext = os.path.splitext(path)[1].lower()
+    reader = _READERS.get(ext)
+    if reader is None:
+        raise ValueError(
+            f"Unsupported overlay file extension {ext!r} for {path!r}. "
+            f"Supported formats: {_SUPPORTED}. "
+            f"For FreeSurfer morph files (no extension) or .mgh/.mgz files "
+            f"the routing is handled automatically by resolve_overlay()."
+        )
+    return reader(path)
+
diff --git a/whippersnappy/geometry/prepare.py b/whippersnappy/geometry/prepare.py
new file mode 100644
index 0000000..d5df718
--- /dev/null
+++ b/whippersnappy/geometry/prepare.py
@@ -0,0 +1,454 @@
+"""Geometry helpers for mesh processing and GPU preparation (prepare.py).
+
+This module contains the primary geometry-preparation pipeline. The
+low-level workhorse is :func:`prepare_geometry_from_arrays` which operates
+entirely on numpy arrays. :func:`prepare_geometry` is a thin file-loading
+wrapper that delegates to the resolver functions in
+:mod:`whippersnappy.geometry.inputs` before calling
+:func:`prepare_geometry_from_arrays`.
+"""
+
+import warnings
+
+import numpy as np
+
+from ..utils.colormap import binary_color, heat_color, mask_sign, rescale_overlay
+from ..utils.types import ColorSelection
+from .inputs import resolve_annot, resolve_bg_map, resolve_mesh, resolve_overlay, resolve_roi
+
+
+def normalize_mesh(v, scale=1.0):
+    """Center and scale mesh vertex coordinates to a unit cube.
+
+    The function recenters the vertices around the origin and scales them so
+    that the maximum extent fits into a unit cube, optionally applying an
+    additional scale factor.
+
+    Parameters
+    ----------
+    v : numpy.ndarray
+        Vertex coordinate array of shape (n_vertices, 3).
+    scale : float, optional
+        Additional multiplicative scale applied after normalization.
+
+    Returns
+    -------
+    numpy.ndarray
+        Normalized vertex coordinates with same shape as ``v``.
+    """
+    bbmax = np.max(v, axis=0)
+    bbmin = np.min(v, axis=0)
+    v = v - 0.5 * (bbmax + bbmin)
+    # NOTE(review): a degenerate mesh with zero extent divides by zero here —
+    # presumably upstream readers never produce one; confirm against callers.
+    v = scale * v / np.max(bbmax - bbmin)
+    return v
+
+
+def vertex_normals(v, t):
+    """Compute per-vertex normals from triangle connectivity.
+
+    Parameters
+    ----------
+    v : numpy.ndarray
+        Vertex coordinates (n_vertices, 3).
+    t : numpy.ndarray
+        Triangle indices (n_faces, 3).
+
+    Returns
+    -------
+    numpy.ndarray
+        Per-vertex unit normals (n_vertices, 3).
+    """
+    v0 = v[t[:, 0], :]
+    v1 = v[t[:, 1], :]
+    v2 = v[t[:, 2], :]
+    v1mv0 = v1 - v0
+    v2mv1 = v2 - v1
+    v0mv2 = v0 - v2
+    # Unnormalized corner cross products: magnitude is proportional to the
+    # triangle area, so larger triangles weigh more in the vertex average.
+    cr0 = np.cross(v1mv0, -v0mv2)
+    cr1 = np.cross(v2mv1, -v1mv0)
+    cr2 = np.cross(v0mv2, -v2mv1)
+    # Vectorized accumulation using bincount
+    idx = np.concatenate([t[:, 0], t[:, 1], t[:, 2]])
+    contribs = np.vstack([cr0, cr1, cr2])
+    n = np.empty((v.shape[0], 3), dtype=np.float64)
+    for j in range(3):
+        n[:, j] = np.bincount(idx, weights=contribs[:, j], minlength=v.shape[0])
+    # Guard zero-length normals (e.g. isolated vertices) before dividing.
+    ln = np.sqrt(np.sum(n * n, axis=1))
+    ln[ln < np.finfo(float).eps] = 1
+    n = n / ln.reshape(-1, 1)
+    return n
+
+
+def _estimate_thresholds_from_array(mapdata, minval=None, maxval=None):
+    """Estimate threshold and saturation values from an already-loaded array.
+
+    Parameters
+    ----------
+    mapdata : numpy.ndarray
+        Per-vertex overlay values.
+    minval : float or None, optional
+        If provided, used as-is; otherwise estimated as the minimum absolute
+        value in the data.
+    maxval : float or None, optional
+        If provided, used as-is; otherwise estimated as the maximum absolute
+        value in the data.
+
+    Returns
+    -------
+    minval : float
+        Threshold value (lower bound of the color scale).
+    maxval : float
+        Saturation value (upper bound of the color scale).
+    """
+    valabs = np.abs(mapdata)
+    # All-zero data yields (0.0, 0.0) rather than calling min/max on it.
+    if maxval is None:
+        maxval = float(np.max(valabs)) if np.any(valabs) else 0.0
+    if minval is None:
+        minval = float(max(0.0, np.min(valabs) if np.any(valabs) else 0.0))
+    return minval, maxval
+
+
+def estimate_overlay_thresholds(overlay, minval=None, maxval=None):
+    """Estimate threshold and saturation values from an overlay file or array.
+
+    Reads the overlay data and derives ``fmin`` / ``fmax`` from the absolute
+    values without performing any geometry or color work. Both values are
+    returned unchanged when they are already provided by the caller, making
+    the function safe to call unconditionally.
+
+    Parameters
+    ----------
+    overlay : str or array-like
+        Path to the overlay file (.mgh or FreeSurfer morph format), or a
+        numpy array / array-like of per-vertex scalar values.
+    minval : float or None, optional
+        If provided, used as-is for the threshold; otherwise estimated as
+        the minimum absolute value in the overlay.
+    maxval : float or None, optional
+        If provided, used as-is for the saturation; otherwise estimated as
+        the maximum absolute value in the overlay.
+
+    Returns
+    -------
+    minval : float
+        Threshold value (lower bound of the color scale).
+    maxval : float
+        Saturation value (upper bound of the color scale).
+    """
+    if isinstance(overlay, str):
+        # Use resolve_overlay with n_vertices=None to skip shape validation
+        overlay_arr = resolve_overlay(overlay, n_vertices=None)
+    else:
+        overlay_arr = np.asarray(overlay)
+    return _estimate_thresholds_from_array(overlay_arr, minval, maxval)
+
+
+def prepare_geometry_from_arrays(
+    vertices,
+    faces,
+    overlay=None,
+    annot=None,
+    ctab=None,
+    bg_map=None,
+    roi=None,
+    minval=None,
+    maxval=None,
+    invert=False,
+    scale=1.85,
+    color_mode=ColorSelection.BOTH,
+):
+    """Prepare vertex and color arrays for GPU upload from numpy arrays.
+
+    This is the core geometry preparation function. All inputs must already
+    be resolved numpy arrays; for file-path support use the thin wrapper
+    :func:`prepare_geometry`.
+
+    Parameters
+    ----------
+    vertices : numpy.ndarray
+        Vertex coordinate array of shape (N, 3), dtype float32.
+    faces : numpy.ndarray
+        Triangle index array of shape (M, 3), dtype uint32.
+    overlay : numpy.ndarray or None, optional
+        Per-vertex scalar values of shape (N,) float32 used for coloring.
+    annot : numpy.ndarray or None, optional
+        Per-vertex integer label indices of shape (N,) int32.
+    ctab : numpy.ndarray or None, optional
+        Color table array (n_labels, ≥3) associated with *annot*.
+    bg_map : numpy.ndarray or None, optional
+        Per-vertex scalar values of shape (N,) float32 whose sign determines
+        background shading (binary light/dark). When ``None`` a flat gray
+        background is used.
+    roi : numpy.ndarray of bool or None, optional
+        Boolean mask of shape (N,). ``True`` = vertex is inside the region
+        of interest and receives overlay coloring; ``False`` = vertex falls
+        back to background shading. When ``None`` all vertices are in-ROI.
+    minval, maxval : float or None, optional
+        Threshold and saturation values for overlay scaling.
+    invert : bool, optional, default False
+        Invert color mapping.
+    scale : float, optional, default 1.85
+        Geometry scaling factor applied by :func:`normalize_mesh`.
+    color_mode : ColorSelection, optional, default ColorSelection.BOTH
+        Which sign(s) of overlay values to use for coloring.
+
+    Returns
+    -------
+    vertexdata : numpy.ndarray
+        Nx9 array (position x3, normal x3, color x3) ready for GPU upload.
+    triangles : numpy.ndarray
+        Mx3 uint32 triangle index array.
+    fmin, fmax : float or None
+        Final threshold and saturation values used for color mapping.
+    pos, neg : bool or None
+        Flags indicating whether positive/negative overlay values are present.
+
+    Raises
+    ------
+    ValueError
+        If overlay or annotation arrays do not match the surface vertex count.
+    """
+    vertices = normalize_mesh(np.array(vertices, dtype=np.float32), scale)
+    triangles = np.array(faces, dtype=np.uint32)
+    vnormals = np.array(vertex_normals(vertices, triangles), dtype=np.float32)
+    num_vertices = vertices.shape[0]
+
+    # Build background (sulcal) colormap
+    if bg_map is not None:
+        if bg_map.shape[0] != num_vertices:
+            # Mismatched bg_map is tolerated (warn + flat gray) rather than
+            # raised, unlike overlay/annot mismatches below.
+            warnings.warn(
+                f"bg_map has {bg_map.shape[0]} values but mesh has {num_vertices}.",
+                stacklevel=2,
+            )
+            sulcmap = 0.5 * np.ones(vertices.shape, dtype=np.float32)
+        else:
+            sulcmap = binary_color(bg_map, 0.0, color_low=0.5, color_high=0.33)
+    else:
+        sulcmap = 0.5 * np.ones(vertices.shape, dtype=np.float32)
+
+    # Initialize defaults for overlay outputs
+    fmin = None
+    fmax = None
+    pos = None
+    neg = None
+    colors = sulcmap  # use as default
+
+    # Apply overlay coloring
+    if overlay is not None:
+        if overlay.shape[0] != num_vertices:
+            raise ValueError(
+                f"overlay has {overlay.shape[0]} values but mesh has {num_vertices}.\n"
+                "This usually means the overlay does not match the provided surface "
+                "(e.g. RH overlay used with LH surface). Provide the correct overlay."
+            )
+        mapdata = overlay.copy().astype(np.float64)
+        # Order matters: thresholds come from the full (unmasked) data, then
+        # sign-masking, then rescaling.
+        minval, maxval = _estimate_thresholds_from_array(mapdata, minval, maxval)
+        mapdata = mask_sign(mapdata, color_mode)
+        mapdata, fmin, fmax, pos, neg = rescale_overlay(mapdata, minval, maxval)
+        colors = heat_color(mapdata, invert)
+        # Some mapdata values could be nan (below min threshold) — fall back to bg
+        missing = np.isnan(mapdata)
+        if np.any(missing):
+            colors[missing, :] = sulcmap[missing, :]
+
+    elif annot is not None and ctab is not None:
+        # Per-vertex annotation coloring
+        if annot.shape[0] != num_vertices:
+            raise ValueError(
+                f"annot has {annot.shape[0]} values but mesh has {num_vertices}.\n"
+                "This usually means the .annot does not match the provided surface "
+                "(e.g. RH annot used with LH surface). Provide the correct annot file."
+            )
+        annot = annot.astype(np.int32)
+        colors = np.array(sulcmap, dtype=np.float32)
+        ctab_rgb = np.asarray(ctab[:, 0:3], dtype=np.float32)
+        # Color tables may be 0-255 or already 0-1; normalize by inspection.
+        denom = 255.0 if np.max(ctab_rgb) > 1 else 1.0
+        valid = (annot >= 0) & (annot < ctab.shape[0])
+        if np.any(valid):
+            colors[valid, :] = ctab_rgb[annot[valid], :] / denom
+
+    # Ensure colors dtype matches vertices/normals
+    colors = np.asarray(colors, dtype=np.float32)
+
+    # Apply ROI mask: vertices where roi == False fall back to sulcmap.
+    # This runs last so it overrides overlay and annot colors alike.
+    if roi is not None:
+        outside = ~roi
+        if np.any(outside):
+            colors[outside, :] = sulcmap[outside, :]
+
+    vertexdata = np.concatenate((vertices, vnormals, colors), axis=1)
+    return vertexdata, triangles, fmin, fmax, pos, neg
+
+
+def prepare_geometry(
+    mesh,
+    overlay=None,
+    annot=None,
+    bg_map=None,
+    roi=None,
+    minval=None,
+    maxval=None,
+    invert=False,
+    scale=1.85,
+    color_mode=ColorSelection.BOTH,
+):
+    """Prepare vertex and color arrays for GPU upload.
+
+    This is a thin file-loading wrapper around
+    :func:`prepare_geometry_from_arrays`. Inputs are resolved via the
+    functions in :mod:`whippersnappy.geometry.inputs` so that every
+    parameter can be either a file path or a numpy array.
+
+    Parameters
+    ----------
+    mesh : str or tuple of (array-like, array-like)
+        Surface file path (FreeSurfer format) **or** a ``(vertices, faces)``
+        tuple/list where *vertices* is (N, 3) float and *faces* is (M, 3) int.
+    overlay : str, array-like, or None, optional
+        Path to an overlay (.mgh / FreeSurfer morph) file, or a (N,) array
+        of per-vertex scalar values.
+    annot : str, tuple, or None, optional
+        Path to a FreeSurfer .annot file, or a ``(labels, ctab)`` /
+        ``(labels, ctab, names)`` tuple.
+    bg_map : str, array-like, or None, optional
+        Path to a curvature/morph file used for background shading, or a
+        (N,) array whose sign determines light/dark shading.
+    roi : str, array-like, or None, optional
+        Path to a FreeSurfer label file or a (N,) boolean array.
Vertices + with ``True`` receive overlay coloring; others fall back to *bg_map*. + minval, maxval : float or None, optional + Threshold and saturation values for overlay scaling. + invert : bool, optional, default False + Invert color mapping. + scale : float, optional, default 1.85 + Geometry scaling factor applied by :func:`normalize_mesh`. + color_mode : ColorSelection, optional, default ColorSelection.BOTH + Which sign(s) of overlay values to use for coloring. + + Returns + ------- + vertexdata : numpy.ndarray + Nx9 array (position x3, normal x3, color x3) ready for GPU upload. + triangles : numpy.ndarray + Mx3 uint32 triangle index array. + fmin, fmax : float or None + Final threshold and saturation values used for color mapping. + pos, neg : bool or None + Flags indicating whether positive/negative overlay values are present. + + Raises + ------ + TypeError + If *mesh* is not a valid type. + ValueError + If overlay or annotation arrays do not match the surface vertex count. + + Examples + -------- + File-path usage:: + + vdata, tris, fmin, fmax, pos, neg = prepare_geometry( + 'fsaverage/surf/lh.white', + overlay='fsaverage/surf/lh.thickness', + bg_map='fsaverage/surf/lh.curv', + roi='fsaverage/label/lh.cortex.label', + ) + + Array inputs:: + + import numpy as np + v = np.array([[0,0,0],[1,0,0],[0,1,0],[0,0,1]], dtype=np.float32) + f = np.array([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], dtype=np.uint32) + vdata, tris, *_ = prepare_geometry((v, f)) + """ + vertices, faces = resolve_mesh(mesh) + n = vertices.shape[0] + overlay_arr = resolve_overlay(overlay, n_vertices=n) + bg_map_arr = resolve_bg_map(bg_map, n_vertices=n) + roi_arr = resolve_roi(roi, n_vertices=n) + annot_result = resolve_annot(annot, n_vertices=n) + annot_arr = annot_result[0] if annot_result is not None else None + ctab_arr = annot_result[1] if annot_result is not None else None + return prepare_geometry_from_arrays( + vertices, faces, overlay_arr, annot_arr, ctab_arr, + bg_map_arr, roi_arr, minval, 
maxval, invert, scale, color_mode, + ) + + +def prepare_and_validate_geometry( + mesh, + overlay=None, + annot=None, + bg_map=None, + roi=None, + fthresh=None, + fmax=None, + invert=False, + scale=1.85, + color_mode=ColorSelection.BOTH, +): + """Load and validate mesh geometry and overlay/annotation inputs. + + This is a small wrapper around :func:`prepare_geometry` that performs + the same overlay-presence validation used throughout the static snapshot + helpers. + + Parameters + ---------- + mesh : str or tuple + Passed through to :func:`prepare_geometry`. + overlay, annot, bg_map, roi : str, array-like, or None + Passed through to :func:`prepare_geometry`. + fthresh, fmax : float or None + Threshold and saturation values passed to the geometry preparer. + invert : bool + Passed to the geometry preparer. + scale : float + Scaling factor passed to the geometry preparer. + color_mode : ColorSelection + Which sign of overlay to display (POSITIVE/NEGATIVE/BOTH). + + Returns + ------- + tuple + ``(meshdata, triangles, fthresh, fmax, pos, neg)`` as returned by + :func:`prepare_geometry`. + + Raises + ------ + ValueError + If the overlay contains no values appropriate for ``color_mode``. 
+ """ + import logging + logger = logging.getLogger(__name__) + meshdata, triangles, out_fthresh, out_fmax, pos, neg = prepare_geometry( + mesh, + overlay, + annot, + bg_map, + roi, + fthresh, + fmax, + invert, + scale=scale, + color_mode=color_mode, + ) + + # Validate overlay presence similar to previous inline checks + if overlay is not None: + if color_mode == ColorSelection.POSITIVE: + if not pos and neg: + logger.error("Overlay has no values to display with positive color_mode") + raise ValueError("Overlay has no values to display with positive color_mode") + neg = False + elif color_mode == ColorSelection.NEGATIVE: + if pos and not neg: + logger.error("Overlay has no values to display with negative color_mode") + raise ValueError("Overlay has no values to display with negative color_mode") + pos = False + if not pos and not neg: + logger.error("Overlay has no values to display") + raise ValueError("Overlay has no values to display") + + return meshdata, triangles, out_fthresh, out_fmax, pos, neg + diff --git a/whippersnappy/geometry/surf_name.py b/whippersnappy/geometry/surf_name.py new file mode 100644 index 0000000..1d1f07d --- /dev/null +++ b/whippersnappy/geometry/surf_name.py @@ -0,0 +1,32 @@ +"""Helper for finding a surface file name inside a subject directory. + +This replaces the previous `io.py` name which was generic; `surf_name.py` is +more descriptive (it provides `get_surf_name`). +""" +import os + + +def get_surf_name(sdir, hemi): + """Find a suitable surface basename in a subject directory. + + The function searches the standard FreeSurfer `surf/` directory for a + common surface name in order of preference and returns the basename + (search for 'pial_semi_inflated', 'white', and then 'inflated'). + + Parameters + ---------- + sdir : str + Path to the subject directory containing a `surf/` subdirectory. + hemi : {'lh','rh'} + Hemisphere prefix to use when searching for surface files. 
+ + Returns + ------- + surf_name : str or None + The surface basename if found, otherwise ``None``. + """ + for surf_name_option in ["pial_semi_inflated", "white", "inflated"]: + path = os.path.join(sdir, "surf", f"{hemi}.{surf_name_option}") + if os.path.exists(path): + return surf_name_option + return None diff --git a/whippersnappy/gl/__init__.py b/whippersnappy/gl/__init__.py new file mode 100644 index 0000000..4b458bf --- /dev/null +++ b/whippersnappy/gl/__init__.py @@ -0,0 +1,46 @@ +"""OpenGL helper utilities (gl package). + +This package replaces the previous `gl_utils.py` module. +Functions are re-exported at package level for convenience, e.g.: + + from whippersnappy.gl import init_window, setup_shader + +""" + +from . import _platform # noqa: F401 — MUST be first; sets PYOPENGL_PLATFORM +from .camera import make_model, make_projection, make_view +from .shaders import get_default_shaders, get_webgl_shaders +from .utils import ( + capture_window, + compile_shader_program, + create_vao, + create_window_with_fallback, + init_window, + render_scene, + set_camera_uniforms, + set_default_gl_state, + set_lighting_uniforms, + setup_buffers, + setup_shader, + setup_vertex_attributes, + terminate_context, +) +from .views import ( + ViewState, + arcball_rotation_matrix, + arcball_vector, + compute_view_matrix, + get_view_matrices, + get_view_matrix, +) + +__all__ = [ + 'create_vao', 'compile_shader_program', 'setup_buffers', 'setup_vertex_attributes', + 'set_default_gl_state', 'set_camera_uniforms', 'set_lighting_uniforms', + 'init_window', 'render_scene', 'setup_shader', 'capture_window', + 'make_model', 'make_projection', 'make_view', + 'get_default_shaders', 'get_view_matrices', 'get_view_matrix', + 'get_webgl_shaders', 'terminate_context', + 'ViewState', 'compute_view_matrix', + 'arcball_vector', 'arcball_rotation_matrix', +] diff --git a/whippersnappy/gl/_platform.py b/whippersnappy/gl/_platform.py new file mode 100644 index 0000000..8a94ad4 --- /dev/null +++ 
b/whippersnappy/gl/_platform.py @@ -0,0 +1,17 @@ +"""Bootstrap PyOpenGL platform selection — must be imported first. + +Imported unconditionally at the top of gl/__init__.py before any other +OpenGL symbol. Sets PYOPENGL_PLATFORM=egl when running headless on Linux +so that PyOpenGL uses the EGL backend instead of GLX. + +On macOS PyOpenGL uses CGL and on Windows it uses WGL — both are handled +natively without EGL. If the user has already set PYOPENGL_PLATFORM that +value is always respected. +""" +import os +import sys + +if "PYOPENGL_PLATFORM" not in os.environ and sys.platform == "linux": + # No X11/Wayland display on Linux → force EGL headless backend. + if not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"): + os.environ["PYOPENGL_PLATFORM"] = "egl" diff --git a/whippersnappy/gl/camera.py b/whippersnappy/gl/camera.py new file mode 100644 index 0000000..bcfa376 --- /dev/null +++ b/whippersnappy/gl/camera.py @@ -0,0 +1,76 @@ +"""Camera and transform helpers (moved under gl package).""" + +import pyrr + + +def make_projection(width, height, fov=20.0, near=0.1, far=100.0): + """Create a 4x4 perspective projection matrix. + + Parameters + ---------- + width, height : int + Viewport dimensions in pixels (used to compute aspect ratio). + fov : float, optional, default 20.0 + Vertical field of view in degrees. + near, far : float, optional, default 0.1, 100.0 + Near and far clipping planes. Default are 0.1 and 100.0, respectively. + + Returns + ------- + numpy.ndarray + 4x4 projection matrix. + """ + return pyrr.matrix44.create_perspective_projection(fov, width / height, near, far) + + +def make_view(camera_pos=(0.0, 0.0, -5.0)): + """Create a view matrix for a camera located at ``camera_pos``. + + Parameters + ---------- + camera_pos : sequence of float, optional, default (0.0, 0.0, -5.0) + 3-element position of the camera in world space. + Default is (0.0, 0.0, -5.0). + + Returns + ------- + numpy.ndarray + 4x4 view matrix. 
+ """ + return pyrr.matrix44.create_from_translation(pyrr.Vector3(camera_pos)) + + +def make_model(): + """Create a default model matrix (identity translation). + + Returns + ------- + numpy.ndarray + 4x4 model matrix. + """ + return pyrr.matrix44.create_from_translation(pyrr.Vector3([0.0, 0.0, 0.0])) + + +def make_transform(translation, rotation, scale): + """Build a model transform matrix from translation, rotation and uniform scale. + + Parameters + ---------- + translation : sequence of float + 3-element translation vector. + rotation : numpy.ndarray + 4x4 rotation matrix. + scale : float + Uniform scaling factor. + + Returns + ------- + numpy.ndarray + 4x4 transformation matrix (translation * rotation * scale). + """ + scale_matrix = pyrr.matrix44.create_from_scale([scale, scale, scale]) + return ( + pyrr.matrix44.create_from_translation(translation) + * rotation + * scale_matrix + ) diff --git a/whippersnappy/gl/egl_context.py b/whippersnappy/gl/egl_context.py new file mode 100644 index 0000000..950dffa --- /dev/null +++ b/whippersnappy/gl/egl_context.py @@ -0,0 +1,374 @@ +"""EGL off-screen (headless) OpenGL context via pbuffer + FBO. + +This module provides a drop-in alternative to GLFW window creation for +headless environments (CI, Docker, HPC clusters) where no X11/Wayland +display is available. It requires: + + - A system EGL library (``libegl1`` on Debian/Ubuntu, already present + in the WhipperSnapPy Dockerfile). + - PyOpenGL >= 3.1 (already a project dependency), which ships + ``OpenGL.EGL`` bindings. + - Either an NVIDIA GPU with the EGL driver, or Mesa ``libEGL-mesa0`` + (llvmpipe software renderer) for CPU-only systems. + +Typical usage (internal, called from ``create_window_with_fallback``):: + + from whippersnappy.gl.egl_context import EGLContext + + ctx = EGLContext(width, height) + ctx.make_current() + # ... OpenGL calls ... 
+ img = ctx.read_pixels() + ctx.destroy() +""" + +import ctypes +import logging +import os +import sys + +if sys.platform == "darwin": + raise ImportError("EGL is not available on macOS; use GLFW/CGL instead.") + +# Must be set before OpenGL.GL is imported anywhere in the process. +# If already set (e.g. user set it, or GLFW succeeded), respect it. +# We set it here because this module is only imported when EGL is needed. +if os.environ.get("PYOPENGL_PLATFORM") != "egl": + os.environ["PYOPENGL_PLATFORM"] = "egl" + +import OpenGL.GL as gl +from PIL import Image + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# EGL constants not exposed by all PyOpenGL versions +# --------------------------------------------------------------------------- +_EGL_SURFACE_TYPE = 0x3033 +_EGL_PBUFFER_BIT = 0x0001 +_EGL_RENDERABLE_TYPE = 0x3040 +_EGL_OPENGL_BIT = 0x0008 +_EGL_NONE = 0x3038 +_EGL_WIDTH = 0x3057 +_EGL_HEIGHT = 0x3056 +_EGL_OPENGL_API = 0x30A2 +_EGL_CONTEXT_MAJOR_VERSION = 0x3098 +_EGL_CONTEXT_MINOR_VERSION = 0x30FB +_EGL_PLATFORM_DEVICE_EXT = 0x313F + + +class EGLContext: + """A headless OpenGL 3.3 Core context backed by an EGL pbuffer + FBO. + + The pbuffer surface is created solely to satisfy EGL's requirement for + a surface when calling ``eglMakeCurrent``. All rendering is directed + into an off-screen Framebuffer Object (FBO) so that ``glReadPixels`` + captures exactly what was rendered regardless of platform quirks with + pbuffer readback. + + Parameters + ---------- + width, height : int + Dimensions of the off-screen render target in pixels. + + Attributes + ---------- + width, height : int + Render target dimensions. + fbo : int + OpenGL FBO handle (valid after ``make_current`` is called). + + Raises + ------ + ImportError + If ``OpenGL.EGL`` bindings are not available. + RuntimeError + If any EGL initialisation step fails. 
+ """ + + def __init__(self, width: int, height: int): + self.width = width + self.height = height + self._libegl = None + self._display = None + self._surface = None + self._context = None + self._config = None + self.fbo = None + self._rbo_color = None + self._rbo_depth = None + self._init_egl() + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _get_ext_fn(self, name, restype, argtypes): + """Load an EGL extension function via eglGetProcAddress.""" + addr = self._libegl.eglGetProcAddress(name.encode()) + if not addr: + raise RuntimeError( + f"eglGetProcAddress('{name}') returned NULL — " + f"extension not available on this driver." + ) + FuncType = ctypes.CFUNCTYPE(restype, *argtypes) + return FuncType(addr) + + def _init_egl(self): + import ctypes.util + + egl_name = ctypes.util.find_library("EGL") or "libEGL.so.1" + try: + libegl = ctypes.CDLL(egl_name) + except OSError as e: + raise RuntimeError( + f"Could not load {egl_name}. " + "Install libegl1-mesa and retry." 
+ ) from e + self._libegl = libegl # keep reference alive + + # Set signatures for direct (non-extension) EGL symbols + libegl.eglGetProcAddress.restype = ctypes.c_void_p + libegl.eglGetProcAddress.argtypes = [ctypes.c_char_p] + libegl.eglQueryString.restype = ctypes.c_char_p + libegl.eglQueryString.argtypes = [ctypes.c_void_p, ctypes.c_int] + libegl.eglInitialize.restype = ctypes.c_bool + libegl.eglInitialize.argtypes = [ctypes.c_void_p, + ctypes.POINTER(ctypes.c_int), + ctypes.POINTER(ctypes.c_int)] + libegl.eglBindAPI.restype = ctypes.c_bool + libegl.eglBindAPI.argtypes = [ctypes.c_uint] + libegl.eglChooseConfig.restype = ctypes.c_bool + libegl.eglChooseConfig.argtypes = [ctypes.c_void_p, + ctypes.POINTER(ctypes.c_int), + ctypes.c_void_p, ctypes.c_int, + ctypes.POINTER(ctypes.c_int)] + libegl.eglCreatePbufferSurface.restype = ctypes.c_void_p + libegl.eglCreatePbufferSurface.argtypes = [ctypes.c_void_p, ctypes.c_void_p, + ctypes.POINTER(ctypes.c_int)] + libegl.eglCreateContext.restype = ctypes.c_void_p + libegl.eglCreateContext.argtypes = [ctypes.c_void_p, ctypes.c_void_p, + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_int)] + libegl.eglMakeCurrent.restype = ctypes.c_bool + libegl.eglMakeCurrent.argtypes = [ctypes.c_void_p, ctypes.c_void_p, + ctypes.c_void_p, ctypes.c_void_p] + libegl.eglDestroyContext.restype = ctypes.c_bool + libegl.eglDestroyContext.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + libegl.eglDestroySurface.restype = ctypes.c_bool + libegl.eglDestroySurface.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + libegl.eglTerminate.restype = ctypes.c_bool + libegl.eglTerminate.argtypes = [ctypes.c_void_p] + + # Check extensions and load ext functions via eglGetProcAddress + _EGL_EXTENSIONS = 0x3055 + client_exts = libegl.eglQueryString(None, _EGL_EXTENSIONS) or b"" + logger.debug("EGL client extensions: %s", client_exts.decode()) + + has_device_enum = b"EGL_EXT_device_enumeration" in client_exts + has_platform_base = b"EGL_EXT_platform_base" in client_exts 
+ + display = None + if has_device_enum and has_platform_base: + eglQueryDevicesEXT = self._get_ext_fn( + "eglQueryDevicesEXT", + ctypes.c_bool, + [ctypes.c_int, ctypes.c_void_p, ctypes.POINTER(ctypes.c_int)], + ) + eglGetPlatformDisplayEXT = self._get_ext_fn( + "eglGetPlatformDisplayEXT", + ctypes.c_void_p, + [ctypes.c_int, ctypes.c_void_p, ctypes.POINTER(ctypes.c_int)], + ) + display = self._open_device_display( + eglQueryDevicesEXT, eglGetPlatformDisplayEXT + ) + + if display is None: + logger.debug("Falling back to eglGetDisplay(EGL_DEFAULT_DISPLAY)") + libegl.eglGetDisplay.restype = ctypes.c_void_p + libegl.eglGetDisplay.argtypes = [ctypes.c_void_p] + display = libegl.eglGetDisplay(ctypes.c_void_p(0)) + + if not display: + raise RuntimeError( + "Could not obtain any EGL display. " + "Install libegl1-mesa for CPU rendering." + ) + self._display = display + + major, minor = ctypes.c_int(0), ctypes.c_int(0) + if not libegl.eglInitialize( + self._display, ctypes.byref(major), ctypes.byref(minor) + ): + raise RuntimeError("eglInitialize failed.") + logger.debug("EGL %d.%d", major.value, minor.value) + + if not libegl.eglBindAPI(_EGL_OPENGL_API): + raise RuntimeError("eglBindAPI(OpenGL) failed.") + + cfg_attribs = (ctypes.c_int * 7)( + _EGL_SURFACE_TYPE, _EGL_PBUFFER_BIT, + _EGL_RENDERABLE_TYPE, _EGL_OPENGL_BIT, + _EGL_NONE, + ) + configs = (ctypes.c_void_p * 1)() + num_cfgs = ctypes.c_int(0) + if not libegl.eglChooseConfig( + self._display, cfg_attribs, configs, 1, ctypes.byref(num_cfgs) + ) or num_cfgs.value == 0: + raise RuntimeError("eglChooseConfig: no suitable config.") + self._config = configs[0] + + pbuf_attribs = (ctypes.c_int * 5)( + _EGL_WIDTH, 1, _EGL_HEIGHT, 1, _EGL_NONE + ) + self._surface = libegl.eglCreatePbufferSurface( + self._display, self._config, pbuf_attribs + ) + if not self._surface: + raise RuntimeError("eglCreatePbufferSurface failed.") + + ctx_attribs = (ctypes.c_int * 5)( + _EGL_CONTEXT_MAJOR_VERSION, 3, + _EGL_CONTEXT_MINOR_VERSION, 3, + 
_EGL_NONE, + ) + self._context = libegl.eglCreateContext( + self._display, self._config, None, ctx_attribs + ) + if not self._context: + raise RuntimeError( + "eglCreateContext for OpenGL 3.3 Core failed. " + "Try: MESA_GL_VERSION_OVERRIDE=3.3 MESA_GLSL_VERSION_OVERRIDE=330" + ) + logger.info("EGL context created (%dx%d)", self.width, self.height) + + + def _open_device_display(self, eglQueryDevicesEXT, eglGetPlatformDisplayEXT): + """Enumerate EGL devices and return first usable display pointer.""" + n = ctypes.c_int(0) + if not eglQueryDevicesEXT(0, None, ctypes.byref(n)) or n.value == 0: + logger.warning("eglQueryDevicesEXT: no devices.") + return None + logger.debug("EGL: %d device(s) found", n.value) + devices = (ctypes.c_void_p * n.value)() + eglQueryDevicesEXT(n.value, devices, ctypes.byref(n)) + no_attribs = (ctypes.c_int * 1)(_EGL_NONE) + for i, dev in enumerate(devices): + dpy = eglGetPlatformDisplayEXT( + _EGL_PLATFORM_DEVICE_EXT, ctypes.c_void_p(dev), no_attribs + ) + if dpy: + logger.debug("EGL: using device %d", i) + return dpy + return None + + + def make_current(self): + """Make this EGL context current and set up the FBO render target. + + Must be called before any OpenGL commands. Creates and binds an + FBO backed by two renderbuffers (RGBA color + depth/stencil). + """ + if not self._libegl.eglMakeCurrent( + self._display, self._surface, self._surface, self._context + ): + raise RuntimeError("eglMakeCurrent failed.") + + # Force PyOpenGL to discover and cache the context we just made current. + # PyOpenGL's contextdata module only recognizes contexts it has "seen" + # via at least one GL call; glGetError() is the cheapest trigger. 
+ gl.glGetError() + + # Build FBO so rendering is directed off-screen + self.fbo = gl.glGenFramebuffers(1) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.fbo) + + # Color renderbuffer + self._rbo_color = gl.glGenRenderbuffers(1) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._rbo_color) + gl.glRenderbufferStorage( + gl.GL_RENDERBUFFER, gl.GL_RGBA8, self.width, self.height + ) + gl.glFramebufferRenderbuffer( + gl.GL_FRAMEBUFFER, + gl.GL_COLOR_ATTACHMENT0, + gl.GL_RENDERBUFFER, + self._rbo_color, + ) + + # Depth + stencil renderbuffer + self._rbo_depth = gl.glGenRenderbuffers(1) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._rbo_depth) + gl.glRenderbufferStorage( + gl.GL_RENDERBUFFER, + gl.GL_DEPTH24_STENCIL8, + self.width, + self.height, + ) + gl.glFramebufferRenderbuffer( + gl.GL_FRAMEBUFFER, + gl.GL_DEPTH_STENCIL_ATTACHMENT, + gl.GL_RENDERBUFFER, + self._rbo_depth, + ) + + status = gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) + if status != gl.GL_FRAMEBUFFER_COMPLETE: + raise RuntimeError( + f"FBO is not complete after EGL setup (status=0x{status:X})." + ) + + # Set the viewport to match the render target + gl.glViewport(0, 0, self.width, self.height) + logger.debug("EGL FBO complete and bound (%dx%d)", self.width, self.height) + + def read_pixels(self) -> Image.Image: + """Read the FBO contents and return a PIL RGB Image. + + Returns + ------- + PIL.Image.Image + Captured frame, vertically flipped to convert from OpenGL's + bottom-left origin to image top-left convention. 
+ """ + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.fbo) + gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) + buf = gl.glReadPixels( + 0, 0, self.width, self.height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE + ) + img = Image.frombytes("RGB", (self.width, self.height), buf) + return img.transpose(Image.FLIP_TOP_BOTTOM) + + def destroy(self): + libegl = self._libegl + # GL cleanup first (context must be current) + if self.fbo is not None: + gl.glDeleteFramebuffers(1, [self.fbo]) + self.fbo = None + if self._rbo_color is not None: + gl.glDeleteRenderbuffers(1, [self._rbo_color]) + self._rbo_color = None + if self._rbo_depth is not None: + gl.glDeleteRenderbuffers(1, [self._rbo_depth]) + self._rbo_depth = None + if self._display: + libegl.eglMakeCurrent(self._display, None, None, None) + if self._context: + libegl.eglDestroyContext(self._display, self._context) + if self._surface: + libegl.eglDestroySurface(self._display, self._surface) + libegl.eglTerminate(self._display) + self._display = None + self._context = None + self._surface = None + logger.debug("EGL context destroyed.") + + # Allow use as a context manager + def __enter__(self): + self.make_current() + return self + + def __exit__(self, *_): + self.destroy() diff --git a/whippersnappy/gl/shaders.py b/whippersnappy/gl/shaders.py new file mode 100644 index 0000000..fd3c789 --- /dev/null +++ b/whippersnappy/gl/shaders.py @@ -0,0 +1,217 @@ +"""Shared shader sources inside the gl package.""" + +def get_default_shaders(): + """Return the default GLSL 330 vertex and fragment shader sources. + + These shaders are intended for desktop OpenGL (GLSL 330) use in the + offline snapshot renderer. The returned strings contain full GLSL + shader sources for vertex and fragment stages. + + Returns + ------- + vertex_shader, fragment_shader : tuple[str, str] + Tuple containing the vertex shader and fragment shader source code + as plain strings. 
+ """ + vertex_shader = """ + + #version 330 + + layout (location = 0) in vec3 aPos; + layout (location = 1) in vec3 aNormal; + layout (location = 2) in vec3 aColor; + + out vec3 FragPos; + out vec3 Normal; + out vec3 Color; + + uniform mat4 transform; + uniform mat4 model; + uniform mat4 view; + uniform mat4 projection; + + void main() + { + gl_Position = projection * view * model * transform * vec4(aPos, 1.0f); + FragPos = vec3(model * transform * vec4(aPos, 1.0)); + // normal matrix should be computed outside and passed! + Normal = mat3(transpose(inverse(view * model * transform))) * aNormal; + Color = aColor; + } + + """ + + fragment_shader = """ + #version 330 + + in vec3 FragPos; + in vec3 Normal; + in vec3 Color; + + out vec4 FragColor; + + uniform vec3 lightColor = vec3(1.0, 1.0, 1.0); + uniform bool doSpecular = true; + uniform float ambientStrength = 0.0; + + void main() + { + // ambient + vec3 ambient = ambientStrength * lightColor; + + // diffuse + vec3 norm = normalize(Normal); + vec4 diffweights = vec4(0.6, 0.4, 0.4, 0.3); + + // key light (overhead) + vec3 lightPos1 = vec3(0.0,5.0,5.0); + vec3 lightDir = normalize(lightPos1 - FragPos); + float diff = max(dot(norm, lightDir), 0.0); + vec3 diffuse = diffweights[0] * diff * lightColor; + + // headlight (at camera) + vec3 lightPos2 = vec3(0.0,0.0,5.0); + lightDir = normalize(lightPos2 - FragPos); + vec3 ohlightDir = lightDir; + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[1] * diff * lightColor; + + // fill light (from below) + vec3 lightPos3 = vec3(0.0,-5.0,5.0); + lightDir = normalize(lightPos3 - FragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[2] * diff * lightColor; + + // left right back lights + vec3 lightPos4 = vec3(5.0,0.0,-5.0); + lightDir = normalize(lightPos4 - FragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[3] * diff * lightColor; + + vec3 lightPos5 = vec3(-5.0,0.0,-5.0); + lightDir = 
normalize(lightPos5 - FragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[3] * diff * lightColor; + + // specular — camera is at (0,0,-5) in world space (from make_view), + // not at origin, so viewDir must point from FragPos toward (0,0,-5). + vec3 result; + if (doSpecular) + { + float specularStrength = 0.5; + vec3 cameraPos = vec3(0.0, 0.0, -5.0); + vec3 viewDir = normalize(cameraPos - FragPos); + vec3 reflectDir = reflect(ohlightDir, norm); + float spec = pow(max(dot(viewDir, reflectDir), 0.0), 32); + vec3 specular = specularStrength * spec * lightColor; + result = (ambient + diffuse + specular) * Color; + } + else + { + result = (ambient + diffuse) * Color; + } + FragColor = vec4(result, 1.0); + } + + """ + + return vertex_shader, fragment_shader + + +def get_webgl_shaders(): + """Return vertex and fragment shader source strings suitable for WebGL/Three.js. + + These shader snippets are small GLSL pieces that expect Three.js to + provide built-in attributes/uniforms (e.g. projectionMatrix, + modelViewMatrix, normalMatrix). They are used by the Jupyter + pythreejs-based viewer. + + Returns + ------- + vertex_shader, fragment_shader : tuple[str, str] + Vertex and fragment shader source strings for WebGL / Three.js. + """ + + # Only declare custom attributes - Three.js provides position, normal, matrices + # Don't declare position, normal, *Matrix + # Only attributes like color, or uniforms like lightColor, ambientStrength + # Use normalMatrix instead of computing transpose... 
+ vertex_shader = """ + attribute vec3 color; + + varying vec3 vFragPos; + varying vec3 vNormal; + varying vec3 vColor; + + void main() + { + gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0); + vFragPos = vec3(modelViewMatrix * vec4(position, 1.0)); + vNormal = normalMatrix * normal; + vColor = color; + } + """ + + fragment_shader = """ + precision highp float; + + varying vec3 vNormal; + varying vec3 vFragPos; + varying vec3 vColor; + + uniform vec3 lightColor; + uniform float ambientStrength; + + void main() + { + // ambient + vec3 ambient = ambientStrength * lightColor; + + // diffuse + vec3 norm = normalize(vNormal); + vec4 diffweights = vec4(0.6, 0.4, 0.4, 0.3); + + // key light (overhead) + vec3 lightPos1 = vec3(0.0, 5.0, 5.0); + vec3 lightDir = normalize(lightPos1 - vFragPos); + float diff = max(dot(norm, lightDir), 0.0); + vec3 diffuse = diffweights[0] * diff * lightColor; + + // headlight (at camera) + vec3 lightPos2 = vec3(0.0, 0.0, 5.0); + lightDir = normalize(lightPos2 - vFragPos); + vec3 ohlightDir = lightDir; + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[1] * diff * lightColor; + + // fill light (from below) + vec3 lightPos3 = vec3(0.0, -5.0, 5.0); + lightDir = normalize(lightPos3 - vFragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[2] * diff * lightColor; + + // left right back lights + vec3 lightPos4 = vec3(5.0, 0.0, -5.0); + lightDir = normalize(lightPos4 - vFragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[3] * diff * lightColor; + + vec3 lightPos5 = vec3(-5.0, 0.0, -5.0); + lightDir = normalize(lightPos5 - vFragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[3] * diff * lightColor; + + // specular + float specularStrength = 0.5; + vec3 viewDir = normalize(-vFragPos); + vec3 reflectDir = reflect(-ohlightDir, norm); + float spec = pow(max(dot(viewDir, reflectDir), 0.0), 32.0); + vec3 specular = 
specularStrength * spec * lightColor; + + vec3 result = (ambient + diffuse + specular) * vColor; + gl_FragColor = vec4(result, 1.0); + } + """ + + return vertex_shader, fragment_shader + diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py new file mode 100644 index 0000000..288e383 --- /dev/null +++ b/whippersnappy/gl/utils.py @@ -0,0 +1,450 @@ +"""GL helper utilities. + +Contains the implementation of OpenGL helpers used by the package. +""" + +import logging +import os +import sys +from typing import Any + +import glfw +import OpenGL.GL as gl +import OpenGL.GL.shaders as shaders +from PIL import Image + +from .camera import make_model, make_projection, make_view +from .shaders import get_default_shaders + +# Module logger +logger = logging.getLogger(__name__) + +# Module-level EGL context handle (None when GLFW is used instead) +_egl_context: Any = None + + +def create_vao(): + """Create and bind a Vertex Array Object (VAO). + + Returns + ------- + int + OpenGL handle for the created VAO. + """ + vao = gl.glGenVertexArrays(1) + gl.glBindVertexArray(vao) + return vao + + +def compile_shader_program(vertex_src, fragment_src): + """Compile GLSL vertex and fragment sources and link them into a program. + + Parameters + ---------- + vertex_src : str + Vertex shader source code. + fragment_src : str + Fragment shader source code. + + Returns + ------- + int + OpenGL program handle. + """ + return shaders.compileProgram( + shaders.compileShader(vertex_src, gl.GL_VERTEX_SHADER), + shaders.compileShader(fragment_src, gl.GL_FRAGMENT_SHADER), + ) + + +def setup_buffers(meshdata, triangles): + """Create and upload vertex and element buffers for the mesh. + + Parameters + ---------- + meshdata : numpy.ndarray + Vertex array with interleaved attributes (position, normal, color). + triangles : numpy.ndarray + Face index array. + + Returns + ------- + (vbo, ebo) : tuple + OpenGL buffer handles for the VBO and EBO. 
+ """ + vbo = gl.glGenBuffers(1) + gl.glBindBuffer(gl.GL_ARRAY_BUFFER, vbo) + gl.glBufferData(gl.GL_ARRAY_BUFFER, meshdata.nbytes, meshdata, gl.GL_STATIC_DRAW) + + ebo = gl.glGenBuffers(1) + gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, ebo) + gl.glBufferData( + gl.GL_ELEMENT_ARRAY_BUFFER, triangles.nbytes, triangles, gl.GL_STATIC_DRAW + ) + + return vbo, ebo + + +def setup_vertex_attributes(shader): + """Configure vertex attribute pointers for position, normal and color. + + Parameters + ---------- + shader : int + OpenGL shader program handle used to query attribute locations. + """ + position = gl.glGetAttribLocation(shader, "aPos") + gl.glVertexAttribPointer( + position, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(0) + ) + gl.glEnableVertexAttribArray(position) + + vnormalpos = gl.glGetAttribLocation(shader, "aNormal") + gl.glVertexAttribPointer( + vnormalpos, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(3 * 4) + ) + gl.glEnableVertexAttribArray(vnormalpos) + + colorpos = gl.glGetAttribLocation(shader, "aColor") + gl.glVertexAttribPointer( + colorpos, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(6 * 4) + ) + gl.glEnableVertexAttribArray(colorpos) + + +def set_default_gl_state(): + """Set frequently used default OpenGL state for rendering. + + This function enables depth testing and sets a default clear color. + """ + gl.glClearColor(0.0, 0.0, 0.0, 1.0) + gl.glEnable(gl.GL_DEPTH_TEST) + + +def set_camera_uniforms(shader, view, projection, model): + """Upload camera MVP (view, projection, model) matrices to the shader. + + Parameters + ---------- + shader : int + OpenGL shader program handle. + view, projection, model : array-like + 4x4 matrices to be uploaded to the corresponding shader uniforms. 
+ """ + view_loc = gl.glGetUniformLocation(shader, "view") + proj_loc = gl.glGetUniformLocation(shader, "projection") + model_loc = gl.glGetUniformLocation(shader, "model") + gl.glUniformMatrix4fv(view_loc, 1, gl.GL_FALSE, view) + gl.glUniformMatrix4fv(proj_loc, 1, gl.GL_FALSE, projection) + gl.glUniformMatrix4fv(model_loc, 1, gl.GL_FALSE, model) + + +def set_lighting_uniforms(shader, specular=True, ambient=0.0, light_color=(1.0, 1.0, 1.0)): + """Set lighting-related uniforms (specular toggle, ambient, light color). + + Parameters + ---------- + shader : int + OpenGL shader program handle. + specular : bool, optional, default True + Enable specular highlights. + ambient : float, optional, default 0.0 + Ambient light strength. + light_color : tuple, optional, default (1.0, 1.0, 1.0) + RGB light color. + """ + specular_loc = gl.glGetUniformLocation(shader, "doSpecular") + gl.glUniform1i(specular_loc, specular) + + light_color_loc = gl.glGetUniformLocation(shader, "lightColor") + gl.glUniform3f(light_color_loc, *light_color) + + ambient_loc = gl.glGetUniformLocation(shader, "ambientStrength") + gl.glUniform1f(ambient_loc, ambient) + + +def init_window(width, height, title="PyOpenGL", visible=True): + """Create a GLFW window, make an OpenGL context current and return the window handle. + + Parameters + ---------- + width, height : int + Window dimensions in pixels. + title : str, optional, default 'PyOpenGL' + Window title. + visible : bool, optional, default True + If False create an invisible/offscreen window (useful for headless + rendering when a display is available but no screen is needed). + + Returns + ------- + window or False + GLFW window handle on success, or False on failure. 
+ """ + if not glfw.init(): + return False + + glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 3) + glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3) + glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, True) + glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE) + if not visible: + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + window = glfw.create_window(width, height, title, None, None) + if not window: + glfw.terminate() + return False + glfw.set_input_mode(window, glfw.STICKY_KEYS, gl.GL_TRUE) + glfw.make_context_current(window) + glfw.swap_interval(0) + return window + + +def create_window_with_fallback(width, height, title="WhipperSnapPy", visible=True): + """Create an OpenGL context, trying GLFW first and EGL as a fallback. + + The function attempts context creation in this priority order: + + 1. **GLFW visible window** — normal path on workstations. + 2. **GLFW invisible window** — when a display exists but no screen + is needed (e.g. a remote desktop session). + 3. **EGL pbuffer** — fully headless; no display server required. + Works with NVIDIA/AMD GPU drivers and Mesa (llvmpipe) on CPU-only + systems. Requires ``libegl1`` (already installed in the Docker + image) and ``pyopengl >= 3.1``. + + When EGL is used the module-level ``_egl_context`` is set and + ``make_current()`` is called so that subsequent OpenGL calls work + identically to the GLFW path. + + Parameters + ---------- + width : int + Render target width in pixels. + height : int + Render target height in pixels. + title : str, optional + Window title (used for GLFW paths only). Default is ``'WhipperSnapPy'``. + visible : bool, optional + Prefer a visible window. Default is ``True``. + + Returns + ------- + GLFWwindow or None + GLFW window handle when GLFW succeeded, ``None`` when EGL is used + (the context is already current via ``_egl_context.make_current()``). + + Raises + ------ + RuntimeError + If all three methods fail to produce a usable OpenGL context. 
+ """ + global _egl_context + + # Fast-path: if _check_display() already determined there is no working + # display, skip the two doomed GLFW attempts and go straight to EGL. + # This avoids warning noise and wasted time in Docker/CI/headless SSH. + # The sys.platform guard is preserved — EGL is Linux-only. + if os.environ.get("PYOPENGL_PLATFORM") == "egl": + if sys.platform != "linux": + raise RuntimeError( + f"Could not create any OpenGL context via GLFW on {sys.platform}. " + "Ensure a display is available." + ) + logger.info("No working display detected — using EGL headless directly.") + try: + from .egl_context import EGLContext + ctx = EGLContext(width, height) + ctx.make_current() + _egl_context = ctx + logger.info("Using EGL headless context — no display server required.") + return None + except (ImportError, RuntimeError) as exc: + raise RuntimeError( + f"EGL headless context failed: {exc}" + ) from exc + + # --- Step 1: GLFW visible window --- + window = init_window(width, height, title, visible=visible) + if window: + return window + + # --- Step 2: GLFW invisible window --- + if visible: + logger.warning( + "Could not create visible GLFW window; retrying with invisible window." + ) + window = init_window(width, height, title, visible=False) + if window: + return window + + # --- Step 3: EGL headless pbuffer (Linux only) --- + logger.warning( + "GLFW context creation failed entirely (no display?). " + "Attempting EGL headless context." + ) + if sys.platform != "linux": + raise RuntimeError( + f"Could not create any OpenGL context via GLFW on {sys.platform}. " + "Ensure a display is available." 
+ ) + try: + from .egl_context import EGLContext + ctx = EGLContext(width, height) + ctx.make_current() + _egl_context = ctx + logger.info("Using EGL headless context — no display server required.") + return None + except (ImportError, RuntimeError) as exc: + raise RuntimeError( + "Could not create any OpenGL context (tried GLFW visible, " + f"GLFW invisible, EGL pbuffer). Last error: {exc}" + ) from exc + + +def terminate_context(window): + """Release the active OpenGL context regardless of how it was created. + + This is a drop-in replacement for ``glfw.terminate()`` that also + handles the EGL path. Call it at the end of every rendering function + instead of calling ``glfw.terminate()`` directly. + + Parameters + ---------- + window : GLFWwindow or None + The GLFW window handle returned by ``create_window_with_fallback``, + or ``None`` when EGL is active. + """ + global _egl_context + if _egl_context is not None: + _egl_context.destroy() # type: ignore[union-attr] + _egl_context = None + else: + glfw.terminate() + + +def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0): + """Create shader program, upload mesh and initialize camera & lighting. + + This is a convenience wrapper that compiles default shaders, creates + VAO/VBO/EBO, and configures common uniforms (camera matrices, lighting). + + Parameters + ---------- + meshdata : numpy.ndarray + Interleaved vertex data. + triangles : numpy.ndarray + Triangle indices. + width, height : int + Framebuffer size used to compute projection matrix. + specular : bool, optional, default True + Enable specular highlights. + ambient : float, optional, default 0.0 + Ambient lighting strength. + + Returns + ------- + shader : int + Compiled OpenGL shader program handle. 
+ """ + vertex_shader, fragment_shader = get_default_shaders() + + create_vao() + shader = compile_shader_program(vertex_shader, fragment_shader) + setup_buffers(meshdata, triangles) + setup_vertex_attributes(shader) + + gl.glUseProgram(shader) + set_default_gl_state() + + view = make_view() + projection = make_projection(width, height) + model = make_model() + set_camera_uniforms(shader, view, projection, model) + set_lighting_uniforms(shader, specular=specular, ambient=ambient) + + return shader + +def capture_window(window): + """Read the current GL framebuffer and return it as a PIL Image (RGB). + + Works for both GLFW windows and EGL headless contexts. When EGL is + active (``window`` is ``None``) the pixels are read from the FBO that + was set up by :class:`~whippersnappy.gl.egl_context.EGLContext`; in + that case there is no HiDPI scaling to account for. + + Parameters + ---------- + window : GLFWwindow or None + GLFW window handle, or ``None`` when an EGL context is active. + + Returns + ------- + PIL.Image.Image + RGB image of the rendered frame, with the vertical flip applied so + that the origin is at the top-left (image convention). + """ + global _egl_context + + # --- EGL path: read directly from the FBO --- + if _egl_context is not None: + return _egl_context.read_pixels() # type: ignore[union-attr] + + # --- GLFW path: read from the default framebuffer --- + monitor = glfw.get_primary_monitor() + if monitor is None: + # Invisible / offscreen GLFW window — no monitor, no HiDPI scaling. 
+ x_scale, y_scale = 1.0, 1.0 + else: + x_scale, y_scale = glfw.get_monitor_content_scale(monitor) + width, height = glfw.get_framebuffer_size(window) + + logger.debug("Framebuffer size = (%s,%s)", width, height) + logger.debug("Monitor scale = (%s,%s)", x_scale, y_scale) + + gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) + img_buf = gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE) + image = Image.frombytes("RGB", (width, height), img_buf) + image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM) + + if x_scale != 1 or y_scale != 1: + rwidth = int(round(width / x_scale)) + rheight = int(round(height / y_scale)) + logger.debug("Rescale to = (%s,%s)", rwidth, rheight) + image.thumbnail((rwidth, rheight), Image.Resampling.LANCZOS) + + return image + + +def render_scene(shader, triangles, transform): + """Render a single draw call using the supplied shader/indices. + + Parameters + ---------- + shader : int + OpenGL shader program handle. + triangles : numpy.ndarray + Element/index array used for the draw call. + transform : array-like + 4x4 transform matrix (model/view/projection combined) to upload to + the shader uniform named ``transform``. + + Raises + ------ + RuntimeError + If a GL error occurs during rendering. 
+ """ + try: + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) + except Exception as exc: + logger.error("glClear failed: %s", exc) + raise RuntimeError(f"glClear failed: {exc}") from exc + + transform_loc = gl.glGetUniformLocation(shader, "transform") + gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, transform) + gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) + + err = gl.glGetError() + if err != gl.GL_NO_ERROR: + logger.error("OpenGL error after draw: %s", err) + raise RuntimeError(f"OpenGL error after draw: {err}") diff --git a/whippersnappy/gl/views.py b/whippersnappy/gl/views.py new file mode 100644 index 0000000..0d2af24 --- /dev/null +++ b/whippersnappy/gl/views.py @@ -0,0 +1,199 @@ +"""View matrices, presets, and interactive view state under gl package.""" + +from dataclasses import dataclass, field + +import numpy as np +import pyrr # still needed for compute_view_matrix + +from ..utils.types import ViewType + +# --------------------------------------------------------------------------- +# ViewState — single source of truth for all mutable view parameters +# --------------------------------------------------------------------------- + +@dataclass +class ViewState: + """Mutable view parameters for the interactive GUI render loop. + + All mouse/keyboard interaction updates this object; the view matrix is + recomputed from it each frame via :func:`compute_view_matrix`. + + Parameters + ---------- + rotation : np.ndarray + 4×4 float32 rotation matrix (identity = no rotation applied). + pan : np.ndarray + (x, y) pan offset in normalised screen-space units. + zoom : float + Z-translation packed into the transform matrix. + last_mouse_pos : np.ndarray or None + Last recorded mouse position in pixels; ``None`` when no button held. + left_button_down : bool + Whether the left mouse button is currently pressed. + right_button_down : bool + Whether the right mouse button is currently pressed. 
+ middle_button_down : bool + Whether the middle mouse button is currently pressed. + """ + rotation: np.ndarray = field( + default_factory=lambda: np.eye(4, dtype=np.float32) + ) + pan: np.ndarray = field( + default_factory=lambda: np.zeros(2, dtype=np.float32) + ) + zoom: float = 0.4 + last_mouse_pos: np.ndarray | None = None + left_button_down: bool = False + right_button_down: bool = False + middle_button_down: bool = False + + +def compute_view_matrix(view_state: ViewState, base_view: np.ndarray) -> np.ndarray: + """Return the ``transform`` uniform — exactly as snap_rotate does it. + + Packs ``transl * rotation * base_view`` into a single matrix, matching + the snap_rotate convention (line: ``viewmat = transl * rot * base_view``). + The ``model`` and ``view`` uniforms are left as set by ``setup_shader`` + (identity and camera respectively) and must not be overwritten. + + Parameters + ---------- + view_state : ViewState + Current interactive view state. + base_view : np.ndarray + Fixed 4×4 orientation preset from :func:`get_view_matrices`. + + Returns + ------- + np.ndarray + 4×4 float32 matrix for the ``transform`` shader uniform. + """ + transl = pyrr.Matrix44.from_translation(( + view_state.pan[0], + view_state.pan[1], + 0.4 + view_state.zoom, + )) + rot = pyrr.Matrix44(view_state.rotation) + return np.array(transl * rot * pyrr.Matrix44(base_view), dtype=np.float32) + + + +# --------------------------------------------------------------------------- +# Arcball helpers +# --------------------------------------------------------------------------- + +def arcball_vector(x: float, y: float, width: int, height: int) -> np.ndarray: + """Map a 2-D screen pixel to a point on the unit arcball sphere. + + Normalises (x, y) to [-1, 1] NDC, then projects onto the unit sphere. + Points outside the sphere radius are clamped to the rim (z = 0). + + Parameters + ---------- + x, y : float + Mouse position in pixels. + width, height : int + Window dimensions in pixels. 
+ + Returns + ------- + np.ndarray + Unit 3-vector on (or clamped to) the arcball sphere. + """ + s = min(width, height) + p = np.array([ + (2.0 * x - width) / s, + -(2.0 * y - height) / s, + 0.0, + ], dtype=np.float64) + sq = p[0] ** 2 + p[1] ** 2 + if sq <= 1.0: + p[2] = np.sqrt(1.0 - sq) + else: + p /= np.sqrt(sq) # clamp to rim + n = np.linalg.norm(p) + return p / n if n > 0 else p + + +def arcball_rotation_matrix(v1: np.ndarray, v2: np.ndarray) -> np.ndarray: + """Return a 4×4 rotation matrix that rotates unit vector *v1* to *v2*. + + Uses Rodrigues' rotation formula in pure numpy — no pyrr dependency. + Returns identity when *v1* and *v2* are coincident. + + Parameters + ---------- + v1, v2 : np.ndarray + Unit 3-vectors on the arcball sphere. + + Returns + ------- + np.ndarray + 4×4 float32 rotation matrix compatible with pyrr. + """ + axis = np.cross(v1, v2) + axis_len = np.linalg.norm(axis) + if axis_len < 1e-10: + return np.eye(4, dtype=np.float32) + + axis = axis / axis_len + angle = np.arctan2(axis_len, np.dot(v1, v2)) + + # Rodrigues' formula: R = I cos(a) + sin(a) [axis]× + (1-cos(a)) axis⊗axis + c, s = np.cos(angle), np.sin(angle) + t = 1.0 - c + x, y, z = axis + r3 = np.array([ + [t*x*x + c, t*x*y - s*z, t*x*z + s*y], + [t*x*y + s*z, t*y*y + c, t*y*z - s*x], + [t*x*z - s*y, t*y*z + s*x, t*z*z + c ], + ], dtype=np.float32) + + r4 = np.eye(4, dtype=np.float32) + r4[:3, :3] = r3 + return r4 + + +def get_view_matrices(): + """Return canonical 4x4 view matrices for common brain orientations. + + The returned dictionary maps :class:`whippersnappy.utils.types.ViewType` + enum members to corresponding 4x4 view matrices (dtype float32) that + can be used as camera/view transforms in the OpenGL renderer. + + Returns + ------- + dict + Mapping of :class:`ViewType` -> 4x4 numpy.ndarray view matrix. 
+ """ + view_left = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) + view_right = np.array([[0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) + view_back = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) + view_front = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) + view_bottom = np.array([[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=np.float32) + view_top = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32) + + return { + ViewType.LEFT: view_left, + ViewType.RIGHT: view_right, + ViewType.BACK: view_back, + ViewType.FRONT: view_front, + ViewType.TOP: view_top, + ViewType.BOTTOM: view_bottom, + } + + +def get_view_matrix(view_type): + """Return the 4x4 view matrix for a single :class:`ViewType`. + + Parameters + ---------- + view_type : ViewType + Enum member indicating the requested view. + + Returns + ------- + numpy.ndarray + 4x4 float32 view matrix. + """ + return get_view_matrices()[view_type] diff --git a/whippersnappy/gui/__init__.py b/whippersnappy/gui/__init__.py new file mode 100644 index 0000000..6c5c9a9 --- /dev/null +++ b/whippersnappy/gui/__init__.py @@ -0,0 +1,4 @@ +from .config_app import ConfigWindow + +__all__ = ["ConfigWindow"] + diff --git a/whippersnappy/config_app.py b/whippersnappy/gui/config_app.py similarity index 73% rename from whippersnappy/config_app.py rename to whippersnappy/gui/config_app.py index e824a79..0f85ed7 100644 --- a/whippersnappy/config_app.py +++ b/whippersnappy/gui/config_app.py @@ -5,10 +5,6 @@ Dependencies: PyQt6 - -@Author : Ahmed Faisal Abdelrahman -@Created : 20.03.2022 - """ from PyQt6.QtCore import Qt @@ -23,20 +19,24 @@ class ConfigWindow(QWidget): - """ - Encapsulates the Qt widget for the parameter configuration. + """Qt configuration window for interactive parameter tuning. 
+ + The configuration window exposes sliders and text boxes to adjust the + f-threshold and f-max parameters used by the renderer. The widget is + intended to run alongside the OpenGL window and to push updated values + to the renderer via polling from the main loop. Parameters ---------- - parent : QWidget - This widget's parent, if any (usually none). - screen_dims : tuple - Integers specifying screen dims in pixels; used to always position - the window in the top-right corner, if given. - initial_fthresh_value : float - Initial fthreshold value is 2.0. - initial_fmax_value : float - Initial fmax value is 4.0. + parent : QWidget, optional + Parent Qt widget. Defaults to ``None``. + screen_dims : tuple or None, optional + (width, height) of the available screen; used to position the + window in the top-right corner when provided. + initial_fthresh_value : float, optional + Initial threshold value (default 2.0). + initial_fmax_value : float, optional + Initial fmax value (default 4.0). """ def __init__( @@ -153,11 +153,11 @@ def __init__( self.setGeometry(0, 0, self.window_size[0], self.window_size[1]) def fthresh_slider_value_cb(self): - """ - Callback function for user-modified fthresh slider. + """Handle changes from the f-threshold slider. - This function is triggered when the user modifies the fthresh slider. It - stores the selected value and updates the corresponding user input box. + This slot is connected to the slider's valueChanged signal. It maps the + slider tick value into the configured value range and updates the + text input box accordingly. """ self.current_fthresh_value = self.convert_value_to_range( self.fthresh_slider.value(), @@ -167,22 +167,12 @@ def fthresh_slider_value_cb(self): self.fthresh_value_box.setText(str(self.current_fthresh_value)) def fthresh_value_cb(self, new_value): - """ - Callback function for user input of fthresh value. - - This function is triggered when the user inputs a value for fthresh. 
It - stores the selected value and updates the corresponding slider. + """Handle text input changes for f-threshold. Parameters ---------- new_value : float or str - The new value input by the user. It can be a float or a string that - can be converted to a float. - - Returns - ------- - None - This function does not return any value. + The new value input by the user. May be a float or numeric string. """ # Do not react to invalid values: try: @@ -200,12 +190,7 @@ def fthresh_value_cb(self, new_value): self.fthresh_slider.setValue(int(slider_fthresh_value)) def fmax_slider_value_cb(self): - """ - Callback function for user-modified fmax slider. - - This function is triggered when the user modifies the fmax slider. It - stores the selected value and updates the corresponding user input box. - """ + """Handle changes from the f-max slider and update the text box.""" self.current_fmax_value = self.convert_value_to_range( self.fmax_slider.value(), self.fmax_slider_tick_limits, @@ -214,22 +199,12 @@ def fmax_slider_value_cb(self): self.fmax_value_box.setText(str(self.current_fmax_value)) def fmax_value_cb(self, new_value): - """ - Callback function for user input of fmax value. - - This function is triggered when the user inputs a value for fmax. It - stores the selected value and updates the corresponding slider. + """Handle text input changes for f-max. Parameters ---------- new_value : float or str - The new value input by the user. It can be a float or a string that - can be converted to a float. - - Returns - ------- - None - This function does not return any value. + New value provided by the user. """ # Do not react to invalid values: try: @@ -247,25 +222,21 @@ def fmax_value_cb(self, new_value): self.fmax_slider.setValue(int(slider_fmax_value)) def convert_value_to_range(self, value, old_limits, new_limits): - """ - Convert a given number from one range to another. 
- - This is useful for transforming values from the original range to that - of the slider widget tick range and vice-versa. + """Map ``value`` from ``old_limits`` to ``new_limits``. Parameters ---------- value : float Value to be converted. old_limits : tuple - Minimum and maximum values that define the source range. + (min, max) source range. new_limits : tuple - Minimum and maximum values that define the target range. + (min, max) target range. Returns ------- - new_value : float - Converted value. + float + Value mapped into ``new_limits``. """ old_range = old_limits[1] - old_limits[0] new_range = new_limits[1] - new_limits[0] @@ -274,40 +245,34 @@ def convert_value_to_range(self, value, old_limits, new_limits): return new_value def get_fthresh_value(self): - """ - Return the current stores value for fthresh. + """Return the currently selected f-threshold value. Returns ------- - current_fthresh_value: float - Current fthresh value. + float + Current f-threshold value. """ return self.current_fthresh_value def get_fmax_value(self): - """ - Return the current stores value for fmax. + """Return the currently selected f-max value. Returns ------- - current_fmax_value : float - Current fmax value. + float + Current f-max value. """ return self.current_fmax_value def keyPressEvent(self, event): - """ - Close the window when the ESC key is pressed. + """Handle key press events for the window. + + The handler closes the window when the ESC key is pressed. Parameters ---------- event : QKeyEvent - The key event object representing the key press. - - Returns - ------- - None - This function return None. + Qt key event delivered by the framework. """ if event.key() == Qt.Key.Escape: self.close() diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py new file mode 100644 index 0000000..8fefcce --- /dev/null +++ b/whippersnappy/plot3d.py @@ -0,0 +1,249 @@ +"""3D plotting for WhipperSnapPy using pythreejs (Three.js) for Jupyter. 
+ +This module provides interactive 3D brain visualization for Jupyter notebooks +using Three.js/WebGL. It works in all Jupyter environments (browser, +JupyterLab, Colab, VS Code). + +Unlike the desktop GUI (``whippersnap`` command), this renders entirely in the +browser via WebGL and is designed for notebook environments. + +Usage:: + + from whippersnappy import plot3d + viewer = plot3d(mesh='path/to/lh.white', bg_map='path/to/lh.curv') + display(viewer) + +Dependencies: + pythreejs, ipywidgets, numpy +""" + +import logging + +import numpy as np +import pythreejs as p3js +from ipywidgets import HTML, VBox + +from .geometry import prepare_geometry +from .gl import get_webgl_shaders +from .utils.types import ColorSelection + +# Module logger +logger = logging.getLogger(__name__) + + +def plot3d( + mesh, + overlay=None, + annot=None, + bg_map=None, + roi=None, + minval=None, + maxval=None, + invert=False, + scale=1.85, + color_mode=None, + width=800, + height=800, + ambient=0.1, +): + """Create an interactive 3D notebook viewer using pythreejs (Three.js). + + This function prepares geometry and color information (via + :func:`whippersnappy.geometry.prepare_geometry`) and constructs a + pythreejs renderer and controls wrapped in an ``ipywidgets.VBox`` for + display inside a Jupyter notebook. + + The mesh can be any triangular surface — not just brain surfaces. + Supported file formats: FreeSurfer binary surface (e.g. ``lh.white``), + ASCII OFF (``.off``), legacy ASCII VTK PolyData (``.vtk``), ASCII PLY + (``.ply``), or a ``(vertices, faces)`` numpy array tuple. + + Parameters + ---------- + mesh : str or tuple of (array-like, array-like) + Path to a mesh file (FreeSurfer binary, ``.off``, ``.vtk``, or + ``.ply``) **or** a ``(vertices, faces)`` tuple. + overlay : str, array-like, or None, optional + Path to a per-vertex scalar file, or a (N,) array of per-vertex + scalar values. 
+ annot : str, tuple, or None, optional + Path to a FreeSurfer .annot file, or a ``(labels, ctab)`` / + ``(labels, ctab, names)`` tuple for categorical labeling. + bg_map : str, array-like, or None, optional + Path to a per-vertex scalar file **or** a (N,) array used as + grayscale background shading for non-overlay regions. + roi : str, array-like, or None, optional + Path to a FreeSurfer label file **or** a (N,) boolean array to + restrict overlay coloring to a subset of vertices. + minval, maxval : float or None, optional + Threshold and saturation values used for color mapping. + If ``None``, sensible defaults are chosen automatically. + invert : bool, optional + If True, invert the overlay color map. Default is ``False``. + scale : float, optional + Global geometry scale applied during preparation. Default is ``1.85``. + color_mode : ColorSelection or None, optional + Which sign of overlay values to color (BOTH/POSITIVE/NEGATIVE). + If None, defaults to ``ColorSelection.BOTH``. + width, height : int, optional + Canvas dimensions for the generated renderer. Default is ``800``. + ambient : float, optional + Ambient lighting strength passed to the Three.js shader. Default is ``0.1``. + + Returns + ------- + ipywidgets.VBox + A widget containing the pythreejs Renderer and a small info panel. + + Raises + ------ + ValueError, FileNotFoundError + Errors from :func:`prepare_geometry` are propagated (for example + shape mismatches between overlay and mesh vertex count). 
+ + Examples + -------- + In a Jupyter notebook:: + + from whippersnappy import plot3d + from IPython.display import display + + # FreeSurfer surface + viewer = plot3d('lh.white', overlay='lh.thickness', bg_map='lh.curv') + display(viewer) + + # Any triangular mesh via OFF / VTK / PLY + viewer = plot3d('mesh.off', overlay='values.mgh') + display(viewer) + + # Array inputs + import numpy as np + v = np.random.randn(500, 3).astype(np.float32) + f = np.zeros((1, 3), dtype=np.uint32) + viewer = plot3d((v, f)) + display(viewer) + """ + # Load and prepare mesh data + color_mode = color_mode or ColorSelection.BOTH + meshdata, triangles, fmin, fmax, pos, neg = prepare_geometry( + mesh, overlay, annot, bg_map, roi, + minval, maxval, invert, scale, color_mode + ) + + logger.info("Loaded mesh: %d vertices, %d faces", meshdata.shape[0], triangles.shape[0]) + + # Extract vertices, normals, and colors + vertices = meshdata[:, :3] # x, y, z + normals = meshdata[:, 3:6] # nx, ny, nz + colors = meshdata[:, 6:9] # r, g, b + + # Center and scale the mesh + center = vertices.mean(axis=0) + vertices = vertices - center + max_extent = np.abs(vertices).max() + vertices = vertices / max_extent * 2.0 + + # Create Three.js mesh + mesh = create_threejs_mesh_with_custom_shaders(vertices, triangles, colors, normals, ambient=ambient) + + camera = p3js.PerspectiveCamera( + position=[-5, 0, 0], + up=[0, 0, 1], + aspect=width/height, + fov=45, + near=0.1, + far=1000 + ) + + # Create scene without lights (use our own custom shader): + scene = p3js.Scene( + children=[mesh, camera], # No lights needed + background='#000000' + ) + + # Create renderer + renderer = p3js.Renderer( + camera=camera, + scene=scene, + controls=[p3js.OrbitControls(controlling=camera)], + width=width, + height=height, + antialias=True + ) + + # Create info display + info_text = f""" +
+ Interactive 3D Viewer (Three.js) ✓
+ Vertices: {len(vertices):,}
+ Triangles: {len(triangles):,}
+
+ Drag to rotate, scroll to zoom, right-drag to pan
+ """ + + if overlay or annot: + info_text += "
📊 Overlay/annotation loaded" + elif bg_map: + info_text += "
🧠 Curvature (grayscale is correct)" + + info_text += "
" + + info = HTML(value=info_text) + + # Combine renderer and info + viewer = VBox([renderer, info]) + + return viewer + +def create_threejs_mesh_with_custom_shaders(vertices, faces, colors, normals, ambient=0.1): + """Create a pythreejs.Mesh using custom shader material and buffers. + + The function builds a BufferGeometry with position, color and normal + attributes, attaches an index buffer, and creates a ShaderMaterial + using the WebGL shader snippets returned by :func:`get_webgl_shaders`. + + Parameters + ---------- + vertices : numpy.ndarray + Array of shape (N, 3) with vertex positions (float32). + faces : numpy.ndarray + Integer face index array shape (M, 3) or flattened (3*M,) dtype uint32. + colors : numpy.ndarray + Array of shape (N, 3) with per-vertex RGB colors (float32). + normals : numpy.ndarray + Array of shape (N, 3) with per-vertex normals (float32). + ambient : float, optional, default 0.1 + Ambient lighting strength for the shader (passed to Three.js uniform). + + Returns + ------- + pythreejs.Mesh + Mesh object ready to be inserted into a pythreejs.Scene. 
+ """ + vertices = vertices.astype(np.float32) + colors = colors.astype(np.float32) + normals = normals.astype(np.float32) + faces = faces.astype(np.uint32).flatten() + + vertex_shader, fragment_shader = get_webgl_shaders() + + geometry = p3js.BufferGeometry( + attributes={ + 'position': p3js.BufferAttribute(array=vertices), + 'color': p3js.BufferAttribute(array=colors), + 'normal': p3js.BufferAttribute(array=normals), + } + ) + geometry.index = p3js.BufferAttribute(array=faces) + + material = p3js.ShaderMaterial( + vertexShader=vertex_shader, + fragmentShader=fragment_shader, + uniforms={ + 'lightColor': {'value': [1.0, 1.0, 1.0]}, + 'ambientStrength': {'value': ambient} + } + ) + + three_mesh = p3js.Mesh(geometry=geometry, material=material) + return three_mesh diff --git a/whippersnappy/Roboto-Regular.ttf b/whippersnappy/resources/fonts/Roboto-Regular.ttf similarity index 100% rename from whippersnappy/Roboto-Regular.ttf rename to whippersnappy/resources/fonts/Roboto-Regular.ttf diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py new file mode 100644 index 0000000..176d98e --- /dev/null +++ b/whippersnappy/snap.py @@ -0,0 +1,712 @@ +"""Snapshot (static rendering) API for WhipperSnapPy.""" + +import logging +import os + +import glfw +import numpy as np +import pyrr +from PIL import Image, ImageFont + +from .geometry import estimate_overlay_thresholds, get_surf_name +from .geometry.prepare import prepare_and_validate_geometry +from .gl.utils import capture_window, create_window_with_fallback, render_scene, setup_shader, terminate_context +from .gl.views import get_view_matrices +from .utils.image import create_colorbar, draw_caption, draw_colorbar, load_roboto_font, text_size +from .utils.types import ColorSelection, OrientationType, ViewType + +# Module logger +logger = logging.getLogger(__name__) + + +def snap1( + mesh, + outpath=None, + overlay=None, + annot=None, + bg_map=None, + roi=None, + view=ViewType.LEFT, + viewmat=None, + width=700, + height=500, 
+ fthresh=None, + fmax=None, + caption=None, + caption_x=None, + caption_y=None, + caption_scale=1, + invert=False, + colorbar=True, + colorbar_x=None, + colorbar_y=None, + colorbar_scale=1, + orientation=OrientationType.HORIZONTAL, + color_mode=ColorSelection.BOTH, + font_file=None, + specular=True, + brain_scale=1.5, + ambient=0.0, +): + """Render a single static snapshot of a surface mesh. + + This function opens an OpenGL context, uploads the provided + surface geometry and colors (overlay or annotation), renders the scene + for a single view, captures the framebuffer, and returns a PIL Image. + When ``outpath`` is provided the image is also written to disk. + + The mesh can be any triangular surface — not just brain surfaces. + Supported file formats: FreeSurfer binary surface (e.g. ``lh.white``), + ASCII OFF (``.off``), legacy ASCII VTK PolyData (``.vtk``), ASCII PLY + (``.ply``), or a ``(vertices, faces)`` numpy array tuple. + + Parameters + ---------- + mesh : str or tuple of (array-like, array-like) + Path to a mesh file (FreeSurfer binary, ``.off``, ``.vtk``, or + ``.ply``) **or** a ``(vertices, faces)`` tuple where *vertices* is + (N, 3) float and *faces* is (M, 3) int. + outpath : str or None, optional + When provided, the resulting image is saved to this path. + overlay : str, array-like, or None, optional + Overlay file path (``.mgh`` or FreeSurfer morph) **or** a (N,) array + of per-vertex scalar values. If ``None``, coloring falls back to + background shading / annotation. + annot : str, tuple, or None, optional + Path to a FreeSurfer .annot file **or** a ``(labels, ctab)`` / + ``(labels, ctab, names)`` tuple with per-vertex labels. + bg_map : str, array-like, or None, optional + Path to a per-vertex scalar file **or** a (N,) array whose sign + determines light/dark background shading for non-overlay vertices. + roi : str, array-like, or None, optional + Path to a FreeSurfer label file **or** a (N,) boolean array. 
+ Vertices with ``True`` receive overlay coloring; others fall back + to *bg_map* shading. + view : ViewType, optional + Which pre-defined view to render (left, right, front, ...). + Default is ``ViewType.LEFT``. + viewmat : 4x4 matrix-like, optional + Optional view matrix to override the pre-defined view. + width, height : int, optional + Output canvas size in pixels. Defaults to (700×500). + fthresh, fmax : float or None, optional + Threshold and saturation values for overlay coloring. + caption, caption_x, caption_y, caption_scale : str/float, optional + Caption text and layout parameters. + invert : bool, optional + Invert the color scale. Default is ``False``. + colorbar : bool, optional + If True, render a colorbar when an overlay is present. Default is ``True``. + colorbar_x, colorbar_y, colorbar_scale : float, optional + Colorbar positioning and scale. Scale defaults to 1. + orientation : OrientationType, optional + Colorbar orientation (HORIZONTAL/VERTICAL). Default is ``OrientationType.HORIZONTAL``. + color_mode : ColorSelection, optional + Which sign of overlay to color (POSITIVE/NEGATIVE/BOTH). Default is ``ColorSelection.BOTH``. + font_file : str or None, optional + Path to a TTF font for captions; fallback to bundled font if None. + specular : bool, optional + Enable specular highlights. Default is ``True``. + brain_scale : float, optional + Scale factor applied when preparing the geometry. Default is ``1.5``. + ambient : float, optional + Ambient lighting strength for shader. Default is ``0.0``. + + Returns + ------- + PIL.Image.Image + Rendered snapshot as a PIL Image. + + Raises + ------ + RuntimeError + If the OpenGL/GLFW context cannot be initialized. + ValueError + If the overlay contains no values to display for the chosen + color_mode. + FileNotFoundError + If a required file cannot be found. 
+ + Examples + -------- + FreeSurfer surface with overlay:: + + >>> from whippersnappy import snap1 + >>> img = snap1('lh.white', overlay='lh.thickness', + ... bg_map='lh.curv', roi='lh.cortex.label') + >>> img.save('/tmp/lh.png') + + Array inputs (any triangular mesh):: + + >>> import numpy as np + >>> v = np.random.randn(100, 3).astype(np.float32) + >>> f = np.array([[0, 1, 2]], dtype=np.uint32) + >>> img = snap1((v, f)) + + OFF / VTK / PLY file:: + + >>> img = snap1('mesh.off', overlay='values.mgh') + """ + ref_width = 700 + ref_height = 500 + ui_scale = min(width / ref_width, height / ref_height) + try: + if glfw.init(): + primary_monitor = glfw.get_primary_monitor() + if primary_monitor: + mode = glfw.get_video_mode(primary_monitor) + if width > mode.size.width: + logger.info("Requested width %d exceeds screen width %d, expect black bars", + width, mode.size.width) + elif height > mode.size.height: + logger.info("Requested height %d exceeds screen height %d, expect black bars", + height, mode.size.height) + except Exception: + pass # headless — no monitor info available, that's fine + + image = Image.new("RGB", (width, height)) + + bwidth = int(540 * brain_scale * ui_scale) + bheight = int(450 * brain_scale * ui_scale) + brain_display_width = min(bwidth, width) + brain_display_height = min(bheight, height) + logger.debug("Requested (width,height) = (%s,%s)", width, height) + logger.debug("Brain (width,height) = (%s,%s)", bwidth, bheight) + logger.debug("B-Display (width,height) = (%s,%s)", brain_display_width, brain_display_height) + + # will raise exception if it cannot be created + window = create_window_with_fallback(brain_display_width, brain_display_height, "WhipperSnapPy", visible=True) + try: + meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( + mesh, + overlay, + annot, + bg_map, + roi, + fthresh, + fmax, + invert, + scale=brain_scale, + color_mode=color_mode, + ) + + shader = setup_shader(meshdata, triangles, 
brain_display_width, brain_display_height, + specular=specular, ambient=ambient) + + transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + view_mats = get_view_matrices() + viewmat = transl * (view_mats[view] if viewmat is None else viewmat) + render_scene(shader, triangles, viewmat) + + # Center the brain rendering in the output image, clamp to zero + brain_x = max(0, (width - brain_display_width) // 2) + brain_y = max(0, (height - brain_display_height) // 2) + image.paste(capture_window(window), (brain_x, brain_y)) + + bar = ( + create_colorbar( + fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, font_file=font_file + ) + if overlay is not None and colorbar + else None + ) + font = ( + load_roboto_font(int(20 * caption_scale * ui_scale)) + if font_file is None + else ImageFont.truetype(font_file, int(20 * caption_scale * ui_scale)) + if caption + else None + ) + + # Compute positions to avoid overlap, unless explicit positions are given + text_w, text_h = text_size(caption, font) if caption and font else (0, 0) + bar_h = bar.height if bar is not None else 0 + gap = int(4 * ui_scale) + bottom_pad = int(20 * ui_scale) + + if orientation == OrientationType.HORIZONTAL: + # If explicit positions are given, use them + if colorbar_x is not None or colorbar_y is not None or caption_x is not None or caption_y is not None: + bx = int(colorbar_x * width) if colorbar_x is not None else None + by = int(colorbar_y * height) if colorbar_y is not None else None + cx = int(caption_x * width) if caption_x is not None else None + cy = int(caption_y * height) if caption_y is not None else None + draw_colorbar(image, bar, orientation, x=bx, y=by) + draw_caption(image, caption, font, orientation, x=cx, y=cy) + else: + # Place colorbar above caption if both present + if bar is not None and caption: + bar_y = image.height - bottom_pad - text_h - gap - bar_h + caption_y = image.height - bottom_pad - text_h + elif bar is not None: + bar_y = image.height - 
bottom_pad - bar_h + caption_y = None + elif caption: + bar_y = None + caption_y = image.height - bottom_pad - text_h + else: + bar_y = caption_y = None + draw_colorbar(image, bar, orientation, y=bar_y) + draw_caption(image, caption, font, orientation, y=caption_y) + else: + # For vertical, allow explicit x/y for both, else use default + bx = int(colorbar_x * width) if colorbar_x is not None else None + by = int(colorbar_y * height) if colorbar_y is not None else None + cx = int(caption_x * width) if caption_x is not None else None + cy = int(caption_y * height) if caption_y is not None else None + draw_colorbar(image, bar, orientation, x=bx, y=by) + draw_caption(image, caption, font, orientation, x=cx, y=cy) + + if outpath: + logger.info("Saving snapshot to %s", outpath) + image.save(outpath) + return image + finally: + terminate_context(window) + + +def snap4( + lh_overlay=None, + rh_overlay=None, + lh_annot=None, + rh_annot=None, + fthresh=None, + fmax=None, + sdir=None, + caption=None, + invert=False, + roi_name="cortex.label", + surfname=None, + bg_map_name="curv", + colorbar=True, + outpath=None, + font_file=None, + specular=True, + ambient=0.0, + brain_scale=1.85, + color_mode=ColorSelection.BOTH, +): + """Render four snapshot views (left/right hemispheres, lateral/medial). + + This convenience function renders four views (lateral/medial for each + hemisphere), stitches them together into a single PIL Image and returns + it (and saves it to ``outpath`` when provided). It is typically used to + produce publication-ready overview figures composed from both + hemispheres. + + Parameters + ---------- + lh_overlay, rh_overlay : str, array-like, or None + Left/right hemisphere overlay — either a file path (FreeSurfer morph + or .mgh) or a per-vertex scalar array. Typically provided as a pair + for a coherent two-hemisphere color scale. 
+ lh_annot, rh_annot : str, tuple, or None + Left/right hemisphere annotation — either a path to a .annot file or + a ``(labels, ctab)`` / ``(labels, ctab, names)`` tuple. + Cannot be combined with ``lh_overlay``/``rh_overlay``. + fthresh, fmax : float or None + Threshold and saturation for overlay coloring. Auto-estimated when + ``None``. + sdir : str or None + Subject directory containing ``surf/`` and ``label/`` subdirectories. + Falls back to ``$SUBJECTS_DIR`` when ``None``. + caption : str or None + Caption string to place on the final image. + invert : bool, optional + Invert color scale. Default is ``False``. + roi_name : str, optional + Basename of the label file used to restrict overlay coloring (default + ``'cortex.label'``). The full path is constructed as + ``<sdir>/label/<hemi>.<roi_name>``. + surfname : str or None, optional + Surface basename to load (e.g. ``'white'``); auto-detected when + ``None``. + bg_map_name : str, optional + Basename of the curvature/morph file used for background shading + (default ``'curv'``). The full path is constructed as + ``<sdir>/surf/<hemi>.<bg_map_name>``. + colorbar : bool, optional + Whether to draw a colorbar on the composed image. Default is ``True``. + outpath : str or None, optional + If provided, save composed image to this path. + font_file : str or None, optional + Path to a font to use for captions. + specular : bool, optional + Enable/disable specular highlights in the renderer. Default is ``True``. + ambient : float, optional + Ambient lighting strength. Default is ``0``. + brain_scale : float, optional + Scaling factor passed to geometry preparation. Default is ``1.85``. + color_mode : ColorSelection, optional + Which sign of overlay to color (POSITIVE/NEGATIVE/BOTH). Default is ``ColorSelection.BOTH``. + + Returns + ------- + PIL.Image.Image + Composed image of the four views. + + Raises + ------ + ValueError + For invalid argument combinations or when required overlay values + are absent. 
+ FileNotFoundError + When required surface files are not found. + + Examples + -------- + >>> from whippersnappy import snap4 + >>> img = snap4( + ... lh_overlay='fsaverage/surf/lh.thickness', + ... rh_overlay='fsaverage/surf/rh.thickness', + ... sdir='./fsaverage' + ... ) + >>> img.save('/tmp/whippersnappy_overview.png') + """ + wwidth = 540 + wheight = 450 + + # Resolve sdir early so path-building works for both the pre-pass and + # the rendering loop. + if sdir is None: + sdir = os.environ.get("SUBJECTS_DIR") + if not sdir and surfname is None: + logger.error("No sdir or SUBJECTS_DIR provided") + raise ValueError("No sdir or SUBJECTS_DIR provided") + if not sdir and surfname is not None: + logger.error("surfname provided but sdir is None") + raise ValueError("surfname provided but sdir is None; cannot construct mesh path.") + + # Pre-pass: estimate missing fthresh/fmax from overlays for global color scale + has_overlay = lh_overlay is not None or rh_overlay is not None + if has_overlay and (fthresh is None or fmax is None): + est_fthreshs = [] + est_fmaxs = [] + for _overlay in filter(None, (lh_overlay, rh_overlay)): + h_fthresh, h_fmax = estimate_overlay_thresholds(_overlay, fthresh, fmax) + est_fthreshs.append(h_fthresh) + est_fmaxs.append(h_fmax) + if fthresh is None and est_fthreshs: + fthresh = min(est_fthreshs) + if fmax is None and est_fmaxs: + fmax = max(est_fmaxs) + logger.debug("Global color range: fthresh=%s fmax=%s", fthresh, fmax) + + # will raise exception if it cannot be created + window = create_window_with_fallback(wwidth, wheight, "WhipperSnapPy", visible=True) + try: + # Use standard view matrices from get_view_matrices and ViewType + view_mats = get_view_matrices() + view_left = view_mats[ViewType.LEFT] + view_right = view_mats[ViewType.RIGHT] + transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + + # Predefine hemisphere images so static analysis knows they exist even if + # an earlier step raises an exception (we still will fail at 
 runtime). + lhimg = None + rhimg = None + + for hemi in ("lh", "rh"): + if surfname is None: + found_surfname = get_surf_name(sdir, hemi) + if found_surfname is None: + logger.error("Could not find valid surface in %s for hemi: %s!", sdir, hemi) + raise FileNotFoundError(f"Could not find valid surface in {sdir} for hemi: {hemi}") + mesh = os.path.join(sdir, "surf", hemi + "." + found_surfname) + else: + mesh = os.path.join(sdir, "surf", hemi + "." + surfname) + + # Assign derived paths for bg_map and roi + bg_map = os.path.join(sdir, "surf", hemi + "." + bg_map_name) if bg_map_name else None + roi = os.path.join(sdir, "label", hemi + "." + roi_name) if roi_name else None + overlay = lh_overlay if hemi == "lh" else rh_overlay + annot = lh_annot if hemi == "lh" else rh_annot + + # If overlay is an array, it doesn't have a path to log; handle gracefully + if isinstance(overlay, str): + logger.debug("overlay=%s exists=%s", overlay, os.path.exists(overlay)) + elif overlay is not None: + logger.debug("overlay shape=%s", getattr(overlay, 'shape', None)) + + # Diagnostic: report mesh and overlay paths and whether they exist + logger.debug("hemisphere=%s", hemi) + if isinstance(mesh, str): + logger.debug("mesh=%s exists=%s", mesh, os.path.exists(mesh)) + if isinstance(annot, str): + logger.debug("annot=%s exists=%s", annot, os.path.exists(annot)) + if bg_map is not None: + logger.debug("bg_map=%s exists=%s", bg_map, os.path.exists(bg_map)) + + try: + meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( + mesh, overlay, annot, bg_map, roi, fthresh, fmax, invert, + scale=brain_scale, color_mode=color_mode + ) + except Exception as e: + logger.error("prepare_geometry failed for %s: %s", mesh, e) + raise + + # Diagnostics about mesh data + try: + logger.debug("meshdata shape: %s; triangles count: %s", getattr(meshdata, 'shape', None), + getattr(triangles, 'size', None)) + except Exception: + pass + + try: + shader = setup_shader(meshdata, 
triangles, wwidth, wheight, specular=specular, ambient=ambient) + logger.debug("Shader setup complete") + except Exception as e: + logger.error("setup_shader failed: %s", e) + raise + + render_scene(shader, triangles, transl * view_left) + im1 = capture_window(window) + render_scene(shader, triangles, transl * view_right) + im2 = capture_window(window) + + if hemi == "lh": + lhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) + lhimg.paste(im1, (0, 0)) + lhimg.paste(im2, (0, im1.height)) + else: + rhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) + # For right hemisphere, reverse the order: top=im2, bottom=im1 + rhimg.paste(im2, (0, 0)) + rhimg.paste(im1, (0, im2.height)) + + # Add small padding around each hemisphere to avoid cropping at edges + pad = max(4, int(0.03 * wwidth)) + padded_lh = Image.new("RGB", (lhimg.width + 2 * pad, lhimg.height + 2 * pad), (0, 0, 0)) + padded_lh.paste(lhimg, (pad, pad)) + padded_rh = Image.new("RGB", (rhimg.width + 2 * pad, rhimg.height + 2 * pad), (0, 0, 0)) + padded_rh.paste(rhimg, (pad, pad)) + + image = Image.new("RGB", (padded_lh.width + padded_rh.width, padded_lh.height)) + image.paste(padded_lh, (0, 0)) + image.paste(padded_rh, (padded_lh.width, 0)) + + font = load_roboto_font(20) if font_file is None else ImageFont.truetype(font_file, 20) if caption else None + # Place caption at bottom, colorbar above if both present + text_w, text_h = text_size(caption, font) if caption and font else (0, 0) + bottom_pad = 20 + gap = 4 + caption_y = image.height - bottom_pad - text_h + bar = ( + create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) + if lh_annot is None and rh_annot is None and colorbar + else None + ) + bar_h = bar.height if bar is not None else 0 + if bar is not None and caption: + bar_y = image.height - bottom_pad - text_h - gap - bar_h + draw_colorbar(image, bar, OrientationType.HORIZONTAL, y=bar_y) + draw_caption(image, caption, font, OrientationType.HORIZONTAL, y=caption_y) + elif bar is 
not None: + bar_y = image.height - bottom_pad - bar_h + draw_colorbar(image, bar, OrientationType.HORIZONTAL, y=bar_y) + elif caption: + draw_caption(image, caption, font, OrientationType.HORIZONTAL, y=caption_y) + + # If outpath is specified, save to disk + if outpath: + logger.info("Saving snapshot to %s", outpath) + image.save(outpath) + + return image + finally: + terminate_context(window) + + +def snap_rotate( + mesh, + outpath, + n_frames=72, + fps=24, + width=700, + height=500, + overlay=None, + bg_map=None, + annot=None, + roi=None, + fthresh=None, + fmax=None, + invert=False, + specular=True, + ambient=0.0, + brain_scale=1.5, + start_view=ViewType.LEFT, + color_mode=ColorSelection.BOTH, +): + """Render a rotating 360° video of a surface mesh. + + Rotates the view around the vertical (Y) axis in ``n_frames`` equal + steps, captures each frame via OpenGL, and encodes the result into a + video file. An animated GIF can be produced by passing an ``outpath`` + ending in ``.gif``; in that case ``imageio-ffmpeg`` is not required. + + The mesh can be any triangular surface — not just brain surfaces. + Supported file formats: FreeSurfer binary surface, ASCII OFF (``.off``), + legacy ASCII VTK PolyData (``.vtk``), ASCII PLY (``.ply``), or a + ``(vertices, faces)`` numpy array tuple. + + Parameters + ---------- + mesh : str or tuple of (array-like, array-like) + Path to a mesh file (FreeSurfer binary, ``.off``, ``.vtk``, or + ``.ply``) **or** a ``(vertices, faces)`` tuple. + outpath : str + Destination file path. The extension controls the output format: + + * ``.mp4`` — H.264 MP4 (recommended, requires ``imageio-ffmpeg``). + * ``.webm`` — VP9 WebM (requires ``imageio-ffmpeg``). + * ``.gif`` — animated GIF (no ffmpeg required, but larger file). + + n_frames : int, optional + Number of frames for a full 360° rotation. Default is ``72`` + (one frame every 5°). + fps : int, optional + Output frame rate in frames per second. Default is ``24``. 
+ width, height : int, optional + Render resolution in pixels. Defaults are ``700`` and ``500``. + overlay : str, array-like, or None, optional + Per-vertex overlay file path or array (e.g. thickness). + bg_map : str, array-like, or None, optional + Curvature/morph file path or array for background shading. + annot : str, tuple, or None, optional + FreeSurfer ``.annot`` file path or ``(labels, ctab)`` tuple. + roi : str, array-like, or None, optional + Label file path or boolean array to restrict overlay coloring. + fthresh : float or None, optional + Overlay threshold value. + fmax : float or None, optional + Overlay saturation value. + invert : bool, optional + Invert the overlay color scale. Default is ``False``. + specular : bool, optional + Enable specular highlights. Default is ``True``. + ambient : float, optional + Ambient lighting strength. Default is ``0.0``. + brain_scale : float, optional + Geometry scale factor. Default is ``1.5``. + start_view : ViewType, optional + Pre-defined view to start the rotation from. + Default is ``ViewType.LEFT``. + color_mode : ColorSelection, optional + Which overlay sign to color (POSITIVE/NEGATIVE/BOTH). + Default is ``ColorSelection.BOTH``. + + Returns + ------- + str + The resolved ``outpath`` that was written. + + Raises + ------ + ImportError + If ``imageio`` or ``imageio-ffmpeg`` is not installed and a + video format (``.mp4``, ``.webm``) was requested. + RuntimeError + If the OpenGL context cannot be initialised. + ValueError + If the overlay contains no values for the chosen color mode. + + Examples + -------- + >>> from whippersnappy import snap_rotate + >>> snap_rotate( + ... 'fsaverage/surf/lh.white', + ... '/tmp/rotation.mp4', + ... overlay='fsaverage/surf/lh.thickness', + ... 
) + '/tmp/rotation.mp4' + """ + ext = os.path.splitext(outpath)[1].lower() + use_gif = ext == ".gif" + + if not use_gif: + try: + import imageio # noqa: F401 + import imageio_ffmpeg # noqa: F401 + except ImportError as exc: + raise ImportError( + f"Video output requires the 'imageio' and 'imageio-ffmpeg' packages. " + f"Install with: pip install 'whippersnappy[video]'\n" + f"Original error: {exc}" + ) from exc + import imageio + else: + try: + import imageio # noqa: F401 + except ImportError as exc: + raise ImportError( + "GIF output requires the 'imageio' package. " + "Install with: pip install 'whippersnappy[video]'" + ) from exc + import imageio + + window = create_window_with_fallback(width, height, "WhipperSnapPy", visible=True) + try: + meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( + mesh, + overlay, + annot, + bg_map, + roi, + fthresh, + fmax, + invert, + scale=brain_scale, + color_mode=color_mode, + ) + logger.info( + "Rendering %d frames at %dx%d (%.0f° per step) → %s", + n_frames, width, height, 360.0 / n_frames, outpath, + ) + + shader = setup_shader(meshdata, triangles, width, height, + specular=specular, ambient=ambient) + + transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + base_view = get_view_matrices()[start_view] + + frames = [] + for i in range(n_frames): + angle = 2 * np.pi * i / n_frames + rot = pyrr.Matrix44.from_y_rotation(angle) + viewmat = transl * rot * base_view + render_scene(shader, triangles, viewmat) + frames.append(np.array(capture_window(window))) + if (i + 1) % max(1, n_frames // 10) == 0: + logger.debug(" frame %d / %d", i + 1, n_frames) + + finally: + terminate_context(window) + + logger.info("Encoding %d frames to %s …", len(frames), outpath) + if use_gif: + # Pure-PIL GIF — no ffmpeg required + pil_frames = [Image.fromarray(f) for f in frames] + pil_frames[0].save( + outpath, + save_all=True, + append_images=pil_frames[1:], + loop=0, + duration=int(1000 / fps), + optimize=True, + ) + else: + 
writer_kwargs = { + "fps": fps, + "codec": "libx264", + "quality": 6, + "pixelformat": "yuv420p", + } + if ext == ".webm": + writer_kwargs["codec"] = "libvpx-vp9" + writer_kwargs.pop("pixelformat", None) + imageio.mimwrite(outpath, frames, **writer_kwargs) + + logger.info("Saved rotation video to %s", outpath) + return outpath + diff --git a/whippersnappy/types.py b/whippersnappy/types.py deleted file mode 100644 index b707a20..0000000 --- a/whippersnappy/types.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Contains the types used in WhipperSnapPy. - -Dependencies: - enum - -@Author : Abdulla Ahmadkhan -@Created : 02.10.2025 -@Revised : 02.10.2025 - -""" -import enum - - -class ColorSelection(enum.Enum): - BOTH = 1 - POSITIVE = 2 - NEGATIVE = 3 - -class OrientationType(enum.Enum): - HORIZONTAL = 1 - VERTICAL = 2 - -class ViewType(enum.Enum): - LEFT = 1 - RIGHT = 2 - BACK = 3 - FRONT = 4 - TOP = 5 - BOTTOM = 6 \ No newline at end of file diff --git a/whippersnappy/utils/__init__.py b/whippersnappy/utils/__init__.py index 285e1d4..26ac282 100644 --- a/whippersnappy/utils/__init__.py +++ b/whippersnappy/utils/__init__.py @@ -1 +1,4 @@ -"""Utilities module.""" +"""Utils subpackage exports.""" +from . import colormap, datasets, image, types + +__all__ = ["colormap", "datasets", "image", "types"] diff --git a/whippersnappy/utils/_config.py b/whippersnappy/utils/_config.py deleted file mode 100644 index b1d393c..0000000 --- a/whippersnappy/utils/_config.py +++ /dev/null @@ -1,112 +0,0 @@ -import platform -import re -import sys -from functools import partial -from importlib.metadata import requires, version -from typing import IO, Callable, Optional - -import psutil - - -def sys_info(fid: Optional[IO] = None, developer: bool = False): - """Print the system information for debugging. - - Parameters - ---------- - fid : file-like, default=None - The file to write to, passed to :func:`print`. - Can be None to use :data:`sys.stdout`. 
- developer : bool, default=False - If True, display information about optional dependencies. - """ - - ljust = 26 - out = partial(print, end="", file=fid) - package = __package__.split(".")[0] - - # OS information - requires python 3.8 or above - out("Platform:".ljust(ljust) + platform.platform() + "\n") - # Python information - out("Python:".ljust(ljust) + sys.version.replace("\n", " ") + "\n") - out("Executable:".ljust(ljust) + sys.executable + "\n") - # CPU information - out("CPU:".ljust(ljust) + platform.processor() + "\n") - out("Physical cores:".ljust(ljust) + str(psutil.cpu_count(False)) + "\n") - out("Logical cores:".ljust(ljust) + str(psutil.cpu_count(True)) + "\n") - # Memory information - out("RAM:".ljust(ljust)) - out(f"{psutil.virtual_memory().total / float(2 ** 30):0.1f} GB\n") - out("SWAP:".ljust(ljust)) - out(f"{psutil.swap_memory().total / float(2 ** 30):0.1f} GB\n") - - # dependencies - out("\nDependencies info\n") - out(f"{package}:".ljust(ljust) + version(package) + "\n") - dependencies = [ - elt.split(";")[0].rstrip() for elt in requires(package) if "extra" not in elt - ] - _list_dependencies_info(out, ljust, dependencies) - - # extras - if developer: - keys = ( - "build", - "doc", - "test", - "style", - ) - for key in keys: - dependencies = [ - elt.split(";")[0].rstrip() - for elt in requires(package) - if f"extra == '{key}'" in elt or f'extra == "{key}"' in elt - ] - if len(dependencies) == 0: - continue - out(f"\nOptional '{key}' info\n") - _list_dependencies_info(out, ljust, dependencies) - - -def _list_dependencies_info(out: Callable, ljust: int, dependencies: list[str]): - """List dependencies names and versions. 
- - Parameters - ---------- - out : Callable - output function - ljust : int - length of returned string - dependencies : List[str] - list of dependencies - - """ - - for dep in dependencies: - # handle dependencies with version specifiers - specifiers_pattern = r"(~=|==|!=|<=|>=|<|>|===)" - specifiers = re.findall(specifiers_pattern, dep) - if len(specifiers) != 0: - dep, _ = dep.split(specifiers[0]) - while not dep[-1].isalpha(): - dep = dep[:-1] - # handle dependencies provided with a [key], e.g. pydocstyle[toml] - if "[" in dep: - dep = dep.split("[")[0] - try: - version_ = version(dep) - except Exception: - version_ = "Not found." - - # handle special dependencies with backends, C dep, .. - if dep in ("matplotlib", "seaborn") and version_ != "Not found.": - try: - from matplotlib import pyplot as plt - - backend = plt.get_backend() - except Exception: - backend = "Not found" - - out(f"{dep}:".ljust(ljust) + version_ + f" (backend: {backend})\n") - - else: - out(f"{dep}:".ljust(ljust) + version_ + "\n") diff --git a/whippersnappy/utils/colormap.py b/whippersnappy/utils/colormap.py new file mode 100644 index 0000000..460f81b --- /dev/null +++ b/whippersnappy/utils/colormap.py @@ -0,0 +1,183 @@ +"""Colormap and value preprocessing utilities.""" + +import logging + +import numpy as np + +from .types import ColorSelection + +# Module logger +logger = logging.getLogger(__name__) + + +def heat_color(values, invert=False): + """Convert an array of float values into RGB heat color values. + + Maps scalar values to RGB triplets suitable for visualization. Input + values are expected to be in a symmetric range around zero; mapping + produces blue-to-red heat colors. NaN inputs propagate to NaN outputs. + + Parameters + ---------- + values : array_like + 1-D array of float values to map. May include NaNs. + invert : bool, optional + If True, invert the sign of the input values before mapping. + Default is False. 
+ + Returns + ------- + numpy.ndarray + Array of shape (N, 3) and dtype float32 with RGB channels in [0, 1]. + """ + if invert: + values = -1.0 * values + vabs = np.abs(values) + colors = np.zeros((vabs.size, 3), dtype=np.float32) + crb = 0.5625 + 3 * 0.4375 * vabs + cg = 1.5 * (vabs - (1.0 / 3.0)) + n1 = values < -1.0 + nm = (values >= -1.0) & (values < -(1.0 / 3.0)) + n0 = (values >= -(1.0 / 3.0)) & (values < 0) + p0 = (values >= 0) & (values < (1.0 / 3.0)) + pm = (values >= (1.0 / 3.0)) & (values < 1.0) + p1 = values >= 1.0 + colors[n1, 1:3] = 1.0 + colors[nm, 1] = cg[nm] + colors[nm, 2] = 1.0 + colors[n0, 2] = crb[n0] + colors[p0, 0] = crb[p0] + colors[pm, 1] = cg[pm] + colors[pm, 0] = 1.0 + colors[p1, 0:2] = 1.0 + colors[np.isnan(values), :] = np.nan + return colors + + +def mask_sign(values, color_mode): + """Mask values that don't match the requested sign selection. + + Parameters + ---------- + values : array_like + Input numeric array. + color_mode : ColorSelection + Enum indicating which sign to preserve (POSITIVE, NEGATIVE, BOTH). + + Returns + ------- + numpy.ndarray + Copy of ``values`` where elements not matching the requested sign + are set to ``np.nan``. + """ + masked_values = np.copy(values) + if color_mode == ColorSelection.POSITIVE: + masked_values[masked_values < 0] = np.nan + elif color_mode == ColorSelection.NEGATIVE: + masked_values[masked_values > 0] = np.nan + return masked_values + + +def rescale_overlay(values, minval, maxval): + """Rescale overlay values into a normalized range for colormap computation. + + Values whose absolute magnitude is below ``minval`` are set to ``NaN``. + Remaining values are shifted by ``minval`` and divided by ``(maxval - minval)``. + + Parameters + ---------- + values : numpy.ndarray + Numeric array of overlay values (1-D). + minval : float + Minimum absolute threshold — values with abs < minval are treated as absent. + maxval : float + Maximum absolute value used for normalization. 
+ + Returns + ------- + tuple + ``(values, minval, maxval, pos, neg)`` where ``values`` is the rescaled + array, and ``pos``/``neg`` are booleans indicating presence of positive + / negative values after rescaling. + + Raises + ------ + ValueError + If ``minval`` or ``maxval`` is negative. + """ + valsign = np.sign(values) + valabs = np.abs(values) + + if maxval < 0 or minval < 0: + logger.error("rescale_overlay ERROR: min and maxval should both be positive!") + raise ValueError("minval and maxval must be non-negative") + + values[valabs < minval] = np.nan + range_val = maxval - minval + if range_val == 0: + values = np.zeros_like(values) + else: + values = values - valsign * minval + values = values / range_val + + pos = np.any(values[~np.isnan(values)] > 0) + neg = np.any(values[~np.isnan(values)] < 0) + + return values, minval, maxval, pos, neg + + +def binary_color(values, thres, color_low, color_high): + """Create a binary colormap for values based on a threshold. + + Parameters + ---------- + values : array_like + 1-D array of values to map. + thres : float + Threshold value used to split the colors. + color_low, color_high : scalar or sequence + Colors assigned to values below/above the threshold. Scalars are + expanded to RGB triplets. + + Returns + ------- + numpy.ndarray + Array of shape (N, 3) and dtype float32 containing RGB colors. + """ + if np.isscalar(color_low): + color_low = np.array((color_low, color_low, color_low), dtype=np.float32) + if np.isscalar(color_high): + color_high = np.array((color_high, color_high, color_high), dtype=np.float32) + colors = np.empty((values.size, 3), dtype=np.float32) + colors[values < thres, :] = color_low + colors[values >= thres, :] = color_high + return colors + + +def mask_label(values, labelpath=None): + """Apply a label file as a mask to an array of per-vertex values. 
+ + If ``labelpath`` is provided the function loads vertex indices from the + label file and sets all entries not listed in the label to ``NaN``. + + Parameters + ---------- + values : numpy.ndarray + 1-D array indexed by vertex id. + labelpath : str or None, optional + Path to a label file readable by ``numpy.loadtxt`` (expected format + with vertex ids in the first column after two header lines). + + Returns + ------- + numpy.ndarray + Array with vertices not included in the label set to ``np.nan``. + """ + if not labelpath: + return values + maskvids = np.loadtxt(labelpath, dtype=int, skiprows=2, usecols=[0]) + imask = np.ones(values.shape, dtype=bool) + imask[maskvids] = False + values[imask] = np.nan + return values + diff --git a/whippersnappy/utils/datasets.py b/whippersnappy/utils/datasets.py new file mode 100644 index 0000000..7f2293f --- /dev/null +++ b/whippersnappy/utils/datasets.py @@ -0,0 +1,124 @@ +"""Sample dataset download utility for WhipperSnapPy. + +Downloads and caches a small anonymized FreeSurfer subject from the +WhipperSnapPy GitHub release assets for use in tutorials and tests. +""" + +from pathlib import Path + +RELEASE_URL = ( + "https://github.com/Deep-MI/WhipperSnapPy" + "/releases/download/data-v1.0/{file_name}" +) + +# Mapping of relative path inside the subject directory → SHA-256 hash. +# GitHub release assets are flat (no subdirectories), so the URL uses only +# the basename while pooch.retrieve() reconstructs the subdirectory locally. 
+_FILES = { + "README.md": "sha256:ecb6ddf31cec17f3a8636fc3ecac90099c441228811efed56104e29fcd301bc5", + "surf/lh.white": "sha256:4ab049fb42ca882ba9b56f8fe0d0e8814973e7fa2e0575a794d8e468abf7d62f", + "surf/lh.curv": "sha256:9edbde57be8593cd9d89d9d1124e2175edd8ecfee55d53e066d89700c480b12a", + "surf/lh.thickness": "sha256:40ab3483284608c6c5cca2d3d794a60cd1bcbeb0140bb1ca6ad0fce7962c57c6", + "surf/rh.white": "sha256:43035c53a8b04bebe4e843c34f80588f253f79052a8dbf7194b706495b11f8d2", + "surf/rh.curv": "sha256:af2bc71133d7ef17ce1a3a6f4208d2495a5a4c96da00c80b59be03bb7c8ea83f", + "surf/rh.thickness": "sha256:50ec291c73928cd697156edd9e0e77f5c54d15c56cf84810d2564b496876e132", + "label/lh.aparc.DKTatlas.mapped.annot": "sha256:4d48d33f4fd8278ab973a1552f6ea9c396dfc1791b707ed17ad8e761299c4960", + "label/lh.cortex.label": "sha256:79ae17fcfde6b2e0a75a0652fcc0f3c072e4ea62a541843b7338e01c598b0b6e", + "label/rh.aparc.DKTatlas.mapped.annot": "sha256:12217166d8ef43ee1fa280511ec2ba0796c6885f527a4455b93760acc73ce273", + "label/rh.cortex.label": "sha256:162c97c887eb1ec857fe575b8cc4e4b950c7dd5ec181a581d709bbe7fca58f9e", +} + + +def _build_dict(base: Path) -> dict: + """Build the return dictionary of paths from a subject base directory.""" + return { + "sdir": str(base), + "lh_white": str(base / "surf/lh.white"), + "lh_curv": str(base / "surf/lh.curv"), + "lh_thickness": str(base / "surf/lh.thickness"), + "rh_white": str(base / "surf/rh.white"), + "rh_curv": str(base / "surf/rh.curv"), + "rh_thickness": str(base / "surf/rh.thickness"), + "lh_annot": str(base / "label/lh.aparc.DKTatlas.mapped.annot"), + "lh_label": str(base / "label/lh.cortex.label"), + "rh_annot": str(base / "label/rh.aparc.DKTatlas.mapped.annot"), + "rh_label": str(base / "label/rh.cortex.label"), + } + + +def fetch_sample_subject() -> dict: + """Download and cache the WhipperSnapPy sample subject (Rhineland Study). 
+ + Downloads FreeSurfer surface files for one anonymized subject into the + OS-specific user cache directory and returns a dictionary of paths to + all files. Files are only downloaded once; subsequent calls use the + local cache. + + If a ``sub-rs/`` directory with all required files is found next to the + package root (i.e. inside the source repository), it is used directly + without any network access. This allows the Sphinx doc build to work + before the GitHub release assets are published. + + Returns + ------- + dict + Dictionary with the following keys: + + * ``sdir`` -- path to the subject root directory (``sub-rs/``), + usable directly as the ``sdir`` argument to :func:`~whippersnappy.snap4`. + * ``lh_white`` -- path to ``surf/lh.white``. + * ``lh_curv`` -- path to ``surf/lh.curv``. + * ``lh_thickness`` -- path to ``surf/lh.thickness``. + * ``rh_white`` -- path to ``surf/rh.white``. + * ``rh_curv`` -- path to ``surf/rh.curv``. + * ``rh_thickness`` -- path to ``surf/rh.thickness``. + * ``lh_annot`` -- path to ``label/lh.aparc.DKTatlas.mapped.annot``. + * ``lh_label`` -- path to ``label/lh.cortex.label``. + * ``rh_annot`` -- path to ``label/rh.aparc.DKTatlas.mapped.annot``. + * ``rh_label`` -- path to ``label/rh.cortex.label``. + + Raises + ------ + ImportError + If ``pooch`` is not installed. Install with + ``pip install 'whippersnappy[notebook]'``. + + Notes + ----- + Data from the Rhineland Study (Koch et al.), + https://doi.org/10.5281/zenodo.11186582, CC BY 4.0. + + Examples + -------- + >>> from whippersnappy import fetch_sample_subject + >>> data = fetch_sample_subject() + >>> print(data["sdir"]) + """ + try: + import pooch + except ImportError as e: + raise ImportError( + "fetch_sample_subject() requires pooch. " + "Install with: pip install 'whippersnappy[notebook]'" + ) from e + + # Use a local sub-rs/ directory (present in the source repo) when all + # required files are already there — no network access needed. 
+    _pkg_root = Path(__file__).parent.parent.parent  # repo root (dir containing the whippersnappy package)
+    _local = _pkg_root / "sub-rs"
+    if _local.is_dir() and all((_local / p).exists() for p in _FILES):
+        return _build_dict(_local)
+
+    # Otherwise download from the GitHub release and cache in the OS cache dir.
+    base = Path(pooch.os_cache("whippersnappy")) / "sub-rs"
+
+    for rel_path, known_hash in _FILES.items():
+        rel = Path(rel_path)
+        pooch.retrieve(
+            url=RELEASE_URL.format(file_name=rel.name),
+            known_hash=known_hash,
+            fname=rel.name,
+            path=base / rel.parent,
+        )
+
+    return _build_dict(base)
diff --git a/whippersnappy/utils/image.py b/whippersnappy/utils/image.py
new file mode 100644
index 0000000..f9d2ea0
--- /dev/null
+++ b/whippersnappy/utils/image.py
@@ -0,0 +1,343 @@
+"""Image and text helper utilities used by snapshot renderers (moved under utils).
+"""
+import numpy as np
+from PIL import Image, ImageDraw
+
+from .colormap import heat_color
+from .types import OrientationType
+
+try:
+    # Prefer stdlib importlib.resources
+    from importlib import resources
+except Exception:
+    import importlib_resources as resources
+import warnings
+
+from PIL import ImageFont
+
+
+def text_size(caption, font):
+    """Return text width and height in pixels for a given caption and font.
+
+    Parameters
+    ----------
+    caption : str
+        Text to measure.
+    font : PIL.ImageFont.FreeTypeFont or similar
+        Font object used for measurement.
+
+    Returns
+    -------
+    (width, height) : tuple[int, int]
+        Pixel dimensions of rendered text.
+    """
+    dummy_img = Image.new("L", (1, 1))
+    draw = ImageDraw.Draw(dummy_img)
+    bbox = draw.textbbox((0, 0), caption, font=font, anchor="lt")
+    text_width = bbox[2] - bbox[0]
+    text_height = bbox[3] - bbox[1]
+    return text_width, text_height
+
+
+def get_colorbar_label_positions(
+    font,
+    labels,
+    colorbar_rect,
+    gapspace=0,
+    pos=True,
+    neg=True,
+    orientation=OrientationType.HORIZONTAL,
+):
+    """Compute positions for colorbar label text.
+ + Parameters + ---------- + font : PIL.ImageFont + Font used to measure text sizes. + labels : dict + Mapping of label keys to text strings (e.g. 'upper','lower','middle'). + colorbar_rect : tuple + Rectangle for the colorbar (x, y, width, height). + gapspace : int, optional, default 0 + Additional spacing used for split colorbars. + pos, neg : bool, optional, default True, True + Whether positive/negative sides are present. + orientation : OrientationType, optional, default OrientationType.HORIZONTAL + Orientation of the colorbar. + + Returns + ------- + positions : dict + Mapping of label key -> (x, y) pixel position. + """ + positions = {} + cb_x, cb_y, cb_width, cb_height = colorbar_rect + cb_labels_gap = 5 + + if orientation == OrientationType.HORIZONTAL: + label_y = cb_y + cb_height + cb_labels_gap + + w, _ = text_size(labels["upper"], font) + if pos: + positions["upper"] = (cb_x + cb_width - w, label_y) + else: + upper_x = cb_x + cb_width - w - int(gapspace) if gapspace > 0 else cb_x + cb_width - w + positions["upper"] = (upper_x, label_y) + + w, _ = text_size(labels["lower"], font) + if neg: + positions["lower"] = (cb_x, label_y) + else: + lower_x = cb_x + int(gapspace) if gapspace > 0 else cb_x + positions["lower"] = (lower_x, label_y) + + if neg and pos: + if gapspace == 0: + w, _ = text_size(labels["middle"], font) + positions["middle"] = (cb_x + cb_width // 2 - w // 2, label_y) + else: + w, _ = text_size(labels["middle_neg"], font) + positions["middle_neg"] = (cb_x + cb_width // 2 - w - int(gapspace), label_y) + w, _ = text_size(labels["middle_pos"], font) + positions["middle_pos"] = (cb_x + cb_width // 2 + int(gapspace), label_y) + else: + label_x = cb_x + cb_width + cb_labels_gap + + _, h = text_size(labels["upper"], font) + if pos: + positions["upper"] = (label_x, cb_y) + else: + upper_y = cb_y + int(gapspace) if gapspace > 0 else cb_y + positions["upper"] = (label_x, upper_y) + + _, h = text_size(labels["lower"], font) + if neg: + 
positions["lower"] = (label_x, cb_y + cb_height - 1.5 * h) + else: + lower_y = cb_y + cb_height - int(gapspace) - 1.5 * h if gapspace > 0 else cb_y + cb_height - 1.5 * h + positions["lower"] = (label_x, lower_y) + + if neg and pos: + if gapspace == 0: + _, h = text_size(labels["middle"], font) + positions["middle"] = (label_x, cb_y + cb_height // 2 - h // 2) + else: + _, h = text_size(labels["middle_pos"], font) + positions["middle_pos"] = (label_x, cb_y + cb_height // 2 - 1.5 * h - int(gapspace)) + _, h = text_size(labels["middle_neg"], font) + positions["middle_neg"] = (label_x, cb_y + cb_height // 2 + int(gapspace)) + + return positions + + +def create_colorbar( + fmin, + fmax, + invert, + orientation=OrientationType.HORIZONTAL, + colorbar_scale=1, + pos=True, + neg=True, + font_file=None, +): + """Create a colored colorbar as a PIL.Image. + + The colorbar visualizes the overlay color mapping (using + :func:`whippersnappy.utils.colormap.heat_color`) and optionally draws + numeric labels for the min/threshold/saturation positions. + + Parameters + ---------- + fmin, fmax : float + Threshold and saturation values used to label the colorbar. + invert : bool + Invert the heat color mapping. + orientation : OrientationType, optional, default OrientationType.HORIZONTAL + Orientation of the colorbar (HORIZONTAL/VERTICAL). + colorbar_scale : float, optional, default 1 + Scale factor for resulting image size. + pos, neg : bool, optional, default True, True + Whether the colorbar has positive/negative regions. + font_file : str or None, optional + Path to a TTF font file to use for labels. + + Returns + ------- + PIL.Image.Image or None + A PIL image containing the colorbar, or ``None`` if inputs are + insufficient (e.g. fmin/fmax are None). + """ + # If fmin/fmax are not specified, we cannot create a meaningful colorbar. 
+ if fmin is None or fmax is None: + return None + + cwidth = int(200 * colorbar_scale) + cheight = int(30 * colorbar_scale) + gapspace = 0 + + if fmin > 0.01: + num = int(0.42 * cwidth) + gapspace = 0.08 * cwidth + else: + num = int(0.5 * cwidth) + if not neg or not pos: + num = num * 2 + gapspace = gapspace * 2 + + values = np.nan * np.ones(cwidth) + steps = np.linspace(0.01, 1, num) + if pos and not neg: + values[-steps.size:] = steps + elif not pos and neg: + values[: steps.size] = -1.0 * np.flip(steps) + else: + values[: steps.size] = -1.0 * np.flip(steps) + values[-steps.size:] = steps + + colors = heat_color(values, invert) + colors[np.isnan(values), :] = 0.33 * np.ones((1, 3)) + img_bar = np.uint8(np.tile(colors, (cheight, 1, 1)) * 255) + + pad_top, pad_left = 3, 10 + img_buf = np.zeros((cheight + 2 * pad_top, cwidth + 2 * pad_left, 3), dtype=np.uint8) + img_buf[pad_top : cheight + pad_top, pad_left : cwidth + pad_left, :] = img_bar + image = Image.fromarray(img_buf) + + if font_file is None: + # Try to load bundled font from package resources + font = None + try: + font_trav = resources.files("whippersnappy").joinpath("resources", "fonts", "Roboto-Regular.ttf") + with resources.as_file(font_trav) as font_path: + font = ImageFont.truetype(str(font_path), int(12 * colorbar_scale)) + except Exception: + warnings.warn("Roboto font not found in package resources; falling back to default font", + UserWarning, stacklevel=2) + font = ImageFont.load_default() + else: + try: + font = ImageFont.truetype(font_file, int(12 * colorbar_scale)) + except Exception: + font = ImageFont.load_default() + + labels = {} + labels["upper"] = f">{fmax:.2f}" if pos else (f"{-fmin:.2f}" if gapspace != 0 else "0") + labels["lower"] = f"<{-fmax:.2f}" if neg else (f"{fmin:.2f}" if gapspace != 0 else "0") + if neg and pos and gapspace != 0: + labels["middle_neg"] = f"{-fmin:.2f}" + labels["middle_pos"] = f"{fmin:.2f}" + elif neg and pos and gapspace == 0: + labels["middle"] = "0" + + 
caption_sizes = [text_size(caption, font) for caption in labels.values()] + max_caption_width = int(max([caption_size[0] for caption_size in caption_sizes])) + max_caption_height = int(max([caption_size[1] for caption_size in caption_sizes])) + + if orientation == OrientationType.VERTICAL: + image = image.rotate(90, expand=True) + new_width = image.width + int(max_caption_width) + new_image = Image.new("RGB", (new_width, image.height), (0, 0, 0)) + new_image.paste(image, (0, 0)) + image = new_image + colorbar_rect = (pad_top, pad_left, cheight, cwidth) + else: + new_height = image.height + int(max_caption_height * 2) + new_image = Image.new("RGB", (image.width, new_height), (0, 0, 0)) + new_image.paste(image, (0, 0)) + image = new_image + colorbar_rect = (pad_left, pad_top, cwidth, cheight) + + positions = get_colorbar_label_positions(font, labels, colorbar_rect, gapspace, pos, neg, orientation) + draw = ImageDraw.Draw(image) + for label_key, position in positions.items(): + draw.text((int(position[0]), int(position[1])), labels[label_key], fill=(220, 220, 220), font=font) + + return image + + +def load_roboto_font(size=14): + """Load the bundled Roboto font from package resources. + + Parameters + ---------- + size : int, optional + Requested point size. + + Returns + ------- + PIL.ImageFont.FreeTypeFont or PIL.ImageFont.ImageFont or None + A PIL font object; falls back to ``ImageFont.load_default()`` or + ``None`` if fonts cannot be loaded. 
+ """ + try: + # resources was imported earlier in this module + font_trav = resources.files("whippersnappy").joinpath("resources", "fonts", "Roboto-Regular.ttf") + with resources.as_file(font_trav) as font_path: + return ImageFont.truetype(str(font_path), size) + except Exception: + warnings.warn("Roboto font not found in package resources; falling back to default font", UserWarning, + stacklevel=2) + try: + return ImageFont.load_default() + except Exception: + return None + + +def draw_colorbar(image, bar, orientation, x=None, y=None): + """Paste a colorbar image onto the target image at the specified position. + + Parameters + ---------- + image : PIL.Image.Image + The target image to paste onto. + bar : PIL.Image.Image + The colorbar image to paste. + orientation : OrientationType + Orientation of the colorbar (HORIZONTAL/VERTICAL). + x, y : int or None, optional + Position to paste the colorbar. If None, defaults to centered at bottom (horizontal) or right (vertical). + """ + if bar is None: + return + if orientation == OrientationType.HORIZONTAL: + bx = int(0.5 * (image.width - bar.width)) if x is None else x + by = image.height - bar.height - 10 if y is None else y + image.paste(bar, (bx, by)) + else: + bx = image.width - bar.width - 10 if x is None else x + by = int(0.5 * (image.height - bar.height)) if y is None else y + image.paste(bar, (bx, by)) + + +def draw_caption(image, caption, font, orientation, x=None, y=None): + """Draw a caption string onto the image at the specified position and orientation. + + Parameters + ---------- + image : PIL.Image.Image + The target image to draw onto. + caption : str + The caption text to draw. + font : PIL.ImageFont + Font to use for the caption. + orientation : OrientationType + Orientation of the caption (HORIZONTAL/VERTICAL). + x, y : int or None, optional + Position to draw the caption. If None, defaults to centered at bottom (horizontal) or right (vertical). 
+ """ + if not caption or font is None: + return + text_w, text_h = text_size(caption, font) + draw = ImageDraw.Draw(image) + if orientation == OrientationType.HORIZONTAL: + cx = int(0.5 * (image.width - text_w)) if x is None else x + cy = image.height - text_h - 10 if y is None else y + draw.text((cx, cy), caption, (220, 220, 220), font=font, anchor="lt") + else: + temp_caption_img = Image.new("RGBA", (text_w, text_h), (0, 0, 0, 0)) + ImageDraw.Draw(temp_caption_img).text((0, 0), caption, font=font, anchor="lt") + rotated_caption = temp_caption_img.rotate(90, expand=True, fillcolor=(0, 0, 0, 0)) + rotated_w, rotated_h = rotated_caption.size + cx = image.width - rotated_w - 10 if x is None else x + cy = int(0.5 * (image.height - rotated_h)) if y is None else y + image.paste(rotated_caption, (cx, cy), rotated_caption) diff --git a/whippersnappy/utils/types.py b/whippersnappy/utils/types.py new file mode 100644 index 0000000..8d0e417 --- /dev/null +++ b/whippersnappy/utils/types.py @@ -0,0 +1,93 @@ +"""Contains the types used in WhipperSnapPy. + +This module defines small enumeration types used across the package for +controlling color selection, colorbar orientation, and predefined views. + +Classes +------- +ColorSelection + Which sign(s) of overlay values should be used to produce colors. +OrientationType + Orientation of UI elements such as the colorbar (horizontal or vertical). +ViewType + Predefined canonical view orientations for rendering the brain surface. +""" + +import enum + + +class ColorSelection(enum.Enum): + """Enum to select which sign(s) of overlay values to color. + + Parameters + ---------- + *values : tuple + Positional arguments passed to the Enum constructor (not used by + consumers of this enum). Documented here to satisfy documentation + linters that inspect the class signature. + + Attributes + ---------- + BOTH : int + Use both positive and negative values for coloring. + POSITIVE : int + Use only positive values for coloring. 
+ NEGATIVE : int + Use only negative values for coloring. + """ + BOTH = 1 + POSITIVE = 2 + NEGATIVE = 3 + + +class OrientationType(enum.Enum): + """Enum describing orientation choices for elements like the colorbar. + + Parameters + ---------- + *values : tuple + Positional arguments passed to the Enum constructor (not used by + consumers of this enum). + + Attributes + ---------- + HORIZONTAL : int + Layout along the horizontal axis. + VERTICAL : int + Layout along the vertical axis. + """ + HORIZONTAL = 1 + VERTICAL = 2 + + +class ViewType(enum.Enum): + """Predefined canonical view directions used by snapshot renderers. + + Parameters + ---------- + *values : tuple + Positional arguments passed to the Enum constructor (not used by + consumers of this enum). + + Attributes + ---------- + LEFT : int + Left hemisphere lateral view. + RIGHT : int + Right hemisphere lateral view. + BACK : int + Posterior view. + FRONT : int + Anterior/frontal view. + TOP : int + Superior/top view. + BOTTOM : int + Inferior/bottom view. + """ + LEFT = 1 + RIGHT = 2 + BACK = 3 + FRONT = 4 + TOP = 5 + BOTTOM = 6 +