feat(lib): add BVH acceleration #406

Open
rwydaegh wants to merge 42 commits into jeertmans:main from rwydaegh:feature/bvh-acceleration

Conversation

@rwydaegh

Fixes Issue

Closes #313

Description

Summary

This PR adds a Rust BVH (Bounding Volume Hierarchy) to differt-core with both PyO3 and XLA FFI bindings, fully integrated into all three compute_paths methods. It solves the O(rays * triangles) memory scaling that causes OOM on scenes with more than a few thousand triangles (#313).

Key results:

  • 951x speedup on Munich (38K triangles) for hard-mode intersection
  • 2-3x speedup for differentiable (soft) mode, but more importantly avoids OOM by reducing allocations from [rays, 39K, 3] to [rays, ~300, 3]
  • Works inside jax.jit and jax.lax.scan via XLA FFI
  • Zero regressions on existing test suite (1,642 tests pass)
  • 15 Rust + 29 Python tests added
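
To make the allocation figures above concrete, here is a rough back-of-the-envelope calculation. The ray count (100,000) and the 4-byte float32 element size are assumptions chosen for illustration; only the triangle count and the ~300 candidates per ray come from this description.

```python
# Rough size of one [rays, triangles, 3] float32 intermediate array.
# num_rays and bytes_per_float are illustrative assumptions.
num_rays = 100_000
num_triangles = 38_936   # Munich scene
num_candidates = 300     # approximate BVH candidate count per ray
bytes_per_float = 4      # float32

full_bytes = num_rays * num_triangles * 3 * bytes_per_float
pruned_bytes = num_rays * num_candidates * 3 * bytes_per_float

print(f"brute force: {full_bytes / 1e9:.1f} GB")   # brute force: 46.7 GB
print(f"with BVH:    {pruned_bytes / 1e6:.0f} MB")  # with BVH:    360 MB
```

A real run allocates several such intermediates, so the actual peak memory is likely higher than this single-array estimate.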

Motivation

DiffeRT's intersection functions allocate O(rays * triangles) intermediate arrays in JAX. On Munich (38K triangles), this means multi-GB arrays that OOM even on beefy GPUs. Jerome tried every pure-JAX approach:

  • vmap+sum: OOM
  • lax.scan: correct but slow
  • lax.map: correct but slow
  • fori_loop with batching: best compromise, still 20s+ on GPU

The JAX team confirmed in jax-ml/jax#30841 that lax.reduce cannot close over Tracers due to a StableHLO limitation with no fix planned. The only viable path is moving the ray-triangle loop out of JAX.

Approach: Rust for spatial queries, JAX for math

The core idea is to split responsibilities: the BVH in Rust handles spatial acceleration (which triangles could a ray hit?), while the Moller-Trumbore intersection math stays in JAX (where it auto-differentiates through sigmoid smoothing). This means:

  • No custom VJPs needed in Rust (candidate indices are integers with zero gradient)
  • The backward pass drops from O(rays * all_triangles) to O(rays * candidates)
  • The Rust BVH code can be reviewed independently from the differentiable logic
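
The split can be illustrated with a small JAX sketch. It is not DiffeRT's implementation: `soft_hit`, `soft_blocked`, the product-of-sigmoids formulation, and the value `alpha = 4.0` are all assumptions made for the example. The candidate indices stand in for what the BVH would return; the point is only that gradients flow through the gathered triangles and are exactly zero for triangles that were never candidates.

```python
import jax
import jax.numpy as jnp


def soft_hit(origin, direction, triangle, alpha):
    """Sigmoid-smoothed Moller-Trumbore test for one ray and one triangle.

    Sketch only: assumes the ray is not parallel to the triangle plane.
    """
    v0, v1, v2 = triangle[0], triangle[1], triangle[2]
    e1, e2 = v1 - v0, v2 - v0
    p = jnp.cross(direction, e2)
    inv_det = 1.0 / jnp.dot(e1, p)
    s = origin - v0
    u = jnp.dot(s, p) * inv_det
    q = jnp.cross(s, e1)
    v = jnp.dot(direction, q) * inv_det
    t = jnp.dot(e2, q) * inv_det
    sig = jax.nn.sigmoid
    # Each hard boolean test (u >= 0, v >= 0, u + v <= 1, t >= 0) becomes a sigmoid.
    return sig(alpha * u) * sig(alpha * v) * sig(alpha * (1.0 - u - v)) * sig(alpha * t)


def soft_blocked(triangles, origin, direction, candidate_idx, alpha):
    """Soft 'any hit' score evaluated on the candidate triangles only."""
    candidates = triangles[candidate_idx]  # integer gather, no gradient w.r.t. indices
    hits = jax.vmap(soft_hit, in_axes=(None, None, 0, None))(
        origin, direction, candidates, alpha
    )
    return 1.0 - jnp.prod(1.0 - hits)


triangles = jnp.array([
    [[0.0, 0.0, 1.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]],        # in front of the ray
    [[100.0, 0.0, 1.0], [101.0, 0.0, 1.0], [100.0, 1.0, 1.0]],  # far away, not a candidate
])
origin = jnp.array([0.25, 0.25, 0.0])
direction = jnp.array([0.0, 0.0, 1.0])
candidate_idx = jnp.array([0])  # stand-in for the indices a BVH query would return

# Gradient of the soft blocking score with respect to all triangle vertices.
grads = jax.grad(soft_blocked)(triangles, origin, direction, candidate_idx, 4.0)
```

Here `grads[0]` is non-zero and `grads[1]` is all zeros, which is why the backward pass only costs O(rays * candidates).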

Two query paths serve different contexts:

| Path | Where it works | How it's used |
|---|---|---|
| PyO3 (direct Python-to-Rust) | Outside JIT only | Visibility estimation, standalone queries |
| XLA FFI (Rust via C++ XLA handler) | Inside jax.jit, lax.scan, vmap | Blocking checks, SBR bounce loops |

The XLA FFI pipeline:

Python: ffi_nearest_hit(origins, dirs, active_mask, bvh_id=id)
  -> jax.ffi.ffi_call("bvh_nearest_hit", ...)   # JAX traces into XLA
  -> BvhNearestHit(XLA_FFI_CallFrame*)           # C++ handler (ffi.cc)
  -> bvh_nearest_hit_ffi(...)                    # Rust via cxx bridge
  -> Bvh::nearest_hit(origin, dir, active_mask)  # BVH traversal

The BVH lives in a global Mutex<HashMap<u64, Arc<Bvh>>> registry. Python calls bvh.register() to get an integer ID passed as an XLA attribute (compile-time constant). A Drop impl on TriangleBvh auto-unregisters when Python GC collects the object.
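
The registry lifecycle can be pictured with a Python analogue. This is only a sketch of the idea, not the Rust code in this PR: `BvhHandle`, `lookup`, and `_unregister` are hypothetical names, a `threading.Lock` plus a `dict` stand in for `Mutex<HashMap<u64, Arc<Bvh>>>`, and `weakref.finalize` plays the role of the `Drop` impl.

```python
import itertools
import threading
import weakref

_REGISTRY = {}                    # bvh_id -> shared tree payload
_LOCK = threading.Lock()          # guards every registry access
_NEXT_ID = itertools.count(1)     # monotonically increasing integer IDs


def _unregister(bvh_id):
    """Remove an entry; safe to call even if it is already gone."""
    with _LOCK:
        _REGISTRY.pop(bvh_id, None)


def lookup(bvh_id):
    """What an FFI handler would do: resolve an integer ID to the tree."""
    with _LOCK:
        return _REGISTRY.get(bvh_id)


class BvhHandle:
    """Hypothetical wrapper owning a tree and its registry entry."""

    def __init__(self, payload):
        self.payload = payload
        self.bvh_id = None

    def register(self):
        if self.bvh_id is None:
            with _LOCK:
                self.bvh_id = next(_NEXT_ID)
                _REGISTRY[self.bvh_id] = self.payload
            # Auto-unregister when the handle is garbage collected.
            self._finalizer = weakref.finalize(self, _unregister, self.bvh_id)
        return self.bvh_id


handle = BvhHandle(payload="tree")
bvh_id = handle.register()
found_before = lookup(bvh_id)   # the payload is reachable by ID
del handle                      # CPython collects the handle immediately
found_after = lookup(bvh_id)    # the entry has been removed
```

Because the ID is baked into the compiled computation as a constant, a jitted function that outlives its handle would look up a missing entry, which is one reason the lifetime of the wrapper object matters.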

Differentiable mode: the expanded BVH

For the soft path (smoothing_factor set), boolean tests become sigmoid(x * alpha). Triangles far from a ray have exponentially small sigmoid values. The expansion radius captures all triangles with gradient contribution above a threshold:

r_near = triangle_size * ln(1/epsilon_grad) / smoothing_factor

When r_near exceeds the scene diagonal or candidate counts overflow, the system automatically falls back to brute force.
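
As a worked instance of the formula, the sketch below evaluates it for the smoothing factors benchmarked later in this description. The function body follows the snippet quoted in the review thread, but the parameter order is an assumption, and `triangle_size = 5.0` and `epsilon_grad = 1e-7` are illustrative values picked only so that the radii land near the reported ones; the PR does not state the values actually used.

```python
import math


def compute_expansion_radius(triangle_size, smoothing_factor, epsilon_grad):
    """r_near = triangle_size * ln(1 / epsilon_grad) / smoothing_factor."""
    if smoothing_factor <= 0:
        return float("inf")  # no finite radius bounds the gradient support
    return triangle_size * math.log(1.0 / epsilon_grad) / smoothing_factor


# Illustrative assumptions, not values taken from the PR.
triangle_size = 5.0
epsilon_grad = 1e-7

radius_10 = compute_expansion_radius(triangle_size, 10.0, epsilon_grad)
radius_100 = compute_expansion_radius(triangle_size, 100.0, epsilon_grad)

for alpha in (10.0, 100.0, 500.0, 1000.0):
    r = compute_expansion_radius(triangle_size, alpha, epsilon_grad)
    print(f"smoothing_factor={alpha:>6.0f}  r_near={r:.2f} m")
```

The radius is inversely proportional to `smoothing_factor`, so a small smoothing factor quickly produces a radius large enough to trigger the brute-force fallback.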

Performance

Hard mode (nearest-hit)

| Scene | Triangles | Rays | BVH build | BVH query | Brute force | Speedup |
|---|---|---|---|---|---|---|
| Munich | 38,936 | 200 | 136 ms | 1 ms | 1,054 ms | 951x |
| Random | 10,000 | 100 | 13 ms | 9 ms | 545 ms | 58x |
| Random | 1,000 | 1,000 | 2 ms | 10 ms | 481 ms | 47x |

BVH build is one-time (cached per scene). Query scales as O(rays * log(triangles)).

XLA FFI vs PyO3 (Munich, 38K triangles, 200 rays)

| Path | Time | Notes |
|---|---|---|
| PyO3 (outside JIT) | 3.5 ms | Direct Python-to-Rust |
| XLA FFI (inside JIT, warm) | 2.6 ms | After JIT compilation |

Soft mode (Munich, 38K triangles, 50 rays)

| smoothing_factor | Expansion radius | BVH time | Brute force | Speedup |
|---|---|---|---|---|
| 10 | 8.06 m | 845 ms | 622 ms | Falls back to BF |
| 100 | 0.81 m | 233 ms | 682 ms | 2.9x |
| 500 | 0.16 m | 252 ms | 735 ms | 2.9x |
| 1000 | 0.08 m | 271 ms | 727 ms | 2.7x |

The soft-mode speedup is modest (2-3x) because JAX soft intersection on candidates still dominates. The real value is avoiding OOM.

Visibility estimation (hybrid method, Munich, 100K rays)

| Method | Visible triangles | Time | Speedup |
|---|---|---|---|
| BVH | 1,143 | 1.12 s | 14x |
| Brute force | 1,128 | 15.18 s | 1x |

Usage

from differt.scene import TriangleScene

scene = TriangleScene.load_xml("munich/munich.xml")
bvh = scene.build_bvh()  # one-time O(N log N) build

# All compute_paths methods accept bvh= parameter
paths = scene.compute_paths(order=1, method="exhaustive", bvh=bvh)
paths = scene.compute_paths(order=1, method="hybrid", bvh=bvh)
paths = scene.compute_paths(order=2, method="sbr", bvh=bvh)

# Standalone queries (PyO3, outside JIT)
from differt.accel import bvh_first_triangles_hit_by_rays
idx, t = bvh_first_triangles_hit_by_rays(
    origins, directions, scene.mesh.triangle_vertices, bvh=bvh,
)

# Differentiable mode with BVH candidate pruning
from differt.accel import bvh_rays_intersect_any_triangle
blocked = bvh_rays_intersect_any_triangle(
    origins, directions, scene.mesh.triangle_vertices,
    smoothing_factor=100.0, bvh=bvh,
)
# Gradients flow through JAX autodiff on the reduced candidate set

What changed

Rust (differt-core/src/accel/, ~1,280 lines):

  • bvh.rs: SAH-binned BVH construction, slab-method AABB traversal, Moller-Trumbore intersection, global registry with Arc<Bvh>, PyO3 bindings, Drop for auto-cleanup. Supports optional active_mask to skip inactive triangles during traversal.
  • ffi.rs: cxx bridge declarations, FFI entry points, PyCapsule exports for jax.ffi.register_ffi_target.

C++ (src/ffi.cc + include/ffi.h, ~120 lines): XLA FFI handlers that decode buffers and call Rust via cxx.

Build (build.rs, 45 lines): Queries Python for JAX XLA include paths, compiles C++ via cxx-build. Gated behind xla-ffi Cargo feature.

Python (differt/src/differt/accel/, ~730 lines):

  • _bvh.py: TriangleBvh wrapper with batch handling, register/unregister lifecycle.
  • _accelerated.py: Drop-in replacements for rays_intersect_any_triangle, first_triangles_hit_by_rays, triangles_visible_from_vertices.
  • _ffi.py: JAX FFI wrappers (ffi_nearest_hit, ffi_get_candidates).

Scene integration (_triangle_scene.py, +120 lines): build_bvh() method, compute_paths(bvh=...) parameter wired into exhaustive, hybrid, and SBR methods.

Tests (15 Rust + 29 Python): Construction, traversal, active_mask filtering, empty BVH, soft/hard mode, expansion radius, visibility, full compute_paths integration for all three methods.

Known limitations

  • CPU only. The BVH runs on CPU; data transfers to/from GPU via PCIe for each query. A GPU BVH (OptiX or pure-JAX) would eliminate this transfer and leverage RT cores, but is a separate effort.
  • No best_t pruning in traversal. The AABB slab test checks ray intersection but does not cull nodes farther than the current closest hit. Correct but suboptimal for dense scenes.
  • Soft mode FFI not wired. ffi_get_candidates exists but the soft blocking check in _compute_paths still falls back to brute force. The standalone bvh_rays_intersect_any_triangle does use the BVH for soft mode.
  • Moller-Trumbore epsilon differs from JAX. The Rust BVH uses 1e-8; the JAX brute-force path uses ~1.2e-6 (10 * f32 eps). This can cause minor disagreements on grazing rays (<0.1% of rays in practice).
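
The missing best_t pruning mentioned above can be sketched as follows. This is a plain-Python illustration of the slab method, not the Rust traversal from this PR; `ray_aabb_slab` and its `t_max` parameter are hypothetical names. Passing the distance of the closest hit found so far as `t_max` lets the test reject boxes that lie entirely behind that hit.

```python
import math


def ray_aabb_slab(origin, direction, box_min, box_max, t_max=math.inf):
    """Slab test: does the ray hit the box within the interval [0, t_max]?

    Sketch only: ignores the degenerate case of an origin lying exactly on a
    slab plane while the direction component along that axis is zero.
    """
    t_near, t_far = 0.0, t_max
    for o, d, lo, hi in zip(origin, direction, box_min, box_max):
        inv = math.inf if d == 0.0 else 1.0 / d
        t0, t1 = (lo - o) * inv, (hi - o) * inv
        if t0 > t1:
            t0, t1 = t1, t0
        t_near = max(t_near, t0)   # latest entry over all slabs
        t_far = min(t_far, t1)     # earliest exit, capped by t_max
        if t_near > t_far:
            return False
    return True


origin = (0.0, 0.0, 0.0)
direction = (0.0, 0.0, 1.0)
box_min, box_max = (-1.0, -1.0, 4.0), (1.0, 1.0, 6.0)

# Without pruning, the node is visited even if a closer hit is already known.
hit_without_pruning = ray_aabb_slab(origin, direction, box_min, box_max)
# With a closest hit at t = 2.0, the box (entered at t = 4.0) can be culled.
hit_with_pruning = ray_aabb_slab(origin, direction, box_min, box_max, t_max=2.0)
```

Both variants return the same nearest hit; the pruned one simply visits fewer nodes, which is why the limitation is described as correct but suboptimal.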

Test plan

  • 15 Rust unit tests pass (cargo test -- accel::bvh)
  • 29 Python tests pass (pytest differt/tests/accel/test_bvh.py)
  • Full DiffeRT suite: 1,642 passed, 4 failed (pre-existing vispy headless failures)
  • BVH results match brute force on random scenes (100+ rays)
  • active_mask correctly skips inactive triangles to find nearest active hit
  • Empty BVH returns miss without hanging
  • SBR paths with BVH match brute force (>95% object index agreement)

Checklist

Check all the applicable boxes:

  • I understand that my contributions need to pass the checks;
  • If I created new functions / methods, I documented them and added type hints;
  • If I modified already existing code, I updated the documentation accordingly;
  • The title of my pull request is a short description of the requested changes.

Note to reviewers

rwydaegh and others added 14 commits March 28, 2026 01:58
Implements a SAH-based BVH in Rust (differt-core) with PyO3 bindings,
providing two query types:
- nearest_hit: O(log N) per ray for SBR (951x speedup on Munich scene)
- get_candidates: expanded-box traversal for differentiable mode

Python integration (differt.accel) provides drop-in replacements for
rays_intersect_any_triangle and first_triangles_hit_by_rays that accept
an optional bvh= parameter. For differentiable mode, the BVH selects
candidate triangles and existing JAX Moller-Trumbore runs on the reduced
set, preserving gradient correctness.

Adds TriangleScene.build_bvh() convenience method.
Includes 11 Rust tests and 20 Python tests (BVH vs brute-force).

Resolves the core memory bottleneck described in issue jeertmans#313.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- bvh_triangles_visible_from_vertices: 14x faster on Munich (38K tris)
- compute_paths(method="hybrid", bvh=bvh): BVH for visibility step
- 7 new tests (27 total): visibility, compute_paths integration
- Update REPORT.md with Phase 3 progress

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace Python for-loops with NumPy array operations for
active_triangles checks in bvh_rays_intersect_any_triangle
and bvh_first_triangles_hit_by_rays.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Global registry (Mutex<HashMap<u64, Arc<Bvh>>>) allows XLA FFI handlers
to look up pre-built BVHs by integer ID. register()/unregister() methods
on TriangleBvh. This is the Rust-side foundation for making BVH queries
work inside JIT-compiled JAX functions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
BVH nearest-hit and get-candidates now work inside jax.jit and
jax.lax.scan via XLA FFI. This enables BVH-accelerated SBR.

Rust side:
- accel/ffi.rs: cxx bridge + FFI entry points + PyCapsule exports
- build.rs: finds JAX XLA headers, compiles C++ via cxx-build
- ffi.cc + ffi.h: XLA FFI handlers (BvhNearestHit, BvhGetCandidates)
- Cargo.toml: optional xla-ffi feature (cxx + cxx-build)

Python side:
- _ffi.py: jax.ffi.register_ffi_target + ffi_call wrappers
- ffi_nearest_hit() and ffi_get_candidates() work in JIT

Verified: BVH inside lax.scan on Munich (38K triangles), exact match
with PyO3 results.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
All three compute_paths methods now use BVH when bvh= is provided:

- exhaustive: BVH FFI replaces blocking check inside @eqx.filter_jit
- sbr: BVH FFI replaces first_triangles_hit_by_rays inside lax.scan
- hybrid: BVH for visibility (PyO3) + blocking check (FFI)

Hard mode only (smoothing_factor=None). Soft mode falls back to
brute force for the blocking check since sigmoid smoothing needs
JAX-side math.

29 BVH tests pass, 245 RT tests pass. Zero regressions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Clean rewrite reflecting all completed work: Rust BVH, XLA FFI
pipeline, full compute_paths integration. Added XLA FFI architecture
diagram, FFI vs PyO3 benchmarks, updated file table and test counts.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add early return in nearest_hit/get_candidates for empty BVH
  (previously caused infinite traversal loop)
- Add Drop impl for TriangleBvh to auto-unregister from global
  registry when Python GC collects the object
- Add Rust tests for empty BVH traversal
- SBR test now checks shapes and object index agreement instead
  of only checking return type

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The BVH nearest_hit now accepts an optional active_mask parameter at
every layer (Rust core, XLA FFI, PyO3, Python wrapper). When provided,
inactive triangles are skipped during traversal, correctly finding the
nearest *active* hit. This replaces the previous approach of finding
the global nearest hit and then discarding it if inactive, which missed
active triangles behind inactive ones.

Changes across the full stack:
- Rust: nearest_hit gets Option<&[bool]> active_mask, skip in leaf loop
- FFI: active_mask passed as u8 slice through cxx bridge
- C++: PRED buffer added to XLA FFI handler binding
- Python: active_mask parameter on ffi_nearest_hit, TriangleBvh.nearest_hit
- Scene: mesh.mask passed directly instead of post-hoc filtering
- Test: Rust test verifies mask-out-front-triangle finds rear triangle
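
The difference between the two approaches can be shown with a tiny plain-Python sketch. The function names and the list-of-hits representation are hypothetical and only model one ray; the miss convention (index -1, distance inf) follows the one used elsewhere in this PR.

```python
import math

MISS = (-1, math.inf)


def nearest_hit_then_filter(hits, active_mask):
    """Previous approach: take the global nearest hit, discard it if inactive."""
    if not hits:
        return MISS
    idx, t = min(hits, key=lambda hit: hit[1])
    return (idx, t) if active_mask[idx] else MISS


def nearest_active_hit(hits, active_mask):
    """New approach: skip inactive triangles while searching."""
    best = MISS
    for idx, t in hits:
        if active_mask[idx] and t < best[1]:
            best = (idx, t)
    return best


# One ray crossing two triangles: triangle 0 at t=1.0 is inactive,
# triangle 1 at t=2.0 is active and sits behind it.
hits = [(0, 1.0), (1, 2.0)]
active_mask = [False, True]

old_result = nearest_hit_then_filter(hits, active_mask)  # reports a miss
new_result = nearest_active_hit(hits, active_mask)       # finds triangle 1
```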

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@github-actions github-actions Bot added python Changes Python code rust Changes Rust code tests Changes tests dependencies Related to the project dependencies labels Mar 29, 2026
@codecov

codecov Bot commented Mar 29, 2026

Codecov Report

❌ Patch coverage is 88.90943% with 60 lines in your changes missing coverage. Please review.
✅ Project coverage is 85.59%. Comparing base (bce9e3b) to head (0b6fed3).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| differt/src/differt/scene/_triangle_scene.py | 45.07% | 39 Missing ⚠️ |
| differt/src/differt/accel/_ffi.py | 52.77% | 17 Missing ⚠️ |
| differt/src/differt/accel/_accelerated.py | 98.42% | 2 Missing ⚠️ |
| differt/src/differt/geometry/_triangle_mesh.py | 71.42% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #406      +/-   ##
==========================================
+ Coverage   85.03%   85.59%   +0.56%     
==========================================
  Files          32       37       +5     
  Lines        3080     3617     +537     
==========================================
+ Hits         2619     3096     +477     
- Misses        461      521      +60     

The xla-ffi feature requires JAX at build time (build.rs queries
JAX for XLA FFI header paths). CI builds differt-core in isolated
environments without JAX, causing all wheel builds and pytest jobs
to fail.

Fix: remove xla-ffi from default maturin features. The BVH still
works via PyO3 (outside JIT). Users who need the FFI path (BVH
inside jax.jit/lax.scan) build with:
  maturin develop --features xla-ffi

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add noqa comments for intentional lazy imports (PLC0415)
- Add return type annotations (ANN202)
- Add Raises section to docstring (DOC501)
- Convert if/else to ternary (SIM108)
- Fix ruff formatting and cargo fmt

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@codspeed-hq

codspeed-hq Bot commented Mar 29, 2026

Merging this PR will improve performance by 13.56%

⚠️ Different runtime environments detected

Some benchmarks with significant performance changes were compared across different runtime environments,
which may affect the accuracy of the results.

⚡ 1 improved benchmark
✅ 9 untouched benchmarks
🆕 3 new benchmarks

Performance Changes

| Benchmark | BASE | HEAD | Efficiency |
|---|---|---|---|
| 🆕 test_first_triangles_hit_by_rays_bvh | N/A | 610.4 ms | N/A |
| 🆕 test_rays_intersect_any_triangle_bvh | N/A | 608.5 ms | N/A |
| test_image_method | 184.1 µs | 162.1 µs | +13.56% |
| 🆕 test_transmitter_visibility_bvh | N/A | 1.3 s | N/A |

Comparing rwydaegh:feature/bvh-acceleration (0b6fed3) with main (bce9e3b)

@rwydaegh rwydaegh changed the title Feature/bvh acceleration feat(lib): add BVH acceleration Mar 29, 2026
- Skip TestComputePathsBvh when xla-ffi feature not built
- Add type stubs for accel.bvh module (fixes typecheck)
- Add accel to _differt_core __init__.pyi
- Fix ruff lint in test file (type annotations, import sorting)
- Pretty-format Cargo.toml

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add three BVH benchmarks matching existing brute-force ones:
- rays_intersect_any_triangle_bvh (1M rays, hard mode)
- first_triangles_hit_by_rays_bvh (1M rays)
- triangles_visible_from_vertices_bvh

These use the PyO3 path (no FFI needed) so they run in CI.
CodSpeed will show BVH vs brute-force performance side by side.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@github-actions github-actions Bot added the benchmarks Changes benchmarks label Mar 29, 2026
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@rwydaegh
Author

rwydaegh commented Apr 1, 2026

LGTM

Owner

@jeertmans jeertmans left a comment

Hi @rwydaegh, I had some time to perform a first review round. Let me know if you have any questions :-)

Comment thread pyproject.toml Outdated
Comment thread differt/tests/accel/test_bvh.py Outdated
Comment thread differt/tests/accel/test_bvh.py Outdated
Comment thread differt/tests/accel/test_bvh.py
Comment on lines +8 to +9
if TYPE_CHECKING:
    from differt.accel._bvh import TriangleBvh
Owner

A try/except block might be needed here; otherwise, it will fail for anyone using differt without the acceleration feature. However, we might just always enable this feature anyway (unless it increases the size of the binary wheels too much).

Author

Changed to try: from differt.accel._bvh import TriangleBvh / except ImportError: TriangleBvh = Any. Also applied the same pattern to _triangle_mesh.py which had the same if TYPE_CHECKING guard.
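
A minimal sketch of that pattern, for reference. The import path and the `Any` fallback come from the reply above; the `uses_bvh` helper is a hypothetical function added only to show the name being used in an annotation.

```python
from typing import Any, Optional

try:
    # Available only when the acceleration module is installed.
    from differt.accel._bvh import TriangleBvh
except ImportError:
    # Permissive fallback so that importing this module still works.
    TriangleBvh = Any


def uses_bvh(bvh: Optional[TriangleBvh] = None) -> bool:
    """Hypothetical helper: report whether a BVH was supplied."""
    return bvh is not None
```

Unlike an `if TYPE_CHECKING:` guard, the name exists at runtime in both branches, which matters when annotations are evaluated eagerly.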

Comment thread differt/src/differt/accel/_bvh.py Outdated
Comment on lines +200 to +204
import math  # noqa: PLC0415

if smoothing_factor <= 0:
    return float("inf")
return triangle_size * math.log(1.0 / epsilon_grad) / smoothing_factor
Owner

Don't import inside the function :-)

Also, is there any reason not to just use JAX here?
I would argue for using JAX everywhere we can, and only casting to NumPy floats (or Python floats) when we really need to. This way, implementing differentiable acceleration structures will only be easier.

Author

Moved import math to top of file. Kept math.log instead of jnp.log because compute_expansion_radius returns a plain Python float used for a non-JIT decision (comparing against a threshold). Can switch to jnp.log if you prefer, it doesn't matter much either way.

Owner

Currently, you are using JAX's FFI to register this data structure and use it within JAX functions, but I don't see any benefit over using PyO3's classes (like I did for TriangleMesh). In the future, using JAX's FFI will definitely help with automatic differentiation or with GPU execution, but here I wonder if there is any reason that made you go for JAX's FFI instead? I'm ok with this, I'm just curious.

Author

The BVH queries need to run inside jax.jit (specifically inside _compute_paths which is jitted). PyO3 classes can't be called from inside a jitted function, so JAX FFI was the way to go here. You're right that for standalone usage PyO3 would work fine, but since the main use case is inside the jitted path computation, FFI was necessary.

Comment thread differt/src/differt/accel/__init__.py Outdated
Comment on lines +6 to +12
Example:
>>> import jax.numpy as jnp
>>> from differt.accel import TriangleBvh
>>> verts = jnp.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=jnp.float32)
>>> bvh = TriangleBvh(verts)
>>> bvh.num_triangles
1
Owner

We might want to find a better example in the future :-)

Author

Ha, fair enough 😄 Updated it to show a two-triangle quad with nearest_hit so it actually demonstrates the BVH doing something useful (hit vs miss).

Comment thread differt-core/Cargo.toml
Comment thread differt-core/build.rs
existing JAX-based Moller-Trumbore runs on the reduced set.
"""

from __future__ import annotations
Owner

We should not use from __future__ import annotations, as it is incompatible with jaxtyping (sadly). It also seems to break the type annotations on the corresponding documentation pages.

Author

Removed from both _accelerated.py and _ffi.py, good catch on this one. It actually exposed a real bug: the BVH wrappers were computing batch_shape from only ray_origins.shape, ignoring broadcasting with other inputs. The brute-force originals use jnp.broadcast_shapes across all inputs. Fixed all three functions to match, return type annotations are correct now too.
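
The shape bug is easy to reproduce in isolation. The sketch below uses NumPy's `broadcast_shapes` and `broadcast_to`, which mirror the `jnp` functions of the same names; the array names and shapes are illustrative assumptions, not the wrappers' real signatures.

```python
import numpy as np

# One shared origin broadcast against five directions.
ray_origins = np.zeros((1, 3))
ray_directions = np.ones((5, 3))

# Buggy: the batch shape is taken from ray_origins alone.
wrong_batch = ray_origins.shape[:-1]

# Fixed: broadcast the leading dimensions of all inputs together.
batch = np.broadcast_shapes(ray_origins.shape[:-1], ray_directions.shape[:-1])
origins_b = np.broadcast_to(ray_origins, (*batch, 3))
directions_b = np.broadcast_to(ray_directions, (*batch, 3))

print(wrong_batch, batch)  # (1,) (5,)
```

With the buggy batch shape, a per-ray result would come back with shape `(1,)` even though five rays were traced.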

- Remove `from __future__ import annotations` (breaks jaxtyping + docs)
- Move all in-function imports to module top level
- Use chex instead of np.testing for JAX array assertions
- Rename hard/soft mode → smoothing naming convention in tests
- Remove redundant "with shape (...)" from FFI docstrings
- Move `build_bvh` from TriangleScene to TriangleMesh (delegates)
- Use try/except for TriangleBvh import in scene (not TYPE_CHECKING)
- Move `import math` to module level in _bvh.py
- Revert codespell ignore-words (maths fixed in jeertmans#412)
- Enable xla-ffi by default in maturin features
- Improve build.rs: robust Python detection, hard fail on missing JAX
- Add num_nodes property to bvh.pyi stub

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Always-on xla-ffi panics in environments without JAX (pre-commit,
MSRV check). Revert to optional feature with graceful warning when
JAX is not found. Keep improved Python interpreter detection
(PYTHON env → pyo3_build_config → python3 fallback).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
With __future__ annotations removed, jaxtyping's import hook wraps
functions at runtime. The #batch dimension constraint on return types
fails because BVH functions reshape internally. Use "..." (any shape)
for return annotations since these are wrapper functions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
rwydaegh and others added 2 commits April 10, 2026 16:22
Missed a bunch of hard mode/soft mode references in comments,
docstrings, and variable names during the initial rename pass.
Also moved import sys to top level in tests, switched
_triangle_mesh.py to try/except ImportError for consistency
with _triangle_scene.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
All three BVH functions computed batch_shape from ray_origins alone,
ignoring broadcasting with other inputs. The brute-force originals
use jnp.broadcast_shapes across all inputs. This meant broadcastable
inputs (legal per the #batch annotation) would produce wrong-shaped
returns. The "..." return type annotations were masking this.

Fixed with proper broadcast_shapes + broadcast_to, restored correct
return type annotations. Also fixed the warning message that said
"increase smoothing_factor" when it should say "decrease", moved
geometry import to top level, and finished the smoothing rename
in this file.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@rwydaegh
Author

Minor, unrelated: CI is a bit slow to iterate on, perhaps worth checking out Ubicloud or Blacksmith

@rwydaegh
Author

Thx for the review @jeertmans! Went through everything together with Claude 🙂

For the tests I moved all the inline imports to top level, switched to chex for array comparisons, and renamed hard/soft to the smoothing_ convention everywhere (method names, class names, variable names, comments, docstrings).

Removed from __future__ import annotations from _accelerated.py and _ffi.py, good catch. This exposed a real bug: the BVH wrappers were computing batch_shape from only ray_origins.shape, ignoring broadcasting with the other inputs. The brute-force originals use jnp.broadcast_shapes across all inputs. Fixed all three functions to do the same, and the return type annotations now match the originals (*batch, *batch num_triangles, etc.).

Moved build_bvh to TriangleMesh, TriangleScene.build_bvh() just delegates now. For storing the BVH inside the mesh: I looked into it and I think the explicit bvh=bvh API is actually better here. There are 12 geometry-mutating methods on TriangleMesh (rotate, scale, translate, append, the .at[] helpers, etc.) that would all need BVH invalidation, and eqx.tree_at can't update static fields so each site needs a dataclasses.replace follow-up. If anyone forgets on a future method, you get a stale BVH returning wrong intersections silently. The one-line ergonomic gain doesn't seem worth that maintenance risk.

Also renamed all the "hard mode"/"soft mode" references in _accelerated.py and _triangle_scene.py to use the smoothing convention consistently.

For the smoothing BVH path (your comment on line 342), the code already does this: ffi_get_candidates (BVH) narrows down candidates, then rays_intersect_triangles (pure JAX) runs on those candidates. Differentiability is preserved because gradients only flow through the JAX intersection, not the FFI call.

Tried always-on xla-ffi, broke pre-commit and MSRV check. The cxx bridge compiles C++ that #includes JAX headers, so any environment without JAX installed (pre-commit, MSRV check, CI lint) fails at build time. Making it always-on means making JAX a build dependency for differt-core, which is a bigger change to the package's dependency tree than fits in this PR.

For build.rs I improved the Python interpreter detection (PYTHON env, then pyo3_build_config, then python3 fallback), similar to your extending-jax setup. Kept the graceful warning instead of hard error since xla-ffi is still optional.

Some smaller stuff: reverted the codespell workaround (#412 fixes it), cleaned up the redundant shape descriptions in FFI docstrings, moved import math to top level in _bvh.py (kept math.log over jnp.log since the result is a plain float for a non-JIT decision, but can switch if you prefer), added num_nodes to the .pyi stub, changed the TriangleBvh import to try/except in both _triangle_scene.py and _triangle_mesh.py, and moved an unconditional inline geometry import to top level in _accelerated.py.

For docs I kept differt.accel.rst under differt since that's the user-facing API. differt_core.accel is the internal Rust FFI layer, so documenting it separately would likely confuse users who should be using differt.accel instead.

@jeertmans
Owner

Hi @rwydaegh, thanks for your work on this! I see you have made many changes and summarized them in one big comment. Could you please look at each comment I made in my first review and indicate whether you addressed it or not? Alternatively, if you disagree with a comment, simply reply with your opinion. This will help me keep track of progress and resolve conversations more easily.

For this, you might need to click on (screenshot omitted) and then on (screenshot omitted) to show all comments.

Alternatively, here are links to my comments:

I believe most comments can already be resolved. I also see that most comments no longer point to the original lines of code, which makes it hard to understand the original context of each comment... I don't know what happened.

> Minor, unrelated: CI is a bit slow to iterate on, perhaps worth checking out Ubicloud or Blacksmith

What part of the CI is too slow? If you are talking about macOS, this is a known issue: the number of macOS runners is very limited and the queue is usually very long before our CI gets to run. For the other tests, I believe it is ok. There exist alternatives, but are they free, or as unlimited as GitHub Actions? Also note that you can very easily test your feature locally with uv run pytest -x [<path-to-file>] [-k <name-or-test(s)>]. If you wait for the full test suite to finish just to verify that a single function works, that's indeed quite long ^^'


On a side note, I'm releasing a new version today with added functionality. I plan to release your contribution in the next version. I'm okay with releasing it as an "experimental feature" in case the API changes later on if I find a better way to organize the library that is more transparent to the user.

I will continue my review next week once I have finished preparing my slides for EuCAP :D

@rwydaegh
Author

Posted them all. Perhaps CI is actually okay then. For me experimental is fine, I've been testing it for real ray-tracing and so far it works as intended. Good luck with EuCAP!

Forgot to push this earlier. The module docstring example now builds
a two-triangle quad and uses nearest_hit to show a hit (idx=0, t=1.0)
and a miss (idx=-1, t=inf) instead of just checking num_triangles.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@jeertmans
Owner

jeertmans commented Apr 16, 2026

EDIT: looks like the docs build is failing, but it was working locally (-:

Hi @rwydaegh! I have reviewed most of the code and pushed a refactor to meet some of my criteria; I hope you don't mind. I left a few to-dos in the code, and made sure that the lints were passing (except one) and that the documentation was built properly. Tests will likely fail because I didn't remove all the occurrences of the "previous" FFI registration logic.

I am considering opening a meta issue with a list of ideas to improve in a future PR. However, I'd like to have your opinion on this; let me know if you think you can implement some of these points before merging:

  • Store TriangleBvh inside TriangleMesh, so that TriangleScene can automatically use it, when available, when calling compute_paths. This would reduce the API complexity. E.g., turn TriangleMesh.build_bvh() -> TriangleBvh into TriangleMesh.build_bvh(update: bool = False) -> TriangleMesh, where TriangleMesh.bvh is set to None by default.
  • The BVH id is a long unsigned integer: I think using a shorter integer type might be better. However, this comment may be unnecessary if we no longer use a global hashmap (see below).
  • The current Rust implementation stores TriangleBvh inside a global hashmap: this is probably not a good idea if we later want to implement differentiability. The best option might be to explicitly take a TriangleBvh as an argument.
  • Choose whether we use the BVH or Bvh naming convention (maybe see if there is another example of acronym use in the library)
  • I don't know if we should prepend bvh_ before the BVH-accelerated functions, since they are already in the bvh module.
  • Maybe a full Python (but JAX powered) BVH implementation would be more efficient, or at least be decent while allowing for automatic differentiation very easily.

Labels

benchmarks Changes benchmarks ci Continuous integration (tests, lints, ...) dependencies Related to the project dependencies documentation Improvements or additions to documentation python Changes Python code rust Changes Rust code tests Changes tests visualization Related to visualization utilities

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[FEATURE] Faster ray-triangle intersection test that does not allocate large arrays

2 participants