feat(lib): add BVH acceleration #406
Implements a SAH-based BVH in Rust (differt-core) with PyO3 bindings, providing two query types:
- nearest_hit: O(log N) per ray for SBR (951x speedup on Munich scene)
- get_candidates: expanded-box traversal for differentiable mode

Python integration (differt.accel) provides drop-in replacements for rays_intersect_any_triangle and first_triangles_hit_by_rays that accept an optional bvh= parameter. For differentiable mode, the BVH selects candidate triangles and existing JAX Moller-Trumbore runs on the reduced set, preserving gradient correctness.

Adds TriangleScene.build_bvh() convenience method. Includes 11 Rust tests and 20 Python tests (BVH vs brute-force).

Resolves the core memory bottleneck described in issue jeertmans#313.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- bvh_triangles_visible_from_vertices: 14x faster on Munich (38K tris)
- compute_paths(method="hybrid", bvh=bvh): BVH for visibility step
- 7 new tests (27 total): visibility, compute_paths integration
- Update REPORT.md with Phase 3 progress

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace Python for-loops with NumPy array operations for active_triangles checks in bvh_rays_intersect_any_triangle and bvh_first_triangles_hit_by_rays. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Global registry (Mutex<HashMap<u64, Arc<Bvh>>>) allows XLA FFI handlers to look up pre-built BVHs by integer ID. register()/unregister() methods on TriangleBvh. This is the Rust-side foundation for making BVH queries work inside JIT-compiled JAX functions. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
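As a rough Python analogue of the registry described above, the sketch below shows the register / look up / unregister lifecycle keyed by an integer ID. The real registry is the Rust `Mutex<HashMap<u64, Arc<Bvh>>>`; the class `BvhRegistry`, its `lookup` method, and the ID allocation scheme are hypothetical and only illustrate the idea.

```python
import itertools
import threading


class BvhRegistry:
    """Hypothetical stand-in for the Rust-side global registry."""

    def __init__(self):
        self._lock = threading.Lock()  # plays the role of the Mutex
        self._entries = {}  # id -> object, like HashMap<u64, Arc<Bvh>>
        self._next_id = itertools.count(1)  # assumed ID scheme: 1, 2, 3, ...

    def register(self, bvh):
        """Store the BVH and return the integer ID a handler would receive."""
        with self._lock:
            bvh_id = next(self._next_id)
            self._entries[bvh_id] = bvh
            return bvh_id

    def lookup(self, bvh_id):
        """Return the stored BVH, or None if the ID is unknown."""
        with self._lock:
            return self._entries.get(bvh_id)

    def unregister(self, bvh_id):
        """Remove the entry; return True only if something was removed."""
        with self._lock:
            return self._entries.pop(bvh_id, None) is not None


registry = BvhRegistry()
fake_bvh = object()  # placeholder for a built BVH
bvh_id = registry.register(fake_bvh)
```

Passing an integer ID rather than the object itself is what lets a handler that only sees plain attributes find the pre-built structure.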
BVH nearest-hit and get-candidates now work inside jax.jit and jax.lax.scan via XLA FFI. This enables BVH-accelerated SBR.

Rust side:
- accel/ffi.rs: cxx bridge + FFI entry points + PyCapsule exports
- build.rs: finds JAX XLA headers, compiles C++ via cxx-build
- ffi.cc + ffi.h: XLA FFI handlers (BvhNearestHit, BvhGetCandidates)
- Cargo.toml: optional xla-ffi feature (cxx + cxx-build)

Python side:
- _ffi.py: jax.ffi.register_ffi_target + ffi_call wrappers
- ffi_nearest_hit() and ffi_get_candidates() work in JIT

Verified: BVH inside lax.scan on Munich (38K triangles), exact match with PyO3 results.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
All three compute_paths methods now use BVH when bvh= is provided:
- exhaustive: BVH FFI replaces blocking check inside @eqx.filter_jit
- sbr: BVH FFI replaces first_triangles_hit_by_rays inside lax.scan
- hybrid: BVH for visibility (PyO3) + blocking check (FFI)

Hard mode only (smoothing_factor=None). Soft mode falls back to brute force for the blocking check since sigmoid smoothing needs JAX-side math.

29 BVH tests pass, 245 RT tests pass. Zero regressions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Clean rewrite reflecting all completed work: Rust BVH, XLA FFI pipeline, full compute_paths integration. Added XLA FFI architecture diagram, FFI vs PyO3 benchmarks, updated file table and test counts. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add early return in nearest_hit/get_candidates for empty BVH (previously caused infinite traversal loop)
- Add Drop impl for TriangleBvh to auto-unregister from global registry when Python GC collects the object
- Add Rust tests for empty BVH traversal
- SBR test now checks shapes and object index agreement instead of only checking return type

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The BVH nearest_hit now accepts an optional active_mask parameter at every layer (Rust core, XLA FFI, PyO3, Python wrapper). When provided, inactive triangles are skipped during traversal, correctly finding the nearest *active* hit. This replaces the previous approach of finding the global nearest hit and then discarding it if inactive, which missed active triangles behind inactive ones.

Changes across the full stack:
- Rust: nearest_hit gets Option<&[bool]> active_mask, skip in leaf loop
- FFI: active_mask passed as u8 slice through cxx bridge
- C++: PRED buffer added to XLA FFI handler binding
- Python: active_mask parameter on ffi_nearest_hit, TriangleBvh.nearest_hit
- Scene: mesh.mask passed directly instead of post-hoc filtering
- Test: Rust test verifies mask-out-front-triangle finds rear triangle

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
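To make the difference between the two approaches concrete, here is a minimal pure-Python sketch. It abstracts the ray-triangle test away and works on precomputed hit distances for a single ray; the function names are hypothetical, and the miss convention (`idx=-1`, `t=inf`) follows the one mentioned elsewhere in this PR.

```python
import math


def nearest_hit_post_filter(hit_distances, active_mask):
    """Previous approach: take the global nearest hit, then drop it if inactive."""
    best_idx, best_t = -1, math.inf
    for idx, t in enumerate(hit_distances):
        if t < best_t:
            best_idx, best_t = idx, t
    if best_idx != -1 and not active_mask[best_idx]:
        return -1, math.inf  # any active triangle behind it is lost
    return best_idx, best_t


def nearest_hit_masked(hit_distances, active_mask):
    """New approach: skip inactive triangles while searching."""
    best_idx, best_t = -1, math.inf
    for idx, t in enumerate(hit_distances):
        if not active_mask[idx]:
            continue  # inactive triangles never become candidates
        if t < best_t:
            best_idx, best_t = idx, t
    return best_idx, best_t


# One ray, two triangles: triangle 0 in front (t=1.0, inactive),
# triangle 1 behind it (t=2.0, active).
hit_distances = [1.0, 2.0]
active_mask = [False, True]

post_filter_result = nearest_hit_post_filter(hit_distances, active_mask)
masked_result = nearest_hit_masked(hit_distances, active_mask)
```

With the post-hoc filter the ray reports a miss, while the masked search returns the rear triangle, which is the case the new Rust test covers.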
Codecov Report ❌

Coverage diff, main vs. #406:

| Metric | main | #406 | +/- |
|---|---|---|---|
| Coverage | 85.03% | 85.59% | +0.56% |
| Files | 32 | 37 | +5 |
| Lines | 3080 | 3617 | +537 |
| Hits | 2619 | 3096 | +477 |
| Misses | 461 | 521 | +60 |
The xla-ffi feature requires JAX at build time (build.rs queries JAX for XLA FFI header paths). CI builds differt-core in isolated environments without JAX, causing all wheel builds and pytest jobs to fail. Fix: remove xla-ffi from default maturin features. The BVH still works via PyO3 (outside JIT). Users who need the FFI path (BVH inside jax.jit/lax.scan) build with: maturin develop --features xla-ffi Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add noqa comments for intentional lazy imports (PLC0415)
- Add return type annotations (ANN202)
- Add Raises section to docstring (DOC501)
- Convert if/else to ternary (SIM108)
- Fix ruff formatting and cargo fmt

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Merging this PR will improve performance by 13.56%

| | Benchmark | BASE | HEAD | Efficiency |
|---|---|---|---|---|
| 🆕 | test_first_triangles_hit_by_rays_bvh | N/A | 610.4 ms | N/A |
| 🆕 | test_rays_intersect_any_triangle_bvh | N/A | 608.5 ms | N/A |
| ⚡ | test_image_method | 184.1 µs | 162.1 µs | +13.56% |
| 🆕 | test_transmitter_visibility_bvh | N/A | 1.3 s | N/A |

Comparing rwydaegh:feature/bvh-acceleration (0b6fed3) with main (bce9e3b)
- Skip TestComputePathsBvh when xla-ffi feature not built
- Add type stubs for accel.bvh module (fixes typecheck)
- Add accel to _differt_core __init__.pyi
- Fix ruff lint in test file (type annotations, import sorting)
- Pretty-format Cargo.toml

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add three BVH benchmarks matching existing brute-force ones:
- rays_intersect_any_triangle_bvh (1M rays, hard mode)
- first_triangles_hit_by_rays_bvh (1M rays)
- triangles_visible_from_vertices_bvh

These use the PyO3 path (no FFI needed) so they run in CI. CodSpeed will show BVH vs brute-force performance side by side.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
LGTM
    if TYPE_CHECKING:
        from differt.accel._bvh import TriangleBvh
A try/except block might be needed here; otherwise the import will fail for anyone using differt without the acceleration feature. However, we might just always enable this feature anyway (unless it increases the size of the binary wheels too much).
Changed to try: from differt.accel._bvh import TriangleBvh / except ImportError: TriangleBvh = Any. Also applied the same pattern to _triangle_mesh.py which had the same if TYPE_CHECKING guard.
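Spelled out, the pattern from the reply above looks roughly like the sketch below. The import path is the one quoted in the thread; the helper `bvh_or_none` is hypothetical and only shows that annotations keep resolving whether or not the accelerated module is available.

```python
from typing import Any, Optional

try:
    # Only importable when the acceleration feature was built.
    from differt.accel._bvh import TriangleBvh
except ImportError:
    # Permissive placeholder so annotations still resolve at runtime.
    TriangleBvh = Any


def bvh_or_none(bvh: Optional[TriangleBvh] = None) -> bool:
    """Hypothetical helper: report whether a BVH was supplied."""
    return bvh is not None


has_bvh_default = bvh_or_none()
has_bvh_given = bvh_or_none(object())
```

Unlike an `if TYPE_CHECKING:` guard, the name `TriangleBvh` exists at runtime in both branches, which matters once annotations are evaluated eagerly.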
    import math  # noqa: PLC0415

    if smoothing_factor <= 0:
        return float("inf")
    return triangle_size * math.log(1.0 / epsilon_grad) / smoothing_factor
Don't import inside the function :-)
Also, is there any reason not to just use JAX here?
I would argue to use JAX everywhere we can, and only cast to NumPy floats (or Python floats) when we really need it. This way, implementing differentiable acceleration structures will only be easier.
Moved import math to top of file. Kept math.log instead of jnp.log because compute_expansion_radius returns a plain Python float used for a non-JIT decision (comparing against a threshold). Can switch to jnp.log if you prefer, it doesn't matter much either way.
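For reference, a self-contained version of the helper under discussion might look like the sketch below. The function body is the snippet quoted in the review; the parameter order, the default `epsilon_grad`, and the `scene_diagonal` comparison are assumptions made for illustration (the PR description only states that brute force is used when the radius exceeds the scene diagonal).

```python
import math


def compute_expansion_radius(triangle_size, smoothing_factor, epsilon_grad=1e-6):
    """Radius beyond which the sigmoid contribution is treated as negligible.

    The default epsilon_grad is an assumed value, not taken from the PR.
    """
    if smoothing_factor <= 0:
        return float("inf")
    return triangle_size * math.log(1.0 / epsilon_grad) / smoothing_factor


# Plain Python floats are enough for this non-JIT decision.
radius = compute_expansion_radius(1.0, 10.0, epsilon_grad=1e-3)
scene_diagonal = 5.0  # hypothetical scene size
use_bvh = radius < scene_diagonal  # otherwise fall back to brute force
```

A larger `smoothing_factor` gives a sharper sigmoid and hence a smaller radius, so fewer candidate triangles need to be kept.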
Currently, you are using JAX's FFI to register this data structure and use it within JAX functions, but I don't see any benefits over using PyO3's classes (like I did for TriangleMesh). In the future, using JAX's FFI will definitely help with automatic differentiation or GPU execution, but here I wonder if there is any reason that made you go for JAX's FFI instead? I'm ok with this, I'm just curious.
The BVH queries need to run inside jax.jit (specifically inside _compute_paths which is jitted). PyO3 classes can't be called from inside a jitted function, so JAX FFI was the way to go here. You're right that for standalone usage PyO3 would work fine, but since the main use case is inside the jitted path computation, FFI was necessary.
    Example:
        >>> import jax.numpy as jnp
        >>> from differt.accel import TriangleBvh
        >>> verts = jnp.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=jnp.float32)
        >>> bvh = TriangleBvh(verts)
        >>> bvh.num_triangles
        1
We might want to find a better example in the future :-)
Ha, fair enough 😄 Updated it to show a two-triangle quad with nearest_hit so it actually demonstrates the BVH doing something useful (hit vs miss).
    existing JAX-based Moller-Trumbore runs on the reduced set.
    """

    from __future__ import annotations
Removed from both _accelerated.py and _ffi.py, good catch on this one. It actually exposed a real bug: the BVH wrappers were computing batch_shape from only ray_origins.shape, ignoring broadcasting with other inputs. The brute-force originals use jnp.broadcast_shapes across all inputs. Fixed all three functions to match, return type annotations are correct now too.
- Remove `from __future__ import annotations` (breaks jaxtyping + docs)
- Move all in-function imports to module top level
- Use chex instead of np.testing for JAX array assertions
- Rename hard/soft mode → smoothing naming convention in tests
- Remove redundant "with shape (...)" from FFI docstrings
- Move `build_bvh` from TriangleScene to TriangleMesh (delegates)
- Use try/except for TriangleBvh import in scene (not TYPE_CHECKING)
- Move `import math` to module level in _bvh.py
- Revert codespell ignore-words (maths fixed in jeertmans#412)
- Enable xla-ffi by default in maturin features
- Improve build.rs: robust Python detection, hard fail on missing JAX
- Add num_nodes property to bvh.pyi stub

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Always-on xla-ffi panics in environments without JAX (pre-commit, MSRV check). Revert to optional feature with graceful warning when JAX is not found. Keep improved Python interpreter detection (PYTHON env → pyo3_build_config → python3 fallback). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
With __future__ annotations removed, jaxtyping's import hook wraps functions at runtime. The #batch dimension constraint on return types fails because BVH functions reshape internally. Use "..." (any shape) for return annotations since these are wrapper functions. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Missed a bunch of hard mode/soft mode references in comments, docstrings, and variable names during the initial rename pass. Also moved import sys to top level in tests, switched _triangle_mesh.py to try/except ImportError for consistency with _triangle_scene.py. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
All three BVH functions computed batch_shape from ray_origins alone, ignoring broadcasting with other inputs. The brute-force originals use jnp.broadcast_shapes across all inputs. This meant broadcastable inputs (legal per the #batch annotation) would produce wrong-shaped returns. The "..." return type annotations were masking this. Fixed with proper broadcast_shapes + broadcast_to, restored correct return type annotations. Also fixed the warning message that said "increase smoothing_factor" when it should say "decrease", moved geometry import to top level, and finished the smoothing rename in this file. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
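The shape bug is easy to reproduce in isolation. The sketch below uses NumPy's `broadcast_shapes` and `broadcast_to` as stand-ins for the `jnp` functions named in the commit message (an assumption made to keep the example dependency-light); the array names are illustrative.

```python
import numpy as np

ray_origins = np.zeros((1, 3))  # a single origin...
ray_directions = np.ones((5, 3))  # ...shared by five directions

# Buggy: batch shape taken from ray_origins alone.
batch_from_origins = ray_origins.shape[:-1]

# Fixed: broadcast the batch dimensions of all inputs together.
batch_shape = np.broadcast_shapes(
    ray_origins.shape[:-1], ray_directions.shape[:-1]
)
origins_b = np.broadcast_to(ray_origins, (*batch_shape, 3))
directions_b = np.broadcast_to(ray_directions, (*batch_shape, 3))
```

Here the origin-only batch shape is `(1,)` while the correct one is `(5,)`, so a result reshaped to the former would not match the annotated return shape.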
Minor, unrelated: CI is a bit slow to iterate on; perhaps worth checking out Ubicloud or Blacksmith.
Thx for the review @jeertmans! Went through everything together with Claude 🙂 For the tests I moved all the inline imports to top level, switched to Removed Moved Also renamed all the "hard mode"/"soft mode" references in For the smoothing BVH path (your comment on line 342), the code already does this: Tried always-on xla-ffi, broke pre-commit and MSRV check. The cxx bridge compiles C++ that For build.rs I improved the Python interpreter detection (PYTHON env, then pyo3_build_config, then python3 fallback), similar to your extending-jax setup. Kept the graceful warning instead of hard error since xla-ffi is still optional. Some smaller stuff: reverted the codespell workaround (#412 fixes it), cleaned up the redundant shape descriptions in FFI docstrings, moved For docs I kept |
|
Hi @rwydaegh, thanks for your work on this! I see you have made many changes and summarized them in one big comment. Could you please look at each comment I made in my first review and indicate whether you addressed it or not? Alternatively, if you disagree with a comment, simply reply with your opinion. This will help me keep track of progress and resolve conversations more easily. For this, you might need to click on and then on
to show all comments. Alternatively, here are links to my comments:
I believe most comments can already be resolved. I also see that most comments no longer point to the original lines of code, which makes it hard to understand their original context... I don't know what happened.
What part of the CI is too slow? If you are talking about macOS, this is a known issue: the number of macOS runners is very limited and the queue is usually very long before our CI gets to run. For the other tests, I believe it is ok. There exist alternatives, but are they free, or as unlimited as GitHub Actions? Also note that you can very easily test your feature locally with

On a side note, I'm releasing a new version today with added functionality. I plan to release your contribution in the next version. I'm okay with releasing it as an "experimental feature" in case the API changes later on if I find a better way to organize the library that is more transparent to the user. I will continue my review next week once I have finished preparing my slides for EuCAP :D
|
Posted them all. Perhaps CI is actually okay then. For me experimental is fine, I've been testing it for real ray-tracing and so far it works as intended. Good luck with EuCAP! |
Forgot to push this earlier. The module docstring example now builds a two-triangle quad and uses nearest_hit to show a hit (idx=0, t=1.0) and a miss (idx=-1, t=inf) instead of just checking num_triangles. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
EDIT: looks like the docs build is failing, but it was working locally (-:

Hi @rwydaegh! I have reviewed most of the code and pushed a refactor to meet some of my criteria; I hope you don't mind. I left a few to-dos in the code, and made sure that the lints were passing (except some) and that the documentation was built properly. Tests will likely fail because I didn't remove all the occurrences of the "previous" FFI registration logic. I am considering opening a meta issue with a list of ideas for things to improve in a future PR. However, I'd like to have your opinion on this; let me know if you think you can implement some of these points before merging:
Fixes Issue
Closes #313
Description
Summary
This PR adds a Rust BVH (Bounding Volume Hierarchy) to `differt-core` with both PyO3 and XLA FFI bindings, fully integrated into all three `compute_paths` methods. It solves the O(rays * triangles) memory scaling that causes OOM on scenes with more than a few thousand triangles (#313).

Key results:
- `[rays, 39K, 3]` to `[rays, ~300, 3]`
- `jax.jit` and `jax.lax.scan` via XLA FFI

Motivation
DiffeRT's intersection functions allocate O(rays * triangles) intermediate arrays in JAX. On Munich (38K triangles), this means multi-GB arrays that OOM even on beefy GPUs. Jerome tried every pure-JAX approach:
- `vmap` + `sum`: OOM
- `lax.scan`: correct but slow
- `lax.map`: correct but slow
- `fori_loop` with batching: best compromise, still 20s+ on GPU

The JAX team confirmed in jax-ml/jax#30841 that `lax.reduce` cannot close over Tracers due to a StableHLO limitation with no fix planned. The only viable path is moving the ray-triangle loop out of JAX.

Approach: Rust for spatial queries, JAX for math
The core idea is to split responsibilities: the BVH in Rust handles spatial acceleration (which triangles could a ray hit?), while the Moller-Trumbore intersection math stays in JAX (where it auto-differentiates through sigmoid smoothing). This means:
Two query paths serve different contexts:
`jax.jit`, `lax.scan`, `vmap`

The XLA FFI pipeline:
The BVH lives in a global `Mutex<HashMap<u64, Arc<Bvh>>>` registry. Python calls `bvh.register()` to get an integer ID passed as an XLA attribute (compile-time constant). A `Drop` impl on `TriangleBvh` auto-unregisters when Python GC collects the object.

Differentiable mode: the expanded BVH

For the soft path (`smoothing_factor` set), boolean tests become `sigmoid(x * alpha)`. Triangles far from a ray have exponentially small sigmoid values. The expansion radius captures all triangles with gradient contribution above a threshold:

When `r_near` exceeds the scene diagonal or candidate counts overflow, the system automatically falls back to brute force.

Performance
Hard mode (nearest-hit)
BVH build is one-time (cached per scene). Query scales as O(rays * log(triangles)).
XLA FFI vs PyO3 (Munich, 38K triangles, 200 rays)
Soft mode (Munich, 38K triangles, 50 rays)
The soft-mode speedup is modest (2-3x) because JAX soft intersection on candidates still dominates. The real value is avoiding OOM.
Visibility estimation (hybrid method, Munich, 100K rays)
Usage
What changed
Rust (`differt-core/src/accel/`, ~1,280 lines):
- `bvh.rs`: SAH-binned BVH construction, slab-method AABB traversal, Moller-Trumbore intersection, global registry with `Arc<Bvh>`, PyO3 bindings, `Drop` for auto-cleanup. Supports optional `active_mask` to skip inactive triangles during traversal.
- `ffi.rs`: cxx bridge declarations, FFI entry points, PyCapsule exports for `jax.ffi.register_ffi_target`.

C++ (`src/ffi.cc` + `include/ffi.h`, ~120 lines): XLA FFI handlers that decode buffers and call Rust via cxx.

Build (`build.rs`, 45 lines): Queries Python for JAX XLA include paths, compiles C++ via cxx-build. Gated behind `xla-ffi` Cargo feature.

Python (`differt/src/differt/accel/`, ~730 lines):
- `_bvh.py`: `TriangleBvh` wrapper with batch handling, register/unregister lifecycle.
- `_accelerated.py`: Drop-in replacements for `rays_intersect_any_triangle`, `first_triangles_hit_by_rays`, `triangles_visible_from_vertices`.
- `_ffi.py`: JAX FFI wrappers (`ffi_nearest_hit`, `ffi_get_candidates`).

Scene integration (`_triangle_scene.py`, +120 lines): `build_bvh()` method, `compute_paths(bvh=...)` parameter wired into exhaustive, hybrid, and SBR methods.

Tests (15 Rust + 29 Python): Construction, traversal, active_mask filtering, empty BVH, soft/hard mode, expansion radius, visibility, full compute_paths integration for all three methods.
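For readers unfamiliar with the slab-method AABB test listed under `bvh.rs`, here is a minimal pure-Python sketch of the general technique. It is not the Rust implementation from this PR; the function name and the optional `t_max` argument are assumptions. Passing the current closest hit distance as `t_max` is one common way to cull nodes that lie beyond an already-found hit.

```python
import math


def ray_hits_aabb(origin, direction, box_min, box_max, t_max=math.inf):
    """Slab test: intersect the ray's parameter interval with each axis slab."""
    t_near, t_far = 0.0, t_max
    for o, d, lo, hi in zip(origin, direction, box_min, box_max):
        if d == 0.0:
            # Ray parallel to this slab: it must start inside it.
            if o < lo or o > hi:
                return False
            continue
        t0, t1 = (lo - o) / d, (hi - o) / d
        if t0 > t1:
            t0, t1 = t1, t0  # order entry/exit for negative directions
        t_near, t_far = max(t_near, t0), min(t_far, t1)
        if t_near > t_far:
            return False  # the per-axis intervals no longer overlap
    return True


box_min, box_max = (1.0, -1.0, -1.0), (2.0, 1.0, 1.0)
origin = (0.0, 0.0, 0.0)

hits_forward = ray_hits_aabb(origin, (1.0, 0.0, 0.0), box_min, box_max)
hits_backward = ray_hits_aabb(origin, (-1.0, 0.0, 0.0), box_min, box_max)
# With a closer hit already found at t=0.5, the box (entered at t=1) is culled.
hits_pruned = ray_hits_aabb(origin, (1.0, 0.0, 0.0), box_min, box_max, t_max=0.5)
```

The box is entered at `t=1` along the forward ray, so it is reported as hit without a bound and skipped once `t_max` drops below that entry distance.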
Known limitations
- `best_t` pruning in traversal. The AABB slab test checks ray intersection but does not cull nodes farther than the current closest hit. Correct but suboptimal for dense scenes.
- `ffi_get_candidates` exists but the soft blocking check in `_compute_paths` still falls back to brute force. The standalone `bvh_rays_intersect_any_triangle` does use the BVH for soft mode.
- `1e-8`; the JAX brute-force path uses `~1.2e-6` (10 * f32 eps). This can cause minor disagreements on grazing rays (<0.1% of rays in practice).

Test plan
- `cargo test -- accel::bvh`
- `pytest differt/tests/accel/test_bvh.py`
- `active_mask` correctly skips inactive triangles to find nearest active hit

Checklist
Check all the applicable boxes:
Note to reviewers