diff --git a/.github/workflows/Documenter.yml b/.github/workflows/Documenter.yml index 57779e1..7422a67 100644 --- a/.github/workflows/Documenter.yml +++ b/.github/workflows/Documenter.yml @@ -21,11 +21,13 @@ permissions: id-token: write statuses: write -# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. -# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. +# Concurrency: PR builds cancel previous in-flight builds for the same ref +# (latest commit wins; no point sitting in a long queue). Master/tag pushes +# go through the production deploy path and must NOT cancel an in-flight +# deploy, so they keep `cancel-in-progress: false` via a different group. concurrency: - group: pages - cancel-in-progress: false + group: pages-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: # Build job diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d8254d9..bedd494 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,25 +6,47 @@ on: tags: '*' pull_request: -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: test: - name: Tests + name: Tests (${{ matrix.backend }}) runs-on: ubuntu-latest + # Concurrency lives on the job (not the workflow) so it can reference + # `matrix.backend`, which is only in scope at the job level. Skip / + # cancel intermediate builds per matrix entry independently. + concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.backend }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + strategy: + fail-fast: false + matrix: + backend: [cpu, lavapipe] env: - DISPLAY: ':0' + RAYCORE_TEST_BACKEND: ${{ matrix.backend }} steps: - name: Checkout uses: actions/checkout@v4 + - name: Install xvfb (GLFW pulled in by Lava needs an X display at module __init__) + run: | + sudo apt-get update + sudo apt-get install -y xvfb + - name: Install lavapipe (mesa software Vulkan) + if: matrix.backend == 'lavapipe' + run: | + sudo apt-get install -y mesa-vulkan-drivers vulkan-tools spirv-tools + echo "VK_DRIVER_FILES=/usr/share/vulkan/icd.d/lvp_icd.x86_64.json" >> "$GITHUB_ENV" + # Confirm lavapipe is the active ICD + xvfb-run --auto-servernum vulkaninfo --summary | grep -E "deviceName|driverName" | head -4 || true - uses: julia-actions/setup-julia@v2 with: version: 1 arch: x64 - uses: julia-actions/cache@v2 - - name: Install pkgs dependencies - run: julia --project=@. -e 'using Pkg; Pkg.test("Raycore", coverage=true)' - uses: julia-actions/julia-runtest@v1 + with: + # --check-bounds=auto: GPU kernels (lavapipe / OpenCL / pocl) crash with + # --check-bounds=yes because bounds checking injects error paths that + # can't compile to SPIR-V. + check_bounds: auto + # Wrap julia with xvfb-run so GLFW's __init__ (transitively loaded + # via Lava) can call glfwInit() on a headless runner. + prefix: xvfb-run --auto-servernum diff --git a/Project.toml b/Project.toml index 17cf1b3..c69f65c 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,10 @@ version = "0.1.1" authors = ["Anton Smirnov ", "Simon Danisch UInt32(mi)) # Cast rays and find intersections ray = Ray(o=Point3f(0, 0, 0), d=Vec3f(0, 0, 1)) -hit_found, triangle, distance, bary_coords = closest_hit(bvh, ray) +hit_found, triangle, distance, bary_coords, instance_id = closest_hit(tlas, ray) if hit_found hit_point = ray.o + ray.d * distance @@ -46,13 +48,38 @@ end ```julia # Calculate scene centroid from a viewing direction viewdir = normalize(Vec3f(0, 0, -1)) -hitpoints, centroid = get_centroid(bvh, viewdir) +hitpoints, centroid = get_centroid(tlas, viewdir) # Analyze illumination -illumination = get_illumination(bvh, viewdir) +illumination = get_illumination(tlas, viewdir) # Compute view factors for radiosity -vf_matrix = view_factors(bvh; rays_per_triangle=1000) +vf_matrix = view_factors(tlas; rays_per_triangle=1000) +``` + +### Hardware ray tracing (Vulkan) + +For hardware-accelerated ray tracing, use `Lava.HWTLAS` as a drop-in replacement: + +```julia +using Raycore, Lava, GeometryBasics, StaticArrays, LinearAlgebra + +backend = Lava.LavaBackend() +hwtlas = Lava.HWTLAS(backend) +mesh = normal_mesh(Sphere(Point3f(0, 0, 2), 1.0f0)) +push!(hwtlas, mesh, SMatrix{4,4,Float32}(I); instance_id=UInt32(1)) +Raycore.sync!(hwtlas) +``` + +See the [HW RT tutorial](https://juliageometry.github.io/Raycore.jl/dev/hw_acceleration.html) for the full setup. + +## Testing + +Run tests with `--check-bounds=auto` (not the `Pkg.test` default of `--check-bounds=yes`), because GPU kernels compiled with bounds checking generate SPIR-V that crashes pocl: + +```julia +using Pkg +Pkg.test("Raycore"; julia_args=`--check-bounds=auto`) ``` ## Documentation diff --git a/benchmarks/implicitbvh_comparison.md b/benchmarks/implicitbvh_comparison.md new file mode 100644 index 0000000..9ba2551 --- /dev/null +++ b/benchmarks/implicitbvh_comparison.md @@ -0,0 +1,65 @@ +# ImplicitBVH.jl vs Raycore.jl — GPU Benchmark + +**Date**: 2026-03-29 +**GPU**: AMD RX 7900 XTX (RDNA3) +**Backend**: AMDGPU.jl (ROCArray) +**Mesh**: xyzrgb_dragon.obj (249,882 triangles) + procedural random geometry + +## BVH Build + +| Triangles | ImplicitBVH | Raycore | Ratio | +|-----------|-------------|---------|-------| +| 250K | 0.98 ms | 4.93 ms | ImplicitBVH 5.0x faster | +| 1M | 2.25 ms | 7.46 ms | ImplicitBVH 3.3x faster | +| 4M | 8.41 ms | 16.16 ms | ImplicitBVH 1.9x faster | + +ImplicitBVH builds faster due to simpler construction (Morton sort + bottom-up aggregate). +Raycore does more work: topology emission, parent pointers, leaf creation, atomic refit — all separate kernel launches. + +## Ray Tracing — Dragon Mesh (249K triangles) + +**Important**: ImplicitBVH `traverse_rays` returns bounding volume candidates (broad-phase only). +Raycore `closest_hit` returns the actual closest triangle intersection (full narrow-phase). +These are fundamentally different operations — ImplicitBVH does less work per ray but doesn't give a usable hit result. + +| Rays | ImplicitBVH (LVT) | Raycore | Speedup (Raycore) | +|------|--------------------|---------|--------------------| +| 100K | 4.60 ms | 1.33 ms | 3.5x | +| 500K | 11.06 ms | 3.14 ms | 3.5x | +| 1M | 20.84 ms | 3.00 ms | 6.9x | +| 2M | 41.52 ms | 6.00 ms | 6.9x | +| 4M | 83.31 ms | 5.91 ms | 14.1x | + +## Ray Tracing — Scaling with Triangle Count (1M rays) + +| Triangles | ImplicitBVH (BFS) | Raycore | Speedup (Raycore) | +|-----------|--------------------|---------|--------------------| +| 250K | 43.89 ms | 8.99 ms | 4.9x | +| 1M | 217.37 ms | 11.08 ms | 19.6x | +| 4M | 313.0 ms | 15.41 ms | 20.3x | + +## Why Raycore Is Faster for Ray Tracing + +| Factor | ImplicitBVH | Raycore | +|--------|-------------|---------| +| Output | All BV candidates (variable-size list) | Single closest hit (fixed) | +| Triangle test | None (BSphere overlap only) | Moller-Trumbore per leaf | +| Passes | Two-pass (count + write) | Single-pass | +| Early termination | No — finds all overlaps | Yes — t_max shrinks on hit | +| Node layout | Implicit tree + skip array | Inline leaves (BVH2IL) | +| Allocations | Output buffer per trace | None | + +## What ImplicitBVH Does Better + +- **Build speed** (2-5x faster) — fewer kernel launches, implicit indexing +- **Collision detection** — LVT (leaf-vs-tree) two-pass is designed for finding all contact pairs +- **Two-BVH collision** — native support for inter-object contact detection +- **Cache reuse** — `BVHTraversal` cache avoids re-allocation across frames +- **Mixed bounding volumes** — BSphere leaves with BBox internal nodes + +## Reference: ImplicitBVH README Numbers (Nvidia A100) + +From the ImplicitBVH.jl README (249,882 triangles, BSphere/BBox): +- Build: 410 us +- Contact detection: 1.14 ms +- 100K rays: 2.00 ms diff --git a/docs/Project.toml b/docs/Project.toml index 209abc4..1ce89af 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Bonito = "824d6782-a2ef-11e9-3a09-e5662e0c26f8" BonitoBook = "b416d416-7a6e-4336-8c1a-1f8a8cd59518" @@ -9,9 +10,13 @@ GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" MeshIO = "7269a6da-0436-5bbc-96c2-40638cbb6118" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Raycore = "afc56b53-c9a9-482a-a956-d1d800e05559" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" WGLMakie = "276b4fcb-3e11-5398-bf8b-a0c2d153d008" [sources] diff --git a/docs/examples.jl b/docs/examples.jl index 6dcae50..0fce655 100644 --- a/docs/examples.jl +++ b/docs/examples.jl @@ -14,8 +14,11 @@ begin l = 0.5 floor = Rect3f(-l, -l, -0.01, 2l, 2l, 0.01) cat = load(Makie.assetpath("cat.obj")) - bvh = Raycore.BVH([s1, s2, s3, s4, cat]); - world_mesh = GeometryBasics.Mesh(bvh) + bvh = Raycore.TLAS([normal_mesh(s1), normal_mesh(s2), normal_mesh(s3), normal_mesh(s4), cat], (mi, ti) -> UInt32(mi)); + # TODO: examples.jl needs rewrite for TLAS API + # bvh.primitives → iterate tlas.blas_array[i].primitives + # GeometryBasics.Mesh(bvh) → use Makie extension: plot(tlas) + world_mesh = Makie.convert_arguments(Makie.Mesh, bvh)[1] f, ax, pl = Makie.mesh(world_mesh; color=:teal) center!(ax.scene) viewdir = normalize(ax.scene.camera.view_direction[]) diff --git a/docs/gpu-optimization-guidelines.md b/docs/gpu-optimization-guidelines.md new file mode 100644 index 0000000..2a09e29 --- /dev/null +++ b/docs/gpu-optimization-guidelines.md @@ -0,0 +1,215 @@ +# GPU Optimization Guidelines for Julia + +This document covers common pitfalls that cause GPU kernel compilation failures or poor performance in Julia, and provides patterns to avoid them. + +## The Core Problem: Dynamic Dispatch on GPU + +GPUs require fully static, type-stable code. Any operation that requires runtime type dispatch will fail to compile or cause `ijl_get_nth_field_checked` errors. Common causes: + +1. **Closure boxing** - captured variables become heap-allocated `Core.Box` +2. **Heterogeneous tuple iteration** - `for x in tuple` with mixed types +3. **Type inference limits** - tuples > 32 elements, deeply nested types +4. **Abstract field access** - accessing fields through abstract types + +## Compiler Limits to Know + +| Constant | Default | Effect when exceeded | +|----------|---------|---------------------| +| `MAX_TUPLETYPE_LEN` | ~32 | Tuple types lose precise inference | +| `MAX_TYPE_DEPTH` | varies | Nested types get widened to supertypes | +| `MAX_UNION_SPLITTING` | ~4 | Union types use dynamic dispatch | +| Inlining threshold | ~100 cycles | Functions won't inline | + +## Pattern 1: Avoid `for` Loops Over Heterogeneous Tuples + +### Bad - Causes boxing and dynamic dispatch +```julia +function sum_lights(lights::Tuple, ray) + result = RGBSpectrum(0f0) + for light in lights # Creates dynamic iteration! + result += le(light, ray) + end + return result +end +``` + +### Good - Recursive tuple traversal (compile-time unrolling) +```julia +@inline sum_lights(::Tuple{}, ray) = RGBSpectrum(0f0) +@inline function sum_lights(lights::Tuple, ray) + return le(first(lights), ray) + sum_lights(Base.tail(lights), ray) +end +``` + +### Better - Use `for_unrolled` with explicit arguments +```julia +# Avoids closure capture entirely by passing all data as arguments +result = for_unrolled(sum_light_contribution, lights, ray, initial_value) +``` + +## Pattern 2: Avoid Closure Capture + +### Bad - Variable capture causes boxing +```julia +function process(data, threshold) + # `threshold` gets boxed because it's captured + map(x -> x > threshold ? x : zero(x), data) +end +``` + +### Good - Use `let` block to create immutable binding +```julia +function process(data, threshold) + f = let t = threshold + x -> x > t ? x : zero(x) + end + map(f, data) +end +``` + +### Better - Avoid closures entirely, pass as argument +```julia +function process(data, threshold) + map((x, t) -> x > t ? x : zero(x), data, Ref(threshold)) +end + +# Or use a functor +struct ThresholdFilter{T} + threshold::T +end +(f::ThresholdFilter)(x) = x > f.threshold ? x : zero(x) +``` + +## Pattern 3: Use `for_unrolled` for GPU-Safe Iteration + +The `for_unrolled` function provides compile-time loop unrolling without closure capture: + +```julia +# Instead of: +for i in 1:N + process(data[i], extra_arg) +end + +# Use: +for_unrolled(process_item, Val(N), data, extra_arg) +# Where process_item(i, data, extra_arg) is your function +``` + +### With Tuples (heterogeneous types) +```julia +# Bad: for light in lights +# Good: +result = for_unrolled( + accumulate_light, # function(elem, acc, ray) -> new_acc + lights, # tuple to iterate + RGBSpectrum(0f0), # initial accumulator + ray # extra arguments... +) +``` + +## Pattern 4: Ensure Type Stability + +### Check with `@code_warntype` +```julia +@code_warntype my_kernel_function(args...) +# Look for: +# - `Any` types (red in color terminals) +# - `Core.Box` (captured variables) +# - `Union{...}` with many types +``` + +### Use JET.jl for deeper analysis +```julia +using JET +@report_opt my_function(args...) +``` + +## Pattern 5: Use Concrete Types in Structs + +### Bad - Abstract field types +```julia +struct Scene + lights::Vector{Light} # Abstract element type +end +``` + +### Good - Parameterized concrete types +```julia +struct Scene{L<:Tuple} + lights::L # Concrete tuple type, e.g., Tuple{SunLight, PointLight} +end +``` + +## Pattern 6: Avoid Runtime Allocations + +### Bad - Creates intermediate arrays +```julia +function compute(points) + distances = [norm(p) for p in points] # Allocates! + return sum(distances) +end +``` + +### Good - Fuse operations +```julia +function compute(points) + total = 0f0 + for p in points + total += norm(p) + end + return total +end +``` + +### For small fixed-size data, use tuples/StaticArrays +```julia +# Heap allocated (bad for GPU registers) +coords = [1.0f0, 2.0f0, 3.0f0] + +# Stack allocated (good) +coords = (1.0f0, 2.0f0, 3.0f0) +coords = SVector{3, Float32}(1, 2, 3) +``` + +## Pattern 7: Use 32-bit Integers + +```julia +# Bad - promotes to Int64 +idx = blockIdx().x - 1 + +# Good - stays Int32 +idx = blockIdx().x - Int32(1) + +# Helper function +gpu_int(x) = x % Int32 +``` + +## Quick Reference: GPU-Safe Alternatives + +| Avoid | Use Instead | +|-------|-------------| +| `for x in heterogeneous_tuple` | Recursive functions or `for_unrolled` | +| `x -> f(x, captured_var)` | `let` blocks or pass args explicitly | +| `Vector{AbstractType}` | `Tuple` or `Vector{ConcreteType}` | +| Dynamic `if typeof(x) == ...` | Multiple dispatch | +| `Int64` literals | `Int32` / `gpu_int()` | +| `Array` in kernel | `Tuple` or `SVector` | + +## Debugging GPU Compilation Errors + +When you see errors like: +- `ijl_get_nth_field_checked` - Dynamic field access (boxing) +- `jl_apply_generic` - Dynamic dispatch +- `jl_gc_*` - Heap allocation attempted + +Steps to debug: +1. Identify the function mentioned in the stack trace +2. Run `@code_warntype` on that function +3. Look for `Core.Box`, `Any`, or large `Union` types +4. Apply the patterns above to make the code type-stable + +## See Also + +- [Julia Performance Tips](https://docs.julialang.org/en/v1/manual/performance-tips/) +- [CUDA.jl Performance Tips](https://cuda.juliagpu.org/stable/tutorials/performance/) +- [Julia Issue #15276](https://github.com/JuliaLang/julia/issues/15276) - Closure boxing diff --git a/docs/instanced-bvh-architecture.md b/docs/instanced-bvh-architecture.md new file mode 100644 index 0000000..be65523 --- /dev/null +++ b/docs/instanced-bvh-architecture.md @@ -0,0 +1,454 @@ +# Instanced BVH Architecture + +## Overview + +This document describes the two-level instanced BVH implementation in Raycore, inspired by AMD's RadeonRays SDK architecture. The design enables efficient ray tracing of scenes with repeated geometry (instances) while maintaining GPU/CPU portability through parametrized array types. + +## Architecture + +### Two-Level Hierarchy + +The instanced BVH uses a **two-level acceleration structure**: + +1. **BLAS (Bottom-Level Acceleration Structure)**: BVH over triangle geometry + - One BLAS per unique mesh + - Built once, reused for all instances + - Stores geometry in local/object space + +2. **TLAS (Top-Level Acceleration Structure)**: BVH over instances + - Contains transformed instances of BLAS objects + - Each instance has a transformation matrix + - Built per-frame or when instances change + +### Memory Layout + +``` +TLAS (Scene-level) +├── Node[0]: Root (Interior) +│ ├── Child0 → Node[1] +│ └── Child1 → Node[2] +├── Node[1]: Interior +│ ├── Child0 → Instance 0 (Leaf) +│ └── Child1 → Instance 1 (Leaf) +└── Node[2]: Instance 2 (Leaf) + +BLAS (Mesh-level) +├── Node[0]: Root (Interior) +│ ├── Child0 → Node[1] +│ └── Child1 → Triangle 0 (Leaf) +└── Node[1]: Triangle 1 (Leaf) +``` + +## Core Data Structures + +### BVHNode2 + +Compact BVH node for binary trees (BVH2IL layout from RadeonRays): + +```julia +struct BVHNode2 + aabb0_min::Point3f # Child 0 AABB min + aabb0_max::Point3f # Child 0 AABB max + aabb1_min::Point3f # Child 1 AABB min + aabb1_max::Point3f # Child 1 AABB max + child0::UInt32 # Child 0 index (INVALID_NODE for leaves) + child1::UInt32 # Child 1 index (primitive for leaves) + parent::UInt32 # Parent node index +end +``` + +**Design Rationale:** +- **Inline AABBs**: Storing both children's AABBs directly in the parent node enables branchless intersection tests +- **Compact size**: 64 bytes per node for cache efficiency +- **Leaf identification**: `child0 == INVALID_NODE` marks leaf nodes +- **Flexible use**: Same structure for both BLAS and TLAS + +### InstanceDescriptor + +Describes an instance of a BLAS in world space: + +```julia +struct InstanceDescriptor + blas_index::UInt32 # Which BLAS to instance + instance_id::UInt32 # User-defined ID + transform::Mat4f # Local-to-world transformation + inv_transform::Mat4f # World-to-local transformation + flags::UInt32 # Reserved for future use +end +``` + +### BLAS + +```julia +struct BLAS{NodeArray <: AbstractVector{BVHNode2}, + TriArray <: AbstractVector{<:Triangle}} + nodes::NodeArray # BVH nodes + primitives::TriArray # Triangles (sorted by Morton code) + root_aabb::Bounds3 # Bounding box in local space +end +``` + +**Type Parameters:** +- `NodeArray`: Can be `Vector`, `CuArray`, `ROCArray`, etc. +- `TriArray`: Same - enables CPU/GPU execution + +### TLAS + +```julia +struct TLAS{NodeArray <: AbstractVector{BVHNode2}, + InstArray <: AbstractVector{InstanceDescriptor}, + BLASArray <: AbstractVector{<:BLAS}} + nodes::NodeArray # Top-level BVH nodes + instances::InstArray # Instance descriptors + blas_array::BLASArray # Array of BLAS objects + root_aabb::Bounds3 # World-space bounding box +end +``` + +## Construction Algorithm + +### LBVH (Linear BVH) + +Both BLAS and TLAS use the **LBVH algorithm** (Karras 2012): + +#### Step 1: Compute Scene AABB +```julia +scene_aabb = mapreduce(world_bound, ∪, primitives) +``` + +#### Step 2: Calculate 30-bit Morton Codes + +Morton codes provide a space-filling Z-curve ordering: + +```julia +function morton_code_30bit(p::Point3f)::UInt32 + # Normalize to [0, 1023] (10 bits per axis) + unit_side = 1024.0f0 + x = clamp(p[1] * unit_side, 0.0f0, unit_side - 1.0f0) + y = clamp(p[2] * unit_side, 0.0f0, unit_side - 1.0f0) + z = clamp(p[3] * unit_side, 0.0f0, unit_side - 1.0f0) + + # Interleave bits: xxyyzzxxyyzzxxyyzz... + return (expand_bits(UInt32(x)) << 2) | + (expand_bits(UInt32(y)) << 1) | + expand_bits(UInt32(z)) +end +``` + +**Why 30 bits?** +- 10 bits per axis = 1024³ spatial resolution +- Leaves 2 bits for flags if needed +- Fits comfortably in UInt32 + +#### Step 3: Sort Primitives + +```julia +sorted_indices = sortperm(morton_codes) +morton_codes .= morton_codes[sorted_indices] +sorted_prims = primitives[sorted_indices] +``` + +#### Step 4: Build Binary Radix Tree + +Using Karras' algorithm to find node spans: + +```julia +# Find span of internal node i +d_left = delta(i, i-1, morton_codes, n) +d_right = delta(i, i+1, morton_codes, n) +direction = sign(d_right - d_left) + +# Binary search for exact span +# ... (see implementation) + +# Find split point +split = find_split(span_left, span_right, morton_codes, n) + +# Determine children +child0 = (split == span_left) ? leaf(split) : internal(split) +child1 = (split+1 == span_right) ? leaf(split+1) : internal(split+1) +``` + +**Key function: `delta` (Longest Common Prefix)** + +```julia +function delta(i1::Int32, i2::Int32, codes::Vector{UInt32}, n::Int32)::Int32 + left = min(i1, i2) + right = max(i1, i2) + + (left < 1 || right > n) && return Int32(-1) + + left_code = codes[left] + right_code = codes[right] + + # If codes differ, count common prefix bits + # If codes identical, use indices as tiebreaker + if left_code != right_code + return Int32(clz32(left_code ⊻ right_code)) + else + return Int32(32 + clz32(UInt32(left) ⊻ UInt32(right))) + end +end +``` + +#### Step 5: Compute AABBs Bottom-Up + +```julia +# Create leaf nodes with primitive AABBs +for i in 1:n + leaf_idx = leaf_index(i, n) + tri_aabb = world_bound(sorted_prims[i]) + nodes[leaf_idx] = create_leaf(tri_aabb, i) +end + +# Propagate AABBs upward +for i in (n-1):-1:1 + child0_aabb = get_node_aabb(nodes[nodes[i].child0]) + child1_aabb = get_node_aabb(nodes[nodes[i].child1]) + nodes[i] = create_interior(child0_aabb ∪ child1_aabb, ...) +end +``` + +## Traversal Algorithm + +### Two-Level Traversal with Stack + +```julia +function closest_hit(tlas::TLAS, ray::AbstractRay) + # Initialize state + stack = MVector{64, UInt32}(undef) + stack_ptr = 1 + current_node_idx = 1 + current_instance = INVALID_NODE + + while current_node_idx != INVALID_NODE + # Fetch node (from TLAS or BLAS) + node = (current_instance == INVALID_NODE) ? + tlas.nodes[current_node_idx] : + tlas.blas_array[current_instance].nodes[current_node_idx] + + # Test ray-AABB intersection + if intersect_aabb(node, ray) + if is_leaf(node) + if current_instance == INVALID_NODE + # Top-level leaf: transition to BLAS + instance_idx = node.child1 + inst = tlas.instances[instance_idx] + + # Push sentinel + stack[stack_ptr++] = TOP_LEVEL_SENTINEL + + # Transform ray to local space + ray_local = transform_ray(inst.inv_transform, ray) + + # Switch to BLAS traversal + current_instance = inst.blas_index + current_node_idx = 1 + else + # Bottom-level leaf: test triangle + test_triangle_intersection(...) + end + else + # Interior node: push far, visit near + push_and_traverse(...) + end + else + # Pop from stack + current_node_idx = stack[--stack_ptr] + + # Check for level transition + if current_node_idx == TOP_LEVEL_SENTINEL + current_node_idx = stack[--stack_ptr] + current_instance = INVALID_NODE + ray = restore_original_ray() + end + end + end +end +``` + +### Key Features + +1. **Sentinel-based Level Switching**: `TOP_LEVEL_SENTINEL (0xFFFFFFFE)` marks transitions +2. **Ray Transformation**: Transform once when entering BLAS, restore when returning +3. **Hybrid Stack**: LDS (16 entries) + global (64 entries) for GPU efficiency +4. **Ordered Traversal**: Visit near child first for early termination + +## Type Stability + +### Critical for GPU Performance + +All functions must be **type-stable** to generate efficient GPU kernels: + +```julia +# ✓ Type-stable +@inline function get_node_aabb(node::BVHNode2, is_interior::Bool)::Bounds3 + if is_interior + Bounds3(min.(node.aabb0_min, node.aabb1_min), + max.(node.aabb0_max, node.aabb1_max)) + else + Bounds3(node.aabb0_min, node.aabb0_max) + end +end + +# ✗ NOT type-stable (return type depends on runtime value) +function bad_example(node::BVHNode2, flag::Bool) + if flag + return node.aabb0_min # Point3f + else + return node.child0 # UInt32 + end +end +``` + +### Testing Type Stability + +```julia +using Test + +@testset "Type Stability" begin + node = BVHNode2(...) + + # Should infer return type at compile time + result = @inferred get_node_aabb(node, true) + @test result isa Bounds3 +end +``` + +## GPU Portability + +### Parametrized Array Types + +```julia +# CPU +cpu_blas = BLAS{Vector{BVHNode2}, Vector{Triangle}}(...) + +# CUDA +using CUDA +gpu_nodes = CuArray(cpu_blas.nodes) +gpu_tris = CuArray(cpu_blas.primitives) +gpu_blas = BLAS{CuArray{BVHNode2}, CuArray{Triangle}}( + gpu_nodes, gpu_tris, cpu_blas.root_aabb +) +``` + +### KernelAbstractions Integration + +```julia +using KernelAbstractions + +@kernel function traverse_rays_kernel!(hits, @Const(tlas), @Const(rays)) + i = @index(Global) + ray = rays[i] + hit, tri, t, bary, inst_id = closest_hit(tlas, ray) + hits[i] = (hit, t, inst_id) +end + +# Execute on different backends +backend = get_backend(rays) # CPU, CUDA, ROC, oneAPI, etc. +kernel! = traverse_rays_kernel!(backend) +kernel!(hits, tlas, rays, ndrange=length(rays)) +``` + +## Performance Characteristics + +### Construction + +- **LBVH**: O(n log n) due to sorting +- **Parallel-friendly**: Each internal node computed independently +- **Memory**: ~64 bytes per node × (2n-1) nodes + +### Traversal + +- **Stack depth**: Typically 20-40 for balanced trees, max 64 +- **Memory access**: Sequential node reads (cache-friendly) +- **Transform overhead**: 2 matrix-vector mults per instance transition + +### Instancing Benefits + +**Memory Savings:** +- Without instancing: n_instances × n_triangles × sizeof(Triangle) +- With instancing: n_triangles × sizeof(Triangle) + n_instances × 128 bytes +- **Example**: 1000 cubes = 98% memory reduction + +**Update Performance:** +- Updating instance transform: O(1) + TLAS rebuild +- No BLAS rebuild needed for rigid transformations + +## Comparison to Radeon Rays + +### Similarities + +✓ BVH2IL node layout (inline children AABBs) +✓ LBVH construction with 30-bit Morton codes +✓ Two-level hierarchy (TLAS/BLAS) +✓ Sentinel-based level switching +✓ Transform caching in instances + +### Differences + +- **No SAH restructuring**: RadeonRays has optional treelet optimization +- **Simplified construction**: No multi-threaded workgroup coordination +- **Julia-specific**: Leverages multiple dispatch and type system +- **No hardware RT**: Pure software traversal (can use RT when available) + +## Future Enhancements + +### Short Term + +1. **Complete triangle intersection in traversal** +2. **Implement any_hit for occlusion queries** +3. **Add multi-instance TLAS builder** (currently only single instance) + +### Medium Term + +1. **SAH-based restructuring** (treelet optimization) +2. **Compressed-wide BVH (CWBVH)** for higher branching factor +3. **Hardware ray tracing backend** (Vulkan RT, OptiX, DXR) + +### Long Term + +1. **Dynamic BVH updates** (refit without full rebuild) +2. **Streaming geometry** (out-of-core BVH) +3. **Neural BVH optimization** (learned splitting) + +## References + +1. Karras, T. (2012). "Maximizing Parallelism in the Construction of BVHs, Octrees, and k-d Trees" +2. AMD RadeonRays SDK: https://github.com/GPUOpen-LibrariesAndSDKs/RadeonRays_SDK +3. Ylitie, H. et al. (2017). "Efficient Incoherent Ray Traversal on GPUs Through Compressed Wide BVHs" +4. Meister, D. et al. (2020). "A Survey on Bounding Volume Hierarchies for Ray Tracing" + +## Example Usage + +```julia +using Raycore +using GeometryBasics +import KernelAbstractions as KA + +# Create geometry and build TLAS using the high-level API +cube_mesh = normal_mesh(Rect3f(Point3f(-0.5), Vec3f(1.0))) + +tlas = Raycore.TLAS(KA.CPU()) + +# Add 10 instances of the cube at different positions +for i in 1:10 + t = Raycore.translate(Vec3f(i * 2, 0, 0)) + push!(tlas, cube_mesh, t.m) +end + +Raycore.sync!(tlas) + +# Get immutable StaticTLAS for traversal +static = Adapt.adapt(KA.CPU(), tlas) + +# Trace ray +ray = Ray(o=Point3f(0, 0, -5), d=Vec3f(0, 0, 1)) +hit, tri, t, bary, inst_id = closest_hit(static, ray) + +if hit + println("Hit instance $inst_id at distance $t") + hit_point = ray.o + ray.d * t + println("Hit point: $hit_point") +end +``` diff --git a/docs/make.jl b/docs/make.jl index e405a22..d250c60 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -19,6 +19,7 @@ makedocs(; "BVH Hit Tests" => "bvh_hit_tests.md", "Ray Tracing Tutorial" => "raytracing_tutorial.md", "GPU Ray Tracing Tutorial" => "gpu_raytracing.md", + "Hardware RT Acceleration" => "hw_acceleration.md", "View Factors and More" => "viewfactors.md", ], ], diff --git a/docs/src/assets/hw_acceleration_compare.png b/docs/src/assets/hw_acceleration_compare.png new file mode 100644 index 0000000..66611da Binary files /dev/null and b/docs/src/assets/hw_acceleration_compare.png differ diff --git a/docs/src/bvh_hit_tests_content.md b/docs/src/bvh_hit_tests_content.md index 82154f5..319cdd1 100644 --- a/docs/src/bvh_hit_tests_content.md +++ b/docs/src/bvh_hit_tests_content.md @@ -1,6 +1,6 @@ -# BVH Hit Testing: `closest_hit` vs `any_hit` +# TLAS Hit Testing: `closest_hit` vs `any_hit` -This document tests and visualizes the difference between `closest_hit` and `any_hit` functions in the BVH implementation using the new `RayIntersectionSession` API. +This document tests and visualizes the difference between `closest_hit` and `any_hit` functions in the TLAS implementation using the `trace_rays` API. ## Test Setup @@ -9,19 +9,24 @@ using Raycore, GeometryBasics, LinearAlgebra using WGLMakie using Test using Bonito +import KernelAbstractions as KA # Create a simple test scene with multiple overlapping primitives function create_test_scene() # Three spheres at different distances along the Z-axis - sphere1 = Tesselation(Sphere(Point3f(0, 0, 5), 1.0f0), 20) # Furthest - sphere2 = Tesselation(Sphere(Point3f(0, 0, 3), 1.0f0), 20) # Middle - sphere3 = Tesselation(Sphere(Point3f(0, 0, 1), 1.0f0), 20) # Closest - - bvh = Raycore.BVH([sphere1, sphere2, sphere3]) - return bvh + sphere1 = normal_mesh(Sphere(Point3f(0, 0, 5), 1.0f0)) # Furthest + sphere2 = normal_mesh(Sphere(Point3f(0, 0, 3), 1.0f0)) # Middle + sphere3 = normal_mesh(Sphere(Point3f(0, 0, 1), 1.0f0)) # Closest + + tlas = Raycore.TLAS(KA.CPU()) + push!(tlas, sphere1) + push!(tlas, sphere2) + push!(tlas, sphere3) + sync!(tlas) + return tlas end -bvh = create_test_scene() +tlas = create_test_scene() ``` ## Test 1: Single Ray Through Center @@ -31,19 +36,14 @@ Test a ray through the center that passes through all three spheres. # Create a ray with slight offset to avoid hitting triangle vertices exactly test_ray = Raycore.Ray(o=Point3f(0.1, 0.1, -5), d=Vec3f(0, 0, 1)) -# Create session with closest_hit -session_closest = RayIntersectionSession(Raycore.closest_hit, [test_ray], bvh) - -# Create session with any_hit for comparison -session_any = RayIntersectionSession(Raycore.any_hit, [test_ray], bvh) +# Trace with closest_hit (default) +result_closest = trace_rays(tlas, [test_ray]) fig = Figure() # Left: closest_hit visualization -plot(fig[1, 1], session_closest; axis=(; show_axis=false)) -plot(fig[1, 2], session_any; axis=(; show_axis=false)) +plot(fig[1, 1], result_closest; axis=(; show_axis=false)) Label(fig[0, 1], "closest_hit", fontsize=20, font=:bold, tellwidth=false) -Label(fig[0, 2], "any_hit", fontsize=20, font=:bold, tellwidth=false) fig ``` @@ -57,64 +57,58 @@ test_positions = map(p-> (p = p.-0.5; Point3f(p..., -5)), rand(Point2f, 10)) # Create rays rays = [Raycore.Ray(o=pos, d=Vec3f(0, 0, 1)) for pos in test_positions] -# Create session -session_multi = RayIntersectionSession(Raycore.closest_hit, rays, bvh) -plot(session_multi; axis=(;show_axis=false)) +# Trace rays and visualize +result_multi = trace_rays(tlas, rays) +plot(result_multi; axis=(;show_axis=false)) ``` -## Visualization: Multiple Rays +## Test 3: Complex Scene -## Test 4: Difference Between any*hit and closest*hit - -Demonstrate that `any_hit` can return different results than `closest_hit`. +Demonstrate ray tracing through a complex scene with many overlapping objects. ```julia (editor=true, logging=false, output=true) # Create a complex scene with overlapping geometry -# This creates a BVH where traversal order can differ from distance order using Random Random.seed!(123) -complex_spheres = [] +complex_tlas = Raycore.TLAS(KA.CPU()) # Add some large overlapping spheres -push!(complex_spheres, Tesselation(Sphere(Point3f(0, 0, 10), 3.0f0), 20)) -push!(complex_spheres, Tesselation(Sphere(Point3f(0.5, 0, 5), 0.5f0), 15)) -push!(complex_spheres, Tesselation(Sphere(Point3f(-0.5, 0, 15), 1.5f0), 18)) +push!(complex_tlas, normal_mesh(Sphere(Point3f(0, 0, 10), 3.0f0))) +push!(complex_tlas, normal_mesh(Sphere(Point3f(0.5, 0, 5), 0.5f0))) +push!(complex_tlas, normal_mesh(Sphere(Point3f(-0.5, 0, 15), 1.5f0))) -# Add many small spheres to create complex BVH structure +# Add many small spheres to create complex TLAS structure for i in 1:30 x = randn() * 5 y = randn() * 5 z = rand(8.0:0.5:12.0) r = 0.3 + rand() * 0.5 - push!(complex_spheres, Tesselation(Sphere(Point3f(x, y, z), r), 8)) + push!(complex_tlas, normal_mesh(Sphere(Point3f(x, y, z), Float32(r)))) end -complex_bvh = Raycore.BVH(complex_spheres) -# Test rays to find cases where any_hit differs from closest_hit +sync!(complex_tlas) + +# Test rays test_rays = map(rand(Point2f, 20)) do p p = (p .* 14f0) .- 8f0 Raycore.Ray(o=Point3f(p..., -5), d=Vec3f(0, 0, 1)) end -session_closest = RayIntersectionSession(Raycore.closest_hit, test_rays, complex_bvh) -session_any = RayIntersectionSession(Raycore.any_hit, test_rays, complex_bvh) +result = trace_rays(complex_tlas, test_rays) + fig = Figure() -# Left: closest_hit visualization -plot(fig[1, 1], session_closest; axis=(; show_axis=false)) -plot(fig[1, 2], session_any; axis=(; show_axis=false)) +plot(fig[1, 1], result; axis=(; show_axis=false)) Label(fig[0, 1], "closest_hit", tellwidth=false) -Label(fig[0, 2], "any_hit", tellwidth=false) fig ``` **Key Findings:** - * `any_hit` exits on the **first** intersection during BVH traversal (uses `intersect`, doesn't update ray) - * `closest_hit` continues searching and updates ray's `t_max` (uses `intersect_p!`) - * In complex scenes with overlapping geometry, `any_hit` can return hits that are significantly farther + * `closest_hit` continues searching and updates ray's `t_max` to find the nearest intersection + * `any_hit` exits on the **first** intersection during TLAS traversal (useful for shadow rays) * Both always agree on **whether** a hit occurred (hit vs miss) - * The difference appears when BVH traversal order differs from spatial distance order + * `any_hit` is typically faster than `closest_hit` due to early termination ## Performance Comparison @@ -131,19 +125,21 @@ end ``` ```julia (editor=true, logging=false, output=true) using BenchmarkTools +using Adapt test_ray = Raycore.Ray(o=Point3f(0.1, 0.1, -5), d=Vec3f(0, 0, 1)) +static_tlas = Adapt.adapt(KA.CPU(), tlas) # Benchmark closest_hit -closest_time = @benchmark Raycore.closest_hit($bvh, $test_ray) +closest_time = @benchmark Raycore.closest_hit($static_tlas, $test_ray) # Benchmark any_hit -any_time = @benchmark Raycore.any_hit($bvh, $test_ray) +any_time = @benchmark Raycore.any_hit($static_tlas, $test_ray) perf_table = map([ - ("closest_hit", any_time), - ("any_hit", closest_time), + ("closest_hit", closest_time), + ("any_hit", any_time), ]) do (method, time_us) (Method = method, Time_μs = render_io(time_us)) end @@ -153,25 +149,25 @@ Bonito.Table(perf_table) This document demonstrated: -1. **`RayIntersectionSession`** - A convenient struct for managing ray tracing sessions +1. **`trace_rays`** - A convenient function for tracing rays against a TLAS and collecting results for visualization - * Bundles rays, BVH, hit function, and results together - * Provides helper functions: `hit_count()`, `miss_count()`, `hit_points()`, `hit_distances()` -2. **Makie visualization recipe** - Automatic visualization via `plot(session)` + * Returns a `RayIntersectionResult` bundling rays, hit data, and the TLAS + * Automatically builds a `StaticTLAS` for traversal +2. **Makie visualization recipe** - Automatic visualization via `plot(result)` - * Automatically renders BVH geometry, rays, and hit points + * Automatically renders TLAS geometry, rays, and hit points * Customizable colors, transparency, markers, and labels * Works with any Makie backend (GLMakie, WGLMakie, CairoMakie) 3. **`closest_hit`** correctly identifies the nearest intersection among multiple overlapping primitives - * Returns: `(hit_found::Bool, hit_primitive::Triangle, distance::Float32, barycentric_coords::Point3f)` - * `distance` is the distance from ray origin to the hit point - * Use `Raycore.sum_mul(bary_coords, primitive.vertices)` to convert to world-space hit point + * Returns: `(hit_found::Bool, triangle::Triangle, distance::Float32, bary_coords::SVector{3,Float32}, instance_id::UInt32)` + * Use `sum(bary_coords .* triangle.vertices)` to convert to world-space hit point 4. **`any_hit`** efficiently determines if any intersection exists, exiting early - * Returns: Same format as `closest_hit`: `(hit_found::Bool, hit_primitive::Triangle, distance::Float32, barycentric_coords::Point3f)` + * Returns: Same format as `closest_hit` * Can exit early on first hit found, making it faster for occlusion testing 5. Both functions handle miss cases correctly (returning `hit_found=false`) 6. `any_hit` is typically faster than `closest_hit` due to early termination All tests passed! ✓ + diff --git a/docs/src/gpu-benchmarks.png b/docs/src/gpu-benchmarks.png index f55ab7f..ac67dab 100644 Binary files a/docs/src/gpu-benchmarks.png and b/docs/src/gpu-benchmarks.png differ diff --git a/docs/src/gpu_raytracing_tutorial.md b/docs/src/gpu_raytracing_tutorial.md index 0713e54..73ed937 100644 --- a/docs/src/gpu_raytracing_tutorial.md +++ b/docs/src/gpu_raytracing_tutorial.md @@ -38,7 +38,7 @@ Let's use the exact same scene as the CPU tutorial - the Makie cat with room geo # Load and prepare the cat model include("raytracing-core.jl") bvh, ctx = example_scene() -md"**Scene loaded: $(length(bvh.primitives)) triangles, $(length(ctx.materials)) materials**" +md"**Scene loaded: $(length(bvh.all_blas_prims)) triangles, $(length(ctx.materials)) materials**" ``` ## Part 2: GPU Kernel Version 1 - Basic Naive Approach diff --git a/docs/src/hikari-wavefront-renderer.jl b/docs/src/hikari-wavefront-renderer.jl new file mode 100644 index 0000000..e6e58f9 --- /dev/null +++ b/docs/src/hikari-wavefront-renderer.jl @@ -0,0 +1,1090 @@ +""" +Wavefront Path Tracer integrated with Hikari MaterialScene. + +This renderer uses Hikari's MaterialScene and Light types while keeping +the original wavefront-renderer.jl shading model. Materials are treated +as data containers - we extract base_color, metallic, roughness from +Hikari materials but use the wavefront shading equations. +""" + +using Raycore, GeometryBasics, LinearAlgebra +using Raycore: gpu_int +using Colors +using KernelAbstractions +using KernelAbstractions: @kernel, @index, @Const +import KernelAbstractions as KA +using Statistics +import Makie + +import Hikari +using Hikari: MaterialScene, MaterialIndex, RGBSpectrum, PointLight, Light, LightδPosition + +# ============================================================================ +# SoA Access Macros (same as wavefront-renderer.jl) +# ============================================================================ + +macro get(expr) + if expr.head != :(=) + error("@get expects assignment syntax: @get field1, field2 = soa[idx]") + end + + lhs = expr.args[1] + rhs = expr.args[2] + + if lhs isa Symbol + fields = [lhs] + elseif lhs.head == :tuple + fields = lhs.args + else + error("@get left side must be field names or tuple of field names") + end + + if rhs.head != :ref + error("@get right side must be array indexing: soa[idx]") + end + soa = rhs.args[1] + idx = rhs.args[2] + + assignments = [:($(esc(field)) = $(esc(soa)).$(field)[$(esc(idx))]) for field in fields] + return Expr(:block, assignments...) +end + +macro set(expr) + if expr.head != :(=) + error("@set expects assignment syntax: @set soa[idx] = (field1=val1, ...)") + end + + lhs = expr.args[1] + rhs = expr.args[2] + + if lhs.head != :ref + error("@set left side must be array indexing: soa[idx]") + end + soa = lhs.args[1] + idx = lhs.args[2] + + assignments = [] + if rhs.head == :tuple || rhs.head == :parameters + for arg in rhs.args + if arg isa Expr && arg.head == :(=) + field = arg.args[1] + val = arg.args[2] + push!(assignments, :($(esc(soa)).$(field)[$(esc(idx))] = $(esc(val)))) + else + error("@set expects named parameters: @set soa[idx] = (field=value, ...)") + end + end + else + error("@set expects a tuple with named fields: @set soa[idx] = (field=value, ...)") + end + return Expr(:block, assignments...) +end + +# ============================================================================ +# Material Property Extraction from Hikari Materials +# ============================================================================ + +""" + WavefrontMaterialProps + +Simple material properties for wavefront shading. +Extracted from Hikari materials at shading time. +""" +struct WavefrontMaterialProps + base_color::Vec3f + metallic::Float32 + roughness::Float32 +end + +""" +Helper to get texture const_value (for constant textures used in our scene). +""" +@inline function get_texture_value(tex::Hikari.Texture{RGBSpectrum}) + return tex.const_value +end + +@inline function get_texture_value(tex::Hikari.Texture{Float32}) + return tex.const_value +end + +""" +Extract wavefront-compatible material properties from a Hikari MatteMaterial. +Matte materials are purely diffuse (metallic=0, roughness from sigma). +""" +@inline function extract_material_props(mat::Hikari.MatteMaterial, ::Point2f) + # Get diffuse color from Kd texture + kd = get_texture_value(mat.Kd) + base_color = Vec3f(kd.c[1], kd.c[2], kd.c[3]) + # σ is roughness in degrees for Oren-Nayar; convert to 0-1 range + σ = get_texture_value(mat.σ) + roughness = clamp(σ / 90f0, 0f0, 1f0) + return WavefrontMaterialProps(base_color, 0f0, roughness) +end + +""" +Extract wavefront-compatible material properties from a Hikari MirrorMaterial. +Mirrors are fully metallic (metallic=1, roughness=0). +""" +@inline function extract_material_props(mat::Hikari.MirrorMaterial, ::Point2f) + kr = get_texture_value(mat.Kr) + base_color = Vec3f(kr.c[1], kr.c[2], kr.c[3]) + return WavefrontMaterialProps(base_color, 1f0, 0f0) +end + +""" +Extract wavefront-compatible material properties from a Hikari PlasticMaterial. +Plastic has both diffuse and specular; map to metallic based on Ks intensity. +""" +@inline function extract_material_props(mat::Hikari.PlasticMaterial, ::Point2f) + kd = get_texture_value(mat.Kd) + ks = get_texture_value(mat.Ks) + base_color = Vec3f(kd.c[1], kd.c[2], kd.c[3]) + # Map specular intensity to metallic factor + metallic = (ks.c[1] + ks.c[2] + ks.c[3]) / 3f0 + roughness = get_texture_value(mat.roughness) + return WavefrontMaterialProps(base_color, metallic, roughness) +end + +""" +Extract wavefront-compatible material properties from a Hikari GlassMaterial. +Glass is mapped to a transparent-ish material; use reflection as base color. +""" +@inline function extract_material_props(mat::Hikari.GlassMaterial, ::Point2f) + kr = get_texture_value(mat.Kr) + base_color = Vec3f(kr.c[1], kr.c[2], kr.c[3]) + roughness = get_texture_value(mat.u_roughness) + # Glass behaves more like a mirror in our simplified model + return WavefrontMaterialProps(base_color, 0.8f0, roughness) +end + +""" +Extract wavefront-compatible material properties from a Hikari MetalMaterial. +Metals are fully metallic (metallic=1) with roughness and reflectance for color. +""" +@inline function extract_material_props(mat::Hikari.MetalMaterial, ::Point2f) + # Use reflectance as the base color (tinting) + refl = get_texture_value(mat.reflectance) + base_color = Vec3f(refl.c[1], refl.c[2], refl.c[3]) + roughness = get_texture_value(mat.roughness) + # Metals are fully metallic + return WavefrontMaterialProps(base_color, 1f0, roughness) +end + +# Fallback for any other material type +@inline function extract_material_props(mat, ::Point2f) + return WavefrontMaterialProps(Vec3f(0.5f0, 0.5f0, 0.5f0), 0f0, 0.5f0) +end + +""" +Generated function for type-stable material property extraction. +Dispatches to the appropriate extract_material_props based on material_type. +""" +@generated function extract_material_from_scene( + materials::NTuple{N,Any}, idx::MaterialIndex, uv::Point2f +) where N + branches = [quote + if idx.material_type === UInt8($i) + return @inline extract_material_props(@inbounds(materials[$i][idx.material_idx]), uv) + end + end for i in 1:N] + quote + $(branches...) + return WavefrontMaterialProps(Vec3f(0.5f0), 0f0, 0.5f0) + end +end + +# ============================================================================ +# Light Representation for Wavefront Renderer +# ============================================================================ + +""" + WavefrontLight + +Simple point light for wavefront shading, compatible with GPU. +""" +struct WavefrontLight + position::Point3f + intensity::Float32 + color::RGB{Float32} +end + +""" +Convert a Hikari PointLight to WavefrontLight format. +""" +function WavefrontLight(light::Hikari.PointLight) + # Extract intensity as luminance of spectrum + intensity = Hikari.to_Y(light.i) + # Extract color (normalized RGB) + rgb = Hikari.rgb(light.i) + max_val = max(rgb[1], rgb[2], rgb[3], 1f-6) + color = RGB{Float32}(rgb[1]/max_val, rgb[2]/max_val, rgb[3]/max_val) + return WavefrontLight(light.position, intensity, color) +end + +# ============================================================================ +# Work Queue Structures (same as wavefront-renderer.jl) +# ============================================================================ + +struct PrimaryRayWork + ray::Raycore.Ray + pixel_x::Int32 + pixel_y::Int32 + sample_idx::Int32 +end + +struct PrimaryHitWork{Tri} + hit_found::Bool + tri::Tri + dist::Float32 + bary::Vec3f + ray::Raycore.Ray + pixel_x::Int32 + pixel_y::Int32 + sample_idx::Int32 +end + +struct ShadowRayWork + ray::Raycore.Ray + hit_idx::Int32 + light_idx::Int32 +end + +struct ShadowResult + visible::Bool + hit_idx::Int32 + light_idx::Int32 +end + +struct ReflectionRayWork + ray::Raycore.Ray + hit_idx::Int32 +end + +struct ReflectionHitWork + hit_found::Bool + material_idx::MaterialIndex + dist::Float32 + bary::Vec3f + normal::Vec3f + ray::Raycore.Ray + primary_hit_idx::Int32 +end + +struct ShadedResult + color::Vec3f + pixel_x::Int32 + pixel_y::Int32 + sample_idx::Int32 +end + +# ============================================================================ +# Render Context for Hikari Integration +# ============================================================================ + +""" + HikariRenderContext + +Holds the lights and materials for GPU rendering. +Uses WavefrontLight for lights (converted from Hikari lights). +Materials are stored as the original Hikari MaterialScene.materials tuple. +""" +struct HikariRenderContext{L<:AbstractVector{WavefrontLight}, M<:Tuple} + lights::L + materials::M + ambient::Float32 +end + +function Raycore.to_gpu(Arr, ctx::HikariRenderContext) + lights_gpu = Raycore.to_gpu(Arr, ctx.lights) + # Materials tuple needs per-type GPU conversion + materials_gpu = map(ctx.materials) do mats + Raycore.to_gpu(Arr, map(m -> Hikari.to_gpu(Arr, m), mats)) + end + return HikariRenderContext(lights_gpu, materials_gpu, ctx.ambient) +end + +# ============================================================================ +# Helper function +# ============================================================================ + +function similar_soa(img, T, num_elements) + fields = [f => similar(img, fieldtype(T, f), num_elements) for f in fieldnames(T)] + return (; fields...) +end + +@generated function for_unrolled(f::F, ::Val{N}) where {F, N} + return Expr(:block, [:(f($(Raycore.gpu_int(i)))) for i in 1:N]...) +end + +# ============================================================================ +# Stage 1: Generate Primary Camera Rays +# ============================================================================ + +@kernel function generate_primary_rays_lookat!( + @Const(width), @Const(height), + @Const(camera_pos), + @Const(camera_right), @Const(camera_up), @Const(camera_forward), + @Const(half_width), @Const(half_height), + ray_queue, + ::Val{NSamples} +) where {NSamples} + i = @index(Global, Cartesian) + y = gpu_int(i[1]) + x = gpu_int(i[2]) + + @inbounds if y <= height && x <= width + pixel_idx = (y - gpu_int(1)) * width + x + ntuple(Val(NSamples)) do s + s_idx = gpu_int(s) + ray_idx = (pixel_idx - gpu_int(1)) * gpu_int(NSamples) + s_idx + jitter = rand(Vec2f) + + u = (2.0f0 * (Float32(x) - 0.5f0 + jitter[1]) / Float32(width) - 1.0f0) + v = (1.0f0 - 2.0f0 * (Float32(y) - 0.5f0 + jitter[2]) / Float32(height)) + + direction = normalize( + camera_forward + + camera_right * (u * half_width) + + camera_up * (v * half_height) + ) + ray = Raycore.Ray(o=camera_pos, d=direction) + + @set ray_queue[ray_idx] = (ray=ray, pixel_x=x, pixel_y=y, sample_idx=s_idx) + nothing + end + end +end + +# ============================================================================ +# Stage 2: Intersect Primary Rays (adapted for MaterialScene) +# ============================================================================ + +@kernel function intersect_primary_rays_hikari!( + @Const(accel), # BVH or TLAS from MaterialScene + @Const(ray_queue), + hit_queue +) + i = @index(Global, Linear) + idx = gpu_int(i) + + @inbounds if idx <= length(ray_queue.ray) + @get ray, pixel_x, pixel_y, sample_idx = ray_queue[idx] + hit_found, tri, dist, bary = Raycore.closest_hit(accel, ray) + @set hit_queue[idx] = (hit_found=hit_found, tri=tri, dist=dist, bary=Vec3f(bary), + ray=ray, pixel_x=pixel_x, pixel_y=pixel_y, sample_idx=sample_idx) + end +end + +# ============================================================================ +# Stage 3: Generate Shadow Rays +# ============================================================================ + +@kernel function generate_shadow_rays_hikari!( + @Const(hit_queue), + @Const(ctx), + shadow_ray_queue, + nlights::Val{NLights} +) where {NLights} + i = @index(Global, Linear) + idx = gpu_int(i) + + @inbounds if idx <= length(hit_queue.hit_found) + @get hit_found, tri, dist, bary, ray = hit_queue[idx] + + if hit_found + hit_point = ray.o + ray.d * dist + v0, v1, v2 = Raycore.normals(tri) + u, v, w = bary[1], bary[2], bary[3] + normal = Vec3f(normalize(v0 * u + v1 * v + v2 * w)) + + for_unrolled(nlights) do light_idx + light_idx_gpu = gpu_int(light_idx) + shadow_ray_idx = (idx - gpu_int(1)) * gpu_int(NLights) + light_idx_gpu + light = ctx.lights[light_idx_gpu] + + shadow_bias = 0.01f0 + shadow_origin = hit_point + normal * shadow_bias + light_vec = light.position - shadow_origin + shadow_dir = normalize(light_vec) + light_dist = norm(light_vec) + shadow_ray = Raycore.Ray(o=shadow_origin, d=shadow_dir, t_max=light_dist) + + @set shadow_ray_queue[shadow_ray_idx] = (ray=shadow_ray, hit_idx=idx, light_idx=light_idx_gpu) + end + else + dummy_ray = Raycore.Ray(o=Point3f(0,0,0), d=Vec3f(0,0,1), t_max=0.0f0) + for_unrolled(nlights) do light_idx + light_idx_gpu = gpu_int(light_idx) + shadow_ray_idx = (idx - gpu_int(1)) * gpu_int(NLights) + light_idx_gpu + shadow_ray_queue.ray[shadow_ray_idx] = dummy_ray + end + end + end +end + +# ============================================================================ +# Stage 4: Test Shadow Rays +# ============================================================================ + +@kernel function test_shadow_rays_hikari!( + @Const(accel), + @Const(shadow_ray_queue), + shadow_result_queue +) + i = @index(Global, Linear) + idx = gpu_int(i) + @inbounds if idx <= length(shadow_ray_queue.ray) + @get ray, hit_idx, light_idx = shadow_ray_queue[idx] + + visible = if ray.t_max > 0.0f0 + hit_found, _, _, _ = Raycore.any_hit(accel, ray) + !hit_found + else + false + end + + @set shadow_result_queue[idx] = (visible=visible, hit_idx=hit_idx, light_idx=light_idx) + end +end + +# ============================================================================ +# Stage 5: Shade Primary Hits with Shadow Information (Hikari materials) +# ============================================================================ + +@kernel function shade_primary_hits_hikari!( + @Const(hit_queue), + @Const(ctx), + @Const(shadow_results), + @Const(sky_color), + shading_queue, + nlights::Val{NLights} +) where {NLights} + i = @index(Global, Linear) + idx = gpu_int(i) + + @inbounds if idx <= length(hit_queue.hit_found) + @get hit_found, tri, dist, bary, ray, pixel_x, pixel_y, sample_idx = hit_queue[idx] + + if hit_found + hit_point = ray.o + ray.d * dist + v0, v1, v2 = Raycore.normals(tri) + u, v, w = bary[1], bary[2], bary[3] + normal = Vec3f(normalize(v0 * u + v1 * v + v2 * w)) + + # Get material properties from Hikari material via MaterialIndex + mat_idx = tri.metadata::MaterialIndex + mat_props = extract_material_from_scene(ctx.materials, mat_idx, Point2f(0)) + base_color = mat_props.base_color + + # Start with ambient + total_color = base_color * ctx.ambient + + # Add contribution from each light + light_samples = ntuple(nlights) do light_idx + light_idx_gpu = gpu_int(light_idx) + shadow_idx = (idx - gpu_int(1)) * gpu_int(NLights) + light_idx_gpu + visible = shadow_results.visible[shadow_idx] + if visible + light = ctx.lights[light_idx_gpu] + light_vec = light.position - hit_point + light_dist = norm(light_vec) + light_dir = light_vec / light_dist + + diffuse = max(0.0f0, dot(normal, light_dir)) + attenuation = light.intensity / (light_dist * light_dist) + light_color = Vec3f(light.color.r, light.color.g, light.color.b) + difa = (diffuse * attenuation) + Vec3f(base_color * (light_color * difa)) + else + Vec3f(0) + end + end + + final_color = total_color + sum(light_samples) + @set shading_queue[idx] = (color=final_color, pixel_x=pixel_x, pixel_y=pixel_y, sample_idx=sample_idx) + else + sky_vec = Vec3f(sky_color.r, sky_color.g, sky_color.b) + @set shading_queue[idx] = (color=sky_vec, pixel_x=pixel_x, pixel_y=pixel_y, sample_idx=sample_idx) + end + end +end + +# ============================================================================ +# Stage 6: Generate Reflection Rays (Hikari materials) +# ============================================================================ + +@kernel function generate_reflection_rays_hikari!( + @Const(hit_queue), + @Const(ctx), + reflection_ray_soa, + active_count +) + i = @index(Global, Linear) + idx = gpu_int(i) + + @inbounds if idx <= length(hit_queue.hit_found) + @get hit_found, tri, dist, bary, ray = hit_queue[idx] + dummy_ray = Raycore.Ray(o=Point3f(0, 0, 0), d=Vec3f(0, 0, 1), t_max=0.0f0) + + if hit_found + mat_idx = tri.metadata::MaterialIndex + mat_props = extract_material_from_scene(ctx.materials, mat_idx, Point2f(0)) + + if mat_props.metallic > 0.0f0 + hit_point = ray.o + ray.d * dist + v0, v1, v2 = Raycore.normals(tri) + u, v, w = bary[1], bary[2], bary[3] + normal = Vec3f(normalize(v0 * u + v1 * v + v2 * w)) + + wo = -ray.d + reflect_dir = Raycore.reflect(wo, normal) + + if mat_props.roughness > 0.0f0 + offset = (rand(Vec3f) .* 2.0f0 .- 1.0f0) * mat_props.roughness + reflect_dir = normalize(reflect_dir + offset) + end + + reflect_ray = Raycore.Ray(o=hit_point + normal * 0.01f0, d=reflect_dir) + @set reflection_ray_soa[idx] = (ray=reflect_ray, hit_idx=idx) + else + reflection_ray_soa.ray[idx] = dummy_ray + end + else + reflection_ray_soa.ray[idx] = dummy_ray + end + end +end + +# ============================================================================ +# Stage 7: Intersect Reflection Rays (Hikari materials) +# ============================================================================ + +@kernel function intersect_reflection_rays_hikari!( + @Const(accel), + @Const(reflection_ray_soa), + reflection_hit_soa +) + i = @index(Global, Linear) + idx = gpu_int(i) + + @inbounds if idx <= length(reflection_ray_soa.ray) + @get ray, hit_idx = reflection_ray_soa[idx] + + if ray.t_max > 0.0f0 + hit_found, tri, dist, bary = Raycore.closest_hit(accel, ray) + if hit_found + v0, v1, v2 = Raycore.normals(tri) + u, v, w = bary[1], bary[2], bary[3] + normal = Vec3f(normalize(v0 * u + v1 * v + v2 * w)) + + # Store MaterialIndex instead of Int32 material_idx + mat_idx = tri.metadata::MaterialIndex + @set reflection_hit_soa[idx] = (hit_found=true, material_idx=mat_idx, + dist=dist, bary=Vec3f(bary), normal=normal, + ray=ray, primary_hit_idx=hit_idx) + else + reflection_hit_soa.hit_found[idx] = false + end + else + reflection_hit_soa.hit_found[idx] = false + end + end +end + +# ============================================================================ +# Stage 8: Shade Reflection Hits and Blend (Hikari materials) +# ============================================================================ + +@kernel function shade_reflections_and_blend_hikari!( + @Const(hit_queue), + @Const(reflection_hit_soa), + @Const(ctx), + @Const(sky_color), + shading_queue +) + i = @index(Global, Linear) + idx = gpu_int(i) + + @inbounds if idx <= length(hit_queue.hit_found) + @get hit_found, tri, pixel_x, pixel_y, sample_idx = hit_queue[idx] + + if hit_found + mat_idx = tri.metadata::MaterialIndex + mat_props = extract_material_from_scene(ctx.materials, mat_idx, Point2f(0)) + + if mat_props.metallic > 0.0f0 + @get hit_found, material_idx, dist, bary, normal, ray = reflection_hit_soa[idx] + + reflection_color = if hit_found + refl_point = ray.o + ray.d * dist + refl_normal = normal + + refl_mat_props = extract_material_from_scene(ctx.materials, material_idx, Point2f(0)) + refl_base_color = refl_mat_props.base_color + + refl_color = refl_base_color * ctx.ambient + + if length(ctx.lights) > 0 + light = ctx.lights[gpu_int(1)] + light_vec = light.position - refl_point + light_dist = norm(light_vec) + light_dir = normalize(light_vec) + diffuse = max(0.0f0, dot(refl_normal, light_dir)) + attenuation = light.intensity / (light_dist * light_dist) + light_color = Vec3f(light.color.r, light.color.g, light.color.b) + refl_color += refl_base_color .* (light_color * (diffuse * attenuation)) + end + refl_color + else + Vec3f(sky_color.r, sky_color.g, sky_color.b) + end + + primary_color = shading_queue.color[idx] + blended_color = primary_color * (1.0f0 - mat_props.metallic) + reflection_color * mat_props.metallic + + @set shading_queue[idx] = (color=blended_color, pixel_x=pixel_x, pixel_y=pixel_y, sample_idx=sample_idx) + end + end + end +end + +# ============================================================================ +# Stage 9: Accumulate Final Image +# ============================================================================ + +@kernel function accumulate_final!( + @Const(shading_queue), + img, + sample_accumulator +) + i = @index(Global, Linear) + idx = gpu_int(i) + + @inbounds if idx <= length(shading_queue.color) + color = shading_queue.color[idx] + sample_accumulator[idx] = color + end +end + +@kernel function finalize_image!( + @Const(sample_accumulator), + img, + nsamples::Val{NSamples} +) where {NSamples} + i = @index(Global, Cartesian) + y = gpu_int(i[1]) + x = gpu_int(i[2]) + height, width = size(img) + + @inbounds if y <= height && x <= width + pixel_idx = (y - gpu_int(1)) * width + x + samples = ntuple(nsamples) do idx + s_idx = gpu_int(idx) + sample_idx = (pixel_idx - gpu_int(1)) * gpu_int(NSamples) + s_idx + sample_accumulator[sample_idx] + end + img[y, x] = RGB{Float32}(mean(samples)...) + end +end + +# ============================================================================ +# HikariWavefrontRenderer +# ============================================================================ + +""" + HikariWavefrontRenderer + +A wavefront renderer that uses Hikari's MaterialScene for materials and lights. +Uses the same shading model as the original wavefront-renderer.jl but with +Hikari data structures. + +Materials are stored in MaterialScene (tuple of typed vectors) and lights +are converted to WavefrontLight format for GPU compatibility. +""" +struct HikariWavefrontRenderer{ImgArr <: AbstractMatrix, Accel, Ctx} + width::Int32 + height::Int32 + + framebuffer::ImgArr + accel::Accel # BVH or TLAS from MaterialScene + ctx::Ctx # HikariRenderContext + + camera_pos::Point3f + camera_lookat::Point3f + camera_up::Vec3f + fov::Float32 + sky_color::RGB{Float32} + samples_per_pixel::Int32 + + # Work queues + primary_ray_queue::NamedTuple + primary_hit_queue::NamedTuple + shadow_ray_queue::NamedTuple + shadow_result_queue::NamedTuple + reflection_ray_soa::NamedTuple + reflection_hit_soa::NamedTuple + shading_queue::NamedTuple + sample_accumulator::AbstractVector + active_count::AbstractVector +end + +""" + HikariWavefrontRenderer(img, material_scene, lights; kwargs...) + +Create a HikariWavefrontRenderer from a Hikari MaterialScene and lights. + +# Arguments +- `img`: Output image buffer +- `material_scene`: Hikari.MaterialScene containing geometry and materials +- `lights`: Vector of Hikari lights (PointLight, etc.) + +# Keyword Arguments +- `camera_pos`: Camera position (default: Point3f(0, -0.9, -2.5)) +- `camera_lookat`: Look-at target (default: Point3f(0, 0, 0)) +- `camera_up`: Up vector (default: Vec3f(0, 0, 1)) +- `fov`: Field of view in degrees (default: 45) +- `sky_color`: Background color (default: light blue) +- `samples_per_pixel`: Anti-aliasing samples (default: 4) +- `ambient`: Ambient light factor (default: 0.1) +""" +function HikariWavefrontRenderer( + img, + material_scene::MaterialScene, + lights::AbstractVector{<:Light}; + camera_pos=Point3f(0, -0.9, -2.5), + camera_lookat=Point3f(0, 0, 0), + camera_up=Vec3f(0, 0, 1), + fov=45.0f0, + sky_color=RGB{Float32}(0.5f0, 0.7f0, 1.0f0), + samples_per_pixel=4, + ambient=0.1f0 + ) + height, width = size(img) + + # Convert Hikari lights to WavefrontLight + wavefront_lights = [WavefrontLight(l) for l in lights] + + # Create render context + ctx = HikariRenderContext(wavefront_lights, material_scene.materials, ambient) + + num_pixels = width * height + num_rays = num_pixels * samples_per_pixel + num_lights = Int32(length(wavefront_lights)) + num_shadow_rays = num_rays * num_lights + + # Get triangle type from accelerator + accel = material_scene.accel + tri_type = eltype(accel) + + # Allocate work queues + primary_ray_queue = similar_soa(img, PrimaryRayWork, num_rays) + primary_hit_queue = similar_soa(img, PrimaryHitWork{tri_type}, num_rays) + shadow_ray_queue = similar_soa(img, ShadowRayWork, num_shadow_rays) + shadow_result_queue = similar_soa(img, ShadowResult, num_shadow_rays) + reflection_ray_soa = similar_soa(img, ReflectionRayWork, num_rays) + reflection_hit_soa = similar_soa(img, ReflectionHitWork, num_rays) + shading_queue = similar_soa(img, ShadedResult, num_rays) + sample_accumulator = similar(img, Vec3f, num_rays) + active_count = similar(img, Int32, 1) + + return HikariWavefrontRenderer( + Int32(width), Int32(height), + img, accel, ctx, + camera_pos, camera_lookat, camera_up, + fov, sky_color, Int32(samples_per_pixel), + primary_ray_queue, primary_hit_queue, + shadow_ray_queue, shadow_result_queue, + reflection_ray_soa, reflection_hit_soa, + shading_queue, sample_accumulator, active_count + ) +end + +""" + to_gpu(ArrayType, renderer::HikariWavefrontRenderer) + +Convert a HikariWavefrontRenderer to GPU arrays. +""" +function Raycore.to_gpu(Arr, renderer::HikariWavefrontRenderer) + img = Arr(renderer.framebuffer) + accel_gpu = Raycore.to_gpu(Arr, renderer.accel) + ctx_gpu = Raycore.to_gpu(Arr, renderer.ctx) + + return HikariWavefrontRenderer( + img, accel_gpu, ctx_gpu; + camera_pos=renderer.camera_pos, + camera_lookat=renderer.camera_lookat, + camera_up=renderer.camera_up, + fov=renderer.fov, + sky_color=renderer.sky_color, + samples_per_pixel=Int(renderer.samples_per_pixel), + ambient=renderer.ctx.ambient + ) +end + +# Inner constructor for to_gpu +function HikariWavefrontRenderer( + img, accel, ctx; + camera_pos, camera_lookat, camera_up, fov, sky_color, samples_per_pixel, ambient + ) + height, width = size(img) + + num_pixels = width * height + num_rays = num_pixels * samples_per_pixel + num_lights = Int32(length(ctx.lights)) + num_shadow_rays = num_rays * num_lights + + tri_type = eltype(accel) + + primary_ray_queue = similar_soa(img, PrimaryRayWork, num_rays) + primary_hit_queue = similar_soa(img, PrimaryHitWork{tri_type}, num_rays) + shadow_ray_queue = similar_soa(img, ShadowRayWork, num_shadow_rays) + shadow_result_queue = similar_soa(img, ShadowResult, num_shadow_rays) + reflection_ray_soa = similar_soa(img, ReflectionRayWork, num_rays) + reflection_hit_soa = similar_soa(img, ReflectionHitWork, num_rays) + shading_queue = similar_soa(img, ShadedResult, num_rays) + sample_accumulator = similar(img, Vec3f, num_rays) + active_count = similar(img, Int32, 1) + + return HikariWavefrontRenderer( + Int32(width), Int32(height), + img, accel, ctx, + camera_pos, camera_lookat, camera_up, + fov, sky_color, Int32(samples_per_pixel), + primary_ray_queue, primary_hit_queue, + shadow_ray_queue, shadow_result_queue, + reflection_ray_soa, reflection_hit_soa, + shading_queue, sample_accumulator, active_count + ) +end + +""" + render!(renderer::HikariWavefrontRenderer) + +Execute the wavefront path tracing pipeline. +""" +function render!(renderer::HikariWavefrontRenderer) + width = Int(renderer.width) + height = Int(renderer.height) + samples_per_pixel = Int(renderer.samples_per_pixel) + + aspect = Float32(width / height) + + backend = KA.get_backend(renderer.framebuffer) + + num_pixels = width * height + num_rays = num_pixels * samples_per_pixel + num_lights = Int(length(renderer.ctx.lights)) + num_shadow_rays = num_rays * num_lights + + # Camera basis vectors + camera_forward = Vec3f(normalize(renderer.camera_lookat - renderer.camera_pos)) + camera_right = Vec3f(normalize(cross(renderer.camera_up, camera_forward))) + camera_up_ortho = Vec3f(cross(camera_forward, camera_right)) + + half_height = tan(deg2rad(renderer.fov / 2)) + half_width = half_height * aspect + + # Stage 1: Generate primary rays + gen_kernel! = generate_primary_rays_lookat!(backend) + gen_kernel!( + renderer.width, renderer.height, + renderer.camera_pos, + camera_right, camera_up_ortho, camera_forward, + half_width, half_height, + renderer.primary_ray_queue, + Val(samples_per_pixel), + ndrange=(height, width) + ) + + # Stage 2: Intersect primary rays + intersect_kernel! = intersect_primary_rays_hikari!(backend) + intersect_kernel!( + renderer.accel, + renderer.primary_ray_queue, + renderer.primary_hit_queue, + ndrange=num_rays + ) + + # Stage 3: Generate shadow rays + shadow_gen_kernel! = generate_shadow_rays_hikari!(backend) + shadow_gen_kernel!( + renderer.primary_hit_queue, + renderer.ctx, + renderer.shadow_ray_queue, + Val(num_lights), + ndrange=num_rays + ) + + # Stage 4: Test shadow rays + shadow_test_kernel! = test_shadow_rays_hikari!(backend) + shadow_test_kernel!( + renderer.accel, + renderer.shadow_ray_queue, + renderer.shadow_result_queue, + ndrange=num_shadow_rays + ) + + # Stage 5: Shade primary hits + shade_kernel! = shade_primary_hits_hikari!(backend) + shade_kernel!( + renderer.primary_hit_queue, + renderer.ctx, + renderer.shadow_result_queue, + renderer.sky_color, + renderer.shading_queue, + Val(num_lights), + ndrange=num_rays + ) + + # Stage 6: Generate reflection rays + refl_gen_kernel! = generate_reflection_rays_hikari!(backend) + refl_gen_kernel!( + renderer.primary_hit_queue, + renderer.ctx, + renderer.reflection_ray_soa, + renderer.active_count, + ndrange=num_rays + ) + + # Stage 7: Intersect reflection rays + refl_intersect_kernel! = intersect_reflection_rays_hikari!(backend) + refl_intersect_kernel!( + renderer.accel, + renderer.reflection_ray_soa, + renderer.reflection_hit_soa, + ndrange=num_rays + ) + + # Stage 8: Shade reflections + refl_shade_kernel! = shade_reflections_and_blend_hikari!(backend) + refl_shade_kernel!( + renderer.primary_hit_queue, + renderer.reflection_hit_soa, + renderer.ctx, + renderer.sky_color, + renderer.shading_queue, + ndrange=num_rays + ) + + # Stage 9: Accumulate final image + accum_kernel! = accumulate_final!(backend) + accum_kernel!( + renderer.shading_queue, + renderer.framebuffer, + renderer.sample_accumulator, + ndrange=num_rays + ) + + final_kernel! = finalize_image!(backend) + final_kernel!( + renderer.sample_accumulator, + renderer.framebuffer, + Val(samples_per_pixel), + ndrange=(height, width) + ) + KA.synchronize(backend) + + return renderer.framebuffer +end + +# ============================================================================ +# Convenience function for creating example scene with Hikari materials +# ============================================================================ + +""" + hikari_example_scene() + +Create an example scene using Hikari materials that matches the original +wavefront renderer's example_scene. + +Returns (material_scene, lights) tuple. +""" +function hikari_example_scene(; glass_cat=false) + cat_mesh = Makie.loadasset("cat.obj") + angle = deg2rad(150f0) + rotation = Makie.Quaternionf(0, sin(angle/2), 0, cos(angle/2)) + rotated_coords = [rotation * Point3f(v) for v in coordinates(cat_mesh)] + + cat_bbox = Rect3f(rotated_coords) + floor_y = -1.5f0 + cat_offset = Vec3f(0, floor_y - cat_bbox.origin[2], 0) + + cat_mesh = GeometryBasics.normal_mesh( + [v + cat_offset for v in rotated_coords], + faces(cat_mesh) + ) + + floor = normal_mesh(Rect3f(Vec3f(-5, -1.5, -2), Vec3f(10, 0.01, 10))) + back_wall = normal_mesh(Rect3f(Vec3f(-5, -1.5, 8), Vec3f(10, 5, 0.01))) + left_wall = normal_mesh(Rect3f(Vec3f(-5, -1.5, -2), Vec3f(0.01, 5, 10))) + + sphere1 = Tesselation(Sphere(Point3f(-2, -1.5 + 0.8, 2), 0.8f0), 64) + sphere2 = Tesselation(Sphere(Point3f(2, -1.5 + 0.6, 1), 0.6f0), 64) + + # Create Hikari materials matching the original scene + # Original: Material(base_color, metallic, roughness, ior, transmission) + cat_material = if glass_cat + # Glass cat: high IOR, transmission + Hikari.GlassMaterial( + Kr=RGBSpectrum(0.95f0, 1.0f0, 0.95f0), + Kt=RGBSpectrum(0.95f0, 1.0f0, 0.95f0), + u_roughness=0f0, + v_roughness=0f0, + index=1.5f0, + remap_roughness=false + ) + else + # Diffuse cat + Hikari.MatteMaterial(Kd=RGBSpectrum(0.8f0, 0.6f0, 0.4f0), σ=0f0) + end + + # Floor: diffuse green + floor_material = Hikari.MatteMaterial(Kd=RGBSpectrum(0.3f0, 0.5f0, 0.3f0), σ=0f0) + + # Back wall: metallic with roughness (original: metallic=0.8, roughness=0.05) + # Using MetalMaterial with reflectance for color tinting + back_wall_material = Hikari.MetalMaterial( + reflectance=(0.8f0, 0.6f0, 0.5f0), + roughness=0.05f0, + remap_roughness=false + ) + + # Left wall: diffuse + left_wall_material = Hikari.MatteMaterial(Kd=RGBSpectrum(0.7f0, 0.7f0, 0.8f0), σ=0f0) + + # Sphere 1: metallic silver with slight roughness (original: metallic=0.8, roughness=0.02) + sphere1_material = Hikari.MetalMaterial( + reflectance=(0.9f0, 0.9f0, 0.9f0), + roughness=0.02f0, + remap_roughness=false + ) + + # Sphere 2: partially metallic blue (using PlasticMaterial for mixed behavior) + # PlasticMaterial is appropriate here since it's only partially metallic + sphere2_material = Hikari.PlasticMaterial( + Kd=RGBSpectrum(0.3f0, 0.6f0, 0.9f0), + Ks=RGBSpectrum(0.5f0, 0.5f0, 0.5f0), + roughness=0.3f0, + remap_roughness=false + ) + + scene_pairs = [ + (cat_mesh, cat_material), + (floor, floor_material), + (back_wall, back_wall_material), + (left_wall, left_wall_material), + (normal_mesh(sphere1), sphere1_material), + (normal_mesh(sphere2), sphere2_material), + ] + + material_scene = Hikari.MaterialScene(scene_pairs) + + # Create Hikari lights matching original + lights = [ + Hikari.PointLight(Point3f(3, 4, -2), RGBSpectrum(50.0f0 * 1.0f0, 50.0f0 * 0.9f0, 50.0f0 * 0.8f0)), + Hikari.PointLight(Point3f(-3, 2, 0), RGBSpectrum(20.0f0 * 0.7f0, 20.0f0 * 0.8f0, 20.0f0 * 1.0f0)), + Hikari.PointLight(Point3f(0, 5, 5), RGBSpectrum(15.0f0, 15.0f0, 15.0f0)), + ] + + return material_scene, lights +end diff --git a/docs/src/hw-accel-benchmarks.png b/docs/src/hw-accel-benchmarks.png new file mode 100644 index 0000000..54bae4d Binary files /dev/null and b/docs/src/hw-accel-benchmarks.png differ diff --git a/docs/src/hw-accel-hw.png b/docs/src/hw-accel-hw.png new file mode 100644 index 0000000..185a0a0 Binary files /dev/null and b/docs/src/hw-accel-hw.png differ diff --git a/docs/src/hw-accel-large.png b/docs/src/hw-accel-large.png new file mode 100644 index 0000000..c997e65 Binary files /dev/null and b/docs/src/hw-accel-large.png differ diff --git a/docs/src/hw-accel-materials.png b/docs/src/hw-accel-materials.png new file mode 100644 index 0000000..185a0a0 Binary files /dev/null and b/docs/src/hw-accel-materials.png differ diff --git a/docs/src/hw_acceleration.md b/docs/src/hw_acceleration.md new file mode 100644 index 0000000..9f55a8d --- /dev/null +++ b/docs/src/hw_acceleration.md @@ -0,0 +1,280 @@ +# Hardware Ray Tracing with Lava + +Modern GPUs include dedicated ray tracing hardware (RT cores on NVIDIA, Ray Accelerators on AMD) that can traverse BVH structures and test ray-triangle intersections in fixed-function silicon. This tutorial shows how to use hardware acceleration with Raycore via the [Lava.jl](https://github.com/SimonDanisch/Lava.jl) Vulkan backend. + +The demo builds the same scene twice — once into a software `Raycore.TLAS`, once into a hardware `Lava.HWTLAS` — traces primary camera rays through both, and verifies the depth buffers agree. + +## When to pick `Raycore.TLAS` vs. `Lava.HWTLAS` + +| Aspect | `Raycore.TLAS` | `Lava.HWTLAS` | +| ------------------:| -----------------------------------------------------:| ------------------------------------:| +| Backend | any KA backend (CUDA, AMDGPU, Metal, Lava) | Lava + Vulkan RT only | +| BVH | software (BVH4 / instanced BVH) | `VkAccelerationStructureKHR` | +| Closest-hit kernel | KA `@kernel`, in-line `Raycore.closest_hit(bvh, ray)` | `vkCmdTraceRaysKHR` over a ray batch | +| Dispatch overhead | low, KA launch | very low, pre-baked SBT | +| Use when | portability / non-Vulkan backend / no RT hardware | max perf on Vulkan RT hardware | + +Both types satisfy `Raycore.AbstractAccel` — `push!`, `delete!`, `update_transform!`, `sync!`, `n_instances`, `n_geometries`, `wait_for_gpu!` work identically. The HW path uses a batched dispatch (`Lava.trace_closest_hits!`) instead of a per-thread `closest_hit` call inside a KA kernel. + +## When Hardware RT Helps + +Hardware RT gives the biggest speedups on scenes with: + + * **High triangle counts** — RT cores traverse the BVH in fixed-function hardware + * **Complex occlusion** — interior scenes, overlapping geometry, more traversal steps per ray + * **Simple shading** — when BVH traversal dominates, not material evaluation + +For trivial scenes the software BVH already runs at GPU memory bandwidth, so the win is modest. The demo below uses ~48k triangles which is enough to see HW pull ahead but small enough to fit in a tutorial. + +Hardware RT requires a Vulkan-capable GPU with `VK_KHR_ray_tracing_pipeline`. NVIDIA RTX (Turing+), AMD RDNA 2+, and Intel Arc all support it. + +## Setup + +```julia +using Raycore, GeometryBasics, LinearAlgebra +using Lava +using WGLMakie +using Adapt +import KernelAbstractions as KA +using KernelAbstractions: @kernel, @index, @Const + +device = Lava.LavaBackend() +``` + +**Lava backend active** — Vulkan device with RT support. + +`LavaBackend` is a KernelAbstractions backend that compiles `@kernel` code through Lava's SPIR-V compiler. It's also the device that owns Vulkan acceleration structures, so SW and HW share the same GPU context. + +## Building a scene twice — software and hardware + +Build a small scene of tessellated spheres on a floor — enough triangles for the BVH traversal cost to matter. Both `Raycore.TLAS` and `Lava.HWTLAS` ingest plain `GeometryBasics.Mesh` objects through `push!`, so the same meshes go into both. + +```julia +function build_meshes() + floor = GeometryBasics.normal_mesh(Rect3f(Vec3f(-3, -3, -0.01), Vec3f(6, 6, 0.01))) + sphere_centers = Point3f[] + for i in -2:2, j in -2:2 + push!(sphere_centers, Point3f(Float32(i)*0.9f0, Float32(j)*0.9f0, 0.4f0)) + end + spheres = [GeometryBasics.normal_mesh(Tesselation(Sphere(c, 0.3f0), 32)) + for c in sphere_centers] + return [floor; spheres] +end + +all_meshes = build_meshes() +total_tris = sum(length(GeometryBasics.faces(m)) for m in all_meshes) + +sw_tlas = Raycore.TLAS(device) +hwtlas = Lava.HWTLAS(device) + +for (i, m) in enumerate(all_meshes) + push!(sw_tlas, m) + push!(hwtlas, m; instance_id=UInt32(i)) +end + +Raycore.sync!(sw_tlas) +Raycore.sync!(hwtlas) +``` + +**Scene built** + +| | meshes | instances | triangles | +|---------------:|-------:|----------:|----------:| +| `Raycore.TLAS` | 26 | 26 | 48 062 | +| `Lava.HWTLAS` | 26 | 26 | 48 062 | + +`sync!` uploads the meshes and builds the acceleration structures. For `Raycore.TLAS` that means GPU LBVH builds (one BLAS per mesh + a TLAS over instances). For `Lava.HWTLAS` it means `vkCmdBuildAccelerationStructuresKHR` calls plus instance-table allocation. + +## Tracing primary rays both ways + +Generate one camera ray per pixel. The SW path uses `Raycore.Ray` (origin + direction), the HW path uses `Raycore.RTRay` (origin + dir + tmin/tmax, 32-byte struct that matches Vulkan's `VkRayTracingShaderRecordKHR` layout). + +```julia +const W, H = 256, 192 + +cam_pos = Point3f(0, -3.5, 1.6) +cam_target = Point3f(0, 0, 0.3) +cam_up = Point3f(0, 0, 1) + +forward = normalize(cam_target - cam_pos) +right = normalize(cross(forward, cam_up)) +up = cross(right, forward) +aspect = Float32(W / H) +focal = 1.0f0 / tan(deg2rad(45.0f0 / 2)) + +function build_rays(W, H, cam_pos, forward, right, up, aspect, focal) + rays_sw = Vector{Raycore.Ray}(undef, W*H) + rays_hw = Vector{Raycore.RTRay}(undef, W*H) + for y in 1:H, x in 1:W + u = (2.0f0 * (Float32(x) - 0.5f0) / Float32(W) - 1.0f0) + v = (1.0f0 - 2.0f0 * (Float32(y) - 0.5f0) / Float32(H)) + d = normalize(forward * focal + right * (u * aspect) + up * v) + i = (y - 1) * W + x + rays_sw[i] = Raycore.Ray(o=cam_pos, d=Vec3f(d)) + rays_hw[i] = Raycore.RTRay(cam_pos[1], cam_pos[2], cam_pos[3], 0f0, + d[1], d[2], d[3], 1f3) + end + rays_sw, rays_hw +end + +rays_sw, rays_hw = build_rays(W, H, cam_pos, forward, right, up, aspect, focal) + +# ---- SW path: KA kernel calls Raycore.closest_hit per pixel +@kernel function depth_kernel_sw!(depth, @Const(bvh), @Const(rays)) + i = @index(Global, Linear) + @inbounds if i <= length(rays) + ray = rays[i] + hit_found, _, dist, _, _ = Raycore.closest_hit(bvh, ray) + depth[i] = hit_found ? dist : -1f0 + end +end + +sw_static = Adapt.adapt(device, sw_tlas) # StaticTLAS for kernels +rays_sw_gpu = Lava.LavaArray(rays_sw) +depth_sw_gpu = Lava.LavaArray(zeros(Float32, W*H)) + +sw_kernel = depth_kernel_sw!(device, 64) +sw_kernel(depth_sw_gpu, sw_static, rays_sw_gpu, ndrange=W*H) +KA.synchronize(device) +depth_sw = Array(depth_sw_gpu) + +# ---- HW path: batched trace_closest_hits! dispatches vkCmdTraceRaysKHR +rays_hw_gpu = Lava.LavaArray(rays_hw) +hits_hw = Lava.LavaArray(fill(Raycore.RTHitResult(0,0,0,0,0,0,0,0), W*H)) + +Lava.trace_closest_hits!(hits_hw, rays_hw_gpu, hwtlas.hw_accel, length(rays_hw)) +Raycore.wait_for_gpu!(hwtlas) + +depth_hw = Float32[h.hit == UInt32(1) ? h.t : -1f0 for h in Array(hits_hw)] + +# ---- Compare +hit_mask_sw = depth_sw .> 0 +hit_mask_hw = depth_hw .> 0 +disagree = count(hit_mask_sw .!= hit_mask_hw) +shared = hit_mask_sw .& hit_mask_hw +max_abs_diff = maximum(abs.(depth_sw[shared] .- depth_hw[shared])) +``` + +**Pixel-wise agreement** + +- hit-mask disagreement: **0 pixels** out of 49 152 +- max abs depth diff on shared hits: **1.29e-5** + +Sub-1e-4 tolerance is the expected noise floor — both paths use the same Möller–Trumbore-style intersection but with slightly different rounding (the HW path goes through Vulkan's intersection shader, the SW path through Raycore's kernel). + +## How the two paths line up + +``` +Software BVH (SW): Hardware RT (HW): +┌──────────────────┐ ┌──────────────────┐ +│ KA @kernel │ │ Build RTRay │ +│ per pixel │ │ batch │ +│ │ └────────┬─────────┘ +│ Raycore. │ │ +│ closest_hit( │ ┌────────▼─────────┐ +│ bvh, ray) │ │ trace_closest_ │ +│ │ │ hits! (one │ +│ writes depth[i] │ │ vkCmdTraceRays │ +└──────────────────┘ │ call) │ + └────────┬─────────┘ + │ + ┌────────▼─────────┐ + │ RTHitResult[] │ + │ — t, prim_id, │ + │ bary, ... │ + └──────────────────┘ +``` + +In the SW path the kernel is fully programmable — anything you can write inside a `@kernel` function (shadow rays, multi-bounce, custom intersection) works the same way. In the HW path the traversal is fixed-function: rays go in as `RTRay`, hit results come out as `RTHitResult`. The raygen / closest-hit / miss shaders are pre-baked and dispatched as one Vulkan call per ray batch. + +## Visualize and time + +```julia +to_disp(d) = d > 0 ? d : NaN32 +img_sw = reshape(depth_sw, W, H) |> permutedims |> x -> to_disp.(x) +img_hw = reshape(depth_hw, W, H) |> permutedims |> x -> to_disp.(x) + +# Warm-up + 5-shot minimum timing +function time_sw() + KA.synchronize(device) + t = @elapsed begin + sw_kernel(depth_sw_gpu, sw_static, rays_sw_gpu, ndrange=W*H) + KA.synchronize(device) + end + return t +end + +function time_hw() + Raycore.wait_for_gpu!(hwtlas) + t = @elapsed begin + Lava.trace_closest_hits!(hits_hw, rays_hw_gpu, hwtlas.hw_accel, length(rays_hw)) + Raycore.wait_for_gpu!(hwtlas) + end + return t +end + +time_sw(); time_sw(); time_hw(); time_hw() # warm +t_sw = minimum(time_sw() for _ in 1:5) +t_hw = minimum(time_hw() for _ in 1:5) + +fig = Figure(size=(900, 380)) +ax1 = Axis(fig[1, 1], title="SW (Raycore.TLAS) $(round(t_sw*1000, digits=2)) ms", + aspect=DataAspect()) +ax2 = Axis(fig[1, 2], title="HW (Lava.HWTLAS) $(round(t_hw*1000, digits=2)) ms", + aspect=DataAspect()) +hidedecorations!(ax1); hidedecorations!(ax2) +heatmap!(ax1, img_sw, colormap=:viridis) +heatmap!(ax2, img_hw, colormap=:viridis) +fig +``` + +![SW vs HW depth heatmaps](assets/hw_acceleration_compare.png) + +The two heatmaps are visually indistinguishable — the depth comparison cell above quantifies that. The timing ratio depends on triangle count, ray coherence, and GPU; the only honest way to know what your scene needs is to measure both. + +## Lifetime and memory management + +`sync!(hwtlas)` owns `hwtlas.static_tlas`. Consumers re-read it (or call `Adapt.adapt(backend, hwtlas)`) per dispatch — do NOT cache across mutations. + +`sync!` does not block the CPU. Backend-internal timeline tracking (Lava's `bq.deferred_as_frees`) handles the "still in flight" case when old acceleration-structure buffers are dropped. If you need a CPU-blocking drain — e.g. before tear-down or between benchmark phases — call `Raycore.wait_for_gpu!(hwtlas)` explicitly. + +## Direct `HWTLAS` usage + +The cells above already show the full direct-API path. The minimum is: + +```julia +using Raycore, Lava, GeometryBasics, LinearAlgebra +using Raycore: RTRay, RTHitResult + +device = Lava.LavaBackend() +hwtlas = Lava.HWTLAS(device) + +mesh = GeometryBasics.normal_mesh(Sphere(Point3f(0, 0, 2), 1.0f0)) +push!(hwtlas, mesh; instance_id=UInt32(1)) +Raycore.sync!(hwtlas) + +rays = Lava.LavaArray([RTRay(0,0,5, 0, 0,0,-1, 1f3)]) +hits = Lava.LavaArray(fill(RTHitResult(0,0,0,0,0,0,0,0), 1)) +Lava.trace_closest_hits!(hits, rays, hwtlas.hw_accel, 1) +Raycore.wait_for_gpu!(hwtlas) +``` + +`HardwareAccel` (`hwtlas.hw_accel`) is the lower-level handle if you need direct control of the pipeline / SBT or want to install a custom any-hit shader (`Lava.set_anyhit_pipeline!`). + +## RT shader intrinsics + +If you write your own raygen / closest-hit / miss shaders in Lava, the RT intrinsics use the `lava_rt_*` naming convention — no `accel` argument since the SBT wires up the hardware automatically: + +| `Raycore.rt_*` (generic) | `lava_rt_*` (Lava-specific) | +| ---------------------------------------------:| -----------------------------------------:| +| `Raycore.rt_primitive_id(accel)` | `lava_rt_primitive_id()` | +| `Raycore.rt_instance_id(accel)` | `lava_rt_instance_id()` | +| `Raycore.rt_instance_custom_index(accel)` | `lava_rt_instance_custom_index()` | +| `Raycore.rt_launch_id_x(accel)` | `lava_rt_launch_id_x()` | +| `Raycore.rt_trace_ray!(accel, ...)` | `lava_rt_trace_ray(...)` | +| `Raycore.rt_ignore_intersection(accel)` | `lava_rt_ignore_intersection()` | +| `Raycore.rt_payload_store!(accel, val, slot)` | `lava_rt_payload_store_f32_at(val, slot)` | +| `Raycore.rt_payload_load(accel, slot)` | `lava_rt_payload_load_f32_at(slot)` | + +The pre-baked shaders shipped with `Lava.HardwareAccel` (raygen / closest-hit / miss for `RTHitResult` payload) are defined in `Lava/src/raytracing/raycore_compat.jl` as a reference implementation. + diff --git a/docs/src/index.md b/docs/src/index.md index bf5d65c..8d1db4d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,21 +1,23 @@ # Raycore.jl -High-performance ray-triangle intersection engine with BVH acceleration for CPU and GPU. +High-performance ray-triangle intersection engine with TLAS/BLAS acceleration for CPU and GPU. ## Features -- **Fast BVH acceleration** for ray-triangle intersection +- **Fast TLAS/BLAS acceleration** for ray-triangle intersection - **CPU and GPU support** via KernelAbstractions.jl +- **MultiTypeSet**: GPU-safe heterogeneous collections with compile-time type-stable dispatch for materials, textures, lights, etc. +- **GPU TLAS**: `Raycore.TLAS` is software, backend-agnostic (any KA backend). For hardware ray tracing on Vulkan, use `Lava.HWTLAS` — a drop-in `AbstractAccel` implemented via `VK_KHR_ray_tracing_pipeline`. See [Hardware Ray Tracing with Lava](@ref). - **Analysis tools**: centroid calculation, illumination analysis, view factors for radiosity - **Makie integration** for visualization ## Interactive Examples -### BVH Hit Tests & Basics +### Hit Tests & Basics -Learn the basics of ray-triangle intersection, BVH construction, and visualization. +Learn the basics of ray-triangle intersection, TLAS construction, and visualization. -![BVH Basics](basics.png) +![Basics](basics.png) [BVH Hit tests](@ref) @@ -35,6 +37,14 @@ Port the ray tracer to the GPU with KernelAbstractions.jl. Learn about kernel op [GPU Ray Tracing with Raycore](@ref) +### Hardware RT Acceleration + +Use dedicated GPU ray tracing hardware (RT cores / Ray Accelerators) for transparent BVH acceleration via Vulkan. + +![Hardware RT](hw-accel-materials.png) + +[Hardware Ray Tracing with Lava](@ref) + ### View Factors Analysis Calculate view factors, illumination, and centroids for radiosity and thermal analysis. diff --git a/docs/src/raytracing-core.jl b/docs/src/raytracing-core.jl index 18273c7..067ce85 100644 --- a/docs/src/raytracing-core.jl +++ b/docs/src/raytracing-core.jl @@ -85,7 +85,7 @@ function compute_light( shadow_dir = normalize(shadow_vec) shadow_ray = Raycore.Ray(o=point + normal * 0.001f0, d=shadow_dir) - shadow_hit, _, hit_dist, _ = Raycore.any_hit(bvh, shadow_ray) + shadow_hit, _, hit_dist, _, _ = Raycore.any_hit(bvh, shadow_ray) if !shadow_hit || hit_dist >= shadow_dist shadow_factor += 1.0f0 @@ -148,7 +148,7 @@ function reflective_kernel(bvh, ctx, tri, dist, bary, ray, sky_color, shadow_sam # Cast reflection ray reflect_ray = Raycore.Ray(o=hit_point + normal * 0.001f0, d=reflect_dir) - refl_hit, refl_tri, refl_dist, refl_bary = Raycore.closest_hit(bvh, reflect_ray) + refl_hit, refl_tri, refl_dist, refl_bary, _ = Raycore.closest_hit(bvh, reflect_ray) reflection_color = if refl_hit refl_point = reflect_ray.o + reflect_ray.d * refl_dist @@ -187,8 +187,8 @@ function example_scene(; glass_cat=false) left_wall = normal_mesh(Rect3f(Vec3f(-5, -1.5, -2), Vec3f(0.01, 5, 10))) # Add a couple of spheres for visual interest - sphere1 = Tesselation(Sphere(Point3f(-2, -1.5 + 0.8, 2), 0.8f0), 64) - sphere2 = Tesselation(Sphere(Point3f(2, -1.5 + 0.6, 1), 0.6f0), 64) + sphere1 = normal_mesh(Sphere(Point3f(-2, -1.5 + 0.8, 2), 0.8f0)) + sphere2 = normal_mesh(Sphere(Point3f(2, -1.5 + 0.6, 1), 0.6f0)) # Material: base_color, metallic, roughness, ior, transmission cat_material = if glass_cat @@ -209,7 +209,7 @@ function example_scene(; glass_cat=false) geometries = [g for (g, _) in scene] materials = [m for (_, m) in scene] - bvh = Raycore.BVH(geometries, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) + bvh = Raycore.TLAS(geometries, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) lights = default_lights() ctx = RenderContext(lights, materials, 0.1f0) return bvh, ctx @@ -222,7 +222,7 @@ end function sample_light(bvh, ctx, width, height, camera_pos, focal_length, aspect, x, y, sky_color) jitter = rand(Vec2f) ray = camera_ray(x, y, width, height, camera_pos, focal_length, aspect; jitter) - hit_found, tri, dist, bary = Raycore.closest_hit(bvh, ray) + hit_found, tri, dist, bary, _ = Raycore.closest_hit(bvh, ray) if hit_found color = reflective_kernel(bvh, ctx, tri, dist, bary, ray, RGB(0.5f0, 0.7f0, 1.0f0), 8) return to_vec3f(color) diff --git a/docs/src/raytracing_tutorial_content.md b/docs/src/raytracing_tutorial_content.md index 62b3895..a6ca187 100644 --- a/docs/src/raytracing_tutorial_content.md +++ b/docs/src/raytracing_tutorial_content.md @@ -1,6 +1,6 @@ # Ray Tracing in one Hour -Analougus to the famous [Ray Tracing in one Weekend](https://raytracing.github.io/), this tutorial uses Raycore to do the hard work of performant ray triangle intersection and therefore get a high performing ray tracer in a much shorter time. We'll start with the absolute basics and progressively add features until we have a ray tracer that produces beautiful images with shadows, materials, and reflections. +Analogous to the famous [Ray Tracing in one Weekend](https://raytracing.github.io/), this tutorial uses Raycore to do the hard work of performant ray triangle intersection and therefore get a high performing ray tracer in a much shorter time. We'll start with the absolute basics and progressively add features until we have a ray tracer that produces beautiful images with shadows, materials, and reflections. ## Setup @@ -47,10 +47,11 @@ left_wall = normal_mesh(Rect3f(Vec3f(-5, -1.5, -2), Vec3f(0.01, 5, 10))) sphere1 = Tesselation(Sphere(Point3f(-2, -1.5 + 0.8, 2), 0.8f0), 64) sphere2 = Tesselation(Sphere(Point3f(2, -1.5 + 0.6, 1), 0.6f0), 64) -# Build our BVH acceleration structure -scene_geometry = [cat_mesh, floor, back_wall, left_wall, sphere1, sphere2] -bvh = Raycore.BVH(scene_geometry) -md"**BVH built with $(length(bvh.primitives)) triangles**" +# Build our TLAS acceleration structure +scene_geometry = [cat_mesh, floor, back_wall, left_wall, + normal_mesh(sphere1), normal_mesh(sphere2)] +bvh = Raycore.TLAS(scene_geometry, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) +md"**TLAS built**" ``` ## Part 2: Helper Functions - Building Blocks @@ -96,7 +97,7 @@ function trace(f, bvh; width=700, height=300, jitter = samples > 1 ? rand(Vec2f) : Vec2f(0) # Calculate the ray shooting from the camera pixel into the scene ray = camera_ray(x, y, width, height, camera_pos, focal_length, aspect; jitter) - hit_found, triangle, distance, bary_coords = Raycore.closest_hit(bvh, ray) + hit_found, triangle, distance, bary_coords, _ = Raycore.closest_hit(bvh, ray) color = if hit_found to_vec3f(f(bvh, ctx, triangle, distance, bary_coords, ray)) else @@ -155,7 +156,7 @@ function compute_light( shadow_dir = normalize(shadow_vec) shadow_ray = Raycore.Ray(o=point + normal * 0.001f0, d=shadow_dir) - shadow_hit, _, hit_dist, _ = Raycore.any_hit(bvh, shadow_ray) + shadow_hit, _, hit_dist, _, _ = Raycore.any_hit(bvh, shadow_ray) if !shadow_hit || hit_dist >= shadow_dist shadow_factor += 1.0f0 @@ -201,11 +202,11 @@ trace((args...)-> shadow_kernel(args...; shadow_samples=8), bvh, samples=8) ## Part 7: Materials and Multiple Lights -Time to add color and multiple lights! To associate materials with geometry, we need to rebuild the BVH with **metadata** that links each triangle to its material. +Time to add color and multiple lights! To associate materials with geometry, we need to rebuild the TLAS with **metadata** that links each triangle to its material. ### Triangle Metadata -When building a BVH, you can pass a `metadata_fn(mesh_idx, tri_idx)` that assigns custom data to each triangle. This metadata is stored in `triangle.metadata` and returned with every ray hit. For materials, we use the mesh index as metadata: +When building a TLAS, you can pass a `metadata_fn(mesh_idx, tri_idx)` that assigns custom data to each triangle. This metadata is stored in `triangle.metadata` and returned with every ray hit. For materials, we use the mesh index as metadata: ```julia (editor=true, logging=false, output=true) struct PointLight @@ -236,9 +237,9 @@ materials = [ Material(RGB(0.3f0, 0.6f0, 0.9f0), 0.5f0, 0.3f0), # 6: sphere2 - semi-metallic ] -# Rebuild BVH with material indices as metadata +# Rebuild TLAS with material indices as metadata # The metadata_fn receives (mesh_idx, tri_idx) and returns data stored per-triangle -bvh = Raycore.BVH(scene_geometry, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) +bvh = Raycore.TLAS(scene_geometry, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) # Create lights lights = [ @@ -310,7 +311,7 @@ function reflective_kernel(bvh, ctx, tri, dist, bary, ray, sky_color) # Cast reflection ray reflect_ray = Raycore.Ray(o=hit_point + normal * 0.001f0, d=reflect_dir) - refl_hit, refl_tri, refl_dist, refl_bary = Raycore.closest_hit(bvh, reflect_ray) + refl_hit, refl_tri, refl_dist, refl_bary, _ = Raycore.closest_hit(bvh, reflect_ray) reflection_color = if refl_hit refl_point = reflect_ray.o + reflect_ray.d * refl_dist @@ -374,7 +375,7 @@ We built a complete ray tracer with: **Core Features:** - * BVH acceleration for fast ray-scene intersections + * TLAS acceleration for fast ray-scene intersections * Perspective camera with configurable FOV * Smooth shading from interpolated normals * Multi-light system with distance attenuation @@ -392,11 +393,11 @@ We built a complete ray tracer with: **Key Raycore Functions:** - * `Raycore.BVH(meshes)` - Build acceleration structure (default metadata = primitive index) - * `Raycore.BVH(meshes, metadata_fn)` - Build with custom per-triangle metadata + * `Raycore.TLAS(meshes)` - Build acceleration structure + * `Raycore.TLAS(meshes, metadata_fn)` - Build with custom per-triangle metadata * `Raycore.Ray(o=origin, d=direction)` - Create ray - * `Raycore.closest_hit(bvh, ray)` - Find nearest intersection, returns `(hit, triangle, distance, bary_coords)` - * `Raycore.any_hit(bvh, ray)` - Test for any intersection (fast shadow test) + * `Raycore.closest_hit(tlas, ray)` - Find nearest intersection, returns `(hit, triangle, distance, bary_coords, instance_id)` + * `Raycore.any_hit(tlas, ray)` - Test for any intersection (fast shadow test) * `Raycore.reflect(wo, normal)` - Compute reflection direction * `triangle.metadata` - Access custom data stored per-triangle @@ -405,8 +406,8 @@ We built a complete ray tracer with: 1. **Material Scene Pattern** - Associate materials with geometry using metadata: ```julia -# Build BVH with mesh index as metadata -bvh = Raycore.BVH(meshes, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) +# Build TLAS with mesh index as metadata +tlas = Raycore.TLAS(meshes, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) # In your shader, look up material from hit triangle mat = materials[triangle.metadata] diff --git a/docs/src/test_hikari_wavefront.jl b/docs/src/test_hikari_wavefront.jl new file mode 100644 index 0000000..c8f6e5b --- /dev/null +++ b/docs/src/test_hikari_wavefront.jl @@ -0,0 +1,23 @@ +using Revise +using Raycore +using Raycore: to_gpu +using KernelAbstractions +using KernelAbstractions: @kernel, @index, @Const +using GeometryBasics, Colors, LinearAlgebra +import Makie +using Makie: RGBf +import KernelAbstractions as KA +using ImageShow + +# Load the Hikari wavefront renderer +# Create the example scene with Hikari materials +material_scene, lights = hikari_example_scene() + +# Create and render +begin + img = fill(RGBf(0, 0, 0), 400, 720) + renderer = HikariWavefrontRenderer(img, material_scene, lights) + render!(renderer) + nothing +end +img diff --git a/docs/src/test_wavefront.jl b/docs/src/test_wavefront.jl index ee0f86d..0cf8e87 100644 --- a/docs/src/test_wavefront.jl +++ b/docs/src/test_wavefront.jl @@ -22,42 +22,3 @@ begin @btime render!(renderer) nothing end -renderer.framebuffer -renderer_instanced.framebuffer -begin - img = fill(RGBf(0, 0, 0), 400, 720) - renderer_instanced = WavefrontRenderer( - img, bvh, ctx; - camera_pos=Point3f(0, -0.9, -2.5), - fov=45.0f0, - sky_color=RGB{Float32}(0.5f0, 0.7f0, 1.0f0), - samples_per_pixel=4 - ) - @btime render!(renderer_instanced) - # on windows + ryzen 395 max - # 381.034 ms (1200456 allocations: 90.13 MiB) - - nothing -end -using ImageShow - -using FileIO -save("wavefront.png", map(col -> mapc(c -> clamp(c, 0f0, 1f0), col), renderer.framebuffer)) - -using AMDGPU -amd_renderer = to_gpu(ROCArray, renderer); -Array(@btime render!(amd_renderer)) -# 36ms on windows + amd 8060s - -using pocl_jll, OpenCL -amd_renderer = to_gpu(CLArray, renderer); -Array(@time render!(amd_renderer)) -r = Raycore.Ray(o=Point3f(0, 0, 0), d=Vec3f(0, 0, 1), t_max=0.0f0) - -@code_warntype any_hit(ibvh, r) - -function test(bvh) - meshes = getfield(bvh, :meshes) -end - -typeof(test(ibvh)) diff --git a/docs/src/viewfactors_content.md b/docs/src/viewfactors_content.md index 510a6db..d340010 100644 --- a/docs/src/viewfactors_content.md +++ b/docs/src/viewfactors_content.md @@ -7,6 +7,7 @@ This example demonstrates Raycore's analysis capabilities for radiosity and illu ```julia (editor=true, logging=false, output=true) using Raycore, GeometryBasics, LinearAlgebra using WGLMakie +import KernelAbstractions as KA function LowSphere(radius, contact=Point3f(0); ntriangles=6) return Tesselation(Sphere(contact .+ Point3f(0, 0, radius), radius), ntriangles) @@ -20,9 +21,35 @@ s3 = LowSphere(0.3f0, Point3f(-0.5, 1, 0); ntriangles) s4 = LowSphere(0.4f0, Point3f(0, 1.0, 0); ntriangles) s5 = LowSphere(0.35f0, Point3f(0.5, 0.0, 0); ntriangles) -# Build BVH acceleration structure -bvh = BVH([s1, s2, s3, s4, s5]) -world_mesh = GeometryBasics.Mesh(bvh) +# Build meshes and TLAS with sequential triangle indices as metadata +sphere_meshes = [normal_mesh(s) for s in [s1, s2, s3, s4, s5]] +tlas = Raycore.TLAS(sphere_meshes, (mesh_idx, tri_idx) -> begin + offset = sum(length(faces(sphere_meshes[i])) for i in 1:mesh_idx-1; init=0) + UInt32(offset + tri_idx) +end) + +# Build visualization mesh from TLAS primitives (which may have degenerate triangles removed) +# This guarantees face count matches between the TLAS and the visualization mesh. +function tlas_to_mesh(tlas) + verts = Point3f[] + norms = Vec3f[] + fs = GLTriangleFace[] + for tri in tlas.all_blas_prims + base = length(verts) + for v in tri.vertices + push!(verts, v) + end + for n in tri.normals + push!(norms, Vec3f(n)) + end + push!(fs, GLTriangleFace(base + 1, base + 2, base + 3)) + end + # GeometryBasics ≥ 0.5 dropped `meta(...)`; vertex normals go via the + # `normal` kwarg on `GeometryBasics.mesh`/`Mesh`. + GeometryBasics.mesh(verts, fs; normal=norms) +end +world_mesh = tlas_to_mesh(tlas) +N = length(tlas.all_blas_prims) # Visualize the scene f, ax, pl = mesh(world_mesh, color=:teal, axis=(show_axis=false,)) @@ -35,11 +62,10 @@ View factors quantify how much each surface "sees" every other surface - essenti ```julia (editor=true, logging=false, output=true) # Calculate view factors between all faces -viewf_matrix = view_factors(bvh, rays_per_triangle=20) +viewf_matrix = view_factors(tlas, rays_per_triangle=20) # Sum up total view factor per face -viewfacts = map(i -> Float32(sum(view(viewf_matrix, :, i))), 1:length(bvh.primitives)) -N = length(world_mesh.faces) +viewfacts = map(i -> Float32(sum(view(viewf_matrix, :, i))), 1:N) # Visualize per_face_vf = FaceView(viewfacts, [GLTriangleFace(i) for i in 1:N]) @@ -59,7 +85,7 @@ Calculate how much each face is exposed to rays from a specific viewing directio viewdir = normalize(ax.scene.camera.view_direction[]) # Compute illumination -illum = get_illumination(bvh, viewdir; grid_size=10) +illum = get_illumination(tlas, viewdir; grid_size=10) # Visualize pf = FaceView(illum, [GLTriangleFace(i) for i in 1:N]) @@ -75,7 +101,7 @@ Find the average position of visible surface points from a given direction. ```julia (editor=true, logging=false, output=true) # Calculate centroid -hitpoints, centroid = get_centroid(bvh, viewdir; grid_size=10) +hitpoints, centroid = get_centroid(tlas, viewdir; grid_size=10) # Visualize f, ax, pl = mesh(world_mesh, color=(:blue, 0.5), transparency=true, axis=(show_axis=false,)) @@ -86,3 +112,4 @@ meshscatter!(ax, [centroid], color=:red, markersize=0.05) f ``` The red sphere marks the centroid - useful for camera placement and focus calculations. + diff --git a/docs/src/wavefront-renderer.jl b/docs/src/wavefront-renderer.jl index 73a0e71..4fc5246 100644 --- a/docs/src/wavefront-renderer.jl +++ b/docs/src/wavefront-renderer.jl @@ -11,15 +11,15 @@ using Statistics # ============================================================================ """ - MaterialScene(geometry_material_pairs) -> (bvh, materials) + MaterialScene(geometry_material_pairs) -> (tlas, materials) -Create a BVH and materials array from a vector of (geometry, material) tuples. +Create a TLAS and materials array from a vector of (geometry, material) tuples. # Arguments - `geometry_material_pairs`: Vector of `(mesh, material)` tuples # Returns -- `bvh`: BVH with material indices as triangle metadata +- `tlas`: TLAS with material indices as triangle metadata - `materials`: Vector of materials # Example @@ -35,12 +35,12 @@ function MaterialScene(geometry_material_pairs::Vector{<:Tuple{<:Any, M}}) where meshes = [p[1] for p in geometry_material_pairs] materials = M[p[2] for p in geometry_material_pairs] - # Create BVH with material index as metadata + # Create TLAS with material index as metadata function metadata_fn(mesh_idx, _tri_idx) return Int32(mesh_idx) end - bvh = Raycore.BVH(meshes, metadata_fn) - return bvh, materials + tlas = Raycore.TLAS(meshes, metadata_fn) + return tlas, materials end # ============================================================================ @@ -212,6 +212,47 @@ end end end +""" +Generate primary rays with look-at camera. +camera_right, camera_up, camera_forward define the camera basis vectors. +""" +@kernel function generate_primary_rays_lookat!( + @Const(width), @Const(height), + @Const(camera_pos), + @Const(camera_right), @Const(camera_up), @Const(camera_forward), + @Const(half_width), @Const(half_height), + ray_queue, + ::Val{NSamples} +) where {NSamples} + i = @index(Global, Cartesian) + y = gpu_int(i[1]) + x = gpu_int(i[2]) + + @inbounds if y <= height && x <= width + pixel_idx = (y - gpu_int(1)) * width + x + ntuple(Val(NSamples)) do s + s_idx = gpu_int(s) + ray_idx = (pixel_idx - gpu_int(1)) * gpu_int(NSamples) + s_idx + jitter = rand(Vec2f) + + # Normalized device coordinates [-1, 1] + u = (2.0f0 * (Float32(x) - 0.5f0 + jitter[1]) / Float32(width) - 1.0f0) + v = (1.0f0 - 2.0f0 * (Float32(y) - 0.5f0 + jitter[2]) / Float32(height)) + + # Ray direction in world space + direction = normalize( + camera_forward + + camera_right * (u * half_width) + + camera_up * (v * half_height) + ) + ray = Raycore.Ray(o=camera_pos, d=direction) + + @set ray_queue[ray_idx] = (ray=ray, pixel_x=x, pixel_y=y, sample_idx=s_idx) + nothing + end + end +end + # ============================================================================ # Stage 2: Intersect Primary Rays # ============================================================================ @@ -227,7 +268,7 @@ end @inbounds if idx <= length(ray_queue.ray) # Read from SoA @get ray, pixel_x, pixel_y, sample_idx = ray_queue[idx] - hit_found, tri, dist, bary = Raycore.closest_hit(bvh, ray) + hit_found, tri, dist, bary, _ = Raycore.closest_hit(bvh, ray) # Write to SoA using @set @set hit_queue[idx] = (hit_found=hit_found, tri=tri, dist=dist, bary=Vec3f(bary), ray=ray, pixel_x=pixel_x, pixel_y=pixel_y, sample_idx=sample_idx) @@ -310,7 +351,7 @@ end # any_hit respects ray.t_max and only returns hits before the light # So if we get a hit, something is blocking the light visible = if ray.t_max > 0.0f0 - hit_found, _, _, _ = Raycore.any_hit(bvh, ray) + hit_found, _, _, _, _ = Raycore.any_hit(bvh, ray) !hit_found # Visible only if no obstruction else false # Dummy ray (sky hits) @@ -451,7 +492,7 @@ end @get ray, hit_idx = reflection_ray_soa[idx] if ray.t_max > 0.0f0 - hit_found, tri, dist, bary = Raycore.closest_hit(bvh, ray) + hit_found, tri, dist, bary, _ = Raycore.closest_hit(bvh, ray) if hit_found # Compute normal here so we don't need to store the triangle v0, v1, v2 = Raycore.normals(tri) @@ -598,7 +639,7 @@ converted to GPU using `to_gpu(ArrayType, renderer)`. # Fields - `framebuffer`: Output image buffer -- `bvh`: BVH acceleration structure +- `bvh`: TLAS acceleration structure - `ctx`: Scene context (materials, lights) - Camera parameters: `camera_pos`, `fov`, `sky_color`, `samples_per_pixel` - Work queues for each wavefront stage @@ -615,6 +656,8 @@ struct WavefrontRenderer{ImgArr <: AbstractMatrix, BVH, Ctx} # Camera parameters camera_pos::Point3f + camera_lookat::Point3f # Look-at target + camera_up::Vec3f # Up vector fov::Float32 sky_color::RGB{Float32} samples_per_pixel::Int32 @@ -632,13 +675,20 @@ struct WavefrontRenderer{ImgArr <: AbstractMatrix, BVH, Ctx} end """ - WavefrontRenderer(img, bvh, ctx; camera_pos, fov, sky_color, samples_per_pixel) + WavefrontRenderer(img, bvh, ctx; camera_pos, camera_lookat, camera_up, fov, sky_color, samples_per_pixel) Create a WavefrontRenderer with all necessary buffers allocated for the given image size and scene. + +# Arguments +- `camera_pos`: Camera position in world space +- `camera_lookat`: Point the camera is looking at (default: origin) +- `camera_up`: Up vector for camera orientation (default: +Y) """ function WavefrontRenderer( img, bvh, ctx; camera_pos=Point3f(0, -0.9, -2.5), + camera_lookat=Point3f(0, -0.9, 10), + camera_up=Vec3f(0, 1, 0), fov=45.0f0, sky_color=RGB{Float32}(0.5f0, 0.7f0, 1.0f0), samples_per_pixel=4 @@ -652,7 +702,7 @@ function WavefrontRenderer( # Allocate work queues as SoA primary_ray_queue = similar_soa(img, PrimaryRayWork, num_rays) - primary_hit_queue = similar_soa(img, PrimaryHitWork{eltype(bvh.primitives)}, num_rays) + primary_hit_queue = similar_soa(img, PrimaryHitWork{eltype(bvh)}, num_rays) shadow_ray_queue = similar_soa(img, ShadowRayWork, num_shadow_rays) shadow_result_queue = similar_soa(img, ShadowResult, num_shadow_rays) reflection_ray_soa = similar_soa(img, ReflectionRayWork, num_rays) @@ -664,7 +714,8 @@ function WavefrontRenderer( return WavefrontRenderer( Int32(width), Int32(height), img, bvh, ctx, - camera_pos, fov, sky_color, Int32(samples_per_pixel), + camera_pos, camera_lookat, camera_up, + fov, sky_color, Int32(samples_per_pixel), primary_ray_queue, primary_hit_queue, shadow_ray_queue, shadow_result_queue, reflection_ray_soa, reflection_hit_soa, @@ -682,7 +733,7 @@ function Raycore.to_gpu(Arr, renderer::WavefrontRenderer) # Convert image img = Arr(renderer.framebuffer) - # Convert BVH and context + # Convert TLAS and context bvh_gpu = Raycore.to_gpu(Arr, renderer.bvh) ctx_gpu = Raycore.to_gpu(Arr, renderer.ctx) @@ -690,6 +741,8 @@ function Raycore.to_gpu(Arr, renderer::WavefrontRenderer) return WavefrontRenderer( img, bvh_gpu, ctx_gpu; camera_pos=renderer.camera_pos, + camera_lookat=renderer.camera_lookat, + camera_up=renderer.camera_up, fov=renderer.fov, sky_color=renderer.sky_color, samples_per_pixel=Int(renderer.samples_per_pixel) @@ -708,7 +761,6 @@ function render!(renderer::WavefrontRenderer) samples_per_pixel = Int(renderer.samples_per_pixel) aspect = Float32(width / height) - focal_length = 1.0f0 / tan(deg2rad(renderer.fov / 2)) backend = KA.get_backend(renderer.framebuffer) @@ -717,11 +769,23 @@ function render!(renderer::WavefrontRenderer) num_lights = Int(length(renderer.ctx.lights)) num_shadow_rays = num_rays * num_lights - # Stage 1: Generate primary rays - gen_kernel! = generate_primary_rays!(backend) + # Compute look-at camera basis vectors + # Standard right-handed: right = up × forward (gives +X when up=+Y, forward=+Z) + camera_forward = Vec3f(normalize(renderer.camera_lookat - renderer.camera_pos)) + camera_right = Vec3f(normalize(cross(renderer.camera_up, camera_forward))) + camera_up_ortho = Vec3f(cross(camera_forward, camera_right)) + + # Compute half-width and half-height based on FOV + half_height = tan(deg2rad(renderer.fov / 2)) + half_width = half_height * aspect + + # Stage 1: Generate primary rays with look-at camera + gen_kernel! = generate_primary_rays_lookat!(backend) gen_kernel!( renderer.width, renderer.height, - renderer.camera_pos, focal_length, aspect, + renderer.camera_pos, + camera_right, camera_up_ortho, camera_forward, + half_width, half_height, renderer.primary_ray_queue, Val(samples_per_pixel), ndrange=(height, width) diff --git a/docs/src/wavefront_dynamic.jl b/docs/src/wavefront_dynamic.jl new file mode 100644 index 0000000..f11942b --- /dev/null +++ b/docs/src/wavefront_dynamic.jl @@ -0,0 +1,201 @@ +# Dynamic Scene Example - Animated TLAS with Transform Updates +# +# This example demonstrates efficient dynamic scene updates using TLAS. +# Instead of rebuilding the entire acceleration structure each frame, +# we update instance transforms and refit the TLAS - much faster! + +using Revise +using Raycore +using KernelAbstractions +using GeometryBasics, Colors, LinearAlgebra +import Makie +using Makie: RGBf +import KernelAbstractions as KA +using FileIO, ImageCore + +# Load helper functions +include("raytracing-core.jl") +include("wavefront-renderer.jl") + +""" +Create a rotation matrix around the Y axis. +""" +function rotation_y(angle::Float32)::Mat4f + c, s = cos(angle), sin(angle) + Mat4f( + c, 0, s, 0, + 0, 1, 0, 0, + -s, 0, c, 0, + 0, 0, 0, 1 + ) +end + +""" +Create a translation matrix. +""" +function translation(x::Float32, y::Float32, z::Float32)::Mat4f + Mat4f( + 1, 0, 0, x, + 0, 1, 0, y, + 0, 0, 1, z, + 0, 0, 0, 1 + ) +end + +""" +Create a scale matrix. +""" +function scaling(sx::Float32, sy::Float32, sz::Float32)::Mat4f + Mat4f( + sx, 0, 0, 0, + 0, sy, 0, 0, + 0, 0, sz, 0, + 0, 0, 0, 1 + ) +end + +""" +Create a dynamic scene with objects that can be animated. +Returns `(tlas, handles, ctx)` — `handles[i]` is the `TLASHandle` for geometry `i`. +""" +function create_dynamic_scene() + sphere1 = Tesselation(Sphere(Point3f(0, 0, 0), 0.5f0), 32) + sphere2 = Tesselation(Sphere(Point3f(0, 0, 0), 0.4f0), 32) + cube = normal_mesh(Rect3f(Vec3f(-0.4f0), Vec3f(0.8f0))) + floor = normal_mesh(Rect3f(Vec3f(-4, -1, -4), Vec3f(8, 0.01, 8))) + back_wall = normal_mesh(Rect3f(Vec3f(-4, -1, 4), Vec3f(8, 4, 0.01))) + + materials = [ + Material(RGB(0.9f0, 0.2f0, 0.2f0), 0.3f0, 0.4f0, 1.0f0, 0.0f0), + Material(RGB(0.2f0, 0.9f0, 0.2f0), 0.5f0, 0.2f0, 1.0f0, 0.0f0), + Material(RGB(0.2f0, 0.2f0, 0.9f0), 0.0f0, 0.6f0, 1.0f0, 0.0f0), + Material(RGB(0.4f0, 0.4f0, 0.4f0), 0.0f0, 0.9f0, 1.0f0, 0.0f0), + Material(RGB(0.8f0, 0.7f0, 0.6f0), 0.0f0, 0.8f0, 1.0f0, 0.0f0), + ] + + println("Building initial TLAS...") + tlas = Raycore.TLAS(KernelAbstractions.CPU()) + handles = [push!(tlas, normal_mesh(g)) for g in [sphere1, sphere2, cube, floor, back_wall]] + Raycore.sync!(tlas) + println("TLAS created with $(Raycore.n_instances(tlas)) instances") + + lights = default_lights() + ctx = RenderContext(lights, materials, 0.1f0) + + return tlas, handles, ctx +end + +""" +Update transforms for frame t (0 to 1 for one animation cycle). +""" +function update_scene!(tlas::Raycore.TLAS, handles::Vector, t::Float32) + orbit_radius = 1.5f0 + orbit_angle = t * 2f0 * Float32(pi) + Raycore.update_transform!(tlas, handles[1], + translation(orbit_radius * cos(orbit_angle), 0.0f0, orbit_radius * sin(orbit_angle))) + + bounce_height = 0.5f0 + 0.8f0 * abs(sin(t * 4f0 * Float32(pi))) + Raycore.update_transform!(tlas, handles[2], translation(-1.5f0, bounce_height, 0.0f0)) + + cube_x = 1.0f0 * sin(t * 2f0 * Float32(pi)) + Raycore.update_transform!(tlas, handles[3], + translation(cube_x, 0.0f0, -1.0f0) * rotation_y(t * 3f0 * Float32(pi))) + + # handles[4] (floor) and handles[5] (wall) are static — no update needed + Raycore.sync!(tlas) +end + +""" +Render a single frame. +""" +function render_frame!(img, tlas, ctx) + renderer = WavefrontRenderer(img, tlas, ctx) + render!(renderer) +end + +# ============================================================================== +# Main Animation Loop +# ============================================================================== + +println("\n" * "="^70) +println("Dynamic Scene Example - Animated TLAS") +println("="^70) + +# Create scene +tlas, ctx = create_dynamic_scene() + +# Animation parameters +num_frames = 60 +width, height = 400, 300 + +println("\nRendering $num_frames frames...") +println(" Resolution: $(width)x$(height)") + +# Render animation frames +frames = Vector{Matrix{RGB{Float32}}}(undef, num_frames) + +# Warmup +img_warmup = fill(RGBf(0, 0, 0), height, width) +update_scene!(tlas, 0.0f0) +render_frame!(img_warmup, tlas, ctx) + +# Time the animation +total_time = @elapsed begin + for frame in 1:num_frames + t = Float32((frame - 1) / num_frames) + + # Update transforms and refit TLAS + update_scene!(tlas, t) + + # Render frame + img = fill(RGBf(0, 0, 0), height, width) + render_frame!(img, tlas, ctx) + frames[frame] = img + + if frame % 10 == 0 + println(" Frame $frame/$num_frames") + end + end +end + +avg_fps = num_frames / total_time +println("\nAnimation complete!") +println(" Total time: $(round(total_time, digits=2))s") +println(" Average FPS: $(round(avg_fps, digits=1))") + +# Save first, middle, and last frames +println("\nSaving sample frames...") +save("dynamic_frame_001.png", map(clamp01nan, frames[1])) +save("dynamic_frame_030.png", map(clamp01nan, frames[30])) +save("dynamic_frame_060.png", map(clamp01nan, frames[60])) +println("Saved: dynamic_frame_001.png, dynamic_frame_030.png, dynamic_frame_060.png") + +# Benchmark refit vs rebuild +println("\n" * "="^70) +println("Benchmarking: Refit vs Rebuild") +println("="^70) + +using BenchmarkTools + +# Benchmark refit only +refit_time = @belapsed begin + update_scene!($tlas, 0.5f0) +end + +println(" Refit time: $(round(refit_time * 1000, digits=3)) ms") + +# For comparison: time to rebuild entire TLAS +geometries = [ + Tesselation(Sphere(Point3f(0, 0, 0), 0.5f0), 32), + Tesselation(Sphere(Point3f(0, 0, 0), 0.4f0), 32), + normal_mesh(Rect3f(Vec3f(-0.4f0), Vec3f(0.8f0))), + normal_mesh(Rect3f(Vec3f(-4, -1, -4), Vec3f(8, 0.01, 8))), + normal_mesh(Rect3f(Vec3f(-4, -1, 4), Vec3f(8, 4, 0.01))), +] + +rebuild_time = @belapsed begin + Raycore.TLAS($geometries, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) +end + +println(" Rebuild time: $(round(rebuild_time * 1000, digits=3)) ms") +println(" Speedup: $(round(rebuild_time / refit_time, digits=1))x faster with refit!") diff --git a/docs/src/wavefront_instanced.jl b/docs/src/wavefront_instanced.jl new file mode 100644 index 0000000..7130845 --- /dev/null +++ b/docs/src/wavefront_instanced.jl @@ -0,0 +1,116 @@ +using Revise +using Raycore +using Raycore: to_gpu +using KernelAbstractions +using KernelAbstractions: @kernel, @index, @Const +using GeometryBasics, Colors, LinearAlgebra +import Makie +using Makie: RGBf +import KernelAbstractions as KA +using ImageShow +using BenchmarkTools + +# Load helper functions +include("raytracing-core.jl") +include("wavefront-renderer.jl") + +function example_scene_tlas(; glass_cat=false) + cat_mesh = Makie.loadasset("cat.obj") + angle = deg2rad(150f0) + rotation = Makie.Quaternionf(0, sin(angle/2), 0, cos(angle/2)) + rotated_coords = [rotation * Point3f(v) for v in coordinates(cat_mesh)] + + # Get bounding box and translate cat to sit on the floor + cat_bbox = Rect3f(rotated_coords) + floor_y = -1.5f0 + cat_offset = Vec3f(0, floor_y - cat_bbox.origin[2], 0) + + cat_mesh = GeometryBasics.normal_mesh( + [v + cat_offset for v in rotated_coords], + GeometryBasics.faces(cat_mesh) + ) + + # Create a simple room: floor, back wall, and side wall + floor = normal_mesh(Rect3f(Vec3f(-5, -1.5, -2), Vec3f(10, 0.01, 10))) + back_wall = normal_mesh(Rect3f(Vec3f(-5, -1.5, 8), Vec3f(10, 5, 0.01))) + left_wall = normal_mesh(Rect3f(Vec3f(-5, -1.5, -2), Vec3f(0.01, 5, 10))) + + # Add a couple of spheres for visual interest + sphere1 = Tesselation(Sphere(Point3f(-2, -1.5 + 0.8, 2), 0.8f0), 64) + sphere2 = Tesselation(Sphere(Point3f(2, -1.5 + 0.6, 1), 0.6f0), 64) + + # Material: base_color, metallic, roughness, ior, transmission + cat_material = if glass_cat + Material(RGB(0.95f0, 1.0f0, 0.95f0), 0.0f0, 0.0f0, 1.5f0, 1.0f0) + else + Material(RGB(0.8f0, 0.6f0, 0.4f0), 0.0f0, 0.8f0, 1.0f0, 0.0f0) + end + + # (geometry, material) pairs + scene = [ + (cat_mesh, cat_material), + (floor, Material(RGB(0.3f0, 0.5f0, 0.3f0), 0.0f0, 0.9f0, 1.0f0, 0.0f0)), + (back_wall, Material(RGB(0.8f0, 0.6f0, 0.5f0), 0.8f0, 0.05f0, 1.0f0, 0.0f0)), + (left_wall, Material(RGB(0.7f0, 0.7f0, 0.8f0), 0.0f0, 0.8f0, 1.0f0, 0.0f0)), + (sphere1, Material(RGB(0.9f0, 0.9f0, 0.9f0), 0.8f0, 0.02f0, 1.0f0, 0.0f0)), + (sphere2, Material(RGB(0.3f0, 0.6f0, 0.9f0), 0.5f0, 0.3f0, 1.0f0, 0.0f0)), + ] + + geometries = [g for (g, _) in scene] + materials = [m for (_, m) in scene] + + println("\nBuilding TLAS (instanced BVH)...") + println(" Each mesh becomes its own BLAS with a single instance") + + # Use TLAS instead of BVH - drop-in replacement! + tlas = Raycore.TLAS(geometries, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) + + lights = default_lights() + ctx = RenderContext(lights, materials, 0.1f0) + return tlas, ctx +end + +println("\n" * "="^70) +println("Creating scene with TLAS...") +println("="^70) + +tlas, ctx = example_scene_tlas() + +println("\n" * "="^70) +println("Rendering image...") +println("="^70) + +# ibvh = Raycore.InstancedBVH(geom) +begin + img = fill(RGBf(0, 0, 0), 400, 720) + # Use original camera parameters to match pre-look-at benchmark image + renderer = WavefrontRenderer(img, tlas, ctx; + camera_pos=Point3f(0, -0.9, -2.5), + camera_lookat=Point3f(0, -0.9, 10), # Look in +Z direction + camera_up=Vec3f(0, 1, 0), # Y up + fov=45.0f0, + sky_color=RGB{Float32}(0.5f0, 0.7f0, 1.0f0), + samples_per_pixel=4`` + ) + @btime render!(renderer) +end +img +using FileIO, ImageCore + +save("wavefront_instanced.png", map(clamp01nan, img)) +using AMDGPU +begin + img = fill(RGBf(0, 0, 0), 400, 720) + # Use original camera parameters to match pre-look-at benchmark image + renderer = WavefrontRenderer(img, tlas, ctx; + camera_pos=Point3f(0, -0.9, -2.5), + camera_lookat=Point3f(0, -0.9, 10), # Look in +Z direction + camera_up=Vec3f(0, 1, 0), # Y up + fov=45.0f0, + sky_color=RGB{Float32}(0.5f0, 0.7f0, 1.0f0), + samples_per_pixel=4 + ) + gpu_renderer = to_gpu(ROCArray, renderer) + @btime render!(gpu_renderer) + Array(gpu_renderer.framebuffer) +end diff --git a/docs/src/wavefront_lego.jl b/docs/src/wavefront_lego.jl new file mode 100644 index 0000000..9c9d500 --- /dev/null +++ b/docs/src/wavefront_lego.jl @@ -0,0 +1,381 @@ +# ============================================================================== +# Animated Lego Figure - TLAS Dynamic Scene Demo +# ============================================================================== +# +# This example demonstrates the power of instanced BVH (TLAS/BLAS) for animated scenes. +# Each body part is a separate BLAS instance that can be transformed independently. +# Animation updates only require refitting the TLAS - no geometry rebuild needed! +# +# Based on the Lego figure model by Kevin-Mattheus-Moerman +# https://twitter.com/KMMoerman/status/1417759722963415041 + +using Revise +using Raycore +using Raycore: Mat4f +using KernelAbstractions +using GeometryBasics, Colors, LinearAlgebra +using FileIO, MeshIO +import Makie +using Makie: RGBf +import KernelAbstractions as KA +using ImageCore + +# Load helper functions +include("raytracing-core.jl") +include("wavefront-renderer.jl") + +const IDENTITY = Mat4f(I) + +# ============================================================================== +# Transform Utilities +# ============================================================================== + +"""Create a rotation matrix around an arbitrary axis (Rodrigues' formula). +Note: Mat4f uses column-major storage, so values are specified column by column.""" +function rotation_axis(axis::Vec3f, angle::Float32)::Mat4f + axis = normalize(axis) + c, s = cos(angle), sin(angle) + t = 1.0f0 - c + x, y, z = axis[1], axis[2], axis[3] + + # Column-major: each group of 4 values is one column + Mat4f( + t*x*x + c, t*x*y + s*z, t*x*z - s*y, 0, # column 1 + t*x*y - s*z, t*y*y + c, t*y*z + s*x, 0, # column 2 + t*x*z + s*y, t*y*z - s*x, t*z*z + c, 0, # column 3 + 0, 0, 0, 1 # column 4 + ) +end + +"""Create a translation matrix. +Note: Mat4f uses column-major storage, translation goes in column 4.""" +function translation(v::Union{Vec3f, Point3f})::Mat4f + # Column-major: translation in column 4, rows 1-3 + Mat4f( + 1, 0, 0, 0, # column 1 + 0, 1, 0, 0, # column 2 + 0, 0, 1, 0, # column 3 + v[1], v[2], v[3], 1 # column 4 + ) +end + +# ============================================================================== +# Lego Figure Configuration +# ============================================================================== + +# Part colors (matching original RPRMakie example) +const LEGO_COLORS = Dict( + "eyes_mouth" => RGB(0.0f0, 0.0f0, 0.0f0), + "belt" => RGB(0.0f0, 0.0f0, 0.35f0), + "arm_right" => RGB(0.0f0, 0.6f0, 0.15f0), + "arm_left" => RGB(0.0f0, 0.6f0, 0.15f0), + "hand_right" => RGB(1.0f0, 0.85f0, 0.0f0), + "hand_left" => RGB(1.0f0, 0.85f0, 0.0f0), + "leg_right" => RGB(0.2f0, 0.4f0, 0.9f0), + "leg_left" => RGB(0.2f0, 0.4f0, 0.9f0), + "torso" => RGB(0.84f0, 0.06f0, 0.15f0), + "head" => RGB(1.0f0, 0.85f0, 0.0f0), +) + +# Joint pivot points (in local mesh coordinates) +const JOINT_ORIGINS = Dict( + "arm_right" => Point3f(0.1427, -6.2127, 5.7342), + "arm_left" => Point3f(0.1427, 6.2127, 5.7342), + "leg_right" => Point3f(0, -1, -8.2), + "leg_left" => Point3f(0, 1, -8.2), +) + +# Rotation axes for joints +const ROTATION_AXES = Dict( + "arm_right" => Vec3f(0.0, -0.9828, 0.1848), + "arm_left" => Vec3f(0.0, 0.9828, 0.1848), + "leg_right" => Vec3f(0, -1, 0), + "leg_left" => Vec3f(0, 1, 0), +) + +# Order of parts for instance indexing +const PART_ORDER = [ + "torso", "head", "eyes_mouth", + "arm_right", "hand_right", + "arm_left", "hand_left", + "belt", "leg_right", "leg_left", + "floor" +] + +# Parent relationships: child => parent +const PART_PARENTS = Dict( + "head" => "torso", + "eyes_mouth" => "head", + "arm_right" => "torso", + "hand_right" => "arm_right", + "arm_left" => "torso", + "hand_left" => "arm_left", + "belt" => "torso", + "leg_right" => "belt", + "leg_left" => "belt", +) + +# ============================================================================== +# Scene Creation +# ============================================================================== + +"""Load a lego part mesh from Makie assets.""" +function load_lego_part(name::String) + path = Makie.assetpath("lego_figure_$name.stl") + mesh = load(path) + return normal_mesh(mesh) +end + +""" +Create the lego scene with TLAS. +Returns (tlas, ctx, materials) +""" +function create_lego_scene() + println("Loading lego figure parts...") + + geometries = [] + materials = Material[] + + # Load all body parts + for part_name in PART_ORDER[1:end-1] # Exclude floor + mesh = load_lego_part(part_name) + push!(geometries, mesh) + + color = get(LEGO_COLORS, part_name, RGB(0.5f0, 0.5f0, 0.5f0)) + # Slight metallic sheen on plastic + mat = Material(color, 0.1f0, 0.4f0, 1.0f0, 0.0f0) + push!(materials, mat) + + println(" Loaded $part_name: $(length(faces(mesh))) triangles") + end + + # Add floor + floor_mesh = normal_mesh(Rect3f(Vec3f(-100, -100, -2), Vec3f(200, 200, 2))) + push!(geometries, floor_mesh) + push!(materials, Material(RGB(0.95f0, 0.95f0, 0.95f0), 0.0f0, 0.9f0, 1.0f0, 0.0f0)) + + println("\nBuilding TLAS...") + tlas = Raycore.TLAS(KernelAbstractions.CPU()) + handles = Dict{String, Raycore.TLASHandle}() + for (i, part_name) in enumerate(PART_ORDER[1:end-1]) + handles[part_name] = push!(tlas, geometries[i]) + end + handles["floor"] = push!(tlas, floor_mesh) + Raycore.sync!(tlas) + println(" Instances: $(Raycore.n_instances(tlas))") + + lights = [ + PointLight(Point3f(50, 0, 100), 8000.0f0, RGB(1.0f0, 0.95f0, 0.9f0)), + PointLight(Point3f(-30, 40, 60), 3000.0f0, RGB(0.8f0, 0.85f0, 1.0f0)), + PointLight(Point3f(0, -50, 80), 2000.0f0, RGB(1.0f0, 1.0f0, 1.0f0)), + ] + + ctx = RenderContext(lights, materials, 0.15f0) + + return tlas, handles, ctx +end + +""" +Compute rotation around a joint pivot point. +Returns transform that rotates around the pivot. +""" +function joint_rotation(pivot::Point3f, axis::Vec3f, angle::Float32)::Mat4f + # Translate to pivot, rotate, translate back + translation(Vec3f(pivot)) * rotation_axis(axis, angle) * translation(-Vec3f(pivot)) +end + +""" +Update all instance transforms for the walking animation. + +joint_angles: Dict mapping joint names to rotation angles +figure_pos: Overall figure position (x translation for walking) +""" +function update_walking_pose!(tlas, handles::Dict{String, Raycore.TLASHandle}, + joint_angles::Dict{String, Float32}, figure_pos::Vec3f) + transforms = Dict{String, Mat4f}() + + base_transform = translation(figure_pos + Vec3f(0, 0, 20)) + transforms["torso"] = base_transform + + for part_name in PART_ORDER[1:end-1] + part_name == "torso" && continue + parent_name = get(PART_PARENTS, part_name, "torso") + transform = get(transforms, parent_name, base_transform) + if haskey(JOINT_ORIGINS, part_name) && haskey(joint_angles, part_name) + transform = transform * joint_rotation(JOINT_ORIGINS[part_name], + ROTATION_AXES[part_name], + joint_angles[part_name]) + end + transforms[part_name] = transform + end + + for part_name in PART_ORDER + t = part_name == "floor" ? IDENTITY : transforms[part_name] + Raycore.update_transform!(tlas, handles[part_name], t) + end + Raycore.sync!(tlas) +end + +""" +Generate the walking angle sequence following the RPRMakie reference. + +Returns a vector of angles: [0→max, max→0, 0→-max, -max→0] +All limbs use the SAME angle - the rotation axes have opposite signs for +left/right limbs, so applying the same angle creates natural alternating motion. +""" +function generate_walk_cycle()::Vector{Float32} + rot_joints_by = 0.25f0 * Float32(pi) + animation_strides = 10 + + a1 = collect(LinRange(0.0f0, rot_joints_by, animation_strides)) + return Vector{Float32}(vcat( + a1, + reverse(a1[1:end-1]), + -a1[2:end], + reverse(-a1[1:end-1]) + )) +end + +""" +Get joint angles for a specific frame in the walk cycle. +All limbs get the same angle - the axes handle the opposition. +""" +function walking_angles(frame_idx::Int, angle_sequence::Vector{Float32})::Dict{String, Float32} + idx = mod1(frame_idx, length(angle_sequence)) + angle = angle_sequence[idx] + + return Dict{String, Float32}( + "arm_right" => angle, + "arm_left" => angle, + "leg_right" => angle, + "leg_left" => angle, + ) +end + +# ============================================================================== +# Main Animation +# ============================================================================== + +println("\n" * "="^70) +println("Animated Lego Figure - TLAS Dynamic Scene Demo") +println("="^70) + +# Create scene with adjusted lighting (dimmer for better visuals) +tlas, ctx_original = create_lego_scene() + +# Create dimmer lighting context +lights_dim = [ + PointLight(Point3f(50, 0, 100), 5000.0f0, RGB(1.0f0, 0.95f0, 0.9f0)), + PointLight(Point3f(-30, 40, 60), 2000.0f0, RGB(0.8f0, 0.85f0, 1.0f0)), + PointLight(Point3f(0, -50, 80), 1500.0f0, RGB(1.0f0, 1.0f0, 1.0f0)), +] +ctx = RenderContext(lights_dim, ctx_original.materials, 0.12f0) + +# Generate walk cycle angles (following RPRMakie reference) +angle_sequence = generate_walk_cycle() +nsteps = length(angle_sequence) + +# Animation parameters - match RPRMakie +total_translation = 50.0f0 +num_frames = nsteps # One frame per angle step +width, height = 1920, 1080 +frames_dir = "lego_walk_frames" + +# Camera positioned like RPRMakie: Vec3f(100, 30, 80) looking at Vec3f(0, 0, -10) +# Figure walks in +X direction toward the camera +camera_pos = Point3f(100, 30, 80) +camera_lookat = Point3f(0, 0, 10) # Look at figure height + +println("\nRendering $num_frames frames...") +println(" Resolution: $(width)x$(height)") +println(" Output: $frames_dir/") +println(" Walk cycle: $nsteps steps") + +# Create output directory +rm(frames_dir; recursive=true) +mkpath(frames_dir) + +# Set initial pose +update_walking_pose!(tlas, walking_angles(1, angle_sequence), Vec3f(0, 0, 0)) + +# Warmup render +println("\nWarmup render...") +img_warmup = fill(RGBf(0, 0, 0), height, width) +renderer = WavefrontRenderer(img_warmup, tlas, ctx; + camera_pos=camera_pos, + camera_lookat=camera_lookat, + camera_up=Vec3f(0, 0, 1), + fov=40.0f0, + sky_color=RGB{Float32}(0.6f0, 0.75f0, 0.95f0), + samples_per_pixel=8 +) +render!(renderer) + +# Animation: figure walks forward toward camera (like RPRMakie) +println("\nRendering animation...") +translations = collect(LinRange(0.0f0, total_translation, nsteps)) + +total_time = @elapsed begin + for frame in 1:num_frames + # Get angle and translation for this frame + angles = walking_angles(frame, angle_sequence) + pos_x = translations[frame] + + # Update pose + update_walking_pose!(tlas, angles, Vec3f(pos_x, 0, 0)) + + # Render frame with fixed camera (figure walks toward it) + fill!(img_warmup, RGBf(0, 0, 0)) + render!(renderer) + + # Save frame with padded number + filename = joinpath(frames_dir, "frame_$(lpad(frame, 4, '0')).png") + save(filename, map(clamp01nan, img_warmup)) + + if frame % 10 == 0 || frame == 1 + println(" Frame $frame/$num_frames") + end + end +end + +avg_fps = num_frames / total_time +println("\nAnimation complete!") +println(" Total time: $(round(total_time, digits=2))s") +println(" Average FPS: $(round(avg_fps, digits=2))") + +# Create video using FFMPEG +println("\nCreating video...") +using FFMPEG_jll +video_output = "lego_walk.mp4" +run(`$(FFMPEG_jll.ffmpeg()) -y -framerate 30 -i $(frames_dir)/frame_%04d.png -c:v libx264 -pix_fmt yuv420p $video_output`) +println("Video saved: $video_output") + +# Performance comparison +println("\n" * "="^70) +println("Performance: TLAS Refit vs Full Rebuild") +println("="^70) + +using BenchmarkTools + +# Benchmark pose update + refit +refit_time = @belapsed begin + angles = walking_angles(15, $angle_sequence) + update_walking_pose!($tlas, angles, Vec3f(15, 0, 0)) +end + +println(" Pose update + refit: $(round(refit_time * 1000, digits=3)) ms") + +# Benchmark full rebuild +geometries_rebuild = vcat( + [load_lego_part(p) for p in PART_ORDER[1:end-1]], + [normal_mesh(Rect3f(Vec3f(-100, -100, -2), Vec3f(200, 200, 2)))] +) + +rebuild_time = @belapsed begin + Raycore.TLAS($geometries_rebuild, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) +end + +println(" Full TLAS rebuild: $(round(rebuild_time * 1000, digits=3)) ms") +println(" Speedup: $(round(rebuild_time / refit_time, digits=1))x faster with refit!") + +println("\nDone!") diff --git a/docs/src/wavefront_particles.jl b/docs/src/wavefront_particles.jl new file mode 100644 index 0000000..4be3f46 --- /dev/null +++ b/docs/src/wavefront_particles.jl @@ -0,0 +1,505 @@ +# ============================================================================== +# Particle Simulation - TLAS Instancing Demo with 10k+ Particles +# ============================================================================== +# +# This example demonstrates massive instancing with TLAS: +# - 10,000+ sphere particles sharing a single BLAS +# - Dynamic transforms (position updates each frame) via the handle-based +# `update_transforms!` + `sync!` API +# - Material changes based on particle velocity (heating effect) +# - Efficient TLAS refit instead of full rebuild +# + +using Revise +using Raycore +using Raycore: Mat4f +using KernelAbstractions +using GeometryBasics, Colors, LinearAlgebra +import Makie +using Makie: RGBf +import KernelAbstractions as KA +using ImageCore +using FileIO + +# Load helper functions +include("raytracing-core.jl") +include("wavefront-renderer.jl") + +# ============================================================================== +# Particle System +# ============================================================================== + +struct Particle + position::Point3f + velocity::Vec3f + radius::Float32 +end + +""" +Particle system - works on CPU (Vector) or GPU (ROCArray, CuArray). +The transforms field is used for TLAS updates. +""" +struct ParticleSystem{ArrParticle, ArrMat4} + particles::ArrParticle + transforms::ArrMat4 + bounds_min::Point3f + bounds_max::Point3f + gravity::Vec3f + damping::Float32 +end + +"""Create a particle system with random initial positions and velocities.""" +function ParticleSystem(n_particles::Int; + bounds_min=Point3f(-50, -50, 0), + bounds_max=Point3f(50, 50, 100), + radius_range=(0.3f0, 0.8f0)) + + particles = Particle[] + + for _ in 1:n_particles + pos = Point3f( + bounds_min[1] + rand(Float32) * (bounds_max[1] - bounds_min[1]), + bounds_min[2] + rand(Float32) * (bounds_max[2] - bounds_min[2]), + bounds_min[3] + rand(Float32) * (bounds_max[3] - bounds_min[3]) + ) + vel = Vec3f( + (rand(Float32) - 0.5f0) * 20, + (rand(Float32) - 0.5f0) * 20, + rand(Float32) * 30 + 10 + ) + r = radius_range[1] + rand(Float32) * (radius_range[2] - radius_range[1]) + push!(particles, Particle(pos, vel, r)) + end + + # Build initial transforms + transforms = [translation(Vec3f(p.position)) * scale(p.radius) for p in particles] + + ParticleSystem(particles, transforms, bounds_min, bounds_max, Vec3f(0, 0, -30), 0.98f0) +end + +"""Convert particle system to GPU.""" +function Raycore.to_gpu(ArrayType, ps::ParticleSystem) + ParticleSystem( + ArrayType(ps.particles), + ArrayType(ps.transforms), + ps.bounds_min, + ps.bounds_max, + ps.gravity, + ps.damping + ) +end + +# ============================================================================== +# GPU Kernels for ParticleSystem +# ============================================================================== + +"""GPU kernel: Physics step - update particles with gravity and boundary bouncing.""" +KA.@kernel function particle_physics_kernel!( + particles, + bounds_min::Point3f, + bounds_max::Point3f, + gravity::Vec3f, + damping::Float32, + dt::Float32 +) + i = @index(Global, Linear) + @inbounds begin + p = particles[i] + pos = Vec3f(p.position...) + vel = p.velocity + r = p.radius + + new_vel = (vel + gravity * dt) * damping + new_pos = pos + new_vel * dt + + # Bounce off boundaries + if new_pos[1] - r < bounds_min[1] + new_pos = Vec3f(bounds_min[1] + r, new_pos[2], new_pos[3]) + new_vel = Vec3f(-new_vel[1] * 0.8f0, new_vel[2], new_vel[3]) + elseif new_pos[1] + r > bounds_max[1] + new_pos = Vec3f(bounds_max[1] - r, new_pos[2], new_pos[3]) + new_vel = Vec3f(-new_vel[1] * 0.8f0, new_vel[2], new_vel[3]) + end + + if new_pos[2] - r < bounds_min[2] + new_pos = Vec3f(new_pos[1], bounds_min[2] + r, new_pos[3]) + new_vel = Vec3f(new_vel[1], -new_vel[2] * 0.8f0, new_vel[3]) + elseif new_pos[2] + r > bounds_max[2] + new_pos = Vec3f(new_pos[1], bounds_max[2] - r, new_pos[3]) + new_vel = Vec3f(new_vel[1], -new_vel[2] * 0.8f0, new_vel[3]) + end + + if new_pos[3] - r < bounds_min[3] + new_pos = Vec3f(new_pos[1], new_pos[2], bounds_min[3] + r) + new_vel = Vec3f(new_vel[1], new_vel[2], -new_vel[3] * 0.8f0) + elseif new_pos[3] + r > bounds_max[3] + new_pos = Vec3f(new_pos[1], new_pos[2], bounds_max[3] - r) + new_vel = Vec3f(new_vel[1], new_vel[2], -new_vel[3] * 0.8f0) + end + + particles[i] = Particle(Point3f(new_pos...), new_vel, r) + end +end + +"""GPU kernel: Build transform matrices from particles.""" +KA.@kernel function build_transforms_kernel!( + transforms, + @Const(particles) +) + i = @index(Global, Linear) + @inbounds begin + p = particles[i] + pos = p.position + r = p.radius + # Translation * Scale matrix (column-major) + transforms[i] = Mat4f( + r, 0, 0, 0, + 0, r, 0, 0, + 0, 0, r, 0, + pos[1], pos[2], pos[3], 1 + ) + end +end + +"""Step particle physics (works on CPU or GPU via KernelAbstractions).""" +function step!(ps::ParticleSystem, dt::Float32) + n = length(ps.particles) + backend = KA.get_backend(ps.particles) + kernel! = particle_physics_kernel!(backend) + kernel!(ps.particles, ps.bounds_min, ps.bounds_max, ps.gravity, ps.damping, dt, ndrange=n) + return nothing +end + +"""Build transforms from current particle state.""" +function build_transforms!(ps::ParticleSystem) + n = length(ps.particles) + backend = KA.get_backend(ps.particles) + kernel! = build_transforms_kernel!(backend) + kernel!(ps.transforms, ps.particles, ndrange=n) + return nothing +end + +""" +Update sphere material color based on the average particle speed. +Particles all share a single material slot (face_meta=1), so the +heat-color effect drives a single material rather than per-particle. +""" +function update_sphere_material!(materials, ps::ParticleSystem; max_speed::Float32=50.0f0) + # Pull velocity of one representative particle from the GPU + p = Array(ps.particles[1:1])[1] + vel = p.velocity + speed = sqrt(vel[1]^2 + vel[2]^2 + vel[3]^2) + t = clamp(speed / max_speed, 0.0f0, 1.0f0) + + color = if t < 0.25f0 + s = t / 0.25f0 + RGB{Float32}(0.1f0, 0.2f0 + 0.5f0 * s, 0.8f0) + elseif t < 0.5f0 + s = (t - 0.25f0) / 0.25f0 + RGB{Float32}(0.1f0 + 0.6f0 * s, 0.7f0, 0.8f0 - 0.6f0 * s) + elseif t < 0.75f0 + s = (t - 0.5f0) / 0.25f0 + RGB{Float32}(0.7f0 + 0.3f0 * s, 0.7f0 - 0.4f0 * s, 0.2f0 - 0.1f0 * s) + else + s = (t - 0.75f0) / 0.25f0 + RGB{Float32}(1.0f0, 0.3f0 + 0.7f0 * s, 0.1f0 + 0.9f0 * s) + end + + new_mat = Material(color, 0.6f0, 0.3f0, 1.0f0, 0.0f0) + # Slot 1 is the sphere material (see create_particle_scene) + materials[1:1] .= [new_mat] + return nothing +end + +""" +Full GPU update: physics -> transforms -> TLAS update -> sphere material refresh. + +`tlas` — the mutable TLAS owning the sphere instances. +`sphere_handle` — the TLASHandle returned when the sphere was pushed. +""" +function update_gpu!(renderer_gpu, tlas, sphere_handle, ps::ParticleSystem, + dt::Float32; max_speed::Float32=50.0f0) + # 1. Physics step + step!(ps, dt) + + # 2. Build transforms from new positions (writes into ps.transforms) + build_transforms!(ps) + + # 3. Stage new transforms for the sphere instance group + Raycore.update_transforms!(tlas, sphere_handle, ps.transforms) + + # 4. Commit: refit the TLAS BVH in place (no topology change → fast refit path) + Raycore.sync!(tlas) + + # 5. Refresh the sphere's material to reflect speed + update_sphere_material!(renderer_gpu.ctx.materials, ps; max_speed=max_speed) + + return nothing +end + +# ============================================================================== +# Scene Creation with Instanced Spheres +# ============================================================================== + +"""Create a translation matrix.""" +function translation(v::Union{Vec3f, Point3f})::Mat4f + Mat4f( + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + v[1], v[2], v[3], 1 + ) +end + +"""Create a uniform scale matrix.""" +function scale(s::Float32)::Mat4f + Mat4f( + s, 0, 0, 0, + 0, s, 0, 0, + 0, 0, s, 0, + 0, 0, 0, 1 + ) +end + +"""Map velocity magnitude to a heat color (blue -> red -> yellow -> white).""" +function velocity_to_color(speed::Float32, max_speed::Float32=50.0f0)::RGB{Float32} + t = clamp(speed / max_speed, 0.0f0, 1.0f0) + if t < 0.25f0 + s = t / 0.25f0 + RGB{Float32}(0.1f0, 0.2f0 + 0.5f0 * s, 0.8f0) + elseif t < 0.5f0 + s = (t - 0.25f0) / 0.25f0 + RGB{Float32}(0.1f0 + 0.6f0 * s, 0.7f0, 0.8f0 - 0.6f0 * s) + elseif t < 0.75f0 + s = (t - 0.5f0) / 0.25f0 + RGB{Float32}(0.7f0 + 0.3f0 * s, 0.7f0 - 0.4f0 * s, 0.2f0 - 0.1f0 * s) + else + s = (t - 0.75f0) / 0.25f0 + RGB{Float32}(1.0f0, 0.3f0 + 0.7f0 * s, 0.1f0 + 0.9f0 * s) + end +end + +""" +Tag every face of `mesh` with the same metadata value `meta`. + +Used so that all triangles of a given BLAS look up the same material slot +(`materials[meta]`) inside the renderer's `tri.metadata` path. +""" +function mesh_with_const_meta(mesh::GeometryBasics.Mesh, meta::UInt32) + fs = decompose(TriangleFace{UInt32}, mesh) + n_faces = length(fs) + face_meta = fill(meta, n_faces) + return GeometryBasics.mesh(mesh; face_meta=GeometryBasics.per_face(face_meta, mesh)) +end + +""" +Create a TLAS with instanced spheres for the particle system. + +Architecture: +- Single unit-sphere BLAS, instanced once per particle (one `push!` with all + initial transforms returns one `TLASHandle` covering the whole batch). +- Three static wall BLASes (floor, back wall, left wall) at identity transform. +- Each BLAS has `face_meta` set to a constant so the renderer's + `materials[tri.metadata]` lookup stays in-bounds: + 1 → all sphere triangles + 2 → floor + 3 → back wall + 4 → left wall + +Returns `(tlas, sphere_handle, ctx, materials)`. Hand `sphere_handle` to +`update_transforms!` each frame. +""" +function create_particle_scene(ps::ParticleSystem; backend=KA.CPU()) + n_particles = length(ps.particles) + println("Creating particle scene with $n_particles particles...") + + # ----- Geometry ----- + unit_sphere = normal_mesh(Tesselation(Sphere(Point3f(0), 1.0f0), 16)) + sphere_tagged = mesh_with_const_meta(unit_sphere, UInt32(1)) + + bmin = ps.bounds_min + bmax = ps.bounds_max + pad = 10.0f0 + wall_thickness = 1.0f0 + + floor_mesh = normal_mesh(Rect3f( + Vec3f(bmin[1] - pad, bmin[2] - pad, bmin[3] - wall_thickness), + Vec3f(bmax[1] - bmin[1] + 2pad, bmax[2] - bmin[2] + 2pad, wall_thickness) + )) + back_wall_mesh = normal_mesh(Rect3f( + Vec3f(bmin[1] - pad, bmin[2] - pad - wall_thickness, bmin[3]), + Vec3f(bmax[1] - bmin[1] + 2pad, wall_thickness, bmax[3] - bmin[3] + pad) + )) + left_wall_mesh = normal_mesh(Rect3f( + Vec3f(bmin[1] - pad - wall_thickness, bmin[2] - pad, bmin[3]), + Vec3f(wall_thickness, bmax[2] - bmin[2] + 2pad, bmax[3] - bmin[3] + pad) + )) + floor_tagged = mesh_with_const_meta(floor_mesh, UInt32(2)) + back_tagged = mesh_with_const_meta(back_wall_mesh, UInt32(3)) + left_tagged = mesh_with_const_meta(left_wall_mesh, UInt32(4)) + + # ----- TLAS ----- + println("Building TLAS on backend $backend...") + tlas = Raycore.TLAS(backend) + + # Initial sphere transforms (CPU; push! adapts to backend internally) + initial_transforms = Mat4f[ + translation(Vec3f(p.position)) * scale(p.radius) + for p in ps.particles + ] + + sphere_handle = push!(tlas, sphere_tagged, initial_transforms) + push!(tlas, floor_tagged) + push!(tlas, back_tagged) + push!(tlas, left_tagged) + + Raycore.sync!(tlas) + println(" TLAS instances: $(Raycore.n_instances(tlas)) geometries: $(Raycore.n_geometries(tlas))") + + # ----- Materials (one slot per face_meta tag) ----- + p1 = ps.particles[1] + sphere_mat = Material(velocity_to_color(norm(p1.velocity)), 0.6f0, 0.3f0, 1.0f0, 0.0f0) + floor_mat = Material(RGB(0.9f0, 0.9f0, 0.92f0), 1.0f0, 0.05f0, 1.0f0, 0.0f0) + back_mat = Material(RGB(0.8f0, 0.75f0, 0.7f0), 0.9f0, 0.1f0, 1.0f0, 0.0f0) + left_mat = Material(RGB(0.95f0, 0.95f0, 0.95f0), 1.0f0, 0.02f0, 1.0f0, 0.0f0) + materials = [sphere_mat, floor_mat, back_mat, left_mat] + + # ----- Lights ----- + lights = [ + PointLight(Point3f(0, 0, 120), 3000.0f0, RGB(1.0f0, 0.98f0, 0.95f0)), + PointLight(Point3f(80, 70, 50), 15000.0f0, RGB(1.0f0, 1.0f0, 1.0f0)), + PointLight(Point3f(-60, 40, 60), 8000.0f0, RGB(0.9f0, 0.92f0, 1.0f0)), + ] + + ctx = RenderContext(lights, materials, 0.25f0) + return tlas, sphere_handle, ctx, materials +end + +# ============================================================================== +# Main Animation +# ============================================================================== + +println("\n" * "="^70) +println("Particle Simulation - 10k Instanced Spheres Demo") +println("="^70) + +# Create particle system +n_particles = 10_000 +println("\nInitializing $n_particles particles...") +ps = ParticleSystem(n_particles; + bounds_min=Point3f(-40, -40, 0), + bounds_max=Point3f(40, 40, 80) +) + +# Backend selection — TLAS arrays must live on the same backend the renderer +# uses, otherwise refit kernels and intersect kernels would talk to different +# memory spaces. +using AMDGPU: ROCBackend +backend = ROCBackend() + +# Create scene on the GPU backend so refits update arrays in place that the +# renderer's StaticTLAS reads. +tlas, sphere_handle, ctx, materials = create_particle_scene(ps; backend=backend) + +# Animation parameters +num_frames = 120 +dt = 1.0f0 / 30.0f0 +width, height = 1920, 1080 +frames_dir = "particle_frames" + +camera_pos = Point3f(70, 55, 45) +camera_lookat = Point3f(0, 0, 30) + +println("\nRendering $num_frames frames...") +println(" Resolution: $(width)x$(height)") +println(" Particles: $n_particles") +println(" Output: $frames_dir/") + +isdir(frames_dir) && rm(frames_dir; recursive=true) +mkpath(frames_dir) + +# Create renderer +img = fill(RGBf(0, 0, 0), height, width) +renderer = WavefrontRenderer(img, tlas, ctx; + camera_pos=camera_pos, + camera_lookat=camera_lookat, + camera_up=Vec3f(0, 0, 1), + fov=50.0f0, + sky_color=RGB{Float32}(0.05f0, 0.05f0, 0.1f0), + samples_per_pixel=8 +); +renderer_gpu = to_gpu(ROCArray, renderer); + +# Warmup +println("\nWarmup render...") +Array(render!(renderer_gpu)) + +# GPU particle system +ps_gpu = to_gpu(ROCArray, ps) +println("GPU particle system created") + +# Animation loop +println("\nRendering animation...") +total_time = @elapsed begin + for frame in 1:num_frames + update_gpu!(renderer_gpu, tlas, sphere_handle, ps_gpu, dt) + + fill!(renderer_gpu.framebuffer, RGBf(0, 0, 0)) + render!(renderer_gpu) + + filename = joinpath(frames_dir, "frame_$(lpad(frame, 4, '0')).png") + save(filename, Array(map(clamp01nan, renderer_gpu.framebuffer))) + + if frame % 20 == 0 || frame == 1 + println(" Frame $frame/$num_frames") + end + end +end + +avg_fps = num_frames / total_time +println("\nAnimation complete!") +println(" Total time: $(round(total_time, digits=2))s") +println(" Average FPS: $(round(avg_fps, digits=2))") +println(" Time per frame: $(round(total_time/num_frames*1000, digits=1))ms") + +# Create video +println("\nCreating video...") +using FFMPEG_jll +video_output = "particles.mp4" +run(`$(FFMPEG_jll.ffmpeg()) -y -framerate 30 -i $(frames_dir)/frame_%04d.png -c:v libx264 -pix_fmt yuv420p -crf 18 $video_output`) +println("Video saved: $video_output") + +# Performance stats +println("\n" * "="^70) +println("Performance Statistics (GPU)") +println("="^70) + +using BenchmarkTools + +gpu_physics_time = @belapsed step!($ps_gpu, $dt) +println(" GPU Physics step: $(round(gpu_physics_time * 1000, digits=3)) ms") + +gpu_transform_time = @belapsed build_transforms!($ps_gpu) +println(" GPU Transform build: $(round(gpu_transform_time * 1000, digits=3)) ms") + +gpu_update_time = @belapsed update_gpu!($renderer_gpu, $tlas, $sphere_handle, $ps_gpu, $dt) +println(" GPU Full update: $(round(gpu_update_time * 1000, digits=3)) ms") + +gpu_render_time = @belapsed begin + fill!($renderer_gpu.framebuffer, RGBf(0, 0, 0)) + render!($renderer_gpu) +end +println(" GPU Render: $(round(gpu_render_time * 1000, digits=1)) ms") + +total_frame_time = gpu_update_time + gpu_render_time +println("\n Total frame time: $(round(total_frame_time * 1000, digits=1)) ms") +println(" Theoretical max FPS: $(round(1.0 / total_frame_time, digits=1))") + +# Per-mesh stats from the mutable TLAS +sphere_blas = tlas.blas_storage[1] +floor_blas = tlas.blas_storage[2] +println("\n Total triangles: $(n_particles * length(sphere_blas.primitives) + length(floor_blas.primitives))") +println(" Unique BLAS geometries: $(length(tlas.blas_storage))") +println(" Memory saved by instancing: ~$(round(n_particles * length(sphere_blas.primitives) * 48 / 1024 / 1024, digits=1)) MB") + +println("\nDone!") diff --git a/ext/RaycoreMakieExt.jl b/ext/RaycoreMakieExt.jl index 7375582..6df259c 100644 --- a/ext/RaycoreMakieExt.jl +++ b/ext/RaycoreMakieExt.jl @@ -3,53 +3,110 @@ module RaycoreMakieExt using Raycore using Makie using GeometryBasics +using Adapt import Makie: plot, plot! +# ============================================================================ +# TLAS → Mesh conversion (plot geometry directly) +# ============================================================================ + +Makie.plottype(::Raycore.TLAS) = Makie.Mesh +Makie.plottype(::Raycore.TLAS4) = Makie.Mesh + +function Makie.convert_arguments(::Type{Makie.Mesh}, tlas::Union{Raycore.TLAS, Raycore.TLAS4}) + vertices = Point3f[] + faces = GeometryBasics.TriangleFace{Int}[] + colors = Float32[] + normals = Vec3f[] + + metadata_to_color = Dict{Any, Float32}() + next_color_idx = Ref(0f0) + + function get_color_for_metadata(meta) + get!(metadata_to_color, meta) do + next_color_idx[] += 1f0 + next_color_idx[] + end + end + + for blas in tlas.blas_array + for prim in blas.primitives + start_idx = length(vertices) + color_val = get_color_for_metadata(prim.metadata) + for (v, n) in zip(prim.vertices, prim.normals) + push!(vertices, v) + push!(colors, color_val) + push!(normals, Vec3f(n)) + end + push!(faces, GeometryBasics.TriangleFace(start_idx + 1, start_idx + 2, start_idx + 3)) + end + end + return (GeometryBasics.Mesh(vertices, faces; normal=normals, color=colors), ) +end + +# ============================================================================ +# RayPlot recipe — visualize rays traced through a TLAS +# ============================================================================ + +""" + RayIntersectionResult + +Stores rays and their intersection results for visualization. +Created by [`trace_rays`](@ref). +""" +struct RayIntersectionResult + rays::Vector{Raycore.Ray} + hits::Vector{Tuple{Bool, Raycore.Triangle, Float32, GeometryBasics.Vec{3,Float32}, UInt32}} + tlas::Raycore.TLAS +end + +""" + trace_rays(tlas::Raycore.TLAS, rays::AbstractVector{Raycore.Ray}) + +Trace rays against a TLAS and return a `RayIntersectionResult` for visualization. + +# Example +```julia +using Raycore, RayMakie + +tlas = TLAS(KA.CPU()) +push!(tlas, mesh) +sync!(tlas) + +rays = [Raycore.Ray(o=Point3f(0,0,-5), d=Vec3f(0,0,1))] +result = trace_rays(tlas, rays) +plot(result) +``` +""" +function Raycore.trace_rays(tlas::Raycore.TLAS, rays::AbstractVector{<:Raycore.AbstractRay}) + static_tlas = Adapt.adapt(tlas.backend, tlas) + hits = map(rays) do ray + Raycore.closest_hit(static_tlas, ray) + end + RayIntersectionResult(collect(rays), collect(hits), tlas) +end + """ - plot(session::RayIntersectionSession; kwargs...) + plot(result::RayIntersectionResult; kwargs...) -Makie recipe for visualizing a RayIntersectionSession. +Makie recipe for visualizing ray intersection results. # Keyword Arguments -- `show_bvh::Bool = true`: Whether to show the BVH geometry -- `bvh_alpha::Float64 = 0.4`: Transparency for BVH meshes -- `bvh_colors = [:red, :yellow, :blue]`: Colors to cycle through for different meshes -- `ray_colors = nothing`: Colors for rays. If `nothing`, uses a gradient based on hit distance -- `ray_color::Symbol = :black`: Default color for all rays if `ray_colors` is `nothing` +- `show_geometry::Bool = true`: Whether to show the TLAS geometry +- `geometry_alpha::Float64 = 0.4`: Transparency for geometry meshes +- `ray_color::Symbol = :green`: Default color for hit rays - `hit_color::Symbol = :green`: Color for hit point markers -- `miss_color::Symbol = :gray`: Color for rays that missed +- `miss_color = (:gray, 0.5)`: Color for rays that missed - `ray_length::Float32 = 15.0f0`: Length to draw rays that miss - `show_hit_points::Bool = true`: Whether to show markers at hit points -- `hit_markersize::Float64 = 0.2`: Size of hit point markers +- `hit_markersize::Float64 = 0.1`: Size of hit point markers - `show_labels::Bool = false`: Whether to show text labels at hit points -- `axis = nothing`: Optional axis to draw on (if not provided, creates new figure) - -# Example -```julia -using Raycore, GeometryBasics, GLMakie - -# Create geometry -sphere1 = Tesselation(Sphere(Point3f(0, 0, 1), 1.0f0), 20) -sphere2 = Tesselation(Sphere(Point3f(0, 0, 3), 1.0f0), 20) -bvh = Raycore.BVH([sphere1, sphere2]) - -# Create rays -rays = [ - Raycore.Ray(Point3f(0, 0, -5), Vec3f(0, 0, 1)), - Raycore.Ray(Point3f(1, 0, -5), Vec3f(0, 0, 1)), -] - -# Create and visualize session -session = RayIntersectionSession(rays, bvh, Raycore.closest_hit) -plot(session) -``` """ -@recipe(RayPlot, session) do scene +@recipe(RayPlot, result) do scene Attributes( - show_bvh = true, - bvh_alpha = 1.0, - bvh_colors = Makie.wong_colors(), - ray_colors = nothing, + show_geometry = true, + geometry_alpha = 0.4, + geometry_colors = Makie.wong_colors(), ray_color = :green, hit_color = :green, miss_color = (:gray, 0.5), @@ -60,17 +117,14 @@ plot(session) ) end -Makie.plottype(::Raycore.RayIntersectionSession) = RayPlot +Makie.plottype(::RayIntersectionResult) = RayPlot Makie.preferred_axis_type(::RayPlot) = LScene function Makie.plot!(plot::RayPlot) - session = plot[:session][] + result = plot[:result][] - # Extract attributes - show_bvh = plot[:show_bvh][] - bvh_alpha = plot[:bvh_alpha][] - bvh_colors = plot[:bvh_colors][] - ray_colors = plot[:ray_colors][] + show_geometry = plot[:show_geometry][] + geometry_alpha = plot[:geometry_alpha][] ray_color = plot[:ray_color][] hit_color = plot[:hit_color][] miss_color = plot[:miss_color][] @@ -79,18 +133,13 @@ function Makie.plot!(plot::RayPlot) hit_markersize = plot[:hit_markersize][] show_labels = plot[:show_labels][] - # Draw BVH if requested - if show_bvh - draw_bvh!(plot, session.bvh, bvh_colors, bvh_alpha) + # Draw geometry + if show_geometry + geo_mesh = Makie.convert_arguments(Makie.Mesh, result.tlas)[1] + mesh!(plot, geo_mesh; alpha=geometry_alpha) end - # Determine ray colors if not provided - if isnothing(ray_colors) - # Use single color for all rays - ray_colors = fill(ray_color, length(session.rays)) - end - - # Collect all data for batch rendering + # Classify rays into hits and misses hit_ray_starts = Point3f[] hit_ray_directions = Vec3f[] hit_ray_colors = [] @@ -102,110 +151,54 @@ function Makie.plot!(plot::RayPlot) hit_labels_pos = Point3f[] hit_labels_text = String[] - for (i, (ray, hit)) in enumerate(zip(session.rays, session.hits)) - hit_found, hit_primitive, distance, bary_coords = hit - - # Get color for this ray - color = i <= length(ray_colors) ? ray_colors[i] : ray_color + for (i, (ray, hit)) in enumerate(zip(result.rays, result.hits)) + hit_found, hit_triangle, distance, bary_coords, instance_id = hit if hit_found - # Calculate hit point - hit_point = Raycore.sum_mul(bary_coords, hit_primitive.vertices) + hit_point = sum(bary_coords .* hit_triangle.vertices) - # Collect ray data push!(hit_ray_starts, ray.o) push!(hit_ray_directions, hit_point - ray.o) - push!(hit_ray_colors, color) + push!(hit_ray_colors, ray_color) - # Collect hit point data if show_hit_points push!(hit_points_pos, hit_point) - # Collect label data if show_labels push!(hit_labels_pos, hit_point .+ Vec3f(0.2, 0.2, 0.2)) push!(hit_labels_text, "Hit $i\nd=$(round(distance, digits=2))") end end else - # Ray missed - collect miss ray data push!(miss_ray_starts, ray.o) push!(miss_ray_directions, ray.d * ray_length) end end - # Draw all hit rays in one call + # Draw hit rays. Makie deprecated `arrows!` in favor of `arrows3d!` for + # 3-D inputs; both arrow batches here are Vec3f (rays in world space). if !isempty(hit_ray_starts) - arrows3d!( - plot, - hit_ray_starts, - hit_ray_directions, - color = hit_ray_colors, - markerscale = 0.3 - ) + arrows3d!(plot, hit_ray_starts, hit_ray_directions, + color=hit_ray_colors, tipradius=0.05f0, tiplength=0.15f0, shaftradius=0.02f0) end - # Draw all miss rays in one call + # Draw miss rays if !isempty(miss_ray_starts) - arrows3d!( - plot, - miss_ray_starts, - miss_ray_directions, - color = miss_color, - markerscale = 0.3 - ) + arrows3d!(plot, miss_ray_starts, miss_ray_directions, + color=miss_color, tipradius=0.05f0, tiplength=0.15f0, shaftradius=0.02f0) end - # Draw all hit points in one call + # Draw hit points if show_hit_points && !isempty(hit_points_pos) - meshscatter!( - plot, - hit_points_pos, - color = hit_color, - markersize = hit_markersize - ) + meshscatter!(plot, hit_points_pos, color=hit_color, markersize=hit_markersize) end - # Draw all labels in one call + # Draw labels if show_labels && !isempty(hit_labels_pos) - text!( - plot, - hit_labels_pos, - text = hit_labels_text, - color = hit_color, - fontsize = 12 - ) + text!(plot, hit_labels_pos, text=hit_labels_text, color=hit_color, fontsize=12) end return plot end -""" -Helper function to draw BVH geometry -""" -function draw_bvh!(plot, bvh::Raycore.BVH, colors, alpha) - # Group primitives by their metadata - mesh!(plot, convert_arguments(Makie.Mesh, bvh)[1]) -end - -Makie.plottype(::Raycore.BVH) = Makie.Mesh - -function Makie.convert_arguments(::Type{Makie.Mesh}, bvh::Raycore.BVH) - # Convert BVH to a Mesh for plotting - vertices = Point3f[] - faces = GeometryBasics.TriangleFace{Int}[] - colors = Float32[] - normals = Vec3f[] - for (i, prim) in enumerate(bvh.primitives) - start_idx = length(vertices) - for (v, n) in zip(prim.vertices, prim.normals) - push!(vertices, v) - push!(colors, Float32(prim.metadata)) - push!(normals, Vec3f(n)) - end - push!(faces, GeometryBasics.TriangleFace(start_idx + 1, start_idx + 2, start_idx + 3)) - end - return (GeometryBasics.Mesh(vertices, faces; normal=normals, color=colors), ) -end - end # module diff --git a/src/Raycore.jl b/src/Raycore.jl index 99a1592..77fd57f 100644 --- a/src/Raycore.jl +++ b/src/Raycore.jl @@ -6,10 +6,47 @@ using StaticArrays using KernelAbstractions import GeometryBasics as GB using Statistics +using Adapt +using GPUArraysCore: @allowscalar abstract type AbstractRay end -abstract type AbstractShape end abstract type Primitive end +""" + AbstractAccel + +Mutable acceleration structure for ray/geometry intersection queries. + +Concrete implementations: +- `Raycore.TLAS` — software BVH/TLAS, runs on any KernelAbstractions backend. +- `Lava.HWTLAS` — hardware ray tracing via `VK_KHR_ray_tracing_pipeline`. + +# Mutation API +- `push!(accel, mesh, transform)`: add geometry, return a `TLASHandle`. +- `delete!(accel, handle)`, `update_transform!(accel, handle, transform)`, + `update_transforms!(accel, handle, transforms)`. + +# Lifecycle +- `sync!(accel)` — sole owner of `accel.static_tlas`. Rebuilds in place + where possible; reassigns when a buffer had to grow. No-op on a clean + accel. Does NOT block the CPU on a GPU fence; backend-internal timeline + tracking handles the "still in flight" case. +- `Adapt.adapt(backend, accel) === accel.static_tlas` between `sync!`s. + Consumers re-read `accel.static_tlas` (or call `Adapt.adapt`) **per + dispatch**. Caching the adapted form across mutations is a contract + violation. + +# Query +- `closest_hit(adapted, ray) -> (hit, tri, t, bary, instance_override)` +- `any_hit(adapted, ray) -> Bool` +- `world_bound(accel)`, `n_instances(accel)`, `n_geometries(accel)`. + +# Flush +- `wait_for_gpu!(accel)` — block CPU until all pending GPU work on this + accel's queue has completed. Convenience for tear-down and benchmark + isolation. Not part of the hot path. +""" +abstract type AbstractAccel end +abstract type AbstractAdaptedAccel end const Maybe{T} = Union{T,Nothing} GB.@fixed_vector Normal = StaticVector @@ -39,24 +76,78 @@ include("bounds.jl") include("transformations.jl") include("math.jl") include("triangle_mesh.jl") -include("bvh.jl") +include("instanced-bvh.jl") +include("instanced-bvh-kernels.jl") +include("bvh4.jl") include("kernel-abstractions.jl") include("kernels.jl") -include("ray_intersection_session.jl") +include("collision.jl") +include("soa.jl") +include("multitypeset.jl") +include("unrolled.jl") +include("rt_transport.jl") + +# Macros +export @_inbounds # Core types -export Ray, RayDifferentials, Triangle, TriangleMesh, AccelPrimitive, BVH, Bounds3, Normal3f +export Ray, RayDifferentials, Triangle, Bounds3, Normal3f, empty_triangle + +# Instanced BVH types +export BLAS, BLASDescriptor, TLAS, InstanceDescriptor, BVHNode2, build_blas, build_tlas, INVALID_NODE +export build_triangle, is_degenerate_face + +# TLAS (GPU two-level acceleration structure) +export TLASHandle, StaticTLAS, INVALID_HANDLE +export sync!, update!, n_total_instances + +# BVH4 types (HIPRT-style 4-wide nodes) +export BVHNode4, BLAS4, TLAS4, build_blas4, closest_hit4, any_hit4 # Ray intersection functions -export closest_hit, any_hit, world_bound +export AbstractAccel, AbstractAdaptedAccel +export closest_hit, any_hit, world_bound, trace_rays +export n_instances, n_geometries, wait_for_gpu! + +# RT transport types (used by Lava.HWTLAS and consumers) +export RTRay, RTHitResult + +# Stubs for Lava/Makie extensions +function trace_rays end + +""" + instance_buffer(tlas, handle::TLASHandle) + +Return the underlying GPU instance buffer (a `Lava.LavaArray{LavaInstanceRecord, 1}`) +that the named batch is using. The caller can write into this buffer (e.g., +via a compute kernel) and then call `refit_tlas!(tlas)` to commit the changes. + +Errors loudly if the handle does not refer to an instance batch (e.g., it +refers to a per-mesh push! instance, which has no GPU instance buffer). +""" +function instance_buffer end + +export instance_buffer # Math utilities export reflect +# Collision detection +export ContactPair, CollisionResult, collide_instances, collide_instances_any + # Analysis functions export get_centroid, get_illumination, view_factors -# Ray intersection session -export RayIntersectionSession, hit_points, hit_distances, hit_count, miss_count +# SoA utilities +export @get, @set, similar_soa + +# GPU-safe unrolled iteration +export FastClosure, for_unrolled, map_unrolled, reduce_unrolled, sum_unrolled, getindex_unrolled + +# MultiTypeSet - type-stable heterogeneous collections +export SetKey, MultiTypeSet, StaticMultiTypeSet, TextureRef +export is_invalid, is_valid, with_index, n_slots, deref, get_static, to_tuple +export maybe_convert_field, store_texture +export free! end diff --git a/src/bounds.jl b/src/bounds.jl index ce971ba..78acfb9 100644 --- a/src/bounds.jl +++ b/src/bounds.jl @@ -148,10 +148,10 @@ function offset(b::Bounds3, p::Point3f) ) end -function bounding_sphere(b::Bounds3)::Tuple{Point3f,Float32} +function bounding_sphere(b::Bounds3)::Sphere{Float32} center = (b.p_min + b.p_max) / 2f0 radius = inside(b, center) ? distance(center, b.p_max) : 0f0 - center, radius + Sphere(center, radius) end function intersect(b::Bounds3, ray::AbstractRay)::Tuple{Bool,Float32,Float32} diff --git a/src/bvh.jl b/src/bvh.jl deleted file mode 100644 index c1789bc..0000000 --- a/src/bvh.jl +++ /dev/null @@ -1,738 +0,0 @@ -abstract type AccelPrimitive <: Primitive end - -struct BVHPrimitiveInfo - primitive_number::UInt32 - bounds::Bounds3 - centroid::Point3f - - function BVHPrimitiveInfo(primitive_number::Integer, bounds::Bounds3) - new( - primitive_number, bounds, - 0.5f0 * bounds.p_min + 0.5f0 * bounds.p_max, - ) - end -end - -struct BVHNode - bounds::Bounds3 - children::Tuple{Maybe{BVHNode},Maybe{BVHNode}} - split_axis::UInt8 - offset::UInt32 - n_primitives::UInt32 - - """ - Construct leaf node. - """ - function BVHNode(offset::Integer, n_primitives::Integer, bounds::Bounds3) - new(bounds, (nothing, nothing), 0, offset, n_primitives) - end - """ - Construct intermediary node. - """ - function BVHNode(axis::Integer, left::BVHNode, right::BVHNode) - new(left.bounds ∪ right.bounds, (left, right), axis, 0, 0) - end -end - -abstract type LinearNode end - -struct LinearBVH <: LinearNode - bounds::Bounds3 - offset::UInt32 - n_primitives::UInt32 - split_axis::UInt8 - is_interior::Bool -end - -function LinearBVHLeaf(bounds::Bounds3, primitives_offset::Integer, n_primitives::Integer) - LinearBVH(bounds, primitives_offset, n_primitives, 0, false) -end - -function LinearBVHInterior(bounds::Bounds3, second_child_offset::Integer, split_axis::Integer) - LinearBVH(bounds, second_child_offset, 0, split_axis, true) -end - -function primitives_to_bvh(primitives, max_node_primitives=1) - max_node_primitives = min(255, max_node_primitives) - isempty(primitives) && return (primitives, max_node_primitives, LinearBVH[]) - primitives_info = [ - BVHPrimitiveInfo(i, world_bound(p)) - for (i, p) in enumerate(primitives) - ] - total_nodes = Ref(0) - ordered_primitives = similar(primitives, 0) - root = _init( - primitives, primitives_info, 1, length(primitives), - total_nodes, ordered_primitives, max_node_primitives, - ) - - offset = Ref{UInt32}(1) - flattened = Vector{LinearBVH}(undef, total_nodes[]) - _unroll(flattened, root, offset) - @real_assert total_nodes[] + 1 == offset[] - return (ordered_primitives, max_node_primitives, flattened) -end - -""" - CompactTriangle - -Pre-transformed triangle data for efficient GPU ray-triangle intersection. -Uses edge vectors for Möller-Trumbore intersection algorithm. - -Fields: -- v0: First vertex position (4th component for alignment) -- edge1: v1 - v0 edge vector -- edge2: v2 - v0 edge vector -- normal: Face normal (4th component for alignment) -""" -struct CompactTriangle - v0::SVector{4, Float32} - edge1::SVector{4, Float32} - edge2::SVector{4, Float32} - normal::SVector{4, Float32} -end - -""" - to_compact_triangle(tri::Triangle) -> CompactTriangle - -Convert a Triangle to CompactTriangle format with pre-computed edge vectors. -""" -@inline function to_compact_triangle(tri::Triangle)::CompactTriangle - vs = vertices(tri) - v0, v1, v2 = vs[1], vs[2], vs[3] - edge1 = v1 - v0 - edge2 = v2 - v0 - - # Compute face normal - face_normal = normalize(cross(Vec3f(edge1), Vec3f(edge2))) - - return CompactTriangle( - SVector{4, Float32}(v0[1], v0[2], v0[3], 0f0), - SVector{4, Float32}(edge1[1], edge1[2], edge1[3], 0f0), - SVector{4, Float32}(edge2[1], edge2[2], edge2[3], 0f0), - SVector{4, Float32}(face_normal[1], face_normal[2], face_normal[3], 0f0), - ) -end - -""" - intersect_compact_triangle(tri::CompactTriangle, ray_o, ray_d, t_max) - -Watertight Möller-Trumbore ray-triangle intersection for GPU. -Uses pre-computed edge vectors for efficiency. - -Returns: (hit, t, u, v) where u,v are barycentric coordinates -""" -@inline function intersect_compact_triangle( - tri::CompactTriangle, - ray_o::Point3f, - ray_d::Vec3f, - t_max::Float32 -) - EPSILON = 1.0f-6 - - # Möller-Trumbore algorithm - work directly with SVector components to avoid allocations - # h = cross(ray_d, edge2) - h_x = ray_d[2] * tri.edge2[3] - ray_d[3] * tri.edge2[2] - h_y = ray_d[3] * tri.edge2[1] - ray_d[1] * tri.edge2[3] - h_z = ray_d[1] * tri.edge2[2] - ray_d[2] * tri.edge2[1] - - # a = dot(edge1, h) - a = tri.edge1[1] * h_x + tri.edge1[2] * h_y + tri.edge1[3] * h_z - - # Check if ray is parallel to triangle - if abs(a) < EPSILON - return (false, t_max, 0f0, 0f0) - end - - f = 1f0 / a - - # s = ray_o - v0 - # s = ray_o - v0 - s_x = ray_o[1] - tri.v0[1] - s_y = ray_o[2] - tri.v0[2] - s_z = ray_o[3] - tri.v0[3] - - # u = f * dot(s, h) - u = f * (s_x * h_x + s_y * h_y + s_z * h_z) - - if u < 0f0 || u > 1f0 - return (false, t_max, 0f0, 0f0) - end - - # q = cross(s, edge1) - q_x = s_y * tri.edge1[3] - s_z * tri.edge1[2] - q_y = s_z * tri.edge1[1] - s_x * tri.edge1[3] - q_z = s_x * tri.edge1[2] - s_y * tri.edge1[1] - - # v = f * dot(ray_d, q) - v = f * (ray_d[1] * q_x + ray_d[2] * q_y + ray_d[3] * q_z) - - if v < 0f0 || u + v > 1f0 - return (false, t_max, 0f0, 0f0) - end - - # t = f * dot(edge2, q) - t = f * (tri.edge2[1] * q_x + tri.edge2[2] * q_y + tri.edge2[3] * q_z) - - if t > EPSILON && t < t_max - return (true, t, u, v) - end - - return (false, t_max, 0f0, 0f0) -end - -""" - BVH{NodeVec, TriVec, OrigTriVec} - -GPU-optimized BVH acceleration structure. - -Key optimizations: -- Uses LinearBVH node structure (flat array, depth-first layout) -- Pre-transforms triangles with edge vectors for Möller-Trumbore -- Designed for efficient GPU kernel traversal - -Fields: -- nodes: LinearBVH nodes (flat array, depth-first layout) -- triangles: Pre-transformed compact triangles -- primitives: Original triangles (for normals, UVs, metadata) -- max_node_primitives: Maximum primitives per leaf node -""" -struct BVH{ - NodeVec <: AbstractVector{LinearBVH}, - TriVec <: AbstractVector{CompactTriangle}, - OrigTriVec <: AbstractVector{<:Triangle} -} <: AccelPrimitive - nodes::NodeVec - triangles::TriVec - primitives::OrigTriVec - max_node_primitives::UInt8 -end - - -to_triangle_mesh(x::TriangleMesh) = x - -function to_triangle_mesh(x::GeometryBasics.AbstractGeometry) - m = GeometryBasics.uv_normal_mesh(x) - return TriangleMesh(m) -end - -""" - BVH(primitives, metadata_fn, max_node_primitives=1) - -Construct a BVH acceleration structure from a list of primitives (meshes or geometries). - -Arguments: -- `primitives`: Vector of triangle meshes or GeometryBasics geometries -- `metadata_fn`: Function `(mesh_index, triangle_index) -> metadata` to generate metadata for each triangle -- `max_node_primitives`: Maximum number of primitives per leaf node (default: 1) - -Returns a GPU-optimized BVH with pre-transformed triangles for efficient ray tracing. - -# Example -```julia -# Simple case: no metadata -bvh = BVH(meshes) - -# With metadata function -bvh = BVH(meshes, (mesh_idx, tri_idx) -> MaterialIndex(UInt8(1), UInt32(mesh_idx))) -``` -""" -function BVH( - primitives::AbstractVector{P}, - metadata_fn::Function, - max_node_primitives::Integer=1, - ) where {P} - # First pass: collect all triangles to determine the metadata type - first_mesh = to_triangle_mesh(first(primitives)) - first_metadata = metadata_fn(1, 1) - TMetadata = typeof(first_metadata) - - triangles = Triangle{TMetadata}[] - for (mi, prim) in enumerate(primitives) - triangle_mesh = to_triangle_mesh(prim) - for i in 1:div(length(triangle_mesh.indices), 3) - metadata = metadata_fn(mi, length(triangles) + 1) - push!(triangles, Triangle(triangle_mesh, i, metadata)) - end - end - ordered_primitives, max_prim, nodes = primitives_to_bvh(triangles, max_node_primitives) - - # Convert triangles to compact format with pre-computed edges - compact_tris = map(to_compact_triangle, ordered_primitives) - - return BVH( - nodes, - compact_tris, - ordered_primitives, - UInt8(max_prim) - ) -end - -# Convenience constructor: metadata defaults to primitive index -function BVH(primitives::AbstractVector{P}, max_node_primitives::Integer=1) where {P} - return BVH(primitives, (_, tri_idx) -> tri_idx, max_node_primitives) -end - -mutable struct BucketInfo - count::UInt32 - bounds::Bounds3 -end - -function _init( - primitives::AbstractVector, primitives_info::Vector{BVHPrimitiveInfo}, - from::Integer, to::Integer, total_nodes::Ref{Int64}, - ordered_primitives::AbstractVector, max_node_primitives::Integer, - ) - - total_nodes[] += 1 - n_primitives = to - from + 1 - # Compute bounds for all primitives in BVH node. - bounds = mapreduce( - i -> primitives_info[i].bounds, ∪, from:to, init = Bounds3(), - ) - @inline function _create_leaf()::BVHNode - first_offset = length(ordered_primitives) + 1 - for i in from:to - push!( - ordered_primitives, - primitives[primitives_info[i].primitive_number], - ) - end - return BVHNode(first_offset, n_primitives, bounds) - end - - n_primitives == 1 && return _create_leaf() - # Compute bound of primitive centroids, choose split dimension. - centroid_bounds = mapreduce( - i -> Bounds3(primitives_info[i].centroid), ∪, from:to, - init = Bounds3(), - ) - dim = maximum_extent(centroid_bounds) - ( # Create leaf node. - !is_valid(centroid_bounds) || - centroid_bounds.p_min[dim] == centroid_bounds.p_max[dim] - ) && return _create_leaf() - # Partition primitives into sets and build children. - if n_primitives <= 2 # Equally-sized subsets. - mid = (from + to) ÷ 2 - pmid = mid > from ? mid - from + 1 : 1 - partialsort!( - @view(primitives_info[from:to]), pmid, by = i -> i.centroid[dim], - ) - else # Perform Surface-Area-Heuristic partitioning. - n_buckets = 12 - buckets = [BucketInfo(0, Bounds3(Point3f(0f0))) for _ in 1:n_buckets] - # Initialize buckets. - for i in from:to - b = Int32(floor(n_buckets * offset( - centroid_bounds, primitives_info[i].centroid, - )[dim])) + 1 - (b == n_buckets + 1) && (b -= 1) - buckets[b].count += 1 - buckets[b].bounds = buckets[b].bounds ∪ primitives_info[i].bounds - end - # Compute costs for splitting after each bucket. - costs = Vector{Float32}(undef, n_buckets - 1) - for i in 1:(n_buckets-1) - it1, it2 = 1:i, (i+1):(n_buckets-1) - s1, s2 = 0, 0 - if length(it1) > 0 - s1 = length(it1) * surface_area( - mapreduce(b -> buckets[b].bounds, ∪, it1), - ) - end - if length(it2) > 0 - s2 = length(it2) * surface_area( - mapreduce(b -> buckets[b].bounds, ∪, it2), - ) - end - costs[i] = 1f0 + (s1 + s2) / surface_area(bounds) - end - # Find bucket to split that minimizes SAH metric. - min_cost_id = argmin(costs) - leaf_cost = n_primitives - # Either create leaf or split primitives at selected SAH bucket. - !( - n_primitives > max_node_primitives - || - costs[min_cost_id] < leaf_cost - ) && return _create_leaf() - mid = partition!(primitives_info, from:to, i -> begin - b = Int32(floor( - n_buckets * offset(centroid_bounds, i.centroid)[dim], - )) + 1 - (b == n_buckets + 1) && (b -= 1) - b <= min_cost_id - end) - end - BVHNode( - dim, - _init( - primitives, primitives_info, from, mid, - total_nodes, ordered_primitives, max_node_primitives, - ), - _init( - primitives, primitives_info, mid + 1, to, - total_nodes, ordered_primitives, max_node_primitives, - ), - ) -end - -function _unroll( - linear_nodes::Vector{LinearBVH}, node::BVHNode, offset::Ref{UInt32}, - ) - - l_offset = offset[] - offset[] += 1 - - if node.children[1] isa Nothing - linear_nodes[l_offset] = LinearBVHLeaf( - node.bounds, node.offset, node.n_primitives, - ) - return l_offset + 1 - end - - _unroll(linear_nodes, node.children[1], offset) - second_child_offset = _unroll(linear_nodes, node.children[2], offset) - 1 - linear_nodes[l_offset] = LinearBVHInterior( - node.bounds, second_child_offset, node.split_axis, - ) - l_offset + 1 -end - -@inline function world_bound(bvh::BVH)::Bounds3 - length(bvh.nodes) > 0 ? bvh.nodes[1].bounds : Bounds3() -end - -struct MemAllocator -end -@inline _allocate(::MemAllocator, T::Type, n::Val{N}) where {N} = MVector{N,T}(undef) -Base.@propagate_inbounds function _setindex(arr::MVector{N, T}, idx::Integer, value::T) where {N, T} - arr[idx] = value - return arr -end - - -""" - closest_hit(bvh::BVH, ray::AbstractRay) - -Find the closest intersection between a ray and the GPU BVH. -Uses manual traversal with compact triangle intersection for best performance. - -Returns: -- `hit_found`: Boolean indicating if an intersection was found -- `hit_primitive`: The primitive that was hit (if any) -- `distance`: Distance along the ray to the hit point (hit_point = ray.o + ray.d * distance) -- `barycentric_coords`: Barycentric coordinates of the hit point -""" -@inline function closest_hit(bvh::BVH, ray::AbstractRay, allocator=MemAllocator()) - ray = check_direction(ray) - inv_dir = 1f0 ./ ray.d - dir_is_neg = is_dir_negative(ray.d) - - # Initialize traversal - local to_visit_offset::Int32 = Int32(1) - current_node_idx = Int32(1) - # Direct MVector construction for type stability (critical for GPU, especially OpenCL/SPIR-V) - nodes_to_visit = MVector{64, Int32}(undef) - nodes = bvh.nodes - triangles = bvh.triangles - original_tris = bvh.primitives - - # Track closest hit - hit_found = false - hit_tri_idx = Int32(0) - closest_t = ray.t_max - hit_u = 0f0 - hit_v = 0f0 - - # Traverse BVH - @_inbounds while true - current_node = nodes[current_node_idx] - - # Test ray against current node's bounding box - if intersect_p(current_node.bounds, ray, inv_dir, dir_is_neg) - local cnprim::Int32 = current_node.n_primitives % Int32 - if !current_node.is_interior && cnprim > Int32(0) - # Leaf node - test all triangles - offset = current_node.offset % Int32 - - for i in Int32(0):(cnprim - Int32(1)) - tri_idx = offset + i - compact_tri = triangles[tri_idx] - - # Use compact intersection - tmp_hit, t, u, v = intersect_compact_triangle(compact_tri, ray.o, ray.d, closest_t) - if tmp_hit && t < closest_t - closest_t = t - hit_found = true - hit_tri_idx = tri_idx - hit_u = u - hit_v = v - end - end - - # Done with leaf, pop next node from stack - if to_visit_offset === Int32(1) - break - end - to_visit_offset -= Int32(1) - current_node_idx = nodes_to_visit[to_visit_offset] - else - # Interior node - push children to stack - # Explicitly unroll axis cases to avoid LLVM select chains in SPIR-V - local is_neg = if current_node.split_axis == Int32(1) - dir_is_neg[1] == Int32(2) - elseif current_node.split_axis == Int32(2) - dir_is_neg[2] == Int32(2) - else # split_axis == 3 - dir_is_neg[3] == Int32(2) - end - - if is_neg - nodes_to_visit[to_visit_offset] = current_node_idx + Int32(1) - current_node_idx = current_node.offset % Int32 - else - nodes_to_visit[to_visit_offset] = current_node.offset % Int32 - current_node_idx += Int32(1) - end - to_visit_offset += Int32(1) - end - else - # Miss - pop next node from stack - if to_visit_offset === Int32(1) - break - end - to_visit_offset -= Int32(1) - current_node_idx = nodes_to_visit[to_visit_offset] - end - end - - # Return result - if hit_found - orig_tri = original_tris[hit_tri_idx] - w = 1f0 - hit_u - hit_v - bary_point = SVector{3, Float32}(w, hit_u, hit_v) - return (true, orig_tri, closest_t, bary_point) - else - # Return dummy result matching standard BVH behavior - dummy_tri = original_tris[1] - bary_point = SVector{3, Float32}(0f0, 0f0, 0f0) - return (false, dummy_tri, 0f0, bary_point) - end -end - -""" - any_hit(bvh::BVH, ray::AbstractRay) - -Test if a ray intersects any primitive in the GPU BVH (for occlusion testing). -Stops at the first intersection found. - -Returns: -- `hit_found`: Boolean indicating if any intersection was found -- `hit_primitive`: The primitive that was hit (if any) -- `distance`: Distance along the ray to the hit point (hit_point = ray.o + ray.d * distance) -- `barycentric_coords`: Barycentric coordinates of the hit point -""" -@inline function any_hit(bvh::BVH, ray::AbstractRay, allocator=MemAllocator()) - ray = check_direction(ray) - inv_dir = 1f0 ./ ray.d - dir_is_neg = is_dir_negative(ray.d) - - # Initialize traversal - local to_visit_offset::Int32 = Int32(1) - current_node_idx = Int32(1) - # Direct MVector construction for type stability (critical for GPU, especially OpenCL/SPIR-V) - nodes_to_visit = MVector{64, Int32}(undef) - nodes = bvh.nodes - triangles = bvh.triangles - original_tris = bvh.primitives - - # Traverse BVH - @_inbounds while true - current_node = nodes[current_node_idx] - - # Test ray against current node's bounding box - if intersect_p(current_node.bounds, ray, inv_dir, dir_is_neg) - local cnprim::Int32 = current_node.n_primitives % Int32 - if !current_node.is_interior && cnprim > Int32(0) - # Leaf node - test triangles - offset = current_node.offset % Int32 - - for i in Int32(0):(cnprim - Int32(1)) - tri_idx = offset + i - compact_tri = triangles[tri_idx] - - # Test for any hit - tmp_hit, t, u, v = intersect_compact_triangle(compact_tri, ray.o, ray.d, ray.t_max) - if tmp_hit - # Return immediately on first hit - orig_tri = original_tris[tri_idx] - w = 1f0 - u - v - bary_point = SVector{3, Float32}(w, u, v) - return (true, orig_tri, t, bary_point) - end - end - - # Done with leaf, pop next node from stack - if to_visit_offset === Int32(1) - break - end - to_visit_offset -= Int32(1) - current_node_idx = nodes_to_visit[to_visit_offset] - else - # Interior node - push children to stack - # Explicitly unroll axis cases to avoid LLVM select chains in SPIR-V - local is_neg = if current_node.split_axis == Int32(1) - dir_is_neg[1] == Int32(2) - elseif current_node.split_axis == Int32(2) - dir_is_neg[2] == Int32(2) - else # split_axis == 3 - dir_is_neg[3] == Int32(2) - end - - if is_neg - nodes_to_visit[to_visit_offset] = current_node_idx + Int32(1) - current_node_idx = current_node.offset % Int32 - else - nodes_to_visit[to_visit_offset] = current_node.offset % Int32 - current_node_idx += Int32(1) - end - to_visit_offset += Int32(1) - end - else - # Miss - pop next node from stack - if to_visit_offset === Int32(1) - break - end - to_visit_offset -= Int32(1) - current_node_idx = nodes_to_visit[to_visit_offset] - end - end - - # No hit found - dummy_tri = original_tris[1] - bary_point = SVector{3, Float32}(0f0, 0f0, 0f0) - return (false, dummy_tri, 0f0, bary_point) -end - -function calculate_ray_grid_bounds(bounds::GeometryBasics.Rect, ray_direction::Vec3f) - # Normalize the direction vector (in case it's not already a unit vector) - direction = normalize(ray_direction) - # 1. Find a plane perpendicular to the ray direction - # We need two basis vectors that are perpendicular to the ray direction - # First, find a non-parallel vector to create our first basis vector - if abs(direction[1]) < 0.9f0 - temp = Vec3f(1.0, 0.0, 0.0) - else - temp = Vec3f(0.0, 1.0, 0.0) - end - - # Create two perpendicular basis vectors for the grid - basis1 = normalize(cross(direction, temp)) - basis2 = normalize(cross(direction, basis1)) - - corners = decompose(Point3f, bounds) - - # 3. Project corners onto our basis vectors - proj1 = [dot(corner, basis1) for corner in corners] - proj2 = [dot(corner, basis2) for corner in corners] - - # 4. Find the min and max projections to determine grid bounds - min_proj1, max_proj1 = extrema(proj1) - min_proj2, max_proj2 = extrema(proj2) - - # 5. Add a small margin to ensure coverage - margin = 0.05f0 * max(max_proj1 - min_proj1, max_proj2 - min_proj2) - grid_width = max_proj1 - min_proj1 + 2 * margin - grid_height = max_proj2 - min_proj2 + 2 * margin - - # 6. Calculate the origin point of the grid - # Choose a point that's sufficiently far back from the bounding box - # Project all corners onto the ray direction - depth_proj = [dot(corner, direction) for corner in corners] - min_depth = minimum(depth_proj) - margin - - # Grid center in world space - grid_center = Point3f(0, 0, 0) + min_depth * direction + - ((min_proj1 + max_proj1) / 2f0) * basis1 + - ((min_proj2 + max_proj2) / 2f0) * basis2 - - # 7. Return the grid information - return ( - center=grid_center, - width=grid_width, - height=grid_height, - basis1=basis1, - basis2=basis2, - ) -end - -# Function to generate ray origins for the grid -function generate_ray_grid(grid_info, grid_size::Int) - ray_origins = Matrix{Point3f}(undef, grid_size, grid_size) - cell_size_width = grid_info.width / grid_size - cell_size_height = grid_info.height / grid_size - for i in 1:grid_size - for j in 1:grid_size - # Calculate the offset from the center - u = (i - (grid_size + 1) / 2) * cell_size_width - v = (j - (grid_size + 1) / 2) * cell_size_height - - # Calculate the ray origin - ray_origins[i, j] = grid_info.center + u * grid_info.basis1 + v * grid_info.basis2 - end - end - return ray_origins -end - -""" - generate_ray_grid(bvh::BVH, ray_direction::Vec3f, grid_size::Int) - -Generate a grid of ray origins based on the BVH bounding box and a given ray direction. -""" -function generate_ray_grid(bvh::BVH, ray_direction::Vec3f, grid_size::Int) - bounds = world_bound(bvh) - bb = Rect3f(bounds.p_min, bounds.p_max .- bounds.p_min) - grid_info = calculate_ray_grid_bounds(bb, ray_direction) - return generate_ray_grid(grid_info, grid_size) -end - - -function GeometryBasics.Mesh(bvh::BVH) - points = Point3f[] - faces = GLTriangleFace[] - prims = bvh.primitives # Use original triangles, not compact ones - for (ti, tringle) in enumerate(prims) - push!(points, tringle.vertices...) - tt = ((ti - 1) * 3) + 1 - face = GLTriangleFace(tt, tt + 1, tt + 2) - push!(faces, face) - end - return GeometryBasics.Mesh(points, faces) -end - -# Pretty printing for BVH -function Base.show(io::IO, ::MIME"text/plain", bvh::BVH) - n_triangles = length(bvh.triangles) - n_nodes = length(bvh.nodes) - - # Count leaf vs interior nodes - n_leaves = count(node -> !node.is_interior, bvh.nodes) - n_interior = n_nodes - n_leaves - - println(io, "BVH:") - println(io, " Triangles: ", n_triangles, " (pre-transformed)") - println(io, " BVH nodes: ", n_nodes, " (", n_interior, " interior, ", n_leaves, " leaves)") - print(io, " Max prims: ", Int(bvh.max_node_primitives), " per leaf") -end - -function Base.show(io::IO, bvh::BVH) - if get(io, :compact, false) - n_triangles = length(bvh.triangles) - n_nodes = length(bvh.nodes) - print(io, "BVH(triangles=", n_triangles, ", nodes=", n_nodes, ")") - else - show(io, MIME("text/plain"), bvh) - end -end diff --git a/src/bvh4.jl b/src/bvh4.jl new file mode 100644 index 0000000..592166b --- /dev/null +++ b/src/bvh4.jl @@ -0,0 +1,774 @@ +# ============================================================================== +# BVH4 - 4-Wide Bounding Volume Hierarchy (HIPRT-style optimization) +# ============================================================================== +# +# Key optimizations from AMD HIPRT: +# - 4-wide nodes reduce tree depth by ~50% vs binary BVH +# - Better cache utilization (128-byte aligned nodes) +# - Fewer memory fetches during traversal +# - Built via collapse pass from LBVH binary tree +# +# Architecture: +# 1. Build binary LBVH as before (fast parallel construction) +# 2. Collapse pass converts BVH2 -> BVH4 (parallel GPU kernel) +# 3. Traversal uses 4-wide AABB tests + +using StaticArrays +import KernelAbstractions as KA +using KernelAbstractions: @index +using Atomix: @atomicswap, @atomic + +# ============================================================================== +# BVH4 Node Structure +# ============================================================================== + +const INVALID_NODE4 = 0xffffffff + +""" + BVHNode4 + +4-wide BVH node optimized for GPU traversal. +Stores up to 4 children with their AABBs inline. + +Memory layout: 128 bytes (one cache line on most GPUs) +- 4 child indices: 16 bytes +- 4 AABBs: 96 bytes (24 bytes each) +- Metadata: 16 bytes + +Leaf nodes indicated by child_count == 0 and first primitive index in child0. +""" +struct BVHNode4 + # Child indices (INVALID_NODE4 for unused slots) + child0::UInt32 + child1::UInt32 + child2::UInt32 + child3::UInt32 + + # Child 0 AABB + aabb0_min::Point3f + aabb0_max::Point3f + + # Child 1 AABB + aabb1_min::Point3f + aabb1_max::Point3f + + # Child 2 AABB + aabb2_min::Point3f + aabb2_max::Point3f + + # Child 3 AABB + aabb3_min::Point3f + aabb3_max::Point3f + + # Metadata + parent::UInt32 + child_count::UInt8 # 0 = leaf, 1-4 = interior + primitive_count::UInt8 # For leaves: number of primitives + _pad1::UInt8 + _pad2::UInt8 +end + +# Verify size (should be 128 bytes for cache alignment) +# 4*4 + 4*24 + 4 + 4 = 16 + 96 + 8 = 120 bytes + padding + +"""Create an empty/invalid BVH4 node.""" +@inline function empty_bvh4_node() + BVHNode4( + INVALID_NODE4, INVALID_NODE4, INVALID_NODE4, INVALID_NODE4, + Point3f(0), Point3f(0), + Point3f(0), Point3f(0), + Point3f(0), Point3f(0), + Point3f(0), Point3f(0), + INVALID_NODE4, UInt8(0), UInt8(0), UInt8(0), UInt8(0) + ) +end + +"""Check if node is a leaf.""" +@inline is_leaf4(node::BVHNode4) = node.child_count == 0 + +"""Check if node is interior.""" +@inline is_interior4(node::BVHNode4) = node.child_count > 0 + +"""Get child index by position (0-3).""" +@inline function get_child4(node::BVHNode4, i::Int)::UInt32 + i == 1 && return node.child0 + i == 2 && return node.child1 + i == 3 && return node.child2 + return node.child3 +end + +"""Get child AABB by position (1-4).""" +@inline function get_child_aabb4(node::BVHNode4, i::Int)::Bounds3 + if i == 1 + return Bounds3(node.aabb0_min, node.aabb0_max) + elseif i == 2 + return Bounds3(node.aabb1_min, node.aabb1_max) + elseif i == 3 + return Bounds3(node.aabb2_min, node.aabb2_max) + else + return Bounds3(node.aabb3_min, node.aabb3_max) + end +end + +"""Get the node's total AABB (union of all valid children).""" +@inline function get_node_aabb4(node::BVHNode4)::Bounds3 + aabb = Bounds3(node.aabb0_min, node.aabb0_max) + if node.child_count >= 2 + aabb = aabb ∪ Bounds3(node.aabb1_min, node.aabb1_max) + end + if node.child_count >= 3 + aabb = aabb ∪ Bounds3(node.aabb2_min, node.aabb2_max) + end + if node.child_count >= 4 + aabb = aabb ∪ Bounds3(node.aabb3_min, node.aabb3_max) + end + return aabb +end + +# ============================================================================== +# BVH4 Leaf Node (stores triangle data) +# ============================================================================== + +""" + BVH4Leaf + +Leaf node that stores triangle vertices directly (like BVH2IL format). +For BVH4, we keep triangles separate and reference by index. +""" +struct BVH4Leaf + prim_start::UInt32 # First primitive index + prim_count::UInt32 # Number of primitives (usually 1-2) + _pad1::UInt32 + _pad2::UInt32 +end + +# ============================================================================== +# BLAS4 / TLAS4 Structures +# ============================================================================== + +""" + BLAS4{NodeArray, TriArray} + +Bottom-Level Acceleration Structure using BVH4 nodes. +""" +struct BLAS4{ + NodeArray <: AbstractVector{BVHNode4}, + TriArray <: AbstractVector{<:Triangle} +} + nodes::NodeArray + primitives::TriArray + root_aabb::Bounds3 + num_interior::Int32 # Number of interior nodes +end + +""" + TLAS4{NodeArray, InstArray, BLASArray} + +Top-Level Acceleration Structure using BVH4 nodes. +""" +struct TLAS4{ + NodeArray <: AbstractVector{BVHNode4}, + InstArray <: AbstractVector{InstanceDescriptor}, + BLASArray <: AbstractVector{<:BLAS4} +} + nodes::NodeArray + instances::InstArray + blas_array::BLASArray + root_aabb::Bounds3 +end + +# ============================================================================== +# Collapse Pass: BVH2 -> BVH4 +# ============================================================================== + +""" + CollapseTask + +Work item for the BVH2 -> BVH4 collapse pass. +""" +struct CollapseTask + bvh2_node_idx::UInt32 # Source node in BVH2 + bvh4_node_idx::UInt32 # Destination node in BVH4 + depth::UInt32 # Current depth (for load balancing) +end + +""" +Gather up to 4 children from a subtree of the binary BVH. +Performs a BFS-like traversal to collect the 4 best children. + +Returns: (child_indices, child_aabbs, child_count, is_leaf_flags) +""" +@inline function gather_children_bvh2( + root_idx::UInt32, + nodes2::AbstractVector{BVHNode2}, + n_prims::Int32 +) + # Use a small fixed-size work queue + # Start with root's two children + queue = MVector{8, UInt32}(undef) + queue_size = 0 + + # Output arrays + children = MVector{4, UInt32}(INVALID_NODE4, INVALID_NODE4, INVALID_NODE4, INVALID_NODE4) + aabbs = MVector{4, Bounds3}(Bounds3(), Bounds3(), Bounds3(), Bounds3()) + child_is_leaf = MVector{4, Bool}(false, false, false, false) + child_count = 0 + + # Start with root node + @inbounds root = nodes2[root_idx] + + if is_leaf(root) + # Root is already a leaf - single child + children[1] = root_idx + aabbs[1] = get_node_aabb(root, false) + child_is_leaf[1] = true + return (children, aabbs, Int32(1), child_is_leaf) + end + + # Add root's children to queue + queue[1] = root.child0 + queue[2] = root.child1 + queue_size = 2 + + # Expand until we have 4 children or can't expand more + while child_count < 4 && queue_size > 0 + # Find the best node to add (prefer interior nodes to expand) + best_idx = 1 + for i in 1:queue_size + @inbounds node_idx = queue[i] + @inbounds node = nodes2[node_idx] + # Prefer non-leaves that can be expanded + if is_interior(node) && child_count + queue_size - 1 + 2 <= 4 + best_idx = i + break + end + end + + # Pop from queue + @inbounds node_idx = queue[best_idx] + @inbounds queue[best_idx] = queue[queue_size] + queue_size -= 1 + + @inbounds node = nodes2[node_idx] + is_node_interior = is_interior(node) + + if is_node_interior && child_count + queue_size + 2 <= 4 + # Expand this node - add its children to queue + queue_size += 1 + @inbounds queue[queue_size] = node.child0 + queue_size += 1 + @inbounds queue[queue_size] = node.child1 + else + # Add this node as a child + child_count += 1 + @inbounds children[child_count] = node_idx + @inbounds aabbs[child_count] = if is_node_interior + # Interior node - union of child AABBs + Bounds3( + min.(node.aabb0_min, node.aabb1_min), + max.(node.aabb0_max, node.aabb1_max) + ) + else + # Leaf node - compute from triangle vertices + get_node_aabb(node, false) + end + @inbounds child_is_leaf[child_count] = !is_node_interior + end + end + + # Any remaining queue items become children directly + while queue_size > 0 && child_count < 4 + child_count += 1 + @inbounds node_idx = queue[queue_size] + queue_size -= 1 + + @inbounds children[child_count] = node_idx + @inbounds node = nodes2[node_idx] + is_node_interior = is_interior(node) + @inbounds aabbs[child_count] = if is_node_interior + Bounds3( + min.(node.aabb0_min, node.aabb1_min), + max.(node.aabb0_max, node.aabb1_max) + ) + else + get_node_aabb(node, false) + end + @inbounds child_is_leaf[child_count] = !is_node_interior + end + + return (children, aabbs, Int32(child_count), child_is_leaf) +end + +""" +Collapse a BVH2 into a BVH4. + +This is a sequential CPU implementation for simplicity. +A GPU version would use work queues similar to HIPRT. + +The key insight is that we need to: +1. Collect up to 4 subtrees at each BVH4 interior node +2. For leaf subtrees, create BVH4 leaf nodes pointing to primitives +3. For interior subtrees, recursively process them +4. Fix up child pointers to point to BVH4 indices (not BVH2 indices) +""" +function collapse_bvh2_to_bvh4( + nodes2::AbstractVector{BVHNode2}, + primitives::AbstractVector{<:Triangle}, + n_prims::Int32 +) + # Estimate BVH4 node count + max_nodes4 = length(nodes2) + 1 + nodes4 = Vector{BVHNode4}(undef, max_nodes4) + node4_count = 0 + + # Map from BVH2 index to BVH4 index (for subtrees that become BVH4 nodes) + bvh2_to_bvh4 = Dict{UInt32, UInt32}() + + # Work queue: (bvh2_idx, slot to update in parent, parent bvh4 idx) + # slot = which child slot (1-4) in the parent to update + queue = Vector{Tuple{UInt32, Int, UInt32}}() + + # Process root + @inbounds root = nodes2[1] + + if is_leaf(root) + # Single triangle - create one leaf node + node4_count += 1 + prim_idx = root.child1 + v0 = Point3f(root.aabb0_min...) + v1 = Point3f(root.aabb0_max...) + v2 = Point3f(root.aabb1_min...) + p_min = min.(min.(v0, v1), v2) + p_max = max.(max.(v0, v1), v2) + + nodes4[node4_count] = BVHNode4( + prim_idx, INVALID_NODE4, INVALID_NODE4, INVALID_NODE4, + p_min, p_max, + Point3f(0), Point3f(0), + Point3f(0), Point3f(0), + Point3f(0), Point3f(0), + INVALID_NODE4, UInt8(0), UInt8(1), UInt8(0), UInt8(0) + ) + else + # Gather children from root + children_bvh2, aabbs, child_count, child_is_leaf_flags = gather_children_bvh2(UInt32(1), nodes2, n_prims) + + node4_count += 1 + root4_idx = UInt32(node4_count) + + # Create placeholders for child indices (will be filled in) + child_indices = MVector{4, UInt32}(INVALID_NODE4, INVALID_NODE4, INVALID_NODE4, INVALID_NODE4) + + # Process each child + for i in 1:child_count + bvh2_child_idx = children_bvh2[i] + @inbounds node2_child = nodes2[bvh2_child_idx] + + if child_is_leaf_flags[i] + # This is a leaf - create BVH4 leaf node + node4_count += 1 + leaf4_idx = UInt32(node4_count) + child_indices[i] = leaf4_idx + + prim_idx = node2_child.child1 + v0 = Point3f(node2_child.aabb0_min...) + v1 = Point3f(node2_child.aabb0_max...) + v2 = Point3f(node2_child.aabb1_min...) + p_min = min.(min.(v0, v1), v2) + p_max = max.(max.(v0, v1), v2) + + nodes4[node4_count] = BVHNode4( + prim_idx, INVALID_NODE4, INVALID_NODE4, INVALID_NODE4, + p_min, p_max, + Point3f(0), Point3f(0), + Point3f(0), Point3f(0), + Point3f(0), Point3f(0), + root4_idx, UInt8(0), UInt8(1), UInt8(0), UInt8(0) + ) + else + # Interior subtree - queue for processing + push!(queue, (bvh2_child_idx, i, root4_idx)) + end + end + + # Create root interior node (will update child pointers later for queued children) + nodes4[1] = BVHNode4( + child_indices[1], child_indices[2], child_indices[3], child_indices[4], + aabbs[1].p_min, aabbs[1].p_max, + aabbs[2].p_min, aabbs[2].p_max, + aabbs[3].p_min, aabbs[3].p_max, + aabbs[4].p_min, aabbs[4].p_max, + INVALID_NODE4, UInt8(child_count), UInt8(0), UInt8(0), UInt8(0) + ) + + # Process queued interior subtrees + while !isempty(queue) + bvh2_idx, parent_slot, parent4_idx = popfirst!(queue) + + # Gather children from this subtree + sub_children, sub_aabbs, sub_count, sub_is_leaf = gather_children_bvh2(bvh2_idx, nodes2, n_prims) + + node4_count += 1 + current4_idx = UInt32(node4_count) + + # Update parent's child pointer + @inbounds parent = nodes4[parent4_idx] + if parent_slot == 1 + nodes4[parent4_idx] = BVHNode4( + current4_idx, parent.child1, parent.child2, parent.child3, + parent.aabb0_min, parent.aabb0_max, parent.aabb1_min, parent.aabb1_max, + parent.aabb2_min, parent.aabb2_max, parent.aabb3_min, parent.aabb3_max, + parent.parent, parent.child_count, parent.primitive_count, parent._pad1, parent._pad2 + ) + elseif parent_slot == 2 + nodes4[parent4_idx] = BVHNode4( + parent.child0, current4_idx, parent.child2, parent.child3, + parent.aabb0_min, parent.aabb0_max, parent.aabb1_min, parent.aabb1_max, + parent.aabb2_min, parent.aabb2_max, parent.aabb3_min, parent.aabb3_max, + parent.parent, parent.child_count, parent.primitive_count, parent._pad1, parent._pad2 + ) + elseif parent_slot == 3 + nodes4[parent4_idx] = BVHNode4( + parent.child0, parent.child1, current4_idx, parent.child3, + parent.aabb0_min, parent.aabb0_max, parent.aabb1_min, parent.aabb1_max, + parent.aabb2_min, parent.aabb2_max, parent.aabb3_min, parent.aabb3_max, + parent.parent, parent.child_count, parent.primitive_count, parent._pad1, parent._pad2 + ) + else + nodes4[parent4_idx] = BVHNode4( + parent.child0, parent.child1, parent.child2, current4_idx, + parent.aabb0_min, parent.aabb0_max, parent.aabb1_min, parent.aabb1_max, + parent.aabb2_min, parent.aabb2_max, parent.aabb3_min, parent.aabb3_max, + parent.parent, parent.child_count, parent.primitive_count, parent._pad1, parent._pad2 + ) + end + + # Create child index array + sub_child_indices = MVector{4, UInt32}(INVALID_NODE4, INVALID_NODE4, INVALID_NODE4, INVALID_NODE4) + + # Process children + for i in 1:sub_count + bvh2_child = sub_children[i] + @inbounds node2_child = nodes2[bvh2_child] + + if sub_is_leaf[i] + # Leaf + node4_count += 1 + leaf4_idx = UInt32(node4_count) + sub_child_indices[i] = leaf4_idx + + prim_idx = node2_child.child1 + v0 = Point3f(node2_child.aabb0_min...) + v1 = Point3f(node2_child.aabb0_max...) + v2 = Point3f(node2_child.aabb1_min...) + p_min = min.(min.(v0, v1), v2) + p_max = max.(max.(v0, v1), v2) + + nodes4[node4_count] = BVHNode4( + prim_idx, INVALID_NODE4, INVALID_NODE4, INVALID_NODE4, + p_min, p_max, + Point3f(0), Point3f(0), + Point3f(0), Point3f(0), + Point3f(0), Point3f(0), + current4_idx, UInt8(0), UInt8(1), UInt8(0), UInt8(0) + ) + else + # Interior - queue for later + push!(queue, (bvh2_child, i, current4_idx)) + end + end + + # Create this interior node + nodes4[current4_idx] = BVHNode4( + sub_child_indices[1], sub_child_indices[2], sub_child_indices[3], sub_child_indices[4], + sub_aabbs[1].p_min, sub_aabbs[1].p_max, + sub_aabbs[2].p_min, sub_aabbs[2].p_max, + sub_aabbs[3].p_min, sub_aabbs[3].p_max, + sub_aabbs[4].p_min, sub_aabbs[4].p_max, + parent4_idx, UInt8(sub_count), UInt8(0), UInt8(0), UInt8(0) + ) + end + end + + # Resize to actual count + resize!(nodes4, node4_count) + + return nodes4 +end + +# ============================================================================== +# BVH4 Build Function +# ============================================================================== + +""" + build_blas4(primitives) -> BLAS4 + +Build a BLAS using BVH4 nodes for faster traversal. + +1. Build standard LBVH (BVH2) using existing kernels +2. Collapse BVH2 -> BVH4 +""" +function build_blas4(primitives::AbstractVector{T}) where {T <: Triangle} + n = length(primitives) + n == 0 && error("Cannot build BLAS4 from empty primitive list") + + # First build BVH2 using existing infrastructure + blas2 = build_blas(primitives) + + # Collapse to BVH4 + nodes4 = collapse_bvh2_to_bvh4(blas2.nodes, blas2.primitives, Int32(n)) + + return BLAS4(nodes4, blas2.primitives, blas2.root_aabb, Int32(length(nodes4))) +end + +# ============================================================================== +# BVH4 Traversal +# ============================================================================== + +""" + fast_intersect_bbox4(ray_o, ray_inv_d, node, child_idx, t_min, t_max) -> (hit, t_entry) + +Test ray against one child AABB of a BVH4 node. +""" +@inline function fast_intersect_bbox4( + ray_o::Point3f, + ray_inv_d::Vec3f, + node::BVHNode4, + child_idx::Int, + t_min::Float32, + t_max::Float32 +)::Tuple{Bool, Float32} + aabb = get_child_aabb4(node, child_idx) + + oxinvdir = -ray_o .* ray_inv_d + f = aabb.p_max .* ray_inv_d .+ oxinvdir + n = aabb.p_min .* ray_inv_d .+ oxinvdir + + tmax_vec = max.(f, n) + tmin_vec = min.(f, n) + + max_t = min(minimum(tmax_vec), t_max) + min_t = max(maximum(tmin_vec), t_min) + + return (min_t <= max_t, min_t) +end + +""" + intersect_all_children4(node, ray_inv_d, ray_o, t_min, t_max) -> sorted hits + +Test ray against all 4 children AABBs and return sorted by distance. +Returns up to 4 (child_idx, t_entry) pairs, sorted near-to-far. +""" +@inline function intersect_all_children4( + node::BVHNode4, + ray_inv_d::Vec3f, + ray_o::Point3f, + t_min::Float32, + t_max::Float32 +) + # Test all children + hits = MVector{4, Tuple{UInt32, Float32}}( + (INVALID_NODE4, Inf32), + (INVALID_NODE4, Inf32), + (INVALID_NODE4, Inf32), + (INVALID_NODE4, Inf32) + ) + hit_count = 0 + + for i in 1:Int(node.child_count) + child_idx = get_child4(node, i) + if child_idx != INVALID_NODE4 + hit, t_entry = fast_intersect_bbox4(ray_o, ray_inv_d, node, i, t_min, t_max) + if hit + hit_count += 1 + hits[hit_count] = (child_idx, t_entry) + end + end + end + + # Simple insertion sort for up to 4 elements + for i in 2:hit_count + j = i + while j > 1 && hits[j][2] < hits[j-1][2] + hits[j], hits[j-1] = hits[j-1], hits[j] + j -= 1 + end + end + + return hits, hit_count +end + +""" + closest_hit4(blas::BLAS4, ray::AbstractRay) -> (hit, primitive, distance, barycentric) + +Traverse BVH4 to find closest intersection. +""" +@inline function closest_hit4(blas::BLAS4, ray::R) where {R <: AbstractRay} + ray = check_direction(ray) + ray_o::Point3f = ray.o + ray_d::Vec3f = ray.d + ray_mint::Float32 = 0.0f0 + ray_maxt::Float32 = ray.t_max + ray_inv_d::Vec3f = safe_invdir(ray_d) + + # Stack for traversal (BVH4 needs smaller stack than BVH2) + stack = MVector{32, UInt32}(undef) + stack_ptr::Int32 = Int32(0) + + # Track closest hit + closest_prim::UInt32 = INVALID_NODE4 + hit_u::Float32 = 0.0f0 + hit_v::Float32 = 0.0f0 + + nodes = blas.nodes + prims = blas.primitives + + # Start at root + node_idx::UInt32 = UInt32(1) + + @inbounds while true + node = nodes[node_idx] + + if is_interior4(node) + # Interior node - test all children + hits, hit_count = intersect_all_children4(node, ray_inv_d, ray_o, ray_mint, ray_maxt) + + # Push far children to stack (in reverse order so nearest is popped first) + for i in hit_count:-1:2 + if hits[i][1] != INVALID_NODE4 + stack_ptr += Int32(1) + stack[stack_ptr] = hits[i][1] + end + end + + # Visit nearest child + if hit_count > 0 && hits[1][1] != INVALID_NODE4 + node_idx = hits[1][1] + continue + end + else + # Leaf node - test triangle + prim_idx = node.child0 + if prim_idx != INVALID_NODE4 && prim_idx <= length(prims) + tri = prims[prim_idx] + verts = tri.vertices + hit, t, u, v = fast_intersect_triangle( + ray_o, ray_d, + verts[1], verts[2], verts[3], + ray_mint, ray_maxt + ) + if hit + ray_maxt = t + closest_prim = prim_idx + hit_u = u + hit_v = v + end + end + end + + # Pop from stack + if stack_ptr > Int32(0) + node_idx = stack[stack_ptr] + stack_ptr -= Int32(1) + else + break + end + end + + # Return result + @inbounds if closest_prim != INVALID_NODE4 + tri = prims[closest_prim] + w = 1.0f0 - hit_u - hit_v + bary = SVector{3, Float32}(w, hit_u, hit_v) + return (true, tri, ray_maxt, bary) + else + dummy_tri = empty_triangle(eltype(prims)) + bary = SVector{3, Float32}(0.0f0, 0.0f0, 0.0f0) + return (false, dummy_tri, 0.0f0, bary) + end +end + +""" + any_hit4(blas::BLAS4, ray::AbstractRay) -> (hit, primitive, distance, barycentric) + +Traverse BVH4 to find any intersection (early exit). +""" +@inline function any_hit4(blas::BLAS4, ray::R) where {R <: AbstractRay} + ray = check_direction(ray) + ray_o::Point3f = ray.o + ray_d::Vec3f = ray.d + ray_mint::Float32 = 0.0f0 + ray_maxt::Float32 = ray.t_max + ray_inv_d::Vec3f = safe_invdir(ray_d) + + # Stack for traversal + stack = MVector{32, UInt32}(undef) + stack_ptr::Int32 = Int32(0) + + nodes = blas.nodes + prims = blas.primitives + + # Start at root + node_idx::UInt32 = UInt32(1) + + @inbounds while true + node = nodes[node_idx] + + if is_interior4(node) + # Interior node - test all children + hits, hit_count = intersect_all_children4(node, ray_inv_d, ray_o, ray_mint, ray_maxt) + + # Push all hit children (order doesn't matter for any_hit) + for i in hit_count:-1:2 + if hits[i][1] != INVALID_NODE4 + stack_ptr += Int32(1) + stack[stack_ptr] = hits[i][1] + end + end + + if hit_count > 0 && hits[1][1] != INVALID_NODE4 + node_idx = hits[1][1] + continue + end + else + # Leaf node - test triangle + prim_idx = node.child0 + if prim_idx != INVALID_NODE4 && prim_idx <= length(prims) + tri = prims[prim_idx] + verts = tri.vertices + hit, t, u, v = fast_intersect_triangle( + ray_o, ray_d, + verts[1], verts[2], verts[3], + ray_mint, ray_maxt + ) + if hit + # Early exit on first hit + w = 1.0f0 - u - v + bary = SVector{3, Float32}(w, u, v) + return (true, tri, t, bary) + end + end + end + + # Pop from stack + if stack_ptr > Int32(0) + node_idx = stack[stack_ptr] + stack_ptr -= Int32(1) + else + break + end + end + + # No hit + dummy_tri = prims[1] + bary = SVector{3, Float32}(0.0f0, 0.0f0, 0.0f0) + return (false, dummy_tri, 0.0f0, bary) +end + +# ============================================================================== +# Exports +# ============================================================================== + +export BVHNode4, BLAS4, TLAS4 +export build_blas4, closest_hit4, any_hit4 +export is_leaf4, is_interior4, get_child4, get_child_aabb4 diff --git a/src/collision.jl b/src/collision.jl new file mode 100644 index 0000000..0c1bad3 --- /dev/null +++ b/src/collision.jl @@ -0,0 +1,261 @@ +# ============================================================================== +# Collision Detection for TLAS +# ============================================================================== +# +# Two-pass GPU collision detection using the existing BVH2 structure. +# Follows ImplicitBVH's leaf-vs-tree (LVT) pattern adapted for Raycore's +# two-level TLAS/BLAS with instance transforms. +# +# Architecture: +# Pass 1: Count contacts per instance (write nothing) +# Prefix sum: Compute write offsets +# Pass 2: Write contact pairs at pre-computed offsets +# +# Works on any KernelAbstractions backend (CPU, CUDA, AMDGPU, Lava, etc.) + +# ============================================================================== +# Contact types +# ============================================================================== + +""" + ContactPair + +A pair of contacting instances in a TLAS, identified by their 1-based instance indices. +""" +struct ContactPair + instance_a::UInt32 + instance_b::UInt32 +end + +""" + CollisionResult{C, B} + +Result of collision detection, containing contact pairs and a reusable cache buffer. + +Fields: +- `contacts`: Vector of `ContactPair` (or similar) on the compute backend +- `num_contacts`: Number of valid contacts +- `cache`: Internal buffer for reuse across frames via `cache` keyword +""" +struct CollisionResult{C, B} + contacts::C + num_contacts::Int + cache::B +end + +# ============================================================================== +# AABB overlap test for BVHNode2 +# ============================================================================== + +"""Test if two AABBs overlap.""" +@inline function aabb_overlaps(a_min::Point3f, a_max::Point3f, b_min::Point3f, b_max::Point3f) + all(a_max .>= b_min) && all(a_min .<= b_max) +end + +"""Get the AABB of a TLAS node (internal: union of children, leaf: stored AABB).""" +@inline function tlas_node_aabb(node::BVHNode2) + if is_interior(node) + # Interior: union of children AABBs + p_min = min.(node.aabb0_min, node.aabb1_min) + p_max = max.(node.aabb0_max, node.aabb1_max) + return (p_min, p_max) + else + # Leaf: instance AABB stored in aabb0 + return (node.aabb0_min, node.aabb0_max) + end +end + +# ============================================================================== +# Instance-level collision kernel (TLAS broad-phase) +# ============================================================================== + +""" + collide_instances_kernel! + +For each TLAS leaf (instance), traverse the TLAS tree to find overlapping instances. +Two-pass: when `contacts` is nothing, only counts. When not nothing, writes pairs. + +Uses stack-based depth-first traversal, same pattern as closest_hit but testing +AABB-AABB overlap instead of ray-AABB intersection. +""" +@kernel function collide_instances_kernel!( + contact_counts, + contacts, # Nothing on counting pass, AbstractVector{ContactPair} on writing pass + @Const(nodes), + n_instances::Int32 +) + i = @index(Global, Linear) + + @inbounds if i <= n_instances + # Get this instance's leaf node and AABB + leaf_idx = Int(n_instances) - 1 + i + leaf_node = nodes[leaf_idx] + a_min, a_max = tlas_node_aabb(leaf_node) + instance_a = leaf_node.child1 # 0-indexed instance index stored in child1 + + # Stack for traversal (16 levels is plenty for TLAS) + stack = MVector{16, UInt32}(undef) + stack_ptr = Int32(0) + + # Start at root + node_index = UInt32(1) + count = UInt32(0) + + while true + node = nodes[node_index] + + if is_interior(node) + # Test children AABBs + overlap0 = aabb_overlaps(a_min, a_max, node.aabb0_min, node.aabb0_max) + overlap1 = aabb_overlaps(a_min, a_max, node.aabb1_min, node.aabb1_max) + + if overlap0 && overlap1 + # Both overlap — push far, visit near + stack_ptr += Int32(1) + stack[stack_ptr] = node.child1 + node_index = node.child0 + continue + elseif overlap0 + node_index = node.child0 + continue + elseif overlap1 + node_index = node.child1 + continue + end + # Neither overlaps — fall through to pop + else + # Leaf node — check if it's a different instance + instance_b = node.child1 + if instance_b > instance_a # Only count each pair once (a < b) + b_min, b_max = tlas_node_aabb(node) + if aabb_overlaps(a_min, a_max, b_min, b_max) + count += UInt32(1) + if contacts !== nothing + # Writing pass: write at pre-computed offset + write_idx = contact_counts[i] - count + UInt32(1) + contacts[write_idx] = ContactPair(instance_a + UInt32(1), instance_b + UInt32(1)) + end + end + end + end + + # Pop from stack + if stack_ptr > Int32(0) + node_index = stack[stack_ptr] + stack_ptr -= Int32(1) + else + break + end + end + + # Store count (on counting pass, this is the total; on writing pass, it's for offset calc) + if contacts === nothing + contact_counts[i] = count + end + end +end + +# ============================================================================== +# Public API +# ============================================================================== + +""" + collide_instances(tlas::TLAS; cache=nothing) -> CollisionResult + +Find all pairs of instances whose world-space AABBs overlap. + +This is a broad-phase test — it identifies which instances *might* be in contact +based on their bounding boxes. For exact triangle-triangle contact, use `collide`. + +Returns a `CollisionResult` with `ContactPair`s (1-indexed instance IDs). + +# Example +```julia +tlas = TLAS(backend) +push!(tlas, mesh_a, transform_a) +push!(tlas, mesh_b, transform_b) +sync!(tlas) + +result = collide_instances(tlas) +for i in 1:result.num_contacts + pair = result.contacts[i] + println("Instance \$(pair.instance_a) overlaps instance \$(pair.instance_b)") +end + +# Reuse buffers for next frame: +result2 = collide_instances(tlas; cache=result.cache) +``` +""" +function collide_instances(tlas::TLAS; cache=nothing) + sync!(tlas) + n = Int32(length(tlas.instances)) + n == 0 && return CollisionResult( + KA.allocate(tlas.backend, ContactPair, 0), 0, + KA.allocate(tlas.backend, UInt32, 0) + ) + + backend = tlas.backend + nodes = tlas.nodes + + # Allocate or reuse count buffer + if cache !== nothing && length(cache) >= n + contact_counts = cache + else + contact_counts = KA.allocate(backend, UInt32, Int(n)) + end + contact_counts .= UInt32(0) + + # Pass 1: Count contacts per instance + kern! = collide_instances_kernel!(backend) + kern!(contact_counts, nothing, nodes, n, ndrange=Int(n)) + KA.synchronize(backend) + + # Prefix sum to get write offsets + # After accumulate, contact_counts[i] = total contacts for instances 1..i + AK.accumulate!(+, contact_counts, init=UInt32(0)) + + # Total contacts + total = Int(@allowscalar contact_counts[end]) + if total == 0 + return CollisionResult( + KA.allocate(backend, ContactPair, 0), 0, contact_counts + ) + end + + # Allocate contacts + contacts = KA.allocate(backend, ContactPair, total) + + # Pass 2: Write contact pairs + kern!(contact_counts, contacts, nodes, n, ndrange=Int(n)) + KA.synchronize(backend) + + return CollisionResult(contacts, total, contact_counts) +end + +""" + collide_instances_any(tlas::TLAS, handle_a::TLASHandle, handle_b::TLASHandle) -> Bool + +Test whether two specific instance groups overlap (broad-phase AABB test). +Fast early-exit — returns true on first overlap found. +""" +function collide_instances_any(tlas::TLAS, handle_a::TLASHandle, handle_b::TLASHandle) + sync!(tlas) + range_a = tlas.handle_to_range[handle_a] + range_b = tlas.handle_to_range[handle_b] + + # CPU check — for a quick boolean we just test instance AABBs directly + nodes = Array(tlas.nodes) # Small download for TLAS nodes + instances = Array(tlas.instances) + n = Int32(length(instances)) + + for ia in range_a, ib in range_b + leaf_a = nodes[Int(n) - 1 + ia] + leaf_b = nodes[Int(n) - 1 + ib] + a_min, a_max = tlas_node_aabb(leaf_a) + b_min, b_max = tlas_node_aabb(leaf_b) + if aabb_overlaps(a_min, a_max, b_min, b_max) + return true + end + end + return false +end diff --git a/src/instanced-bvh-kernels.jl b/src/instanced-bvh-kernels.jl new file mode 100644 index 0000000..74bcd57 --- /dev/null +++ b/src/instanced-bvh-kernels.jl @@ -0,0 +1,572 @@ +# ============================================================================== +# Instanced BVH - GPU Kernels with KernelAbstractions +# ============================================================================== +# +# Follows KA best practices: +# - @kernel only for parallel dispatch +# - All logic in regular type-stable Julia functions +# - Minimal code inside @kernel functions + +import KernelAbstractions as KA +using KernelAbstractions: @index +using Atomix: @atomic + +# ============================================================================== +# GPU Kernel 0: Fill arrays (workaround for OpenCL fill! struct issue) +# ============================================================================== + +"""GPU kernel: Fill array with a value (workaround for OpenCL's fill! not supporting structs).""" +KA.@kernel function fill_bvhnode2_kernel!(arr, val) + i = @index(Global, Linear) + @inbounds arr[i] = val +end + +"""GPU kernel: Fill array with sequential indices [1, 2, 3, ..., n].""" +KA.@kernel function iota_kernel!(arr) + i = @index(Global, Linear) + @inbounds arr[i] = i +end + +# ============================================================================== +# GPU Kernel 0b: Compute Instance World AABBs +# ============================================================================== + +""" +Compute world AABB for a single instance by transforming local AABB corners. +Returns (min_point, max_point) as two Point3f values. +""" +@inline function compute_instance_world_aabb( + inst::InstanceDescriptor, + blas_array::AbstractVector{<:BLAS} +) + blas = blas_array[inst.blas_index] + local_aabb = blas.root_aabb + + # Initialize with first corner + corner1 = transform_point(inst.transform, corner(local_aabb, 1)) + min_p = corner1 + max_p = corner1 + + # Expand to include all 8 corners + for c in 2:8 + world_corner = transform_point(inst.transform, corner(local_aabb, c)) + min_p = Point3f(min(min_p[1], world_corner[1]), + min(min_p[2], world_corner[2]), + min(min_p[3], world_corner[3])) + max_p = Point3f(max(max_p[1], world_corner[1]), + max(max_p[2], world_corner[2]), + max(max_p[3], world_corner[3])) + end + + return (min_p, max_p) +end + +"""GPU kernel: Compute world AABBs for all instances, storing min/max points separately.""" +KA.@kernel function compute_instance_aabbs_kernel!( + aabb_mins::AbstractVector{Point3f}, + aabb_maxs::AbstractVector{Point3f}, + @Const(instances), + @Const(blas_array) +) + i = @index(Global, Linear) + @inbounds begin + inst = instances[i] + min_p, max_p = compute_instance_world_aabb(inst, blas_array) + aabb_mins[i] = min_p + aabb_maxs[i] = max_p + end +end + +# ============================================================================== +# GPU Kernel 1: Calculate Morton Codes +# ============================================================================== + +""" +Calculate Morton code for a single primitive. +This is a regular Julia function, callable from CPU or GPU. +""" +@inline function calculate_morton_code_for_prim( + prim_idx::Int, + primitives::AbstractVector{<:Triangle}, + scene_min::Point3f, + scene_extent::Vec3f +)::UInt32 + tri_aabb = world_bound(primitives[prim_idx]) + centroid = 0.5f0 * (tri_aabb.p_min + tri_aabb.p_max) + normalized = (centroid - scene_min) ./ scene_extent + return morton_code_30bit(normalized) +end + +"""GPU kernel: Parallel dispatch for Morton code calculation.""" +KA.@kernel function calculate_morton_codes_kernel!( + morton_codes, + primitives, + scene_min, + scene_extent +) + i = @index(Global, Linear) + @inbounds morton_codes[i] = calculate_morton_code_for_prim(i, primitives, scene_min, scene_extent) +end + +# ============================================================================== +# GPU Kernel 2: Emit Topology +# ============================================================================== + +""" +Build topology for one internal node. +Regular Julia function for testability. +""" +@inline function build_topology_for_node( + idx::Int32, + morton_codes::AbstractVector{UInt32}, + n_prims::Int32 +)::BVHNode2 + # Helper functions + @inline leaf_idx(j::Int32) = n_prims - Int32(1) + j + + # Find span + span_left, span_right = find_span_for_node(idx, morton_codes, n_prims) + + # Find split + split = find_split_in_span(span_left, span_right, morton_codes, n_prims) + + # Determine children + # If split is at boundary, it's a leaf. Otherwise it's a valid internal node. + child0 = (split == span_left) ? leaf_idx(split) : split + child1_idx = split + Int32(1) + child1 = (child1_idx == span_right) ? leaf_idx(child1_idx) : child1_idx + + return BVHNode2( + Point3f(0), Point3f(0), Point3f(0), Point3f(0), + UInt32(child0), UInt32(child1), INVALID_NODE + ) +end + +"""GPU kernel: Parallel topology emission.""" +KA.@kernel function emit_topology_kernel!(nodes, morton_codes, n_prims::Int32) + i = @index(Global, Linear) + idx = Int32(i) + if idx < n_prims + @inbounds nodes[idx] = build_topology_for_node(idx, morton_codes, n_prims) + end +end + +# ============================================================================== +# GPU Kernel 3: Set Parent Pointers +# ============================================================================== + +"""Set parent pointers for one node's children. Regular Julia function.""" +@inline function set_parents_for_node( + node_idx::Int32, + nodes::AbstractVector{BVHNode2} +)::Tuple{Int32, BVHNode2, Int32, BVHNode2} + @inbounds node = nodes[node_idx] + child0_idx = Int32(node.child0) + child1_idx = Int32(node.child1) + + @inbounds old0 = nodes[child0_idx] + @inbounds old1 = nodes[child1_idx] + + new0 = BVHNode2( + old0.aabb0_min, old0.aabb0_max, old0.aabb1_min, old0.aabb1_max, + old0.child0, old0.child1, UInt32(node_idx) + ) + new1 = BVHNode2( + old1.aabb0_min, old1.aabb0_max, old1.aabb1_min, old1.aabb1_max, + old1.child0, old1.child1, UInt32(node_idx) + ) + + return (child0_idx, new0, child1_idx, new1) +end + +"""GPU kernel: Parallel parent pointer assignment.""" +KA.@kernel function set_parent_pointers_kernel!(nodes, n_prims::Int32) + i = @index(Global, Linear) + idx = Int32(i) + if idx < n_prims + child0_idx, new0, child1_idx, new1 = set_parents_for_node(idx, nodes) + @inbounds nodes[child0_idx] = new0 + @inbounds nodes[child1_idx] = new1 + end +end + +# ============================================================================== +# GPU Kernel 4: Create Leaf Nodes +# ============================================================================== + +"""Create leaf node for one primitive. Regular Julia function.""" +@inline function create_leaf_for_prim( + prim_idx::Int, + primitives::AbstractVector{<:Triangle}, + parent_node::BVHNode2, + n_prims::Int32 +)::BVHNode2 + # Store triangle vertices directly in leaf node (BVH2IL format) + tri = primitives[prim_idx] + verts = tri.vertices + v0 = verts[1] + v1 = verts[2] + v2 = verts[3] + + return BVHNode2( + v0, v1, v2, Point3f(0), # Store v0, v1, v2 in aabb slots + INVALID_NODE, UInt32(prim_idx), parent_node.parent + ) +end + +"""GPU kernel: Parallel leaf node creation.""" +KA.@kernel function create_leaf_nodes_kernel!(nodes, primitives, n_prims::Int32) + i = @index(Global, Linear) + if i <= n_prims + @inline leaf_idx(j::Int) = Int(n_prims) - 1 + j + leaf_node_idx = leaf_idx(i) + @inbounds parent_node = nodes[leaf_node_idx] + @inbounds nodes[leaf_node_idx] = create_leaf_for_prim(i, primitives, parent_node, n_prims) + end +end + +# ============================================================================== +# Refit AABBs Kernel (Parallel Bottom-Up) +# ============================================================================== + +""" +Parallel bottom-up AABB refit using atomic counters. + +Each thread starts at a leaf and walks up the tree. Uses atomic operations +to ensure each internal node is updated exactly once after both children are ready. +Based on RadeonRays Refit kernel. +""" +KA.@kernel function refit_aabbs_kernel!( + nodes, + update_flags, + n_prims::Int32 +) + prim_idx = @index(Global, Linear) + + # Start at parent of this leaf + leaf_idx = leaf_index(prim_idx, n_prims) + @inbounds parent_idx = nodes[leaf_idx].parent + + # Walk up the tree + while parent_idx != INVALID_NODE + # Atomic increment: mark this node as visited + # If new_value == 1: we're first thread (was 0), bail out + # If new_value == 2: we're second thread (was 1), update AABB and continue + # Note: @atomicswap doesn't work on OpenCL, so we use @atomic += instead + new_value = @inbounds @atomic update_flags[parent_idx] += UInt32(1) + + if new_value == UInt32(2) + # Second thread arrived - compute AABB from both children + @inbounds begin + node = nodes[parent_idx] + child0 = node.child0 + child1 = node.child1 + + is_child0_internal = child0 < n_prims + is_child1_internal = child1 < n_prims + + aabb0 = get_node_aabb(nodes[child0], is_child0_internal) + aabb1 = get_node_aabb(nodes[child1], is_child1_internal) + + # Update this node's AABBs + updated_node = BVHNode2( + aabb0.p_min, aabb0.p_max, + aabb1.p_min, aabb1.p_max, + node.child0, node.child1, node.parent + ) + nodes[parent_idx] = updated_node + # Move to parent + parent_idx = node.parent + end + else + # First thread - bail out + break + end + end +end + +# ============================================================================== +# TLAS-Specific Kernels +# ============================================================================== + +""" +Calculate Morton code for a single instance centroid. +""" +@inline function calculate_tlas_morton_code( + inst_idx::Int, + instances::AbstractVector{InstanceDescriptor}, + blas_array::AbstractVector{<:BLAS}, + scene_min::Point3f, + scene_extent::Vec3f +)::UInt32 + inst = instances[inst_idx] + blas = blas_array[inst.blas_index] + local_aabb = blas.root_aabb + + # Transform centroid to world space + local_center = 0.5f0 * (local_aabb.p_min + local_aabb.p_max) + world_center = transform_point(inst.transform, local_center) + + # Normalize and compute Morton code + normalized = (world_center - scene_min) ./ scene_extent + return morton_code_30bit(normalized) +end + +"""GPU kernel: Calculate Morton codes for TLAS instances.""" +KA.@kernel function calculate_tlas_morton_codes_kernel!( + morton_codes, + @Const(instances), + @Const(blas_array), + scene_min, + scene_extent +) + i = @index(Global, Linear) + @inbounds morton_codes[i] = calculate_tlas_morton_code( + i, instances, blas_array, scene_min, scene_extent + ) +end + +""" +Create TLAS leaf node for one instance (stores world-space AABB, not triangle vertices). +""" +@inline function create_tlas_leaf_for_instance( + sorted_leaf_idx::Int, + sorted_indices::AbstractVector{<:Integer}, + instances::AbstractVector{InstanceDescriptor}, + blas_array::AbstractVector{<:BLAS}, + parent::UInt32 +)::BVHNode2 + # Get original instance index (sorted_indices maps sorted position -> original position) + original_idx = sorted_indices[sorted_leaf_idx] + inst = instances[original_idx] + blas = blas_array[inst.blas_index] + local_aabb = blas.root_aabb + + # Transform AABB to world space (8 corners) + world_aabb = Bounds3() + for c in 1:8 + world_corner = transform_point(inst.transform, corner(local_aabb, c)) + world_aabb = world_aabb ∪ Bounds3(world_corner) + end + + return BVHNode2( + world_aabb.p_min, world_aabb.p_max, Point3f(0), Point3f(0), + INVALID_NODE, UInt32(original_idx - 1), # 0-indexed instance index + parent + ) +end + +"""GPU kernel: Create TLAS leaf nodes.""" +KA.@kernel function create_tlas_leaf_nodes_kernel!( + nodes, + @Const(sorted_indices), + @Const(instances), + @Const(blas_array), + n_instances::Int32 +) + i = @index(Global, Linear) + if i <= n_instances + leaf_node_idx = Int(n_instances) - 1 + i + @inbounds parent = nodes[leaf_node_idx].parent + @inbounds nodes[leaf_node_idx] = create_tlas_leaf_for_instance( + i, sorted_indices, instances, blas_array, parent + ) + end +end + +""" +Parallel bottom-up AABB refit for TLAS using atomic counters. +Uses get_tlas_node_aabb which treats leaves as storing AABBs directly (not triangle vertices). +""" +KA.@kernel function refit_tlas_aabbs_kernel!( + nodes, + update_flags, + n_instances::Int32 +) + inst_idx = @index(Global, Linear) + + # Start at parent of this leaf + leaf_idx = Int(n_instances) - 1 + inst_idx + @inbounds parent_idx = nodes[leaf_idx].parent + + # Walk up the tree + while parent_idx != INVALID_NODE + # Atomic increment: mark this node as visited + # Note: @atomicswap doesn't work on OpenCL, so we use @atomic += instead + new_value = @inbounds @atomic update_flags[parent_idx] += UInt32(1) + + if new_value == UInt32(2) + # Second thread arrived - compute AABB from both children + @inbounds begin + node = nodes[parent_idx] + child0 = node.child0 + child1 = node.child1 + + is_child0_internal = child0 < n_instances + is_child1_internal = child1 < n_instances + + # Use TLAS-specific AABB computation (leaves store AABBs, not vertices) + aabb0 = get_tlas_node_aabb(nodes[child0], is_child0_internal) + aabb1 = get_tlas_node_aabb(nodes[child1], is_child1_internal) + + # Update this node's AABBs + updated_node = BVHNode2( + aabb0.p_min, aabb0.p_max, + aabb1.p_min, aabb1.p_max, + node.child0, node.child1, node.parent + ) + nodes[parent_idx] = updated_node + + # Move to parent + parent_idx = node.parent + end + else + # First thread - bail out + break + end + end +end + +# ============================================================================== +# GPU Dynamic Update Kernels +# ============================================================================== + +KA.@kernel function update_instance_transforms_kernel!( + instances, + @Const(transforms), + n_particles::Int32 +) + i = @index(Global, Linear) + if i <= n_particles + @inbounds begin + old_inst = instances[i] + transform = transforms[i] + instances[i] = InstanceDescriptor( + old_inst.blas_index, + old_inst.instance_id, + transform, + mat3x4_inverse(transform), + old_inst.flags + ) + end + end +end + +KA.@kernel function update_instance_transforms_offset_kernel!( + instances, + @Const(transforms), + n_particles::Int32, + first_idx::Int32 +) + i = @index(Global, Linear) + if i <= n_particles + @inbounds begin + inst_idx = first_idx + i - Int32(1) + old_inst = instances[inst_idx] + transform = transforms[i] + instances[inst_idx] = InstanceDescriptor( + old_inst.blas_index, + old_inst.instance_id, + transform, + mat3x4_inverse(transform), + old_inst.flags + ) + end + end +end + +""" +GPU kernel: Update TLAS leaf node AABBs from instance transforms. + +After transforms are updated, this kernel recomputes world-space AABBs for +all leaf nodes. Must be called before refit_tlas_aabbs_kernel!. + +NOTE: Leaf nodes are Morton-sorted, so we must use the stored instance index +(child1) to look up the correct instance. +""" +KA.@kernel function update_tlas_leaf_aabbs_kernel!( + nodes, + @Const(instances), + @Const(blas_array), + n_instances::Int32 +) + i = @index(Global, Linear) + if i <= n_instances + @inbounds begin + leaf_node_idx = Int(n_instances) - 1 + i + old_node = nodes[leaf_node_idx] + + # Get the actual instance index from the leaf node (stored as 0-indexed in child1) + inst_idx = Int(old_node.child1) + 1 + inst = instances[inst_idx] + blas = blas_array[inst.blas_index] + local_aabb = blas.root_aabb + + # Transform AABB to world space (8 corners) + world_aabb = Bounds3() + for c in 1:8 + world_corner = transform_point(inst.transform, corner(local_aabb, c)) + world_aabb = world_aabb ∪ Bounds3(world_corner) + end + + # Update leaf node with new AABB (preserve topology) + nodes[leaf_node_idx] = BVHNode2( + world_aabb.p_min, world_aabb.p_max, Point3f(0), Point3f(0), + old_node.child0, old_node.child1, old_node.parent + ) + end + end +end + +""" +GPU kernel: Batch update materials based on particle velocities. + +Updates material colors using a heat-map based on velocity magnitude. +Also updates metallic/roughness with per-particle noise variation. +""" +KA.@kernel function update_particle_materials_kernel!( + materials, + @Const(positions), + @Const(velocities), + @Const(radii), + n_particles::Int32, + max_speed::Float32 +) + i = @index(Global, Linear) + if i <= n_particles + @inbounds begin + vel = velocities[i] + speed = sqrt(vel[1]^2 + vel[2]^2 + vel[3]^2) + + # Velocity to heat color (same as CPU version) + t = clamp(speed / max_speed, 0.0f0, 1.0f0) + color = if t < 0.25f0 + s = t / 0.25f0 + RGB{Float32}(0.1f0, 0.2f0 + 0.5f0 * s, 0.8f0) + elseif t < 0.5f0 + s = (t - 0.25f0) / 0.25f0 + RGB{Float32}(0.1f0 + 0.6f0 * s, 0.7f0, 0.8f0 - 0.6f0 * s) + elseif t < 0.75f0 + s = (t - 0.5f0) / 0.25f0 + RGB{Float32}(0.7f0 + 0.3f0 * s, 0.7f0 - 0.4f0 * s, 0.2f0 - 0.1f0 * s) + else + s = (t - 0.75f0) / 0.25f0 + RGB{Float32}(1.0f0, 0.3f0 + 0.7f0 * s, 0.1f0 + 0.9f0 * s) + end + + # Noise for metallic/roughness variety + noise = sin(Float32(i) * 0.1f0) * 0.5f0 + 0.5f0 + metallic = 0.3f0 + noise * 0.6f0 + roughness = 0.2f0 + (1.0f0 - noise) * 0.4f0 + + materials[i] = Material(color, metallic, roughness, 1.0f0, 0.0f0) + end + end +end + +# Export kernel helper functions for testing +export calculate_morton_code_for_prim, build_topology_for_node +export set_parents_for_node, create_leaf_for_prim, refit_node_aabb +export calculate_tlas_morton_code, create_tlas_leaf_for_instance +export update_instance_transforms_kernel!, update_instance_transforms_offset_kernel!, update_tlas_leaf_aabbs_kernel! +export update_particle_materials_kernel! diff --git a/src/instanced-bvh.jl b/src/instanced-bvh.jl new file mode 100644 index 0000000..e2b759f --- /dev/null +++ b/src/instanced-bvh.jl @@ -0,0 +1,2440 @@ +# ============================================================================== +# Instanced BVH - Two-Level Acceleration Structure +# ============================================================================== +# +# Based on AMD RadeonRays SDK architecture: +# - BLAS (Bottom-Level Acceleration Structure): BVH over triangle geometry +# - TLAS (Top-Level Acceleration Structure): BVH over instances with transforms +# - Two-level traversal with transform handling +# +# Key optimizations: +# - LBVH (Linear BVH) construction using 30-bit Morton codes +# - Parametrized array types for CPU/GPU compatibility +# - Fully type-stable traversal kernels +# - Compact memory layout for cache efficiency +# +# Architecture (GPU-First): +# - BLASes are built directly on the backend (GPU arrays from the start) +# - TLAS stores a Vector of backend BLASes (CPU vector for management, but arrays inside are on GPU) +# - During sync!/adapt, isbits device pointers are extracted for kernel traversal +# - CPU-side dictionaries provide O(1) instance lookup +# - No CPU→GPU copy needed for BLAS data during sync! + +using StaticArrays +using LinearAlgebra: I +import KernelAbstractions as KA +import AcceleratedKernels as AK + +# Vulkan-compatible 3×4 transform (row-major). SMatrix{4,3} is column-major and +# byte-identical to a Vulkan row-major 3×4, so the two conventions share the +# same memory layout without any reinterpret. +const Mat3x4f = SMatrix{4, 3, Float32, 12} + +# ============================================================================== +# Core Data Structures +# ============================================================================== + +""" + BVHNode2 + +Compact BVH node for two-child binary trees. +Stores AABBs for both children inline (BVH2IL layout from RadeonRays). + +Fields: +- `aabb0_min`, `aabb0_max`: Child 0 bounding box (or vertex data for leaves) +- `aabb1_min`, `aabb1_max`: Child 1 bounding box (or unused for leaves) +- `child0`: Child 0 index (INVALID_NODE for leaves) +- `child1`: Child 1 index (or primitive index for leaves) +- `parent`: Parent node index +""" +struct BVHNode2 + # Child 0 AABB (or triangle vertex 0 for leaves) + aabb0_min::Point3f + aabb0_max::Point3f + + # Child 1 AABB (or triangle vertices 1,2 for leaves) + aabb1_min::Point3f + aabb1_max::Point3f + + # Topology + child0::UInt32 # INVALID_NODE (0xFFFFFFFF) indicates leaf + child1::UInt32 # Primitive index for leaves, child index for interior + parent::UInt32 +end + +const INVALID_NODE = 0xffffffff + +"""Check if a node is a leaf node.""" +@inline is_leaf(node::BVHNode2) = node.child0 == INVALID_NODE + +"""Check if a node is an interior node.""" +@inline is_interior(node::BVHNode2) = node.child0 != INVALID_NODE + +""" + InstanceDescriptor + +Describes an instance of a bottom-level BVH in world space. + +Fields: +- `blas_index`: Index into BLAS array +- `instance_id`: Scene-binding slot for this instance. `0` means "inherit + from the triangle's per-face metadata"; any nonzero value is forwarded + verbatim by `closest_hit` / `any_hit` as the 5th return value. Hikari + uses it as a `medium_interface_idx` override so N instances of one BLAS + can have distinct materials / media / emission without duplicating the + BLAS geometry. Matches Vulkan's `gl_InstanceCustomIndexEXT`. +- `transform`: Local-to-world transform (Vulkan row-major 3×4, `Mat3x4f`) +- `inv_transform`: World-to-local transform (Vulkan row-major 3×4, `Mat3x4f`) +- `flags`: Instance flags (reserved for future use) +""" +struct InstanceDescriptor + blas_index::UInt32 + instance_id::UInt32 + transform::Mat3x4f + inv_transform::Mat3x4f + flags::UInt32 +end + +# Mat4f convenience: `convert(Mat3x4f, ::Mat4f)` fails (different SMatrix +# shape), so an explicit outer constructor lets callers pass the natural +# homogeneous 4×4 form. Mirrors push!(tlas, mesh, ::Mat4f). +InstanceDescriptor(blas_index, instance_id, transform::Mat4f, inv_transform::Mat4f, flags) = + InstanceDescriptor(blas_index, instance_id, mat4_to_mat3x4(transform), mat4_to_mat3x4(inv_transform), flags) + +""" + BLAS{NodeArray, TriArray} + +Bottom-Level Acceleration Structure - BVH over triangle geometry. + +Type parameters allow CPU (Vector) or GPU (CuArray, ROCArray) storage. +""" +struct BLAS{ + NodeArray <: AbstractVector{BVHNode2}, + TriArray <: AbstractVector{<:Triangle} +} + nodes::NodeArray + primitives::TriArray + root_aabb::Bounds3 +end + +""" + BLASDescriptor + +Lightweight descriptor for a BLAS in flat-array layout. +Instead of storing device pointers to per-BLAS arrays (which fail on Metal when +stored in GPU buffers), this stores offsets into concatenated flat arrays. + +Fields: +- `nodes_offset`: 0-based offset into the flat all_blas_nodes array +- `primitives_offset`: 0-based offset into the flat all_blas_prims array +- `root_aabb`: Bounding box of the BLAS in local space +""" +struct BLASDescriptor + nodes_offset::UInt32 + primitives_offset::UInt32 + root_aabb::Bounds3 +end + +# ============================================================================== +# StaticTLAS - Immutable structure for kernel traversal +# ============================================================================== + +""" + StaticTLAS{NodeArray, InstArray, BLASNodeArray, BLASPrimArray, DescArray} + +Immutable Top-Level Acceleration Structure for GPU kernel traversal. +This is what `Adapt.adapt_structure` returns from a TLAS. + +Uses flat arrays with offset-based indexing instead of per-BLAS pointer arrays. +This avoids the Metal issue where device pointers stored in GPU buffers cannot +be reliably dereferenced by kernels. + +The struct is immutable and contains only the arrays needed for ray traversal. +No management state (dictionaries, free lists, etc.) - those stay on CPU in TLAS. +""" +struct StaticTLAS{ + NodeArray <: AbstractVector{BVHNode2}, + InstArray <: AbstractVector{InstanceDescriptor}, + BLASNodeArray <: AbstractVector{BVHNode2}, + BLASPrimArray <: AbstractVector{<:Triangle}, + DescArray <: AbstractVector{BLASDescriptor} +} <: AbstractAdaptedAccel + nodes::NodeArray + instances::InstArray + all_blas_nodes::BLASNodeArray + all_blas_prims::BLASPrimArray + blas_descriptors::DescArray + root_aabb::Bounds3 +end + +# ============================================================================== +# TLAS - Mutable structure with backend arrays + CPU management +# ============================================================================== + +""" + TLASHandle + +Stable handle for referencing instances in a TLAS. +Simple unique ID for O(1) lookup in handle_to_range dictionary. +""" +struct TLASHandle + id::UInt32 +end + +# Sentinel for invalid handle +const INVALID_HANDLE = TLASHandle(UInt32(0)) + +""" + BLASArrays + +Per-BLAS backing GPU arrays. Field ownership transitively keeps the +`nodes` / `primitives` buffers alive while the TLAS holds them, so the +isbits device pointers stored in `blas_array` (and the flat arrays used +by StaticTLAS) remain valid. + +`primitives` is left at the bare `AbstractVector` bound rather than +`AbstractVector{<:Triangle}` because differently-parameterized +`Triangle{TMetadata}` subtypes need to coexist across BLASes; the tighter +bound forces a UnionAll that doesn't help dispatch. +""" +struct BLASArrays + nodes::AbstractVector{BVHNode2} + primitives::AbstractVector +end + +""" + TLAS{Backend} + +Mutable Top-Level Acceleration Structure with direct GPU arrays. + +GPU-first design: instances are appended directly to GPU array using efficient +GPU append. CPU-side dictionary provides O(1) handle lookups. + +# Adapted-form invariant (READ THIS) + +`sync!(tlas)` is the single owner of the GPU-adapted form. It rebuilds as +efficiently as possible — in place via `resize!`/`copyto!` where the backing +buffer can be reused, freshly allocated only when a buffer grew — and stores +the result in `tlas.static_tlas`. `sync!` MAY reassign `tlas.static_tlas = +new_static_tlas` when a buffer was reallocated. + +Every consumer that hands an accel to a raytracing kernel MUST go through +`tlas.static_tlas` or `Adapt.adapt(backend, tlas)` (which reads / refreshes +`tlas.static_tlas`) per dispatch. Those are cheap; `sync!` did the heavy +lifting. NEVER cache the `StaticTLAS` returned by `adapt` across mutations — +after a reshape-driven reallocation the cached snapshot holds a stale device +pointer. + +The same contract applies to `HWTLAS` — see its docstring. + +# Type Parameters +- `Backend`: KernelAbstractions backend (CPU(), LavaBackend(), CUDABackend(), etc.) + +# Fields +- `backend`: KernelAbstractions backend for kernels +- `nodes`: BVH nodes array (on backend, grown in place on sync! where possible) +- `instances`: Instance descriptors array (GPU array, direct append) +- `blas_array`: BLAS objects array (GPU array with isbits pointers) +- `root_aabb`: World-space bounding box +- `handle_to_range`: Handle -> range in instances array (CPU-side) +- `deleted_handles`: Handles deleted but not yet compacted (CPU-side) +- `blas_storage`: Per-BLAS backing arrays; keeps GPU buffers alive for isbits pointers +- `static_tlas`: GPU-adapted form, owned by `sync!`. Consumers read this, do + not cache it across mutations. +- `dirty`: Whether BVH topology needs rebuild +- `transforms_dirty`: Whether instance transforms changed and the TLAS needs refit +- `revision`: Monotonic counter bumped by every mutation that reshapes + GPU-visible arrays. Intended purely as a diagnostic signal — consumers that + go through `tlas.static_tlas` don't need to check it. + +# Usage +```julia +tlas = TLAS(CPU()) +h1 = push!(tlas, mesh) +h2 = push!(tlas, mesh, transforms) +update_transform!(tlas, h2, new_transform) +delete!(tlas, h1) +sync!(tlas) # Rebuild / refresh tlas.static_tlas +static = adapt(backend, tlas) # Reads tlas.static_tlas — cheap, call per dispatch +``` +""" +mutable struct TLAS{Backend} <: AbstractAccel + backend::Backend + + # Backend arrays for kernel traversal (GPU from start). Field types are + # bounded by the abstract element type so `tlas.instances[i]` returns a + # concrete `InstanceDescriptor` without boxing, while still allowing the + # backing array to be reallocated to a different concrete container (e.g. + # `Vector` ↔ `LavaArray`) across `sync!`. The container itself stays + # abstract because `KA.allocate` returns backend-specific types we don't + # want to fix at struct definition. + nodes::AbstractVector{BVHNode2} # rebuilt on sync! + instances::AbstractVector{InstanceDescriptor} # direct GPU append + # `blas_array` is a backend array of isbits `BLAS{NodeArr,TriArr}`. Element + # type varies by mesh metadata, so no tighter bound; `nothing` until the + # first `push!` because the concrete BLAS type isn't known at construction. + blas_array::Union{Nothing, AbstractVector} + + root_aabb::Bounds3 + + # CPU-side management (dictionaries must stay on CPU for O(1) lookup) + handle_to_range::Dict{TLASHandle, UnitRange{Int}} + deleted_handles::Set{TLASHandle} + + # Per-BLAS backing arrays. Structural composition: the TLAS owning + # `blas_storage` transitively pins every BLAS's `nodes` / `primitives` + # buffer, so the isbits pointers stored in `blas_array` stay valid as + # long as the TLAS does. + blas_storage::Vector{BLASArrays} + + # Flat BLAS arrays for StaticTLAS traversal (built during sync!, kept alive for isbits pointers). + # `nothing` before the first `build_flat_blas_arrays!`. + _flat_blas_nodes::Union{Nothing, AbstractVector{BVHNode2}} + _flat_blas_prims::Union{Nothing, AbstractVector} # see blas_array note + _flat_blas_descs::Union{Nothing, AbstractVector{BLASDescriptor}} + + # GPU-adapted form, owned by sync!. Consumers read this via `tlas.static_tlas` + # or `Adapt.adapt(backend, tlas)` per dispatch — do NOT cache across + # mutations. See the TLAS docstring for the full invariant. + # `nothing` until the first sync!. + static_tlas::Union{Nothing, StaticTLAS} + + # Whether BVH topology needs rebuild (geometry added/removed) + dirty::Bool + + # Whether instance transforms changed and the TLAS needs refit (leaf AABB update) + transforms_dirty::Bool + + # Counters + next_handle_id::UInt32 +end + +# Note: get_isbits_ptr is defined in multitypeset.jl and reused here + +# ------------------------------------------------------------------------------ +# TLAS Constructor and Core Operations +# ------------------------------------------------------------------------------ + +""" + TLAS(backend) -> TLAS + +Create an empty TLAS for the given backend. +Use `push!` to add geometries/instances, then `sync!` to rebuild the BVH. +`Adapt.adapt_structure` returns a StaticTLAS for kernel traversal. + +# Example +```julia +tlas = TLAS(OpenCLBackend()) +h1 = push!(tlas, geometry) +h2 = push!(tlas, Instance(geometry, transforms)) +sync!(tlas) # Rebuild BVH on backend +static = adapt(backend, tlas) # StaticTLAS with isbits pointers for kernels +``` +""" +function TLAS(backend) + # GPU-first design: all arrays on backend from the start + tlas = TLAS( + backend, + KA.allocate(backend, BVHNode2, 0), # nodes (empty, rebuilt on sync!) + KA.allocate(backend, InstanceDescriptor, 0), # instances (direct GPU append) + allocate_empty_blas_array(backend), # blas_array (GPU array of isbits BLASes) + Bounds3(), # root_aabb + Dict{TLASHandle, UnitRange{Int}}(), # handle_to_range + Set{TLASHandle}(), # deleted_handles + BLASArrays[], # blas_storage + nothing, # _flat_blas_nodes + nothing, # _flat_blas_prims + nothing, # _flat_blas_descs + nothing, # static_tlas (owned by sync!) + true, # dirty (topology) + false, # transforms_dirty + UInt32(1), # next_handle_id + ) + + # Register finalizer to free GPU memory when TLAS is garbage collected + finalizer(free!, tlas) + + return tlas +end + +""" + free!(x) + +Trigger the registered finalizer on `x` to release GPU memory. +Safe to call on any object — no-op if no finalizer is registered. +""" +free!(x) = (finalize(x); nothing) + +""" + free!(tlas::TLAS) + +Release all GPU memory held by `tlas`. Does **not** synchronize. + +**Precondition (caller's responsibility):** the GPU must be idle for +`tlas.backend` before this is called — either because the caller just +returned from `sync!(tlas)` / `Hikari.sync!(scene)`, from a +`colorbuffer` that issued its own `device_wait_idle`, or because the +caller has otherwise issued `KA.synchronize(tlas.backend)`. + +Calling `free!` while dispatches are still in flight through a `LavaArray` +/ `VkAccelerationStructureKHR` BDA captured in an arg buffer is a +use-after-free. +""" +function free!(tlas::TLAS) + finalize(tlas.nodes) + finalize(tlas.instances) + finalize(tlas.blas_array) + for ba in tlas.blas_storage + finalize(ba.nodes) + finalize(ba.primitives) + end + empty!(tlas.blas_storage) + tlas._flat_blas_nodes !== nothing && finalize(tlas._flat_blas_nodes) + tlas._flat_blas_prims !== nothing && finalize(tlas._flat_blas_prims) + tlas._flat_blas_descs !== nothing && finalize(tlas._flat_blas_descs) + tlas._flat_blas_nodes = nothing + tlas._flat_blas_prims = nothing + tlas._flat_blas_descs = nothing + return nothing +end + +"""Helper to create initial empty BLAS array placeholder.""" +function allocate_empty_blas_array(_backend) + # Return nothing - the array will be created on first push with the correct type + return nothing +end + +"""Get the isbits pointer type for a given element type and backend.""" +function get_isbits_ptr_type(backend::KA.CPU, ::Type{T}) where T + return Vector{T} # On CPU, Vector is already isbits-compatible for our purposes +end + +function get_isbits_ptr_type(backend, ::Type{T}) where T + # For GPU backends, use argconvert to get the isbits device pointer type + arr = KA.allocate(backend, T, 1) + isbits_ptr = get_isbits_ptr(backend, arr) + return typeof(isbits_ptr) +end + +""" +Convert a BLAS with backend arrays to an isbits BLAS with device pointers. + +Appends a `BLASArrays(nodes, primitives)` to `blas_storage` so the backing +buffers outlive every isbits pointer that references them. + +Note: The isbits BLAS is only used by management kernels that read root_aabb +(inline data). For traversal, StaticTLAS uses flat arrays with offset-based +indexing instead (see BLASDescriptor). +""" +function to_isbits_blas(backend, blas::BLAS, blas_storage::Vector{BLASArrays}) + push!(blas_storage, BLASArrays(blas.nodes, blas.primitives)) + + # Get isbits device pointers + isbits_nodes = get_isbits_ptr(backend, blas.nodes) + isbits_prims = get_isbits_ptr(backend, blas.primitives) + + return BLAS(isbits_nodes, isbits_prims, blas.root_aabb) +end + +""" +Append a single isbits BLAS to blas_array using GPU-friendly append!. +Returns the (possibly new) blas_array. +""" +function append_blas!(backend, blas_array, isbits_blas) + # Create a single-element array on CPU with the isbits BLAS, then adapt to backend + single_arr = [isbits_blas] + backend_arr = Adapt.adapt(backend, single_arr) + + if blas_array === nothing + # First BLAS - create the array with correct type + return backend_arr + else + # Append to existing array + append!(blas_array, backend_arr) + return blas_array + end +end + +""" + build_flat_blas_arrays!(tlas::TLAS) + +Build concatenated flat arrays from individual BLAS GPU arrays and store them +in `tlas._flat_blas_nodes`, `tlas._flat_blas_prims`, `tlas._flat_blas_descs`. + +This avoids storing device pointers in GPU buffers (which fails on Metal). +Instead, traversal kernels use BLASDescriptor offsets to index into the flat arrays. + +The flat arrays are MtlVector/CuVector etc., kept alive by the TLAS. +During adapt, they are converted to isbits device pointers for kernels. +""" +function build_flat_blas_arrays!(tlas::TLAS) + n_blas = length(tlas.blas_storage) + backend = tlas.backend + + if n_blas == 0 + tlas._flat_blas_nodes = nothing + tlas._flat_blas_prims = nothing + tlas._flat_blas_descs = nothing + return + end + + # Read root_aabb from blas_array (inline data, always correct even on Metal) + cpu_blas = Array(tlas.blas_array) + + # Compute total sizes and build descriptors + descriptors = Vector{BLASDescriptor}(undef, n_blas) + total_nodes = 0 + total_prims = 0 + for i in 1:n_blas + ba = tlas.blas_storage[i] + descriptors[i] = BLASDescriptor(UInt32(total_nodes), UInt32(total_prims), cpu_blas[i].root_aabb) + total_nodes += length(ba.nodes) + total_prims += length(ba.primitives) + end + + # Allocate flat arrays on backend (use first BLAS's array type as template) + first_ba = tlas.blas_storage[1] + all_nodes = similar(first_ba.nodes, total_nodes) + all_prims = similar(first_ba.primitives, total_prims) + + # Copy BLAS data into flat arrays + nodes_pos = 1 + prims_pos = 1 + for ba in tlas.blas_storage + nn = length(ba.nodes) + copyto!(all_nodes, nodes_pos, ba.nodes, 1, nn) + nodes_pos += nn + + np = length(ba.primitives) + copyto!(all_prims, prims_pos, ba.primitives, 1, np) + prims_pos += np + end + + # Store on TLAS to keep alive (prevents GC of backing GPU buffers) + tlas._flat_blas_nodes = all_nodes + tlas._flat_blas_prims = all_prims + tlas._flat_blas_descs = Adapt.adapt(backend, descriptors) +end + +""" + is_valid(tlas::TLAS, handle::TLASHandle) -> Bool + +Check if a handle is still valid (not deleted). O(1) operation. +""" +function is_valid(tlas::TLAS, handle::TLASHandle)::Bool + haskey(tlas.handle_to_range, handle) && !(handle in tlas.deleted_handles) +end + +""" + n_instances(tlas::TLAS, handle::TLASHandle) -> Int + +Get the number of instances referenced by a handle. +""" +function n_instances(tlas::TLAS, handle::TLASHandle)::Int + haskey(tlas.handle_to_range, handle) || return 0 + handle in tlas.deleted_handles && return 0 + return length(tlas.handle_to_range[handle]) +end + +""" + n_total_instances(tlas::TLAS) -> Int + +Get the total number of active instances in the TLAS. +""" +n_total_instances(tlas::TLAS) = length(tlas.instances) + +# ------------------------------------------------------------------------------ +# TLAS: push! operations - Direct GPU append +# ------------------------------------------------------------------------------ + +""" + build_triangle(vertices, normals, uvs, indices, face_idx, metadata) + +Build a Triangle from decomposed mesh arrays at the given face index. +""" +function build_triangle(vertices, normals, uvs, indices, face_idx, metadata) + f_idx = 1 + (3 * (face_idx - 1)) + vs = @SVector [vertices[indices[f_idx + i]] for i in 0:2] + ns = @SVector [normals[indices[f_idx + i]] for i in 0:2] + ts = @SVector [Vec3f(NaN) for _ in 1:3] + uv = if !isempty(uvs) + @SVector [uvs[indices[f_idx + i]] for i in 0:2] + else + SVector(Point2f(0), Point2f(1, 0), Point2f(1, 1)) + end + Triangle(vs, ns, ts, uv, metadata) +end + +""" + is_degenerate_face(vertices, indices, face_idx) + +Check if a triangle face is degenerate (zero area) without constructing a full Triangle. +""" +function is_degenerate_face(vertices, indices, face_idx) + f_idx = 1 + (3 * (face_idx - 1)) + vs = @SVector [vertices[indices[f_idx + i]] for i in 0:2] + is_degenerate(vs) +end + +"""Internal: decompose a GB.Mesh, build a BLAS, append it to the TLAS, and +return the new BLAS's 1-based index. Shared by every `push!` variant.""" +function build_and_append_blas!(tlas::TLAS, mesh::GeometryBasics.Mesh) + nmesh = GeometryBasics.expand_faceviews(mesh) + fs = decompose(TriangleFace{UInt32}, nmesh) + verts = decompose(Point3f, nmesh) + norms = Normal3f.(decompose_normals(nmesh)) + uvs_raw = GeometryBasics.decompose_uv(nmesh) + uvs = isnothing(uvs_raw) ? Point2f[] : Point2f.(uvs_raw) + indices = collect(reinterpret(UInt32, fs)) + + has_meta = hasproperty(nmesh, :face_meta) + n_faces = length(fs) + + cpu_triangles = [begin + # After expand_faceviews, face_meta is per-vertex with all 3 verts sharing same value + meta = has_meta ? nmesh.face_meta[indices[3*(i-1)+1]] : UInt32(i) + build_triangle(verts, norms, uvs, indices, i, meta) + end + for i in 1:n_faces + if !is_degenerate_face(verts, indices, i) + ] + isempty(cpu_triangles) && error("Geometry has no valid triangles") + + backend_triangles = Adapt.adapt(tlas.backend, cpu_triangles) + blas = build_blas(backend_triangles) + isbits_blas = to_isbits_blas(tlas.backend, blas, tlas.blas_storage) + tlas.blas_array = append_blas!(tlas.backend, tlas.blas_array, isbits_blas) + return UInt32(length(tlas.blas_array)) +end + +"""Internal: append the given instance descriptors to `tlas.instances`, register +a handle for the appended range, and return it.""" +function append_instances_with_handle!(tlas::TLAS, cpu_descriptors::AbstractVector{InstanceDescriptor}) + start_idx = length(tlas.instances) + 1 + backend_descriptors = Adapt.adapt(tlas.backend, collect(cpu_descriptors)) + append!(tlas.instances, backend_descriptors) + end_idx = length(tlas.instances) + + handle = TLASHandle(tlas.next_handle_id) + tlas.next_handle_id += UInt32(1) + tlas.handle_to_range[handle] = start_idx:end_idx + tlas.dirty = true + return handle +end + +""" + push!(tlas::TLAS, mesh::GeometryBasics.Mesh, transform::Mat4f=Mat4f(I); + instance_id::UInt32=UInt32(0)) -> TLASHandle + +Add a GeometryBasics.Mesh to the TLAS. Per-face metadata is read from the mesh's +`face_meta` attribute (if present). If no `face_meta` attribute exists, each +triangle gets `UInt32(face_idx)` as metadata. + +`instance_id` is forwarded to the `InstanceDescriptor`. It defaults to `0` +(inherit from triangle metadata); pass a nonzero value to override the +per-triangle interface — see `InstanceDescriptor` for semantics. + +Returns a stable handle for later reference. +""" +function Base.push!(tlas::TLAS, mesh::GeometryBasics.Mesh, transform::Mat4f=Mat4f(I); + instance_id::UInt32=UInt32(0)) + blas_idx = build_and_append_blas!(tlas, mesh) + t = mat4_to_mat3x4(transform) + cpu_descriptors = [InstanceDescriptor(blas_idx, instance_id, t, mat3x4_inverse(t), UInt32(0))] + return append_instances_with_handle!(tlas, cpu_descriptors) +end + +""" + push!(tlas::TLAS, mesh::GeometryBasics.Mesh, transforms::AbstractVector{Mat4f}; + instance_ids::Union{Nothing, AbstractVector{UInt32}}=nothing) -> TLASHandle + +Add a GeometryBasics.Mesh to the TLAS with multiple transforms (instancing). +Builds BLAS once, creates `length(transforms)` InstanceDescriptors. + +`instance_ids` (if given) must match `length(transforms)` and supplies the +per-instance interface override. When `nothing`, every instance gets `0` +(inherit from triangle metadata). + +Returns a stable handle for later reference. +""" +function Base.push!(tlas::TLAS, mesh::GeometryBasics.Mesh, transforms::AbstractVector{Mat4f}; + instance_ids::Union{Nothing, AbstractVector{<:Integer}}=nothing) + if instance_ids !== nothing && length(instance_ids) != length(transforms) + throw(ArgumentError("instance_ids length $(length(instance_ids)) != transforms length $(length(transforms))")) + end + + blas_idx = build_and_append_blas!(tlas, mesh) + + cpu_descriptors = map(enumerate(transforms)) do (i, transform) + t = mat4_to_mat3x4(transform) + iid = instance_ids === nothing ? UInt32(0) : UInt32(instance_ids[i]) + InstanceDescriptor(blas_idx, iid, t, mat3x4_inverse(t), UInt32(0)) + end + return append_instances_with_handle!(tlas, cpu_descriptors) +end + +# ------------------------------------------------------------------------------ +# TLAS: delete! operation +# ------------------------------------------------------------------------------ + +""" + delete!(tlas::TLAS, handle::TLASHandle) -> Bool + +Remove all instances referenced by the handle. Returns true if successful. +The handle becomes invalid after deletion. + +Note: The instances array is compacted during sync!, not immediately. +""" +function Base.delete!(tlas::TLAS, handle::TLASHandle)::Bool + haskey(tlas.handle_to_range, handle) || return false + handle in tlas.deleted_handles && return false + + # Mark as deleted (will be compacted on sync!) + push!(tlas.deleted_handles, handle) + tlas.dirty = true + + return true +end + +# ------------------------------------------------------------------------------ +# TLAS: get_instance - retrieve instance data +# ------------------------------------------------------------------------------ + +""" + get_instance(tlas::TLAS, handle::TLASHandle) -> InstanceDescriptor + get_instance(tlas::TLAS, handle::TLASHandle, instance_idx::Integer) -> InstanceDescriptor + +Retrieve the InstanceDescriptor for a handle. If the handle has multiple instances +(created with multiple transforms), use `instance_idx` to specify which one (1-based). + +Note: Reads from GPU array, may involve a device-to-host copy. +""" +function get_instance(tlas::TLAS, handle::TLASHandle, instance_idx::Integer=1) + haskey(tlas.handle_to_range, handle) || error("Invalid handle") + handle in tlas.deleted_handles && error("Handle has been deleted") + range = tlas.handle_to_range[handle] + 1 <= instance_idx <= length(range) || error("Instance index $instance_idx out of range 1:$(length(range))") + + idx = first(range) + instance_idx - 1 + # Copy single element from GPU to CPU + return Array(tlas.instances[idx:idx])[1] +end + +""" + get_instances(tlas::TLAS, handle::TLASHandle) -> Vector{InstanceDescriptor} + +Retrieve all InstanceDescriptors for a handle (for handles with multiple transforms). + +Note: Reads from GPU array, involves a device-to-host copy. +""" +function get_instances(tlas::TLAS, handle::TLASHandle) + haskey(tlas.handle_to_range, handle) || error("Invalid handle") + handle in tlas.deleted_handles && error("Handle has been deleted") + range = tlas.handle_to_range[handle] + # Copy range from GPU to CPU + return Array(tlas.instances[range]) +end + +# ------------------------------------------------------------------------------ +# TLAS: update_transform! operations - Direct GPU updates +# ------------------------------------------------------------------------------ + +""" + update_transform!(tlas::TLAS, handle::TLASHandle, transform) + +Update the transform of a single-instance handle directly on GPU. +For handles with multiple instances, use `update_transforms!`. + +`transform` may be a `Mat4f` (homogeneous 4×4) or the canonical Vulkan +row-major 3×4 (`Mat3x4f`); the `Mat4f` form is converted internally via +`mat4_to_mat3x4`. After calling this, use `refit_tlas!` to update the +BVH AABBs. +""" +function update_transform!(tlas::TLAS, handle::TLASHandle, transform::Mat3x4f) + haskey(tlas.handle_to_range, handle) || error("Invalid handle") + handle in tlas.deleted_handles && error("Handle has been deleted") + range = tlas.handle_to_range[handle] + length(range) == 1 || error("Handle has $(length(range)) instances, use update_transforms! for multiple") + + transforms = Adapt.adapt(tlas.backend, [transform]) + update_instance_transforms!(tlas, transforms, 1, first(range)) + + return nothing +end + +# Mat4f convenience: mirrors push!(tlas, mesh, ::Mat4f) which also accepts +# the homogeneous 4×4 form and converts to Mat3x4f at the boundary. +update_transform!(tlas::TLAS, handle::TLASHandle, transform::Mat4f) = + update_transform!(tlas, handle, mat4_to_mat3x4(transform)) + +""" + update_transforms!(tlas::TLAS, handle::TLASHandle, transforms) + +Update all instances' transforms in a group directly on GPU. +Length must match the number of instances in the handle. + +Transforms may be `Mat4f` or the canonical Vulkan row-major 3×4 +(`Mat3x4f`); `Mat4f` arrays are converted element-wise via +`mat4_to_mat3x4`. The array can be CPU or GPU — it'll be adapted to the +TLAS backend. After calling this, use `refit_tlas!` to update the BVH +AABBs. +""" +function update_transforms!(tlas::TLAS, handle::TLASHandle, transforms::AbstractVector{Mat3x4f}) + haskey(tlas.handle_to_range, handle) || error("Invalid handle") + handle in tlas.deleted_handles && error("Handle has been deleted") + range = tlas.handle_to_range[handle] + length(transforms) == length(range) || error("Transform count ($(length(transforms))) != instance count ($(length(range)))") + + backend_transforms = Adapt.adapt(tlas.backend, transforms) + update_instance_transforms!(tlas, backend_transforms, length(range), first(range)) + + return nothing +end + +update_transforms!(tlas::TLAS, handle::TLASHandle, transforms::AbstractVector{Mat4f}) = + update_transforms!(tlas, handle, map(mat4_to_mat3x4, transforms)) + +# ------------------------------------------------------------------------------ +# TLAS: update! for geometry replacement +# ------------------------------------------------------------------------------ + +""" + update!(tlas::TLAS, handle::TLASHandle, new_geometry) + +Replace the geometry (BLAS) for a handle. All instances sharing this BLAS get updated. +""" +function update!(tlas::TLAS, handle::TLASHandle, new_geometry) + haskey(tlas.handle_to_range, handle) || error("Invalid handle") + handle in tlas.deleted_handles && error("Handle has been deleted") + range = tlas.handle_to_range[handle] + isempty(range) && error("Handle has no instances") + + # Get blas_index from first instance (read from GPU) + first_desc = Array(tlas.instances[first(range):first(range)])[1] + blas_idx = Int(first_desc.blas_index) + + # Build new BLAS on backend (GPU-first) - decompose GB.Mesh directly + nmesh = GeometryBasics.expand_faceviews(new_geometry) + fs = decompose(TriangleFace{UInt32}, nmesh) + verts = decompose(Point3f, nmesh) + norms = Normal3f.(decompose_normals(nmesh)) + uvs_raw = GeometryBasics.decompose_uv(nmesh) + uvs = isnothing(uvs_raw) ? Point2f[] : Point2f.(uvs_raw) + indices = collect(reinterpret(UInt32, fs)) + + has_meta = hasproperty(nmesh, :face_meta) + n_faces = length(fs) + + cpu_triangles = [begin + meta = has_meta ? nmesh.face_meta[indices[3*(i-1)+1]] : UInt32(i) + build_triangle(verts, norms, uvs, indices, i, meta) + end + for i in 1:n_faces + if !is_degenerate_face(verts, indices, i) + ] + isempty(cpu_triangles) && error("New geometry has no valid triangles") + + backend_triangles = Adapt.adapt(tlas.backend, cpu_triangles) + new_blas = build_blas(backend_triangles) + + # Swap in the new backing arrays; the old BLASArrays entry is replaced + # atomically so the isbits pointers in `blas_array` are never left + # dangling mid-update. + tlas.blas_storage[blas_idx] = BLASArrays(new_blas.nodes, new_blas.primitives) + + # Create isbits version and replace in blas_array + isbits_blas = BLAS( + get_isbits_ptr(tlas.backend, new_blas.nodes), + get_isbits_ptr(tlas.backend, new_blas.primitives), + new_blas.root_aabb + ) + tlas.blas_array[blas_idx] = isbits_blas + + tlas.dirty = true + return nothing +end + +# ------------------------------------------------------------------------------ +# TLAS: sync! - Rebuild BVH from GPU instances array +# ------------------------------------------------------------------------------ + +""" + sync!(tlas::TLAS) -> TLAS + +Rebuild the BVH structure if dirty, then **wait for the GPU to finish** +all work currently queued on `tlas.backend` (via `KA.synchronize`). +No-op if already up-to-date AND no work is pending. + +`sync!` is the single owner of `tlas.static_tlas`. It rebuilds in place +(`resize!`/`copyto!` on the same backing buffer) wherever possible so that a +`StaticTLAS` returned by an earlier `Adapt.adapt(backend, tlas)` still sees +the new geometry, and reallocates + reassigns `tlas.static_tlas` when a +buffer needed to grow. + +After `sync!(tlas)` returns: +- The TLAS reflects all prior `push!`/`delete!` calls. +- `tlas.static_tlas` is the fresh adapted form. +- The GPU is idle for `tlas.backend`. +- Any resource that is no longer reachable through the (new) accel is + safe to release synchronously: its `VkAccelerationStructureKHR` / + BDA captures in in-flight work have drained. + +Invariants for callers: +- `update!` / `push!` / `delete!` on the TLAS do NOT wait and do NOT + perform cleanup; they mark state dirty and return immediately. Call + `sync!` to establish a safe boundary before freeing transitively- + owned resources. +- Consumers of the adapted form read `tlas.static_tlas` (or call + `Adapt.adapt(backend, tlas)`) per dispatch. They MUST NOT cache the + returned `StaticTLAS` across mutations — after a reshape-driven reassign + the cached snapshot holds a stale device pointer. +""" +function sync!(tlas::TLAS) + # True no-op when nothing changed and static_tlas is already built. + # adapt(backend, tlas) calls this per dispatch — the fast path MUST be + # allocation-free and must not issue a GPU synchronize. + if !tlas.dirty && !tlas.transforms_dirty && tlas.static_tlas !== nothing + return tlas + end + + if tlas.dirty + rebuild_bvh!(tlas) + tlas.transforms_dirty = false + rebuild_static_tlas!(tlas) + elseif tlas.transforms_dirty + refit_tlas!(tlas) + # refit updates AABBs in tlas.nodes in place — if static_tlas wraps the + # same backing buffer (normal case) it sees the new AABBs. If the field + # was never built yet, build it now. + tlas.static_tlas === nothing && rebuild_static_tlas!(tlas) + else + # Freshly-built TLAS, no mutations yet — still need a valid static_tlas + # so `adapt(backend, tlas)` returns something usable. + rebuild_static_tlas!(tlas) + end + # Only synchronize when we actually touched the GPU. This keeps sync! cheap + # on the clean path (above), so callers can invoke it liberally. + KA.synchronize(tlas.backend) + return tlas +end + +""" + rebuild_static_tlas!(tlas::TLAS) + +Build a fresh `StaticTLAS` from the current `tlas.nodes` / `tlas.instances` / +flat-BLAS arrays and store it on `tlas.static_tlas`. Called by `sync!` — not a +public API; consumers read `tlas.static_tlas` or `Adapt.adapt(backend, tlas)`. +""" +function rebuild_static_tlas!(tlas::TLAS) + backend = tlas.backend + if tlas._flat_blas_nodes === nothing + # Empty scene — use correctly-typed empty backing arrays so the + # StaticTLAS type parameters are still concrete. + prim_type = isempty(tlas.blas_storage) ? Triangle{UInt32} : eltype(tlas.blas_storage[1].primitives) + empty_nodes = KA.allocate(backend, BVHNode2, 0) + empty_prims = KA.allocate(backend, prim_type, 0) + empty_descs = Adapt.adapt(backend, BLASDescriptor[]) + tlas.static_tlas = StaticTLAS( + Adapt.adapt(backend, tlas.nodes), + Adapt.adapt(backend, tlas.instances), + Adapt.adapt(backend, empty_nodes), + Adapt.adapt(backend, empty_prims), + Adapt.adapt(backend, empty_descs), + tlas.root_aabb, + ) + return tlas + end + + tlas.static_tlas = StaticTLAS( + Adapt.adapt(backend, tlas.nodes), + Adapt.adapt(backend, tlas.instances), + Adapt.adapt(backend, tlas._flat_blas_nodes), + Adapt.adapt(backend, tlas._flat_blas_prims), + Adapt.adapt(backend, tlas._flat_blas_descs), + tlas.root_aabb, + ) + return tlas +end + +"""Internal: Compact deleted instances and rebuild TLAS BVH.""" +function rebuild_bvh!(tlas::TLAS) + # If there are deletions, compact the instances array + if !isempty(tlas.deleted_handles) + compact_instances!(tlas) + end + + n = length(tlas.instances) + if n == 0 + tlas.nodes = KA.allocate(tlas.backend, BVHNode2, 0) + tlas.root_aabb = Bounds3() + # Drain any stale flat-BLAS arrays from a previous non-empty state so + # `length(tlas._flat_blas_*)` faithfully reports "no live geometry" + # and the next non-empty rebuild starts from a clean slate. + build_flat_blas_arrays!(tlas) + tlas.dirty = false + return + end + + # Build TLAS BVH topology from existing GPU arrays + # blas_array is only used for root_aabb (inline data, safe on Metal) + nodes, root_aabb = build_tlas_topology(tlas.blas_array, tlas.instances, tlas.backend) + + tlas.nodes = nodes + tlas.root_aabb = root_aabb + + # Build flat BLAS arrays during sync (not during adapt_structure). + # This ensures the data is ready before any kernel dispatch. + build_flat_blas_arrays!(tlas) + + tlas.dirty = false + return +end + +"""Internal: Compact instances array by removing deleted handles, and compact BLASes.""" +function compact_instances!(tlas::TLAS) + # Copy valid ranges to new array and update handle mappings + cpu_instances = Array(tlas.instances) + new_instances = InstanceDescriptor[] + new_handle_to_range = Dict{TLASHandle, UnitRange{Int}}() + + for (handle, range) in tlas.handle_to_range + handle in tlas.deleted_handles && continue + new_start = length(new_instances) + 1 + append!(new_instances, cpu_instances[range]) + new_end = length(new_instances) + new_handle_to_range[handle] = new_start:new_end + end + + # Remove deleted handles from tracking + for handle in tlas.deleted_handles + delete!(tlas.handle_to_range, handle) + end + empty!(tlas.deleted_handles) + tlas.handle_to_range = new_handle_to_range + + # Compact BLASes: find which blas_index values are still referenced + used_blas_indices = Set{UInt32}() + for inst in new_instances + push!(used_blas_indices, inst.blas_index) + end + + n_blas = tlas.blas_array === nothing ? 0 : length(tlas.blas_array) + if n_blas > 0 && length(used_blas_indices) < n_blas + # Build old→new index mapping (only for referenced BLASes) + sorted_used = sort!(collect(used_blas_indices)) + old_to_new = Dict{UInt32, UInt32}() + for (new_idx, old_idx) in enumerate(sorted_used) + old_to_new[old_idx] = UInt32(new_idx) + end + + # Remap blas_index in all instances + for i in eachindex(new_instances) + inst = new_instances[i] + new_blas_idx = old_to_new[inst.blas_index] + if new_blas_idx != inst.blas_index + new_instances[i] = InstanceDescriptor( + new_blas_idx, inst.instance_id, + inst.transform, inst.inv_transform, inst.flags + ) + end + end + + # Rebuild blas_storage keeping only referenced entries; finalize the rest. + new_blas_storage = BLASArrays[tlas.blas_storage[Int(old_idx)] for old_idx in sorted_used] + for old_idx in UInt32(1):UInt32(n_blas) + old_idx in used_blas_indices && continue + drop = tlas.blas_storage[Int(old_idx)] + finalize(drop.nodes) + finalize(drop.primitives) + end + empty!(tlas.blas_storage) + append!(tlas.blas_storage, new_blas_storage) + + # Rebuild blas_array with only referenced isbits BLASes + cpu_blas = Array(tlas.blas_array) + new_cpu_blas = [cpu_blas[Int(old_idx)] for old_idx in sorted_used] + tlas.blas_array = Adapt.adapt(tlas.backend, new_cpu_blas) + end + + # Update instances on backend + tlas.instances = isempty(new_instances) ? + KA.allocate(tlas.backend, InstanceDescriptor, 0) : + Adapt.adapt(tlas.backend, new_instances) +end + +# ------------------------------------------------------------------------------ +# TLAS: Adapt integration - Returns StaticTLAS for kernel traversal +# ------------------------------------------------------------------------------ + +""" + Adapt.adapt_structure(to, tlas::TLAS) -> StaticTLAS + +Return `tlas.static_tlas` after a `sync!` to make sure it reflects any +pending mutation. `sync!` is the single owner of that field; this function +just reads it. On a clean TLAS, `sync!` is a true no-op — no GPU +synchronize, no allocations — so this is cheap to call per dispatch. Do +call it per dispatch, and do NOT cache the return. + +`tlas.static_tlas` is in kernel-ready isbits form on `tlas.backend`. The +`to` argument is checked: passing a `KA.Backend` other than `tlas.backend` +errors loudly. Cross-backend conversion is not supported — build the TLAS +on the same backend you intend to dispatch kernels against. +""" +function Adapt.adapt_structure(to, tlas::TLAS) + # Loud error on cross-backend adapt: silently returning a static_tlas + # whose arrays live on a different backend than the kernel expects + # surfaces later as a confusing GPUCompiler "non-bitstype argument" error. + # Catch it at the API boundary instead. + # + # Compare by type, not by `===`: two distinct `LavaBackend()` instances + # are semantically the same backend but not object-identical, so `!==` + # would falsely flag a same-backend adapt as cross-backend. + if to isa KA.Backend && typeof(to) !== typeof(tlas.backend) + error( + "Cross-backend Adapt.adapt(::$(typeof(to)), ::TLAS) is not supported. " * + "TLAS was built on $(typeof(tlas.backend)), but adapt was called with $(typeof(to)). " * + "Construct the TLAS on the matching backend (e.g. `Raycore.TLAS($(typeof(to))())`).") + end + sync!(tlas) + return tlas.static_tlas +end + +""" + Adapt.adapt_structure(to, tlas::StaticTLAS) -> StaticTLAS + +Adapt StaticTLAS arrays. If already isbits, returns as-is. +Otherwise adapts each array (CLArray → CLDeviceVector). + +Note: StaticTLAS should come from adapting a mutable TLAS, where BLASes +already have isbits device pointers. Use TLAS(items) to create a mutable +TLAS that properly manages GPU array lifetimes. +""" +function Adapt.adapt_structure(to, tlas::StaticTLAS) + isbitstype(typeof(tlas)) && return tlas + + return StaticTLAS( + Adapt.adapt(to, tlas.nodes), + Adapt.adapt(to, tlas.instances), + Adapt.adapt(to, tlas.all_blas_nodes), + Adapt.adapt(to, tlas.all_blas_prims), + Adapt.adapt(to, tlas.blas_descriptors), + tlas.root_aabb + ) +end + +# Adapt BLAS when adapting to kernel arguments +function Adapt.adapt_structure(to, blas::BLAS) + BLAS( + Adapt.adapt(to, blas.nodes), + Adapt.adapt(to, blas.primitives), + blas.root_aabb + ) +end + +# ============================================================================== +# AABB and Morton Code Utilities +# ============================================================================== + +"""Compute AABB from BVHNode2 for BLAS (BVH2IL format with triangle vertices in leaves).""" +@inline function get_node_aabb(node::BVHNode2, is_interior::Bool)::Bounds3 + if is_interior + # Interior: union of children AABBs + Bounds3( + min.(node.aabb0_min, node.aabb1_min), + max.(node.aabb0_max, node.aabb1_max) + ) + else + # Leaf: vertices stored in aabb slots (BVH2IL format) + # Compute AABB from triangle vertices v0, v1, v2 + v0 = Point3f(node.aabb0_min...) + v1 = Point3f(node.aabb0_max...) + v2 = Point3f(node.aabb1_min...) + + p_min = min.(min.(v0, v1), v2) + p_max = max.(max.(v0, v1), v2) + + Bounds3(p_min, p_max) + end +end + +"""Compute AABB from BVHNode2 for TLAS (AABBs stored directly in leaves).""" +@inline function get_tlas_node_aabb(node::BVHNode2, is_interior::Bool)::Bounds3 + if is_interior + # Interior: union of children AABBs + Bounds3( + min.(node.aabb0_min, node.aabb1_min), + max.(node.aabb0_max, node.aabb1_max) + ) + else + # Leaf: instance AABB stored directly in aabb0 fields + Bounds3(node.aabb0_min, node.aabb0_max) + end +end + +"""3-dilate bits for Morton code (spreads bits by factor of 3).""" +@inline function expand_bits(x::UInt32)::UInt32 + x = (x * 0x00010001) & 0xFF0000FF + x = (x * 0x00000101) & 0x0F00F00F + x = (x * 0x00000011) & 0xC30C30C3 + x = (x * 0x00000005) & 0x49249249 + return x +end + +""" +Calculate 30-bit Morton code from normalized 3D point [0,1]³. +Interleaves x,y,z bits to create space-filling Z-curve ordering. +""" +@inline function morton_code_30bit(p::Point3f)::UInt32 + # Clamp to [0, 1023] for 10-bit precision per axis + unit_side = 1024.0f0 + x = clamp(p[1] * unit_side, 0.0f0, unit_side - 1.0f0) + y = clamp(p[2] * unit_side, 0.0f0, unit_side - 1.0f0) + z = clamp(p[3] * unit_side, 0.0f0, unit_side - 1.0f0) + + # Interleave bits: xxyyzzxxyyzzxxyyzz... + return (expand_bits(unsafe_trunc(UInt32, x)) << 2) | + (expand_bits(unsafe_trunc(UInt32, y)) << 1) | + expand_bits(unsafe_trunc(UInt32, z)) +end + +"""Count leading zeros (clz) for 32-bit integer.""" +@inline function clz32(x::UInt32)::Int32 + x == 0 && return Int32(32) + return Int32(31 - (sizeof(UInt32)*8 - 1 - leading_zeros(x))) +end + +""" +Compute longest common prefix (LCP) of Morton codes. +Uses index fallback when codes are identical. +""" +@inline function delta(i1::Int32, i2::Int32, morton_codes::AbstractVector{UInt32}, num_prims::Int32)::Int32 + # Bounds check + left = min(i1, i2) + right = max(i1, i2) + + (left < 1 || right > num_prims) && return Int32(-1) + + left_code = morton_codes[left] + right_code = morton_codes[right] + + # If codes differ, return common prefix length + # If codes are same, use indices as tiebreaker + if left_code != right_code + return Int32(clz32(left_code ⊻ right_code)) + else + return Int32(32 + clz32(UInt32(left) ⊻ UInt32(right))) + end +end + +"""Find the span of primitives covered by this internal node (Karras 2012).""" +@inline function find_span_for_node( + idx::Int32, + morton_codes::AbstractVector{UInt32}, + n_prims::Int32 +)::Tuple{Int32, Int32} + # Determine direction + d_left = delta(idx, idx - Int32(1), morton_codes, n_prims) + d_right = delta(idx, idx + Int32(1), morton_codes, n_prims) + d = d_right > d_left ? Int32(1) : Int32(-1) + + # Compute upper bound for length + delta_min = delta(idx, idx - d, morton_codes, n_prims) + l_max = Int32(2) + while delta(idx, idx + l_max * d, morton_codes, n_prims) > delta_min + l_max *= Int32(2) + end + + # Binary search for the other end + l = Int32(0) + t = l_max + while t > Int32(1) + t = t ÷ Int32(2) + if delta(idx, idx + (l + t) * d, morton_codes, n_prims) > delta_min + l = l + t + end + end + j = idx + l * d + + # Return sorted span + return d > Int32(0) ? (idx, j) : (j, idx) +end + +"""Find the split position within a span (Karras 2012).""" +@inline function find_split_in_span( + span_left::Int32, + span_right::Int32, + morton_codes::AbstractVector{UInt32}, + n_prims::Int32 +)::Int32 + # Calculate the number of identical bits from higher end + numidentical = delta(span_left, span_right, morton_codes, n_prims) + + # Binary search for split position using midpoint + left = span_left + right = span_right + while right > left + Int32(1) + # Proposed split at midpoint + newsplit = (right + left) ÷ Int32(2) + + # If it has more equal leading bits than left and right, accept it + if delta(left, newsplit, morton_codes, n_prims) > numidentical + left = newsplit + else + right = newsplit + end + end + + return left +end + +"""Compute leaf node index from primitive index.""" +@inline function leaf_index(prim_idx::Integer, n_prims::Int32)::Int + return Int(n_prims) - 1 + prim_idx +end + +"""Refit AABB for one internal node from its children.""" +@inline function refit_node_aabb( + node_idx::Int32, + nodes::AbstractVector{BVHNode2}, + n_prims::Int32 +)::BVHNode2 + @inbounds node = nodes[node_idx] + child0 = node.child0 + child1 = node.child1 + + is_child0_internal = child0 < n_prims + is_child1_internal = child1 < n_prims + + aabb0 = get_node_aabb(nodes[child0], is_child0_internal) + aabb1 = get_node_aabb(nodes[child1], is_child1_internal) + + return BVHNode2( + aabb0.p_min, aabb0.p_max, + aabb1.p_min, aabb1.p_max, + node.child0, node.child1, node.parent + ) +end + +@inline function refit_tlas_node_aabb( + node_idx::Int32, + nodes::AbstractVector{BVHNode2}, + n_instances::Int32 +)::BVHNode2 + @inbounds node = nodes[node_idx] + child0 = node.child0 + child1 = node.child1 + + is_child0_internal = child0 < n_instances + is_child1_internal = child1 < n_instances + + aabb0 = get_tlas_node_aabb(nodes[child0], is_child0_internal) + aabb1 = get_tlas_node_aabb(nodes[child1], is_child1_internal) + + return BVHNode2( + aabb0.p_min, aabb0.p_max, + aabb1.p_min, aabb1.p_max, + node.child0, node.child1, node.parent + ) +end + +# ============================================================================== +# BLAS Construction (LBVH Algorithm) +# ============================================================================== + +""" + build_blas(primitives) -> BLAS + +Build a Bottom-Level Acceleration Structure using Linear BVH (LBVH). + +Uses KernelAbstractions for automatic CPU/GPU execution based on input array type. + +Algorithm: +1. Compute scene AABB +2. Calculate Morton codes in parallel (GPU kernel) +3. Sort primitives by Morton code +4. Build binary radix tree in parallel (GPU kernel) +5. Compute AABBs bottom-up + +Based on Karras 2012 "Maximizing Parallelism in the Construction of BVHs, Octrees, and k-d Trees" + +# Arguments +- `primitives`: Vector or GPU array of Triangle objects (array type determines backend) + +# Example +```julia +# CPU execution +blas_cpu = build_blas(triangles) # Vector{Triangle} + +# GPU execution (CUDA) +using CUDA +gpu_triangles = CuArray(triangles) +blas_gpu = build_blas(gpu_triangles) # CuArray{Triangle} +``` +""" +function build_blas( + primitives::AbstractVector{T} +) where {T <: Triangle} + n = length(primitives) + n == 0 && error("Cannot build BLAS from empty primitive list") + + # Infer backend from input array type + backend = KA.get_backend(primitives) + + # Compute scene AABB (works on both CPU and GPU arrays) + scene_aabb = mapreduce(world_bound, ∪, primitives, init=Bounds3()) + scene_min = scene_aabb.p_min + scene_extent = Vec3f(scene_aabb.p_max - scene_aabb.p_min) + + # Allocate arrays on same backend as input + morton_codes = KA.allocate(backend, UInt32, n) + + # Launch kernel: Calculate Morton codes + calc_kernel! = calculate_morton_codes_kernel!(backend) + calc_kernel!(morton_codes, primitives, scene_min, scene_extent, ndrange=n) + + # Sort primitives by Morton codes + # AcceleratedKernels only supports GPU backends, use Julia's sortperm for CPU + perm = AK.sortperm(morton_codes) + KA.synchronize(backend) # Ensure sort temp buffers aren't freed while GPU is still using them + morton_codes = morton_codes[perm] + primitives = primitives[perm] + + # Allocate nodes and initialize with empty values + # Use kernel-based fill (OpenCL's fill! doesn't support struct types) + nodes = KA.allocate(backend, BVHNode2, 2*n - 1) + empty_node = BVHNode2( + Point3f(0), Point3f(0), Point3f(0), Point3f(0), + INVALID_NODE, INVALID_NODE, INVALID_NODE + ) + fill_kernel! = fill_bvhnode2_kernel!(backend) + fill_kernel!(nodes, empty_node, ndrange=length(nodes)) + + # Launch kernel: Emit topology (only if n > 1, i.e., there are internal nodes) + if n > 1 + topo_kernel! = emit_topology_kernel!(backend) + topo_kernel!(nodes, morton_codes, Int32(n), ndrange=n-1) + + # Launch kernel: Set parent pointers + parent_kernel! = set_parent_pointers_kernel!(backend) + parent_kernel!(nodes, Int32(n), ndrange=n-1) + end + + # Launch kernel: Create leaf nodes + leaf_kernel! = create_leaf_nodes_kernel!(backend) + leaf_kernel!(nodes, primitives, Int32(n), ndrange=n) + # Ensure leaf writes are visible before the cross-workgroup atomic refit pass. + KA.synchronize(backend) + + # Refit AABBs bottom-up (parallel using atomic counters) + update_flags = KA.zeros(backend, UInt32, n - 1) # One flag per internal node + refit_kernel! = refit_aabbs_kernel!(backend) + refit_kernel!(nodes, update_flags, Int32(n), ndrange=n) + + # Compute root AABB - check if root is interior or leaf + # Use explicit copy to CPU to avoid scalar indexing issues on GPU + KA.synchronize(backend) + root_node = Array(nodes[1:1])[1] + root_is_interior = is_interior(root_node) + root_aabb = get_node_aabb(root_node, root_is_interior) + + return BLAS(nodes, primitives, root_aabb) +end + +# ============================================================================== +# TLAS Construction +# ============================================================================== + +"""Build topology for one TLAS internal node (same algorithm as BLAS).""" +@inline function build_tlas_topology_for_node( + idx::Int32, + morton_codes::AbstractVector{UInt32}, + n_instances::Int32 +)::BVHNode2 + # Helper function + @inline leaf_idx(j::Int32) = n_instances - Int32(1) + j + + # Find span + span_left, span_right = find_span_for_node(idx, morton_codes, n_instances) + + # Find split + split = find_split_in_span(span_left, span_right, morton_codes, n_instances) + + # Determine children (matches HLSL reference exactly) + # child0 is leaf only if split == span_left + # child1 is leaf only if split + 1 == span_right + child0 = (split == span_left) ? leaf_idx(split) : split + child1_idx = split + Int32(1) + child1 = (child1_idx == span_right) ? leaf_idx(child1_idx) : child1_idx + + return BVHNode2( + Point3f(0), Point3f(0), Point3f(0), Point3f(0), + UInt32(child0), UInt32(child1), INVALID_NODE + ) +end + +""" + build_tlas_topology(blas_array, instances, backend) -> (nodes, root_aabb) + +Internal: Build TLAS BVH topology (Morton codes, sorting, tree construction, refit). +Returns (nodes, root_aabb). Only accesses blas_array for root_aabb (inline data). + +`instances` must already be on the backend. +""" +function build_tlas_topology(blas_array, instances, backend) + n = length(instances) + + # Compute scene AABB from transformed instance bounds using GPU kernel + # Allocate arrays for per-instance world AABBs + aabb_mins = KA.allocate(backend, Point3f, n) + aabb_maxs = KA.allocate(backend, Point3f, n) + + # Launch kernel to compute world AABBs in parallel + aabb_kernel! = compute_instance_aabbs_kernel!(backend) + aabb_kernel!(aabb_mins, aabb_maxs, instances, blas_array, ndrange=n) + KA.synchronize(backend) + + # Copy results to CPU and compute scene AABB via reduction + cpu_mins = Array(aabb_mins) + cpu_maxs = Array(aabb_maxs) + + scene_min_p = cpu_mins[1] + scene_max_p = cpu_maxs[1] + for i in 2:n + scene_min_p = Point3f(min(scene_min_p[1], cpu_mins[i][1]), + min(scene_min_p[2], cpu_mins[i][2]), + min(scene_min_p[3], cpu_mins[i][3])) + scene_max_p = Point3f(max(scene_max_p[1], cpu_maxs[i][1]), + max(scene_max_p[2], cpu_maxs[i][2]), + max(scene_max_p[3], cpu_maxs[i][3])) + end + scene_aabb = Bounds3(scene_min_p, scene_max_p) + + scene_min = scene_aabb.p_min + aabb_extent = scene_aabb.p_max - scene_aabb.p_min + # Handle degenerate cases (avoid division by zero) + scene_extent = Vec3f( + max(aabb_extent[1], 1f-6), + max(aabb_extent[2], 1f-6), + max(aabb_extent[3], 1f-6) + ) + + # Calculate Morton codes on same backend as input + morton_codes = KA.allocate(backend, UInt32, n) + calc_kernel! = calculate_tlas_morton_codes_kernel!(backend) + calc_kernel!(morton_codes, instances, blas_array, scene_min, scene_extent, ndrange=n) + KA.synchronize(backend) + + # Sort instances by Morton codes. + # On Lava, merge_sort_by_key! with a 64-bit Int payload can corrupt the + # permutation vector, which later sends TLAS leaf creation out of bounds. + # Use sortperm like the BLAS path so the permutation type matches the backend. + if backend isa KA.CPU + sorted_indices = sortperm(morton_codes) + morton_codes .= morton_codes[sorted_indices] + else + sorted_indices = AK.sortperm(morton_codes) + KA.synchronize(backend) # Ensure sort temp buffers aren't freed while GPU is still using them + morton_codes = morton_codes[sorted_indices] + end + + # Allocate nodes and initialize with empty values + # Use kernel-based fill (OpenCL's fill! doesn't support struct types) + nodes = KA.allocate(backend, BVHNode2, max(1, 2*n - 1)) + empty_node = BVHNode2( + Point3f(0), Point3f(0), Point3f(0), Point3f(0), + INVALID_NODE, INVALID_NODE, INVALID_NODE + ) + fill_kernel! = fill_bvhnode2_kernel!(backend) + fill_kernel!(nodes, empty_node, ndrange=length(nodes)) + + # Single-instance case: trivial TLAS + if n == 1 + # For CPU, sorted_indices is already a CPU array; for GPU, copy to avoid scalar indexing + original_idx = backend isa KA.CPU ? sorted_indices[1] : Array(sorted_indices[1:1])[1] + # Use scene_aabb computed from kernel (same as the single instance's world AABB) + world_aabb = scene_aabb + + # Create leaf node on CPU and copy to backend + leaf_node = BVHNode2( + world_aabb.p_min, world_aabb.p_max, + Point3f(0), Point3f(0), + INVALID_NODE, UInt32(original_idx - 1), + INVALID_NODE + ) + cpu_nodes = [leaf_node] + copyto!(nodes, Adapt.adapt(backend, cpu_nodes)) + + return (nodes, world_aabb) + end + + # Multi-instance case: build proper LBVH + # Launch kernel: Emit topology (reuse BLAS topology kernel - same algorithm) + topo_kernel! = emit_topology_kernel!(backend) + topo_kernel!(nodes, morton_codes, Int32(n), ndrange=n-1) + # Launch kernel: Set parent pointers + parent_kernel! = set_parent_pointers_kernel!(backend) + parent_kernel!(nodes, Int32(n), ndrange=n-1) + # Launch kernel: Create TLAS leaf nodes (different from BLAS - stores AABBs, not vertices) + leaf_kernel! = create_tlas_leaf_nodes_kernel!(backend) + leaf_kernel!(nodes, sorted_indices, instances, blas_array, Int32(n), ndrange=n) + # Ensure leaf writes are visible before the cross-workgroup atomic refit pass. + KA.synchronize(backend) + # Refit AABBs bottom-up (parallel using atomic counters) + update_flags = KA.zeros(backend, UInt32, n - 1) + refit_kernel! = refit_tlas_aabbs_kernel!(backend) + refit_kernel!(nodes, update_flags, Int32(n), ndrange=n) + + # Get root AABB (copy to CPU to avoid scalar indexing) + root_node = Array(nodes[1:1])[1] + root_aabb = get_tlas_node_aabb(root_node, true) + + return (nodes, root_aabb) +end + +""" + build_tlas(blas_array::AbstractVector{BLAS}, instances::AbstractVector{InstanceDescriptor}) -> StaticTLAS + +Build a Top-Level Acceleration Structure over instances. +Uses LBVH over transformed instance AABBs. + +Returns a StaticTLAS with flat BLAS arrays suitable for ray traversal. +Uses KernelAbstractions for automatic CPU/GPU execution based on input array type. +""" +function build_tlas( + blas_array::AbstractVector{B}, + instances::AbstractVector{InstanceDescriptor} +) where {B <: BLAS} + n_blas = length(blas_array) + n = length(instances) + + if n == 0 + prim_type = n_blas > 0 ? eltype(blas_array[1].primitives) : Triangle{UInt32} + return StaticTLAS( + BVHNode2[], instances, + BVHNode2[], prim_type[], + BLASDescriptor[], + Bounds3() + ) + end + + backend = KA.get_backend(blas_array) + backend_instances = Adapt.adapt(backend, instances) + + nodes, root_aabb = build_tlas_topology(blas_array, backend_instances, backend) + + # Build flat arrays from BLAS data + descriptors = Vector{BLASDescriptor}(undef, n_blas) + total_nodes = 0 + total_prims = 0 + for i in 1:n_blas + descriptors[i] = BLASDescriptor(UInt32(total_nodes), UInt32(total_prims), blas_array[i].root_aabb) + total_nodes += length(blas_array[i].nodes) + total_prims += length(blas_array[i].primitives) + end + + all_nodes = similar(blas_array[1].nodes, total_nodes) + all_prims = similar(blas_array[1].primitives, total_prims) + nodes_pos = 1 + prims_pos = 1 + for i in 1:n_blas + nn = length(blas_array[i].nodes) + copyto!(all_nodes, nodes_pos, blas_array[i].nodes, 1, nn) + nodes_pos += nn + np = length(blas_array[i].primitives) + copyto!(all_prims, prims_pos, blas_array[i].primitives, 1, np) + prims_pos += np + end + + return StaticTLAS(nodes, backend_instances, all_nodes, all_prims, descriptors, root_aabb) +end + + +# Type union for traversal - both TLAS and StaticTLAS have the same traversal-relevant fields +const TraversableTLAS = Union{TLAS, StaticTLAS} + +# ============================================================================== +# Transform Utilities +# ============================================================================== + +# Mat4f (column-major 4×4) → Mat3x4f (Vulkan row-major 3×4). +# The upper three rows of the 4×4 become the three Vulkan rows. +@inline function mat4_to_mat3x4(m)::Mat3x4f + Mat3x4f( + Float32(m[1,1]), Float32(m[1,2]), Float32(m[1,3]), Float32(m[1,4]), + Float32(m[2,1]), Float32(m[2,2]), Float32(m[2,3]), Float32(m[2,4]), + Float32(m[3,1]), Float32(m[3,2]), Float32(m[3,3]), Float32(m[3,4]), + ) +end + +# Affine inverse of a Mat3x4f. +# SMatrix{4,3} stores columns: col k = [Vk_row0_colK, Vk_row1_colK, Vk_row2_colK, tx/ty/tz]. +# The upper-left SMatrix{3,3} equals inv(A_vulkan) in row-major terms; translation inverts as +# -(inv(A_vulkan)^T * t) which reduces to -(B^T * t) with B = inv(upper-left 3×3). +@inline function mat3x4_inverse(m::Mat3x4f)::Mat3x4f + R = m[SOneTo(3), SOneTo(3)] + B = inv(R) + tx, ty, tz = m[4,1], m[4,2], m[4,3] + t_inv_x = -(B[1,1]*tx + B[2,1]*ty + B[3,1]*tz) + t_inv_y = -(B[1,2]*tx + B[2,2]*ty + B[3,2]*tz) + t_inv_z = -(B[1,3]*tx + B[2,3]*ty + B[3,3]*tz) + return Mat3x4f( + B[1,1], B[2,1], B[3,1], t_inv_x, + B[1,2], B[2,2], B[3,2], t_inv_y, + B[1,3], B[2,3], B[3,3], t_inv_z, + ) +end + +# Transform point by a Mat3x4f (Vulkan row-major 3×4 affine transform). +# m[j+1, i+1] = Vulkan element (row i, col j), so each result component is +# a dot of one Julia column with (p..., 1). +@inline function transform_point(m::Mat3x4f, p::Point3f)::Point3f + Point3f( + m[1,1] * p[1] + m[2,1] * p[2] + m[3,1] * p[3] + m[4,1], + m[1,2] * p[1] + m[2,2] * p[2] + m[3,2] * p[3] + m[4,2], + m[1,3] * p[1] + m[2,3] * p[2] + m[3,3] * p[3] + m[4,3], + ) +end + +# Transform point by a homogeneous Mat4f (Julia column-major; translation in column 4). +# Standard graphics convention: out_i = sum_j m[i,j]*p[j] + m[i,4]. +@inline function transform_point(m::Mat4f, p::Point3f)::Point3f + Point3f( + m[1,1] * p[1] + m[1,2] * p[2] + m[1,3] * p[3] + m[1,4], + m[2,1] * p[1] + m[2,2] * p[2] + m[2,3] * p[3] + m[2,4], + m[3,1] * p[1] + m[3,2] * p[2] + m[3,3] * p[3] + m[3,4], + ) +end + +# Transform direction (ignoring translation). +@inline function transform_direction(m::Mat3x4f, v::Vec3f)::Vec3f + Vec3f( + m[1,1] * v[1] + m[2,1] * v[2] + m[3,1] * v[3], + m[1,2] * v[1] + m[2,2] * v[2] + m[3,2] * v[3], + m[1,3] * v[1] + m[2,3] * v[2] + m[3,3] * v[3], + ) +end + +# Mat4f variant — translation column ignored for direction transforms. +@inline function transform_direction(m::Mat4f, v::Vec3f)::Vec3f + Vec3f( + m[1,1] * v[1] + m[1,2] * v[2] + m[1,3] * v[3], + m[2,1] * v[1] + m[2,2] * v[2] + m[2,3] * v[3], + m[3,1] * v[1] + m[3,2] * v[2] + m[3,3] * v[3], + ) +end + +# ============================================================================== +# Two-Level Traversal +# ============================================================================== + +"""Sentinel value to mark top-level to bottom-level transitions.""" +const TOP_LEVEL_SENTINEL = 0xFFFFFFFE + +""" + safe_invdir(d::Vec3f) -> Vec3f + +Safe ray direction inversion that avoids division by zero. +Clamps near-zero components to ±1e-5. +Matches HLSL reference implementation. +""" +@inline function safe_invdir(d::Vec3f)::Vec3f + ooeps = 1.0f-5 + inv_x = 1.0f0 / (abs(d[1]) > ooeps ? d[1] : copysign(ooeps, d[1])) + inv_y = 1.0f0 / (abs(d[2]) > ooeps ? d[2] : copysign(ooeps, d[2])) + inv_z = 1.0f0 / (abs(d[3]) > ooeps ? d[3] : copysign(ooeps, d[3])) + return Vec3f(inv_x, inv_y, inv_z) +end + +""" + fast_intersect_triangle(ray_o, ray_d, v0, v1, v2, t_min, closest_t) -> (hit, t, u, v) + +Möller-Trumbore ray-triangle intersection test. +Matches HLSL reference implementation. +""" +@inline function fast_intersect_triangle( + ray_o::Point3f, ray_d::Vec3f, + v0::Point3f, v1::Point3f, v2::Point3f, + t_min::Float32, closest_t::Float32 +) + # Edge vectors + e1 = v1 - v0 + e2 = v2 - v0 + + # Begin calculating determinant - also used to calculate u parameter + s1 = cross(ray_d, e2) + determinant = dot(s1, e1) + invd = 1.0f0 / determinant + + # Calculate distance from v0 to ray origin + d = ray_o - v0 + u = dot(d, s1) * invd + + # Test u parameter + if u < 0.0f0 || u > 1.0f0 + return (false, 0.0f0, 0.0f0, 0.0f0) + end + + # Prepare to test v parameter + s2 = cross(d, e1) + v = dot(ray_d, s2) * invd + + # Test v parameter + if v < 0.0f0 || (u + v) > 1.0f0 + return (false, 0.0f0, 0.0f0, 0.0f0) + end + + # Calculate t + t = dot(e2, s2) * invd + + # Test t against range + if t < t_min || t > closest_t + return (false, 0.0f0, 0.0f0, 0.0f0) + end + + return (true, t, u, v) +end + +""" + intersect_internal_node(node, ray_inv_d, ray_o, t_min, t_max) -> (near_child, far_child) + +Test ray against internal node's two children AABBs. +Returns ordered children indices (near first, far second). +INVALID_NODE if child is not intersected. +Matches HLSL IntersectInternalNode. +""" +@inline function intersect_internal_node( + node::BVHNode2, + ray_inv_d::Vec3f, + ray_o::Point3f, + t_min::Float32, + t_max::Float32 +) + # Get child AABBs + aabb0 = Bounds3(Point3f(node.aabb0_min...), Point3f(node.aabb0_max...)) + aabb1 = Bounds3(Point3f(node.aabb1_min...), Point3f(node.aabb1_max...)) + + # Test both children + t0_min, t0_max = fast_intersect_bbox(ray_o, ray_inv_d, aabb0, t_min, t_max) + t1_min, t1_max = fast_intersect_bbox(ray_o, ray_inv_d, aabb1, t_min, t_max) + + # Determine which children to traverse + traverse0 = (t0_min <= t0_max) ? node.child0 : INVALID_NODE + traverse1 = (t1_min <= t1_max) ? node.child1 : INVALID_NODE + + # Order by distance (near first) + if t0_min < t1_min && traverse0 != INVALID_NODE + return (traverse0, traverse1) + else + return (traverse1, traverse0) + end +end + +""" + fast_intersect_bbox(ray_o, ray_inv_d, bbox, t_min, t_max) -> (entry_t, exit_t) + +Fast ray-AABB intersection using slab method. +Returns parametric distances to entry and exit points. +Matches HLSL fast_intersect_bbox. +""" +@inline function fast_intersect_bbox( + ray_o::Point3f, + ray_inv_d::Vec3f, + bbox::Bounds3, + t_min::Float32, + t_max::Float32 +) + oxinvdir = -ray_o .* ray_inv_d + f = bbox.p_max .* ray_inv_d .+ oxinvdir + n = bbox.p_min .* ray_inv_d .+ oxinvdir + + tmax_vec = max.(f, n) + tmin_vec = min.(f, n) + + max_t = min(minimum(tmax_vec), t_max) + min_t = max(maximum(tmin_vec), t_min) + + return (min_t, max_t) +end + +""" + intersect_leaf_node(node, ray_d, ray_o, t_min, closest_t) -> (hit, t, u, v) + +Test ray against triangle stored in leaf node. +Returns hit status and intersection parameters. +Matches HLSL IntersectLeafNode. +""" +@inline function intersect_leaf_node( + node::BVHNode2, + ray_d::Vec3f, + ray_o::Point3f, + t_min::Float32, + closest_t::Float32 +) + # In BVH2IL format, leaf nodes store triangle vertices in AABB slots + v0 = Point3f(node.aabb0_min...) + v1 = Point3f(node.aabb0_max...) + v2 = Point3f(node.aabb1_min...) + + return fast_intersect_triangle(ray_o, ray_d, v0, v1, v2, t_min, closest_t) +end + +""" + closest_hit(tlas::TLAS, ray::AbstractRay) -> (hit, primitive, distance, barycentric, instance_idx) + +Traverse two-level BVH to find closest ray intersection. + +`instance_idx` is the 1-based position in `tlas.instances` (or `UInt32(0)` +on miss). Dereferencing `tlas.instances[instance_idx]` yields the full +`InstanceDescriptor` — from which the caller can read the transforms, the +interface-override `instance_id`, or anything else. This keeps a single +source of truth (the TLAS instance array) rather than duplicating a +pre-extracted sub-field. + +Algorithm: +1. Traverse TLAS to find candidate instances +2. Transform ray to local space +3. Traverse BLAS for geometry intersection +4. Transform back to world space +5. Return closest hit across all instances +""" +@inline function closest_hit(tlas::StaticTLAS, ray::R) where {R <: AbstractRay} + # Initialize traversal state - matches HLSL TraceRays + ray = check_direction(ray) + ray_o::Point3f = ray.o + ray_d::Vec3f = ray.d + ray_mint::Float32 = ray.t_min # Minimum t for intersection + ray_maxt::Float32 = ray.t_max + ray_inv_d::Vec3f = safe_invdir(ray_d) # Use safe inversion to avoid division by zero + + # Stack for traversal (32 entries sufficient for typical BVH depths of ~20 levels) + stack = MVector{32, UInt32}(undef) + stack_ptr::Int32 = Int32(1) + @inbounds stack[stack_ptr] = INVALID_NODE + + # Traversal state - use Int32 for indices to avoid UInt32 arithmetic issues + current_instance::Int32 = Int32(-1) # -1 means no instance (top level) + closest_instance::Int32 = Int32(-1) + closest_prim::UInt32 = INVALID_NODE + hit_u::Float32 = 0.0f0 + hit_v::Float32 = 0.0f0 + + # Entry point is node 1 (1-indexed in Julia) + node_index::UInt32 = UInt32(1) + + # Cached BLAS offset for current instance (avoids repeated descriptor lookup) + current_blas_offset::UInt32 = UInt32(0) + + # Get typed references to avoid repeated field access + tlas_nodes = tlas.nodes + tlas_instances = tlas.instances + tlas_blas_nodes = tlas.all_blas_nodes + tlas_blas_prims = tlas.all_blas_prims + tlas_blas_descs = tlas.blas_descriptors + + @inbounds while node_index != INVALID_NODE + # Fetch node based on current level + node::BVHNode2 = if current_instance < Int32(0) + tlas_nodes[node_index] + else + tlas_blas_nodes[current_blas_offset + node_index] + end + + is_leaf::Bool = (node.child0 == INVALID_NODE) + + if !is_leaf + # Interior node - test both children and get ordered traversal + near_child, far_child = intersect_internal_node(node, ray_inv_d, ray_o, ray_mint, ray_maxt) + + # Push far child if valid + if far_child != INVALID_NODE + stack_ptr += Int32(1) + stack[stack_ptr] = far_child + end + + # Visit near child if valid + if near_child != INVALID_NODE + node_index = near_child + continue + end + elseif current_instance < Int32(0) + # Top-level leaf - transition to instance + current_instance = Int32(node.child1) # 0-indexed instance index + + # Push sentinel + stack_ptr += Int32(1) + stack[stack_ptr] = TOP_LEVEL_SENTINEL + + # Get instance and transform ray + node_index = UInt32(1) # Start at root of BLAS + inst = tlas_instances[current_instance + Int32(1)] + desc = tlas_blas_descs[inst.blas_index] + current_blas_offset = desc.nodes_offset + ray_o = transform_point(inst.inv_transform, ray.o) + ray_d = transform_direction(inst.inv_transform, ray.d) + ray_inv_d = safe_invdir(ray_d) + continue + else + # Bottom-level leaf - test triangle + hit, t, u, v = intersect_leaf_node(node, ray_d, ray_o, ray_mint, ray_maxt) + if hit + # Update closest hit + ray_maxt = t + closest_instance = current_instance + closest_prim = node.child1 + hit_u = u + hit_v = v + end + end + + # Pop from stack + node_index = stack[stack_ptr] + stack_ptr -= Int32(1) + + # Check for level transition + if node_index == TOP_LEVEL_SENTINEL + # Return to top level + node_index = stack[stack_ptr] + stack_ptr -= Int32(1) + current_instance = Int32(-1) + + # Restore original ray + ray_o = ray.o + ray_d = ray.d + ray_inv_d = safe_invdir(ray_d) + end + end + + # Fill in hit output - matches HLSL + @inbounds if closest_instance >= Int32(0) + inst_idx = UInt32(closest_instance + Int32(1)) + inst = tlas_instances[inst_idx] + desc = tlas_blas_descs[inst.blas_index] + tri = tlas_blas_prims[desc.primitives_offset + closest_prim] + w = 1.0f0 - hit_u - hit_v + bary = SVector{3, Float32}(w, hit_u, hit_v) + return (true, tri, ray_maxt, bary, inst_idx) + else + # No hit - return zero sentinel + dummy_tri = empty_triangle(eltype(tlas_blas_prims)) + bary = SVector{3, Float32}(0.0f0, 0.0f0, 0.0f0) + return (false, dummy_tri, 0.0f0, bary, UInt32(0)) + end +end + +""" + any_hit(tlas::TLAS, ray::AbstractRay) -> (hit, primitive, distance, barycentric, instance_idx) + +Traverse two-level BVH to find ANY ray intersection (returns on first hit). +Faster than closest_hit when only occlusion testing is needed. + +Matches HLSL TraceRays with ANY_HIT defined. +""" +@inline function any_hit(tlas::StaticTLAS, ray::R) where {R <: AbstractRay} + # Initialize traversal state - matches HLSL TraceRays + ray = check_direction(ray) + ray_o::Point3f = ray.o + ray_d::Vec3f = ray.d + ray_mint::Float32 = 0.0f0 + ray_maxt::Float32 = ray.t_max + ray_inv_d::Vec3f = safe_invdir(ray_d) + + # Stack for traversal (32 entries sufficient for typical BVH depths of ~20 levels) + stack = MVector{32, UInt32}(undef) + stack_ptr::Int32 = Int32(1) + @inbounds stack[stack_ptr] = INVALID_NODE + + # Traversal state - use Int32 for indices to avoid UInt32 arithmetic issues + current_instance::Int32 = Int32(-1) # -1 means no instance (top level) + # Entry point is node 1 (1-indexed in Julia) + node_index::UInt32 = UInt32(1) + current_blas_offset::UInt32 = UInt32(0) + + # Get typed references to avoid repeated field access + tlas_nodes = tlas.nodes + tlas_instances = tlas.instances + tlas_blas_nodes = tlas.all_blas_nodes + tlas_blas_prims = tlas.all_blas_prims + tlas_blas_descs = tlas.blas_descriptors + + @inbounds while node_index != INVALID_NODE + # Fetch node based on current level + node::BVHNode2 = if current_instance < Int32(0) + tlas_nodes[node_index] + else + tlas_blas_nodes[current_blas_offset + node_index] + end + + is_leaf::Bool = (node.child0 == INVALID_NODE) + + if !is_leaf + # Interior node - test both children and get ordered traversal + near_child, far_child = intersect_internal_node(node, ray_inv_d, ray_o, ray_mint, ray_maxt) + + # Push far child if valid + if far_child != INVALID_NODE + stack_ptr += Int32(1) + stack[stack_ptr] = far_child + end + + # Visit near child if valid + if near_child != INVALID_NODE + node_index = near_child + continue + end + elseif current_instance < Int32(0) + # Top-level leaf - transition to instance + current_instance = Int32(node.child1) # 0-indexed instance index + + # Push sentinel + stack_ptr += Int32(1) + stack[stack_ptr] = TOP_LEVEL_SENTINEL + + # Get instance and transform ray + node_index = UInt32(1) # Start at root of BLAS + inst = tlas_instances[current_instance + Int32(1)] + desc = tlas_blas_descs[inst.blas_index] + current_blas_offset = desc.nodes_offset + ray_o = transform_point(inst.inv_transform, ray.o) + ray_d = transform_direction(inst.inv_transform, ray.d) + ray_inv_d = safe_invdir(ray_d) + continue + else + # Bottom-level leaf - test triangle + hit, t, u, v = intersect_leaf_node(node, ray_d, ray_o, ray_mint, ray_maxt) + if hit + # ANY_HIT: return immediately on first hit + inst_idx = UInt32(current_instance + Int32(1)) + inst = tlas_instances[inst_idx] + desc = tlas_blas_descs[inst.blas_index] + tri = tlas_blas_prims[desc.primitives_offset + node.child1] + w = 1.0f0 - u - v + bary = SVector{3, Float32}(w, u, v) + return (true, tri, t, bary, inst_idx) + end + end + + # Pop from stack + node_index = stack[stack_ptr] + stack_ptr -= Int32(1) + + # Check for level transition + if node_index == TOP_LEVEL_SENTINEL + # Return to top level + node_index = stack[stack_ptr] + stack_ptr -= Int32(1) + current_instance = Int32(-1) + + # Restore original ray + ray_o = ray.o + ray_d = ray.d + ray_inv_d = safe_invdir(ray_d) + end + end + + # No hit found + @inbounds dummy_tri = tlas_blas_prims[1] + bary = SVector{3, Float32}(0.0f0, 0.0f0, 0.0f0) + return (false, dummy_tri, 0.0f0, bary, UInt32(0)) +end + +# ============================================================================== +# Helper Functions +# ============================================================================== + +"""Get world-space AABB of a TLAS.""" +function world_bound(tlas::TraversableTLAS)::Bounds3 + return tlas.root_aabb +end + +"""Get world-space AABB of a BLAS.""" +function world_bound(blas::BLAS)::Bounds3 + return blas.root_aabb +end + +# ============================================================================== +# Dynamic TLAS Updates +# ============================================================================== + +""" + update_instance_transform!(tlas::TLAS, instance_idx::Integer, transform) + +Update the transform of a single instance. Call `refit_tlas!` after updating transforms. + +# Arguments +- `tlas`: The TLAS to update +- `instance_idx`: 1-based index of the instance to update +- `transform`: New local-to-world transform — `Mat4f` or canonical Vulkan row-major 3×4 (`Mat3x4f`) +""" +function update_instance_transform!(tlas::TLAS, instance_idx::Integer, transform::Mat3x4f) + @allowscalar begin + old_inst = tlas.instances[instance_idx] + tlas.instances[instance_idx] = InstanceDescriptor( + old_inst.blas_index, + old_inst.instance_id, + transform, + mat3x4_inverse(transform), + old_inst.flags + ) + end + tlas.transforms_dirty = true + return nothing +end + +update_instance_transform!(tlas::TLAS, instance_idx::Integer, transform::Mat4f) = + update_instance_transform!(tlas, instance_idx, mat4_to_mat3x4(transform)) + +""" + refit_tlas!(tlas::TLAS) + +Refit the TLAS after instance transforms have been updated. +Updates leaf AABBs from instance transforms and propagates changes up the tree. + +This is much faster than rebuilding the TLAS from scratch when only transforms change. +Operates directly on the backend arrays stored in the TLAS. +""" +function refit_tlas!(tlas::TLAS) + tlas.transforms_dirty || return tlas + n = length(tlas.instances) + n == 0 && (tlas.transforms_dirty = false; return tlas) + backend = tlas.backend + + # Update leaf node AABBs from new transforms (kernel) + # blas_array is only used for root_aabb (inline data, safe on Metal) + leaf_kernel! = update_tlas_leaf_aabbs_kernel!(backend) + leaf_kernel!(tlas.nodes, tlas.instances, tlas.blas_array, Int32(n), ndrange=n) + KA.synchronize(backend) + # Refit internal nodes bottom-up using atomic counters + if n > 1 + update_flags = KA.zeros(backend, UInt32, n - 1) + refit_kernel! = refit_tlas_aabbs_kernel!(backend) + refit_kernel!(tlas.nodes, update_flags, Int32(n), ndrange=n) + KA.synchronize(backend) + end + + # Propagate the refitted root-node AABB back to the cached world bound so + # `world_bound(tlas)` reports the post-transform extents (rays/grids depend on this). + root_node = Array(tlas.nodes[1:1])[1] + tlas.root_aabb = get_tlas_node_aabb(root_node, is_interior(root_node)) + tlas.transforms_dirty = false + return tlas +end + +function update_instance_transforms!(tlas::TLAS, transforms::AbstractVector{Mat3x4f}, n_to_update::Integer) + backend = KA.get_backend(transforms) + kernel! = update_instance_transforms_kernel!(backend) + kernel!(tlas.instances, transforms, Int32(n_to_update), ndrange=n_to_update) + KA.synchronize(backend) + tlas.transforms_dirty = true + return nothing +end + +# Mat4f convenience: convert host-side and re-dispatch. GPU-resident Mat4f +# arrays are converted via a CPU map; that's fine for the only known caller +# (RayMakie's mesh.jl, which builds a 1-element CPU array per update). +update_instance_transforms!(tlas::TLAS, transforms::AbstractVector{Mat4f}, n_to_update::Integer) = + update_instance_transforms!(tlas, map(mat4_to_mat3x4, transforms), n_to_update) + +function update_instance_transforms!(tlas::TLAS, transforms::AbstractVector{Mat3x4f}, n_to_update::Integer, first_idx::Integer) + backend = KA.get_backend(transforms) + kernel! = update_instance_transforms_offset_kernel!(backend) + kernel!(tlas.instances, transforms, Int32(n_to_update), Int32(first_idx), ndrange=n_to_update) + KA.synchronize(backend) + tlas.transforms_dirty = true + return nothing +end + +update_instance_transforms!(tlas::TLAS, transforms::AbstractVector{Mat4f}, n_to_update::Integer, first_idx::Integer) = + update_instance_transforms!(tlas, map(mat4_to_mat3x4, transforms), n_to_update, first_idx) + + +# ============================================================================== +# BVH-Compatible API +# ============================================================================== + +""" + TLAS(primitives::AbstractVector, metadata_fn::Function; backend=KA.CPU()) + +Universal TLAS constructor. Each primitive (GB.Mesh or AbstractGeometry) becomes +a BLAS with a single instance. + +Each mesh is automatically treated as an instance at identity transform. +Perfect for scenes where you just have different meshes and want automatic instancing. + +GPU-first: Specify backend to build all BLASes directly on GPU. + +Example: +```julia +geometries = [cat_mesh, floor, sphere] +tlas = TLAS(geometries, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) + +# GPU-first: +tlas = TLAS(geometries, metadata_fn; backend=OpenCLBackend()) +``` +""" +function TLAS( + primitives::AbstractVector{P}, + metadata_fn::Function; + backend = KA.CPU() +) where {P} + first_metadata = metadata_fn(1, 1) + TMetadata = typeof(first_metadata) + + identity = mat4_to_mat3x4(Mat4f(I)) + blas_array = BLAS[] + instances = InstanceDescriptor[] + + for mi in 1:length(primitives) + prim = primitives[mi] + # Convert to GB.Mesh if needed + gb_mesh = prim isa GeometryBasics.Mesh ? prim : GeometryBasics.uv_normal_mesh(prim) + nmesh = GeometryBasics.expand_faceviews(gb_mesh) + fs = decompose(TriangleFace{UInt32}, nmesh) + verts = decompose(Point3f, nmesh) + norms = Normal3f.(decompose_normals(nmesh)) + uvs_raw = GeometryBasics.decompose_uv(nmesh) + uvs = isnothing(uvs_raw) ? Point2f[] : Point2f.(uvs_raw) + indices = collect(reinterpret(UInt32, fs)) + + triangles = Triangle{TMetadata}[] + for i in 1:length(fs) + if !is_degenerate_face(verts, indices, i) + metadata = metadata_fn(mi, i) + push!(triangles, build_triangle(verts, norms, uvs, indices, i, metadata)) + end + end + + # Build BLAS on backend + backend_tris = Adapt.adapt(backend, triangles) + blas = build_blas(backend_tris) + push!(blas_array, blas) + + # Create instance at identity + push!(instances, InstanceDescriptor( + UInt32(length(blas_array)), + UInt32(mi), + identity, + identity, + UInt32(0) + )) + end + + return build_tlas(blas_array, instances) +end + +# Note: TLAS(meshes::AbstractVector{<:GB.Mesh}) is defined below. + +""" + Base.eltype(tlas::TraversableTLAS) + +Get the element type of primitives stored in the TLAS. Returns the element +type of the first BLAS's primitives; defaults to `Triangle{UInt32}` when the +TLAS has no BLASes yet. +""" +function Base.eltype(tlas::TLAS) + isempty(tlas.blas_storage) && return Triangle{UInt32} + return eltype(tlas.blas_storage[1].primitives) +end + +function Base.eltype(::StaticTLAS{NA, IA, BNA, BPA, DA}) where {NA, IA, BNA, BPA, DA} + return eltype(BPA) +end + +# ============================================================================== +# Convenience TLAS Constructor +# ============================================================================== + +""" + TLAS(meshes::AbstractVector{<:GeometryBasics.Mesh}; backend=KA.CPU()) -> (TLAS, Vector{TLASHandle}) + +Create a mutable TLAS from a vector of GB.Mesh objects. +Each mesh becomes a BLAS with a single instance at identity transform. + +Returns the mutable TLAS and a vector of TLASHandles for later reference. + +# Examples +```julia +tlas, handles = TLAS([floor_mesh, wall_mesh, sphere_mesh]) +``` +""" +function TLAS(meshes::AbstractVector{<:GeometryBasics.Mesh}; backend=KA.CPU()) + isempty(meshes) && error("Cannot create TLAS from empty mesh list") + + # Create mutable TLAS with the specified backend + tlas = TLAS(backend) + handles = TLASHandle[] + + # Push each mesh at identity transform + for mesh in meshes + h = push!(tlas, mesh) + push!(handles, h) + end + + # Sync to build the BVH structure + sync!(tlas) + + return tlas, handles +end + +""" + n_instances(tlas::TraversableTLAS) + +Return total number of *live* instance descriptors in the TLAS. + +For a `TLAS`, instances `delete!`d but not yet compacted by `sync!` are +excluded; the count tracks what the next `sync!` will publish, not the +backing buffer length. `StaticTLAS` has no pending state, so it returns +`length(tlas.instances)` directly. +""" +n_instances(tlas::StaticTLAS) = length(tlas.instances) +function n_instances(tlas::TLAS) + pending = 0 + for h in tlas.deleted_handles + r = get(tlas.handle_to_range, h, nothing) + r === nothing || (pending += length(r)) + end + return length(tlas.instances) - pending +end + +""" + n_geometries(tlas::TraversableTLAS) + +Return number of unique BLAS geometries in the TLAS. +""" +n_geometries(tlas::TLAS) = tlas.blas_array === nothing ? 0 : length(tlas.blas_array) +n_geometries(tlas::StaticTLAS) = length(tlas.blas_descriptors) + +""" + wait_for_gpu!(accel::AbstractAccel) + +Block the CPU until the GPU has completed all prior work that could be +reading `accel` or its adapted form. Default implementation calls +`KA.synchronize` on `accel.backend`. Concrete types that carry their own +queue (e.g. `Lava.HWTLAS`) override this to wait on the specific timeline. + +Convenience only. The per-dispatch hot path does NOT call this; see `sync!`. +""" +function wait_for_gpu!(accel::AbstractAccel) + KA.synchronize(accel.backend) + return accel +end + +# Export public API +export BLAS, BLASDescriptor, TLAS, StaticTLAS, TraversableTLAS, InstanceDescriptor, BVHNode2 +export build_blas, build_tlas, closest_hit, any_hit, world_bound +export INVALID_NODE + +# TLAS Handle API — the only user-facing mutation interface. +# +# Design contract: mutations write directly to the backing buffers +# (`tlas.instances` / `tlas.nodes`), set `dirty` or `transforms_dirty`, and +# return. `sync!` is the single commit boundary that decides +# rebuild-vs-refit and runs it. No staging, no caches keyed on +# `objectid(tlas)`, no global mutable state — every piece of mutable state +# lives on a TLAS instance and dies with it. +# +# `refit_tlas!` / `update_instance_transform!` / `update_instance_transforms!` +# are internal helpers used by the public methods; they must not be called +# directly by user code. +export TLASHandle +export n_instances, n_geometries, get_instance, get_instances +export update_transform!, update_transforms!, is_valid diff --git a/src/kernel-abstractions.jl b/src/kernel-abstractions.jl index fa5446d..d04878a 100644 --- a/src/kernel-abstractions.jl +++ b/src/kernel-abstractions.jl @@ -2,27 +2,46 @@ import KernelAbstractions as KA KA.@kernel some_kernel_f() = nothing -global PRESERVE = [] - function some_kernel(arr) backend = KA.get_backend(arr) return some_kernel_f(backend) end +# Get KernelAbstractions backend from an ArrayType +function _array_type_to_backend(ArrayType) + # Create a small temporary array to get the backend + tmp = ArrayType{Int}(undef, 1) + return KA.get_backend(tmp) +end + +# Convert array to GPU array +# The caller is responsible for keeping the returned array alive. +# Typically this is done by storing in a scene struct. function to_gpu(ArrayType, m::AbstractArray) arr = ArrayType(m) - push!(PRESERVE, arr) - finalizer((arr) -> filter!(x-> x === arr, PRESERVE), arr) kernel = some_kernel(arr) return KA.argconvert(kernel, arr) end -# GPU conversion for BVH -function to_gpu(ArrayType, bvh::Raycore.BVH) - nodes = to_gpu(ArrayType, bvh.nodes) - triangles = to_gpu(ArrayType, bvh.triangles) - primitives = to_gpu(ArrayType, bvh.primitives) - return Raycore.BVH(nodes, triangles, primitives, bvh.max_node_primitives) +# GPU conversion for BLAS (instanced BVH bottom-level) +function to_gpu(ArrayType, blas::Raycore.BLAS) + nodes = to_gpu(ArrayType, blas.nodes) + primitives = to_gpu(ArrayType, blas.primitives) + return Raycore.BLAS(nodes, primitives, blas.root_aabb) +end + +# GPU conversion for TLAS - use Adapt to create StaticTLAS for kernel traversal +function to_gpu(ArrayType, tlas::Raycore.TLAS) + # Get the backend from the ArrayType + backend = _array_type_to_backend(ArrayType) + # Adapt returns StaticTLAS with isbits arrays for kernel traversal + return Adapt.adapt(backend, tlas) +end + +# Also support StaticTLAS (already GPU-ready, just adapt arrays) +function to_gpu(ArrayType, static_tlas::Raycore.StaticTLAS) + backend = _array_type_to_backend(ArrayType) + return Adapt.adapt(backend, static_tlas) end gpu_int(x) = Base.unsafe_trunc(Int32, x) diff --git a/src/kernels.jl b/src/kernels.jl index dadb7de..8bf977a 100644 --- a/src/kernels.jl +++ b/src/kernels.jl @@ -4,44 +4,95 @@ struct RayHit{TMetadata} metadata::TMetadata end -function hits_from_grid(bvh, viewdir; grid_size=32) - # Calculate grid bounds +# Access all triangles from a StaticTLAS +_primitives(tlas::StaticTLAS) = tlas.all_blas_prims + +function generate_ray_grid(tlas, ray_direction::Vec3f, grid_size::Int) + direction = normalize(ray_direction) + bounds = world_bound(tlas) + rect = GB.Rect3f(Point3f(bounds.p_min), Point3f(bounds.p_max) - Point3f(bounds.p_min)) + corners = GB.decompose(Point3f, rect) + + # Create perpendicular basis for the grid plane + if abs(direction[1]) < 0.9f0 + temp = Vec3f(1.0f0, 0.0f0, 0.0f0) + else + temp = Vec3f(0.0f0, 1.0f0, 0.0f0) + end + basis1 = normalize(cross(direction, temp)) + basis2 = normalize(cross(direction, basis1)) + + # Project corners onto basis vectors + proj1 = [dot(Vec3f(c...), basis1) for c in corners] + proj2 = [dot(Vec3f(c...), basis2) for c in corners] + + min_proj1, max_proj1 = extrema(proj1) + min_proj2, max_proj2 = extrema(proj2) + + margin = 0.05f0 * max(max_proj1 - min_proj1, max_proj2 - min_proj2) + grid_width = max_proj1 - min_proj1 + 2 * margin + grid_height = max_proj2 - min_proj2 + 2 * margin + + # Place grid origin behind the scene + depth_proj = [dot(Vec3f(c...), direction) for c in corners] + min_depth = minimum(depth_proj) - margin + + grid_center = Point3f(0, 0, 0) + min_depth * direction + + ((min_proj1 + max_proj1) / 2) * basis1 + + ((min_proj2 + max_proj2) / 2) * basis2 + + cell_w = grid_width / grid_size + cell_h = grid_height / grid_size + + ray_origins = Matrix{Point3f}(undef, grid_size, grid_size) + for i in 1:grid_size + for j in 1:grid_size + u = (i - (grid_size + 1) / 2) * cell_w + v = (j - (grid_size + 1) / 2) * cell_h + ray_origins[i, j] = grid_center + u * basis1 + v * basis2 + end + end + return ray_origins +end + +function hits_from_grid(tlas, viewdir; grid_size=32) ray_direction = normalize(viewdir) - ray_origins = Raycore.generate_ray_grid(bvh, ray_direction, grid_size) - TMetadata = eltype(bvh.primitives).parameters[1] # Get metadata type from Triangle{TMetadata} + ray_origins = generate_ray_grid(tlas, ray_direction, grid_size) + prims = _primitives(tlas) + TMetadata = eltype(prims).parameters[1] result = similar(ray_origins, RayHit{TMetadata}) Threads.@threads for idx in CartesianIndices(ray_origins) o = ray_origins[idx] - ray = Raycore.Ray(; o=o, d=ray_direction) - hit, prim, dist, bary = Raycore.closest_hit(bvh, ray) + ray = Ray(; o=o, d=ray_direction) + hit, prim, dist, bary, _ = closest_hit(tlas, ray) hitpoint = sum_mul(bary, prim.vertices) @inbounds result[idx] = RayHit{TMetadata}(hit, hitpoint, prim.metadata) end return result end -function view_factors(bvh; rays_per_triangle=10000) - result = zeros(UInt32, length(bvh.primitives), length(bvh.primitives)) - return view_factors!(result, bvh, rays_per_triangle) +function view_factors(tlas; rays_per_triangle=10000) + prims = _primitives(tlas) + result = zeros(UInt32, length(prims), length(prims)) + return view_factors!(result, tlas, rays_per_triangle) end -# Note: view_factors requires metadata to be the primitive index (Int) -# This is the default when constructing BVH without a custom metadata_fn -function view_factors!(result, bvh, rays_per_triangle=10000) - Threads.@threads for idx in eachindex(bvh.primitives) +function view_factors!(result, tlas, rays_per_triangle=10000) + prims = _primitives(tlas) + Threads.@threads for idx in eachindex(prims) @inbounds begin - triangle = bvh.primitives[idx] - tri_idx = triangle.metadata # metadata is the primitive index + triangle = prims[idx] + tri_idx = triangle.metadata n = GB.orthogonal_vector(Vec3f, GB.Triangle(triangle.vertices...)) normal = normalize(n) u, v = get_orthogonal_basis(normal) for i in 1:rays_per_triangle point_on_triangle = random_triangle_point(triangle) - o = point_on_triangle .+ (normal .* 0.01f0) # Offset so it doesn't self intersect + o = point_on_triangle .+ (normal .* 0.01f0) ray = Ray(; o=o, d=random_hemisphere_uniform(normal, u, v)) - hit, hit_prim, dist, _ = closest_hit(bvh, ray) + hit, hit_prim, dist, _, _ = closest_hit(tlas, ray) if hit - hit_idx = hit_prim.metadata # metadata is the primitive index + hit_idx = hit_prim.metadata if hit_idx != tri_idx result[tri_idx, hit_idx] += UInt32(1) end @@ -52,17 +103,15 @@ function view_factors!(result, bvh, rays_per_triangle=10000) return result end -function get_centroid(bvh, viewdir; grid_size=32) - # Calculate grid bounds - hits = hits_from_grid(bvh, viewdir; grid_size=grid_size) +function get_centroid(tlas, viewdir; grid_size=32) + hits = hits_from_grid(tlas, viewdir; grid_size=grid_size) surface_points = [hit.point for hit in hits if hit.hit] return surface_points, mean(surface_points) end -function get_illumination(bvh, viewdir; grid_size=1000) - # Calculate grid bounds - hits = hits_from_grid(bvh, viewdir; grid_size=grid_size) - # Use primitive metadata as keys - requires metadata to be the primitive index +function get_illumination(tlas, viewdir; grid_size=1000) + hits = hits_from_grid(tlas, viewdir; grid_size=grid_size) + prims = _primitives(tlas) result = Dict{Int, Float32}() for hit in hits if hit.hit @@ -71,5 +120,5 @@ function get_illumination(bvh, viewdir; grid_size=1000) result[idx] = count + 1f0 end end - return [get(result, idx, 0.0f0) for idx in 1:length(bvh.primitives)] + return [get(result, idx, 0.0f0) for idx in 1:length(prims)] end diff --git a/src/multitypeset.jl b/src/multitypeset.jl new file mode 100644 index 0000000..6615f78 --- /dev/null +++ b/src/multitypeset.jl @@ -0,0 +1,656 @@ +# ============================================================================ +# MultiTypeSet - Type-stable heterogeneous collections for GPU +# ============================================================================ +# Provides compile-time type-stable dispatch over collections of different types. +# Used for materials, textures, media, lights, etc. + +using Adapt +using Base: @propagate_inbounds +import KernelAbstractions as KA + +# ============================================================================ +# SetKey - Encodes type slot + vector index +# ============================================================================ + +""" + SetKey + +Index into a heterogeneous vector, encoding both which type slot (1-based) +and the index within that type's array. + +- `type_idx`: Which tuple slot (1-based), 0 = invalid/constant sentinel +- `vec_idx`: 1-based index within the vector at that slot +""" +struct SetKey + # UInt32 (not UInt8) is intentional: LLVM's select-scalarization pass produces broken + # IR (`select i1` with mismatched result type) when scalarizing a `select { i8, i32 }`. + # Using uniform UInt32 fields gives `{ i32, i32 }`, which scalarizes correctly. + type_idx::UInt32 + vec_idx::UInt32 +end + +# Default constructor for invalid/placeholder index +SetKey() = SetKey(UInt32(0), UInt32(0)) + +# Check for invalid sentinel +is_invalid(idx::SetKey) = idx.type_idx == UInt32(0) && idx.vec_idx == UInt32(0) +is_valid(idx::SetKey) = !is_invalid(idx) + +# ============================================================================ +# StaticMultiTypeSet - Immutable with separate texture storage for GPU +# ============================================================================ + +""" + StaticMultiTypeSet{Data, Textures} + +Immutable heterogeneous collection with separate texture storage. +- `data`: Tuple of GPU vectors for materials/objects +- `textures`: Tuple of GPU vectors containing isbits device pointers +""" +struct StaticMultiTypeSet{Data<:Tuple,Textures<:Tuple} <: AbstractVector{Any} + data::Data + textures::Textures +end + +# Empty constructor +StaticMultiTypeSet() = StaticMultiTypeSet((), ()) + +Base.isempty(smv::StaticMultiTypeSet) = isempty(smv.data) +Base.length(smv::StaticMultiTypeSet) = sum(length, smv.data; init=0) +n_slots(smv::StaticMultiTypeSet) = length(smv.data) + +# Get the static version - identity for StaticMultiTypeSet, .static field for MultiTypeSet +get_static(smv::StaticMultiTypeSet) = smv + +# Convert to a flat Tuple of all elements (preserves concrete element types) +_concat_to_tuple() = () +_concat_to_tuple(v::AbstractVector, rest...) = (v..., _concat_to_tuple(rest...)...) +to_tuple(smv::StaticMultiTypeSet) = _concat_to_tuple(smv.data...) + +# ============================================================================ +# foreach_element - Type-stable iteration over all elements +# ============================================================================ + +""" + foreach_element(f, smv::StaticMultiTypeSet, args...) + +Execute function `f` for each element in the StaticMultiTypeSet, passing additional `args`. +The function is called as `f(element, linear_idx, args...)` where `element` has a concrete type +and `linear_idx` is the 1-based linear index across all type slots. + +Uses compile-time unrolled loops for type stability. +The function `f` must not capture variables - pass all data as `args`. +""" +@inline @generated function foreach_element( + f::F, smv::StaticMultiTypeSet{Data, Textures}, args... +) where {F, Data<:Tuple, Textures} + N = length(Data.parameters) + + if N == 0 + return :(nothing) + end + + # Generate unrolled loops over each type slot + loops = Expr[] + for i in 1:N + push!(loops, quote + for j in eachindex(smv.data[$i]) + linear_idx += 1 + @inbounds f(smv.data[$i][j], linear_idx, args...) + end + end) + end + + quote + linear_idx = 0 + $(loops...) + nothing + end +end + +# ============================================================================ +# mapreduce - Type-stable reduction over all elements +# ============================================================================ + + +@inline function Base.mapreduce( + f::F, op::Op, smv::StaticMultiTypeSet{Data,Textures}, args...; init + ) where {F,Op,Data<:Tuple,Textures} + _mapreduce(f, op, smv, init, args...) +end +@inline function Base.mapreduce( + f::F, op::Op, smv::StaticMultiTypeSet{Data,Textures}, args::Vararg{Union{Base.AbstractBroadcasted,AbstractArray}}; init +) where {F,Op,Data<:Tuple,Textures} + _mapreduce(f, op, smv, init, args...) +end + +@inline @generated function _mapreduce( + f::F, op::Op, smv::StaticMultiTypeSet{Data,Textures}, init, args...) where {F, Op, Data<:Tuple, Textures} + N = length(Data.parameters) + + if N == 0 + return :(init) + end + + # Generate unrolled reduction over each type slot + reductions = Expr[] + for i in 1:N + push!(reductions, quote + for j in eachindex(smv.data[$i]) + @inbounds acc = op(acc, f(smv.data[$i][j], args...)) + end + end) + end + + quote + acc = init + $(reductions...) + acc + end +end + +# ============================================================================ +# TextureRef - Typed reference to a texture +# ============================================================================ + +# TIdx is the 1-based type slot index, idx is the element index within that slot's vector +struct TextureRef{ReferencedArrayType, T, N, TIdx} <: AbstractArray{T, N} + idx::Int +end + +Base.size(::TextureRef{ReferencedArrayType, T, N}) where {ReferencedArrayType, T, N} = ntuple(_ -> 0, N) + +# Deref for StaticMultiTypeSet - textures stored as Tuple{GPUVector{IsbitsPtr1}, GPUVector{IsbitsPtr2}, ...} +@inline function deref(smv::StaticMultiTypeSet{Data, Textures}, tref::TextureRef{ReferencedArrayType, T, N, TIdx}) where {Data, Textures, ReferencedArrayType, T, N, TIdx} + @inbounds smv.textures[TIdx][tref.idx] +end + +# Fallback: if already a concrete array, just return it (no-op for CPU paths or non-TextureRef fields) +@inline deref(::StaticMultiTypeSet, arr::AbstractArray) = arr + +# Fallback for nothing context (used by convenience overloads for CPU code that doesn't use MultiTypeSet) +@inline deref(::Nothing, arr::AbstractArray) = arr + +@inline function deref(smv::StaticMultiTypeSet{Data,Textures}, tref::TextureRef{ReferencedArrayType,T,N,TIdx}) where {Data<:Tuple,Textures<:Tuple,ReferencedArrayType,T,N,TIdx} + @inbounds smv.textures[TIdx][tref.idx] +end +# ============================================================================ +# Dummy kernel for argconvert (same pattern as kernel-abstractions.jl) +# ============================================================================ + +KA.@kernel multitypeset_dummy_kernel() = nothing + +function get_isbits_ptr(backend, gpu_arr) + kernel = multitypeset_dummy_kernel(backend) + return KA.argconvert(kernel, gpu_arr) +end + +# ============================================================================ +# MultiTypeSet - Mutable, builds GPU-ready structures on push! +# ============================================================================ + +""" + MultiTypeSet(backend) + +Mutable heterogeneous vector that builds GPU-ready structures on each push!. +Takes a KernelAbstractions backend at construction. + +# Example +```julia +backend = OpenCL.OpenCLBackend() +dhv = MultiTypeSet(backend) +texture = rand(Float32, 20, 20) +idx1 = push!(dhv, MatteMaterial(texture)) +idx2 = push!(dhv, GlassMaterial(1.5f0)) + +# Access the GPU-ready StaticMultiTypeSet directly +gpu_smv = dhv.static # Always up-to-date, no adapt needed +``` + +Push converts arrays to TextureRefs and stores texture data as GPU arrays. +The static field is rebuilt on each push to stay up-to-date. +""" +mutable struct MultiTypeSet{Backend} <: AbstractVector{Any} + backend::Backend + # Material storage - CPU vectors for accumulation (the authoritative data). + data_vectors::Dict{DataType, Any} # Type -> Vector{Type} + data_order::Vector{DataType} + # Texture type order. The shader-visible table of isbits device pointers + # lives only in `static.textures[slot]` — no parallel CPU mirror. + texture_order::Vector{DataType} + # Keep GPU texture arrays alive (the actual texture data). The backend + # handle kept here is the single owner for each texture's backing buffer. + texture_gpu_arrays::Vector{Any} + # Canonical GPU state. Every mutator (`push!` / `update!` / + # `store_texture` / `copyto_texture!`) keeps this field consistent by + # design — surgical `resize!` + `@allowscalar setindex!` on the affected + # slot — so there is no dirty flag and no batched rebuild step. The + # TLAS (`scene.accel`) has its own dirty+sync because BVH rebuilds are + # genuinely expensive; MultiTypeSet's element-level updates are cheap + # (one scalar GPU write) and pay no amortisation benefit from batching. + static::StaticMultiTypeSet +end + +Base.size(set::MultiTypeSet) = (length(set),) +function Base.length(set::MultiTypeSet) + return sum(length, values(set.data_vectors); init=0) +end + +function Base.show(io::IO, ::MIME"text/plain", set::MultiTypeSet) + n_types = length(set.data_order) + total = length(set) + print(io, "MultiTypeSet with $n_types type(s), $total element(s)") + for T in set.data_order + vec = set.data_vectors[T]::Vector + print(io, "\n ", length(vec), "× ", T) + end +end + +Base.show(io::IO, set::MultiTypeSet) = print(io, "MultiTypeSet(", length(set.data_order), " types, ", length(set), " elements)") + +function MultiTypeSet(backend) + return MultiTypeSet( + backend, + Dict{DataType, Any}(), + DataType[], + DataType[], + Any[], + StaticMultiTypeSet(), + ) +end + +n_slots(dhv::MultiTypeSet) = length(dhv.data_order) + +# `static` is always in sync with CPU state (maintained per-mutation). +get_static(dhv::MultiTypeSet) = dhv.static + +# MultiTypeSet delegates to its static version +to_tuple(mts::MultiTypeSet) = to_tuple(get_static(mts)) + +# `rebuild_static!` is deleted — mutators (`push!`, `update!`, `store_texture`, +# `copyto_texture!`) keep `static` consistent surgically, so there is no +# batched rebuild step. The TLAS (`scene.accel`) has its own `dirty + sync!` +# because BVH rebuilds are expensive; MultiTypeSet operations are all O(1) +# scalar GPU writes and pay no amortisation benefit from batching. + +# ============================================================================ +# Texture conversion and storage +# ============================================================================ + +""" + maybe_convert_field(dhv::MultiTypeSet, fval) + +Convert a struct field value for GPU storage. Override this for custom types. +- AbstractArray → TextureRef (uploaded to GPU) +- Everything else → unchanged (default) + +Materials should use loose type parameters so fields can be either raw values OR +TextureRef. This way constant values don't need texture indirection at all. +This function should not be overloaded outside Raycore. +""" +# Convert large arrays to TextureRef, but NOT StaticArrays (they're inline values, not textures) +maybe_convert_field(dhv::MultiTypeSet, arr::A) where A<:AbstractArray = store_texture(dhv, arr) +maybe_convert_field(::MultiTypeSet, arr::StaticArrays.StaticArray) = arr # Keep StaticArrays inline +# Don't re-convert already converted refs +maybe_convert_field(::MultiTypeSet, ref::TextureRef) = ref +# Default: recurse into structs, pass through primitives +function maybe_convert_field(dhv::MultiTypeSet, item::T) where T + # Recurse into struct types to convert nested arrays + if !isbitstype(T) + return convert_to_texturerefs(dhv, item) + end + # Primitives and empty structs pass through unchanged + return item +end + +# Convert arrays in a struct to TextureRefs, storing them as GPU arrays +function convert_to_texturerefs(dhv::MultiTypeSet, item::T) where T + if !isstructtype(T) || T <: AbstractArray + return item + end + fnames = fieldnames(T) + if isempty(fnames) + return item + end + new_fields = map(fnames) do fname + fval = getfield(item, fname) + maybe_convert_field(dhv, fval) + end + if all(getfield(item, fn) === nf for (fn, nf) in zip(fnames, new_fields)) + return item + end + BaseT = Base.typename(T).wrapper + return BaseT(new_fields...) +end + +# Store a texture as a GPU array, return a TextureRef pointing to its isbits +# device pointer. Keeps `texture_gpu_arrays` and `static.textures[slot]` +# consistent in one call: +# * existing texture type → `resize!` + `@allowscalar setindex!` to append one +# isbits pointer to the matching slot (one scalar GPU write). +# * new texture type → build a 1-element LavaArray{IsbitsPtr} and grow the +# `static.textures` tuple by one slot (tuple shape change is unavoidable). +function store_texture(dhv::MultiTypeSet, arr::AbstractArray{T}) where T + if !isbitstype(T) + arr = map(x -> maybe_convert_field(dhv, x), arr) + end + gpu_arr = Adapt.adapt(dhv.backend, arr) + AT = typeof(gpu_arr) + push!(dhv.texture_gpu_arrays, gpu_arr) + isbits_ptr = get_isbits_ptr(dhv.backend, gpu_arr) + + type_idx = findfirst(==(AT), dhv.texture_order) + if type_idx === nothing + # New texture type: extend `static.textures` by one slot. + push!(dhv.texture_order, AT) + type_idx = length(dhv.texture_order) + new_slot = Adapt.adapt(dhv.backend, [isbits_ptr]) + dhv.static = StaticMultiTypeSet(dhv.static.data, (dhv.static.textures..., new_slot)) + vec_idx = 1 + else + # Existing texture type: surgical one-element append into the GPU slot. + slot = dhv.static.textures[type_idx] + old_len = length(slot) + resize!(slot, old_len + 1) + @allowscalar slot[old_len + 1] = isbits_ptr + vec_idx = old_len + 1 + end + return TextureRef{AT, eltype(AT), ndims(AT), type_idx}(vec_idx) +end + +# ============================================================================ +# push! - Append item, keeping CPU + GPU state consistent in one call. +# ============================================================================ +# Existing type slot → surgical `resize!` + `@allowscalar setindex!` appends +# one element to `static.data[type_idx]` (one scalar GPU write). New type +# slot → build a 1-element LavaArray and extend the `static.data` tuple by +# one slot (tuple shape change is unavoidable when a new type appears). +function Base.push!(dhv::MultiTypeSet, item::T)::SetKey where T + # Convert arrays in the item to TextureRefs (textures stored via `store_texture`). + converted_item = maybe_convert_field(dhv, item) + CT = typeof(converted_item) + + type_idx = findfirst(==(CT), dhv.data_order) + if type_idx === nothing + # New material type: extend `static.data` by one slot. + dhv.data_vectors[CT] = [converted_item] + push!(dhv.data_order, CT) + type_idx = length(dhv.data_order) + new_slot = Adapt.adapt(dhv.backend, [converted_item]) + dhv.static = StaticMultiTypeSet((dhv.static.data..., new_slot), dhv.static.textures) + vec_idx = 1 + else + # Existing material type: surgical one-element append into the GPU slot. + push!(dhv.data_vectors[CT], converted_item) + slot = dhv.static.data[type_idx] + old_len = length(slot) + resize!(slot, old_len + 1) + @allowscalar slot[old_len + 1] = converted_item + vec_idx = old_len + 1 + end + return SetKey(UInt32(type_idx), UInt32(vec_idx)) +end + +# ============================================================================ +# update! - Sync modified CPU data into existing GPU arrays +# ============================================================================ + +""" + update!(dhv::MultiTypeSet, key::SetKey, new_item) + +Update an existing item in the set. The new item is walked against the +stored form via `update_item`: existing TextureRef slots are reused (the new +array data is copied into the existing GPU buffer, reallocating on size +mismatch), const-Texture fields are unwrapped to their scalar values, and +other fields fall through as plain value replacement. + +There is deliberately **no** `maybe_convert_field`/`store_texture` call on +this path — that would allocate a new GPU slot per update and leak hundreds +of MB per frame for plots with per-vertex color textures. +""" +function update!(dhv::MultiTypeSet, key::SetKey, new_item) + # Invalid-key sentinel ⇒ item was never stored in this set (e.g. the + # `SetKey()` returned by `push!(::MultiTypeSet, ::NullMaterial)` which is + # pbrt-v4's "Material interface"/nullptr equivalent and owns no slot, or + # a `MediumInterface` with `inside=nothing`/`outside=nothing` whose media + # side is likewise unpushed). Updating a slot that doesn't exist is a + # no-op, not an error — the caller's intent ("refresh what's stored") + # is vacuously satisfied. + is_invalid(key) && return nothing + CT = dhv.data_order[key.type_idx] + old_converted = dhv.data_vectors[CT][key.vec_idx] + updated = update_item(dhv, old_converted, new_item) + if updated !== old_converted + # Keep CPU and GPU consistent in one go — surgical single-element write. + dhv.data_vectors[CT][key.vec_idx] = updated + @allowscalar dhv.static.data[key.type_idx][key.vec_idx] = updated + end + return nothing +end + +public update_item, copyto_texture! + +""" + update_item(dhv::MultiTypeSet, old, new) + +Compute the updated representation of `old` after applying `new`'s data. +Reuses existing TextureRef slots (copying `new`'s arrays into them rather +than allocating fresh GPU buffers). Extended by backend / material packages +(e.g. Hikari) with overloads for their wrapper types — notably `Texture` +(unwraps const, routes array data to `copyto_texture!`) and `VertexColorTexture`. +""" +function update_item(dhv::MultiTypeSet, old::TextureRef{AT}, new_data::AbstractArray) where AT + copyto_texture!(dhv, old, new_data) + return old +end + +# Already-converted TextureRef on both sides: just reuse the existing slot. +update_item(::MultiTypeSet, old::TextureRef, ::TextureRef) = old + +# Nothing/Nothing: no-op. +update_item(::MultiTypeSet, ::Nothing, ::Nothing) = nothing + +# Generic fallback: walk field-by-field when field names match, otherwise +# replace `old` with `new` (leaf case — isbits values, identically-typed +# structs with no nested arrays, etc.). A type parameter mismatch (e.g. +# `Diffuse{TextureRef,Float32}` ↔ `Diffuse{VertexColorTexture{Matrix},Texture{Float32}}`) +# still recurses because the field names are identical. The reconstructed +# struct uses the concrete type produced by the per-field recursive calls; +# when TextureRefs are kept in place and const-Textures are unwrapped, that +# matches the stored (old) type. +# Tuple / NamedTuple recursion: these can't be reconstructed via +# `T.name.wrapper(fields...)` like regular structs can (`Tuple(a,b,c)` expects +# an iterable, not varargs). Handle them explicitly. +function update_item(dhv::MultiTypeSet, old::Tuple, new::Tuple) + length(old) == length(new) || return new + changed = false + new_fields = ntuple(length(old)) do i + uf = update_item(dhv, old[i], new[i]) + uf !== old[i] && (changed = true) + uf + end + changed || return old + return new_fields +end +function update_item(dhv::MultiTypeSet, old::NamedTuple{K}, new::NamedTuple{K}) where K + changed = false + new_fields = ntuple(length(K)) do i + uf = update_item(dhv, old[i], new[i]) + uf !== old[i] && (changed = true) + uf + end + changed || return old + return NamedTuple{K}(new_fields) +end + +function update_item(dhv::MultiTypeSet, old, new) + T_old = typeof(old) + fnames = fieldnames(T_old) + isempty(fnames) && return new # leaf — swap in the new value + # Only recurse if new exposes the same field names (types of individual + # fields may still differ). + fnames == fieldnames(typeof(new)) || return new + changed = false + new_fields = ntuple(length(fnames)) do i + of = getfield(old, fnames[i]) + nf = getfield(new, fnames[i]) + uf = update_item(dhv, of, nf) + uf !== of && (changed = true) + uf + end + changed || return old + return T_old.name.wrapper(new_fields...) +end + +""" + copyto_texture!(dhv, ref, data) + +Write `data` into the GPU texture addressed by `ref`. Same-size is a plain +`copyto!` (device pointer unchanged — no other state needs updating). Size +mismatch goes through `Base.resize!(::LavaArray)` which is capacity-aware: +the VkBuffer is only re-allocated on genuine growth beyond current capacity, +and the old buffer is retired via the deferred-free path (`bq.deferred_frees` +gated on the batch timeline — safe w.r.t. in-flight GPU work without any +CPU-side `synchronize`). + +If the device pointer actually moved (pool-alloc returned a fresh buffer), +the one affected slot of `static.textures[AT_slot]` is updated via a single +scalar `setindex!` — no re-adapt of the whole table, no per-frame leak. +""" +function copyto_texture!(dhv::MultiTypeSet, ref::TextureRef{AT}, new_data::AbstractArray) where AT + AT_slot = findfirst(==(AT), dhv.texture_order) + AT_slot === nothing && error("MultiTypeSet has no texture type slot for $AT (TextureRef broken?)") + + count = 0 + for arr in dhv.texture_gpu_arrays + typeof(arr) === AT || continue + count += 1 + count == ref.idx || continue + + if size(arr) == size(new_data) + # Same shape: pointer cannot move — one copyto!, done. + copyto!(arr, new_data) + else + # Capacity-aware grow: zero-alloc within capacity, deferred-free + # on true growth. Compare the device pointer before/after to see + # whether the surrounding isbits table needs one slot refreshed. + old_ptr = get_isbits_ptr(dhv.backend, arr) + resize!(arr, size(new_data)) + copyto!(arr, new_data) + new_ptr = get_isbits_ptr(dhv.backend, arr) + if new_ptr != old_ptr + # Pointer moved (grow past capacity). Refresh just this one + # slot of the shader-visible isbits buffer — no rebuild, no + # dirty flag, no sync. A later `get_static` that finds + # `dirty == false` sees a tuple that's already consistent. + @allowscalar dhv.static.textures[AT_slot][ref.idx] = new_ptr + end + end + return + end + error("GPU array not found for TextureRef(idx=$(ref.idx))") +end + +# No `update!` / `resize_and_overwrite!` hook: call sites use `Base.resize!` + +# `Base.copyto!` directly. Lava's `Base.resize!(::LavaArray)` is capacity-aware +# (no Vulkan alloc when within capacity, deferred-free on grow) and `copyto!` +# is the standard GPUArrays upload path — that pair is the full "make dst match +# src" operation without a Raycore-owned generic. + +# ============================================================================ +# with_index - Type-stable dispatch +# ============================================================================ + +""" + with_index(f, smv::StaticMultiTypeSet, idx::SetKey, args...) + +Execute function `f` with the element at index `idx`, passing additional `args`. +The function is called as `f(element, args...)` where `element` has a concrete type. + +Uses a single if-elseif-else chain for SPIR-V structured control flow compatibility. +The function `f` must not capture variables - pass all data as `args`. +""" +@inline @generated function with_index( + f::F, smv::StaticMultiTypeSet{Data, Textures}, idx::SetKey, args... +) where {F, Data<:Tuple, Textures} + N = length(Data.parameters) + + if N == 0 + return :(error("with_index: empty StaticMultiTypeSet")) + end + + # Build a single if-elseif-else chain for structured control flow (SPIR-V compatible) + # Start from the last branch and work backwards to build the chain + result = :(@inbounds f(smv.data[1][1], args...)) # default/else case + + for i in N:-1:1 + result = Expr(:if, + :(idx.type_idx === UInt32($i)), + :(@inbounds f(smv.data[$i][idx.vec_idx], args...)), + result + ) + end + + quote + return $result + end +end + +# ============================================================================ +# Adapt.jl integration for GPU array conversion +# ============================================================================ + +# Adapt StaticMultiTypeSet - adapts data and texture arrays +# For MultiTypeSet.static, arrays are already GPU - this converts to isbits for kernel +function Adapt.adapt_structure(to, smv::StaticMultiTypeSet) + adapted_data = map(smv.data) do arr + Adapt.adapt(to, arr) + end + adapted_textures = map(smv.textures) do tex + Adapt.adapt(to, tex) + end + return StaticMultiTypeSet(adapted_data, adapted_textures) +end + +# Adapt MultiTypeSet - `static` is always consistent (surgical-per-mutation); +# just hand it to the StaticMultiTypeSet adapt method. +function Adapt.adapt_structure(to, dhv::MultiTypeSet) + return Adapt.adapt_structure(to, dhv.static) +end + +# ============================================================================ +# GPU Resource Cleanup +# ============================================================================ + +""" + free!(set::MultiTypeSet) + +Release GPU memory held by the MultiTypeSet — the shadow-owned texture +arrays plus the static material/texture slot buffers. Does **not** +synchronize. + +**Precondition (caller's responsibility):** the GPU must be idle for +`set.backend` before this is called. `MultiTypeSet.texture_gpu_arrays` +is a shadow-ownership site: the only references from GPU work are raw +BDAs in arg buffers, so nothing inside `free!` can prove it's safe to +finalize — the caller establishes that, typically by calling `sync!` on +the enclosing accel / scene (which synchronizes its backend) or by +returning from a `colorbuffer` that completed with `device_wait_idle`. +""" +function free!(set::MultiTypeSet) + for arr in set.texture_gpu_arrays + finalize(arr) + end + empty!(set.texture_gpu_arrays) + # `set.static.data` / `.textures` are the canonical ownership sites for + # the adapted material / isbits-ptr arrays. Finalize them, then drop + # the tuples via a fresh empty StaticMultiTypeSet. + for arr in set.static.data + finalize(arr) + end + for arr in set.static.textures + finalize(arr) + end + set.static = StaticMultiTypeSet() + return nothing +end diff --git a/src/ray.jl b/src/ray.jl index 93739bf..c672d03 100644 --- a/src/ray.jl +++ b/src/ray.jl @@ -1,12 +1,13 @@ Base.@kwdef struct Ray <: AbstractRay o::Point3f d::Vec3f + t_min::Float32 = 0.0f0 t_max::Float32 = Inf32 time::Float32 = 0.0f0 end -@inline function Ray(ray::Ray; o::Point3f = ray.o, d::Vec3f = ray.d, t_max::Float32 = ray.t_max, time::Float32 = ray.time) - Ray(o, d, t_max, time) +@inline function Ray(ray::Ray; o::Point3f = ray.o, d::Vec3f = ray.d, t_min::Float32 = ray.t_min, t_max::Float32 = ray.t_max, time::Float32 = ray.time) + Ray(o, d, t_min, t_max, time) end @@ -61,7 +62,7 @@ end increase_hit(ray::Ray, t_hit) = Ray(ray; t_max=t_hit) increase_hit(ray::RayDifferentials, t_hit) = RayDifferentials(ray; t_max=t_hit) -@inline function intersect_p!(shape::AbstractShape, ray::R) where {R<:AbstractRay} +@inline function intersect_p!(shape::AbstractGeometry, ray::R) where {R<:AbstractRay} intersects, t_hit, barycentric = intersect(shape, ray) !intersects && return false, ray, barycentric ray = increase_hit(ray, t_hit) diff --git a/src/ray_intersection_session.jl b/src/ray_intersection_session.jl deleted file mode 100644 index 28383d3..0000000 --- a/src/ray_intersection_session.jl +++ /dev/null @@ -1,95 +0,0 @@ -""" - RayIntersectionSession{F} - -Represents a ray tracing session containing rays, a BVH structure, a hit function, -and the computed intersection results. - -# Fields -- `rays::Vector{<:AbstractRay}`: Array of rays to trace -- `bvh::BVH`: BVH acceleration structure to intersect against -- `hit_function::F`: Function to use for intersection testing (e.g., `closest_hit` or `any_hit`) -- `hits::Vector{Tuple{Bool, Triangle, Float32, Point3f}}`: Results of hit_function applied to each ray - -# Example -```julia -using Raycore, GeometryBasics - -# Create BVH from geometry -sphere = Tesselation(Sphere(Point3f(0, 0, 1), 1.0f0), 20) -bvh = Raycore.BVH([sphere]) - -# Create rays -rays = [ - Raycore.Ray(Point3f(0, 0, -5), Vec3f(0, 0, 1)), - Raycore.Ray(Point3f(1, 0, -5), Vec3f(0, 0, 1)), -] - -# Create session -session = RayIntersectionSession(rays, bvh, Raycore.closest_hit) - -# Access results -for (i, hit) in enumerate(session.hits) - hit_found, primitive, distance, bary_coords = hit - if hit_found - println("Ray \$i hit at distance \$distance") - end -end -``` -""" -struct RayIntersectionSession{Rays, F} - hit_function::F - rays::Rays - bvh::BVH - hits::Vector{Tuple{Bool, Triangle, Float32, Point3f}} - - function RayIntersectionSession(hit_function::F, rays::Rays, bvh::BVH) where {Rays,F} - # Compute all hits - hits = [hit_function(bvh, ray) for ray in rays] - new{Rays, F}(hit_function, rays, bvh, hits) - end -end - -""" - hit_points(session::RayIntersectionSession) - -Extract all valid hit points from a RayIntersectionSession. - -Returns a vector of `Point3f` containing the world-space hit points for all rays that intersected geometry. -""" -function hit_points(session::RayIntersectionSession) - return map(filter(first, session.hits)) do hit - _, hit_primitive, _, bary_coords = hit - return sum_mul(bary_coords, hit_primitive.vertices) - end -end - -""" - hit_distances(session::RayIntersectionSession) - -Extract all hit distances from a RayIntersectionSession. - -Returns a vector of `Float32` containing distances for all rays that intersected geometry. -""" -function hit_distances(session::RayIntersectionSession) - return map(filter(first, session.hits)) do hit - return hit[3] - end -end - -""" - hit_count(session::RayIntersectionSession) - -Count the number of rays that hit geometry in the session. -""" -function hit_count(session::RayIntersectionSession) - count(hit -> hit[1], session.hits) -end - -""" - miss_count(session::RayIntersectionSession) - -Count the number of rays that missed all geometry in the session. -""" -function miss_count(session::RayIntersectionSession) - count(hit -> !hit[1], session.hits) -end diff --git a/src/rt_transport.jl b/src/rt_transport.jl new file mode 100644 index 0000000..7b93493 --- /dev/null +++ b/src/rt_transport.jl @@ -0,0 +1,42 @@ +# ============================================================================ +# RT transport structs — shared by all hardware RT backends +# ============================================================================ + +""" + RTRay + +Ray input for hardware RT dispatch. 32 bytes, matches Vulkan/Metal layout. +""" +struct RTRay + origin_x::Float32 + origin_y::Float32 + origin_z::Float32 + tmin::Float32 + dir_x::Float32 + dir_y::Float32 + dir_z::Float32 + tmax::Float32 +end + +""" + RTHitResult + +Ray hit output from hardware RT dispatch. 32 bytes. + +- `instance_custom_index` — value of `gl_InstanceCustomIndexEXT` at the hit. + Under the current semantics this carries the `InstanceDescriptor.instance_id` + (the interface-override slot). `0` means "inherit from triangle metadata". +- `instance_id` — value of `gl_InstanceID` at the hit (0-based instance array + position). Used by the caller to look up per-instance data such as the + BLAS triangle offset. +""" +struct RTHitResult + hit::UInt32 + t::Float32 + primitive_id::UInt32 + instance_custom_index::UInt32 + bary_u::Float32 + bary_v::Float32 + instance_id::UInt32 + _pad2::UInt32 +end diff --git a/src/soa.jl b/src/soa.jl new file mode 100644 index 0000000..aa143f5 --- /dev/null +++ b/src/soa.jl @@ -0,0 +1,110 @@ +# ============================================================================ +# Structure of Arrays (SoA) Utilities +# ============================================================================ +# Macros and helpers for efficient SoA data access in GPU kernels. +# SoA layout is critical for GPU memory coalescing. + +""" + @get field1, field2, ... = soa[idx] + +Macro to extract multiple fields from a Structure of Arrays (SoA) at index `idx`. + +# Example +```julia +ray_queue = (ray=[r1, r2, r3], pixel_x=[1, 2, 3], pixel_y=[4, 5, 6]) +@get ray, pixel_x, pixel_y = ray_queue[2] +# Expands to: +# ray = ray_queue.ray[2] +# pixel_x = ray_queue.pixel_x[2] +# pixel_y = ray_queue.pixel_y[2] +``` +""" +macro get(expr) + if expr.head != :(=) + error("@get expects assignment syntax: @get field1, field2 = soa[idx]") + end + + lhs = expr.args[1] + rhs = expr.args[2] + + # Parse left side (field names) + if lhs isa Symbol + fields = [lhs] + elseif lhs.head == :tuple + fields = lhs.args + else + error("@get left side must be field names or tuple of field names") + end + + # Parse right side (soa[idx]) + if rhs.head != :ref + error("@get right side must be array indexing: soa[idx]") + end + soa = rhs.args[1] + idx = rhs.args[2] + + # Generate field extraction code + assignments = [:($(esc(field)) = $(esc(soa)).$(field)[$(esc(idx))]) for field in fields] + + return Expr(:block, assignments...) +end + +""" + @set soa[idx] = (field1=val1, field2=val2, ...) + +Macro to set multiple fields in a Structure of Arrays (SoA) at index `idx`. +Expects named tuple syntax on the right side. + +# Example +```julia +ray_queue = (ray=Vector{Ray}(undef, 10), pixel_x=zeros(Int32, 10)) +@set ray_queue[1] = (ray=my_ray, pixel_x=Int32(5)) +# Expands to: +# ray_queue.ray[1] = my_ray +# ray_queue.pixel_x[1] = Int32(5) +``` +""" +macro set(expr) + if expr.head != :(=) + error("@set expects assignment syntax: @set soa[idx] = (field1=val1, ...)") + end + + lhs = expr.args[1] + rhs = expr.args[2] + + # Parse left side (soa[idx]) + if lhs.head != :ref + error("@set left side must be array indexing: soa[idx]") + end + soa = lhs.args[1] + idx = lhs.args[2] + + # Parse right side (named tuple or parameters) + assignments = [] + if rhs.head == :tuple || rhs.head == :parameters + for arg in rhs.args + if arg isa Expr && arg.head == :(=) + field = arg.args[1] + val = arg.args[2] + push!(assignments, :($(esc(soa)).$(field)[$(esc(idx))] = $(esc(val)))) + else + error("@set expects named parameters: @set soa[idx] = (field=value, ...)") + end + end + else + error("@set expects a tuple with named fields: @set soa[idx] = (field=value, ...)") + end + + return Expr(:block, assignments...) +end + +""" + similar_soa(template_array, T::Type, num_elements) -> NamedTuple + +Create a Structure of Arrays (SoA) layout for type `T` with `num_elements` entries. +Uses `template_array` to determine the array type (Array, ROCArray, etc.). +""" +function similar_soa(template, ::Type{T}, num_elements) where T + fields = [f => similar(template, fieldtype(T, f), num_elements) for f in fieldnames(T)] + return (; fields...) +end diff --git a/src/transformations.jl b/src/transformations.jl index 178523a..afb94a6 100644 --- a/src/transformations.jl +++ b/src/transformations.jl @@ -140,9 +140,9 @@ end function (t::Transformation)(p::Point3f)::Point3f ph = Point4f(p..., 1f0) pt = t.m * ph - pr = Point3f(pt[Vec(1, 2, 3)]) - pt[4] == 1 && return pr - pr ./ pt[4] + # Always divide by w - avoids branch that causes SPIR-V structured control flow errors + # Division by 1.0 is essentially free on GPU + Point3f(pt[Vec(1, 2, 3)]) ./ pt[4] end (t::Transformation)(v::Vec3f)::Vec3f = t.m[Vec(1, 2, 3), Vec(1, 2, 3)] * v diff --git a/src/triangle_mesh.jl b/src/triangle_mesh.jl index 6b691b1..f07bf1f 100644 --- a/src/triangle_mesh.jl +++ b/src/triangle_mesh.jl @@ -1,32 +1,4 @@ -struct TriangleMesh{VT<:AbstractVector{Point3f}, IT<:AbstractVector{UInt32}, NT<:AbstractVector{Normal3f}, TT<:AbstractVector{Vec3f}, UT<:AbstractVector{Point2f}} <: AbstractShape - vertices::VT - # For the i-th triangle, its 3 vertex positions are: - # [vertices[indices[3 * i + j]] for j in 0:2]. - indices::IT - # Optional normal vectors, one per vertex. - normals::NT - # Optional tangent vectors, one per vertex. - tangents::TT - # Optional parametric (u, v) values, one for each vertex. - uv::UT - - function TriangleMesh( - vertices::VT, - indices::IT, - normals::NT = Normal3f[], - tangents::TT = Vec3f[], - uv::UT = Point2f[], - ) where {VT, IT, NT, TT, UT} - - return new{VT, IT, NT, TT, UT}( - vertices, - copy(indices), copy(normals), - copy(tangents), copy(uv), - ) - end -end - -struct Triangle{TMetadata} <: AbstractShape +struct Triangle{TMetadata} <: AbstractGeometry{3, Float32} vertices::SVector{3,Point3f} normals::SVector{3,Normal3f} tangents::SVector{3,Vec3f} @@ -34,45 +6,6 @@ struct Triangle{TMetadata} <: AbstractShape metadata::TMetadata end -# Constructor with metadata -function Triangle(m::TriangleMesh, face_indx, metadata) - f_idx = 1 + (3 * (face_indx - 1)) - vs = @SVector [m.vertices[m.indices[f_idx + i]] for i in 0:2] - ns = @SVector [m.normals[m.indices[f_idx + i]] for i in 0:2] # Every mesh should have normals!? - if !isempty(m.tangents) - ts = @SVector [m.tangents[m.indices[f_idx + i]] for i in 0:2] - else - ts = @SVector [Vec3f(NaN) for _ in 1:3] - end - if !isempty(m.uv) - uv = @SVector [m.uv[m.indices[f_idx + i]] for i in 0:2] - else - uv = SVector(Point2f(0), Point2f(1, 0), Point2f(1, 1)) - end - return Triangle(vs, ns, ts, uv, metadata) -end - -# Convenience constructor without metadata (uses Nothing) -function Triangle(m::TriangleMesh, face_indx) - return Triangle(m, face_indx, nothing) -end - -function TriangleMesh(mesh::GeometryBasics.Mesh) - nmesh = GeometryBasics.expand_faceviews(mesh) - fs = decompose(TriangleFace{UInt32}, nmesh) - vertices = decompose(Point3f, nmesh) - normals = Normal3f.(decompose_normals(nmesh)) - uvs = GeometryBasics.decompose_uv(nmesh) - if isnothing(uvs) - uvs = Point2f[] - end - indices = collect(reinterpret(UInt32, fs)) - return TriangleMesh( - vertices, indices, - normals, Vec3f[], Point2f.(uvs), - ) -end - function area(t::Triangle) vs = vertices(t) 0.5f0 * norm((vs[2] - vs[1]) × (vs[3] - vs[1])) @@ -103,6 +36,45 @@ object_bound(t::Triangle) = mapreduce( world_bound(t::Triangle) = reduce(∪, Bounds3.(vertices(t))) +""" + empty_triangle(::Type{Triangle{TMeta}}) -> Triangle{TMeta} + +Zero-initialized `Triangle` suitable as a no-hit sentinel. All vertex, normal, +tangent, and uv components are 0; metadata is field-wise zero-constructed +(works for any POD struct whose primitive fields have `zero(T)` defined). + +Takes the full Triangle type (not just the metadata type) so callers can +write `empty_triangle(eltype(triangles))` directly. +""" +function empty_triangle(::Type{Triangle{TMeta}}) where TMeta + Triangle{TMeta}( + SVector{3,Point3f}(Point3f(0,0,0), Point3f(0,0,0), Point3f(0,0,0)), + SVector{3,Normal3f}(Normal3f(0,0,0), Normal3f(0,0,0), Normal3f(0,0,0)), + SVector{3,Vec3f}(Vec3f(0,0,0), Vec3f(0,0,0), Vec3f(0,0,0)), + SVector{3,Point2f}(Point2f(0,0), Point2f(0,0), Point2f(0,0)), + _zero_struct(TMeta), + ) +end + +""" + _zero_struct(::Type{T}) -> T + +Field-wise zero-initialise a POD struct `T`. Recursive over `fieldtype(T, i)` +so nested structs (e.g. `Triangle{TriangleMeta}` where `TriangleMeta` has +`UInt32` fields) work without requiring `zero(T)` to be defined on the outer +type. Falls back to `zero(T)` for primitive leaves. +""" +@generated function _zero_struct(::Type{T}) where T + nf = fieldcount(T) + if nf == 0 + # Primitive leaf — use Base.zero. + return :(zero(T)) + else + fields = [:(_zero_struct($(fieldtype(T, i)))) for i in 1:nf] + return Expr(:call, T, fields...) + end +end + function _argmax(vec::Vec3) max_val = vec[1] max_idx = Int32(1) @@ -184,13 +156,9 @@ end ∂n∂u, ∂n∂v end -# Note: surface_interaction and init_triangle_shading_geometry have been removed -# These functions are now handled by Trace.jl's triangle_to_surface_interaction -# Raycore only provides low-level ray-triangle intersection via intersect_triangle - @inline function intersect(triangle::Triangle, ray::AbstractRay)::Tuple{Bool,Float32,Point3f} - verts = vertices(triangle) # Get triangle vertices - return intersect_triangle(verts, ray) # Check if ray hits triangle + verts = vertices(triangle) + return intersect_triangle(verts, ray) end @inline function intersect_p(t::Triangle, ray::Union{Ray,RayDifferentials}, ::Bool=false) @@ -225,7 +193,6 @@ end # Test against t_max range. det < 0f0 && (t_scaled >= 0f0 || t_scaled < ray.t_max * det) && return false, t_hit, barycentric det > 0f0 && (t_scaled <= 0f0 || t_scaled > ray.t_max * det) && return false, t_hit, barycentric - # TODO test against alpha texture if present. # Compute barycentric coordinates and t value for triangle intersection. inv_det = 1.0f0 / det barycentric = edges .* inv_det diff --git a/src/unrolled.jl b/src/unrolled.jl new file mode 100644 index 0000000..9204542 --- /dev/null +++ b/src/unrolled.jl @@ -0,0 +1,322 @@ +# ============================================================================ +# GPU-Safe Unrolled Iteration Utilities +# ============================================================================ +# Provides compile-time unrolled iteration over tuples without closure capture. +# Critical for GPU kernels where dynamic dispatch and boxing are not allowed. + +# ============================================================================ +# Compiler Limits +# ============================================================================ + +const MAX_TUPLE_LENGTH = 32 +const MAX_TYPE_DEPTH = 10 + +# ============================================================================ +# FastClosure - Compile-time validation wrapper +# ============================================================================ + +""" + FastClosure{F, Args<:Tuple} + +A callable wrapper that validates a function and arguments are GPU-safe: +1. Function `f` has no captured variables (no `Core.Box` fields) +2. Arguments don't exceed compiler limits (tuple length, type depth) + +When called, appends stored args to the call: `fc(x) == fc.f(x, fc.args...)` + +This is used internally by `for_unrolled`, `map_unrolled`, and `reduce_unrolled`. +""" +struct FastClosure{F, Args<:Tuple} + f::F + args::Args + + function FastClosure(f::F, args::Args) where {F, Args<:Tuple} + check_no_capture(F) + check_args_limits(Args) + new{F, Args}(f, args) + end +end + +# Make FastClosure callable - appends stored args to call +@inline (fc::FastClosure)(outer_args...) = fc.f(outer_args..., fc.args...) + +""" + check_no_capture(::Type{F}) where F + +Compile-time check that function type `F` has no captured variables. +Any closure field indicates a captured variable which should be passed as an argument instead. +`Core.Box` fields are especially problematic (heap-allocated, type-unstable). +""" +@generated function check_no_capture(::Type{F}) where F + # Regular functions have no fields - OK + if fieldcount(F) == 0 + return :nothing + end + + # Any field on a closure = captured variable + # Collect all captured variable names + captured_names = [fieldname(F, i) for i in 1:fieldcount(F)] + boxed_names = [fieldname(F, i) for i in 1:fieldcount(F) if fieldtype(F, i) === Core.Box] + + if !isempty(boxed_names) + # Boxed captures are the worst - definitely error + names_str = join(boxed_names, ", ") + return :(error("FastClosure: function captures boxed variable(s): " * $names_str * ". Pass as argument(s) instead.")) + else + # Non-boxed captures: still problematic for GPU, error with helpful message + names_str = join(captured_names, ", ") + return :(error("FastClosure: function captures variable(s): " * $names_str * ". Pass as argument(s) instead to ensure GPU compatibility.")) + end +end + +""" + check_args_limits(::Type{Args}) where Args + +Compile-time check that argument types don't exceed compiler limits. +""" +@generated function check_args_limits(::Type{Args}) where Args <: Tuple + # Check total tuple length + n = length(Args.parameters) + if n > MAX_TUPLE_LENGTH + return :(error("FastClosure: too many arguments ($($n) > $MAX_TUPLE_LENGTH). This may cause inference failures.")) + end + + # Check for overly long tuple arguments + for (i, T) in enumerate(Args.parameters) + if T <: Tuple && length(T.parameters) > MAX_TUPLE_LENGTH + return :(error("FastClosure: argument $($i) is a tuple with $($(length(T.parameters))) elements (> $MAX_TUPLE_LENGTH). This may cause inference failures.")) + end + end + + return :nothing +end + +# ============================================================================ +# for_unrolled - Side effects, no return value +# ============================================================================ + +""" + for_unrolled(f, tuple, args...) + +Iterate over `tuple` at compile-time, calling `f(element, args...)` for each element. +No return value (use for side effects). + +The function `f` must not capture any variables - pass all data as `args` instead. + +# Example +```julia +lights = (sun_light, point_light, spot_light) +total = Ref(RGBSpectrum(0f0)) + +# Bad - captures `total` and `ray`: +for light in lights + total[] += le(light, ray) # Boxing on GPU! +end + +# Good - pass as arguments: +for_unrolled(add_light!, lights, total, ray) +# Where: add_light!(light, total, ray) = total[] += le(light, ray) +``` +""" +@inline function for_unrolled(f::F, tuple::Tuple, args...) where F + fc = FastClosure(f, args) + _for_unrolled(fc, tuple) + return nothing +end + +@inline _for_unrolled(_fc, ::Tuple{}) = nothing +@inline function _for_unrolled(fc, tuple::Tuple) + fc(first(tuple)) + _for_unrolled(fc, Base.tail(tuple)) + return nothing +end + +# Val{N} version for index-based iteration +""" + for_unrolled(f, ::Val{N}, args...) + +Iterate from 1 to N at compile-time, calling `f(i, args...)` for each index. +""" +@inline function for_unrolled(f::F, ::Val{N}, args...) where {F, N} + fc = FastClosure(f, args) + _for_unrolled_n(fc, Val(N)) + return nothing +end + +@inline _for_unrolled_n(_fc, ::Val{0}) = nothing +@inline function _for_unrolled_n(fc, ::Val{N}) where N + _for_unrolled_n(fc, Val(N-1)) + fc(N % Int32) + return nothing +end + +# ============================================================================ +# map_unrolled - Transform tuple elements +# ============================================================================ + +""" + map_unrolled(f, tuple, args...) -> Tuple + +Transform each element of `tuple` at compile-time, returning a new tuple. +Calls `f(element, args...)` for each element. + +# Example +```julia +lights = (sun_light, point_light) +contributions = map_unrolled(compute_light, lights, hit_point, normal) +# Returns: (compute_light(sun_light, hit_point, normal), +# compute_light(point_light, hit_point, normal)) +``` +""" +@inline function map_unrolled(f::F, tuple::Tuple, args...) where F + fc = FastClosure(f, args) + return _map_unrolled(fc, tuple) +end + +@inline _map_unrolled(_fc, ::Tuple{}) = () + +# Use @generated to avoid tuple splatting which causes allocations +@generated function _map_unrolled(fc, tup::T) where T <: Tuple + N = length(T.parameters) + exprs = [:(fc(tup[$i])) for i in 1:N] + return :(($(exprs...),)) +end + +# ============================================================================ +# reduce_unrolled - Accumulate over tuple elements +# ============================================================================ + +""" + reduce_unrolled(f, tuple, init, args...) -> result + +Reduce `tuple` at compile-time using `f(accumulator, element, args...)`. + +# Example +```julia +lights = (sun_light, point_light, env_light) + +# Compute total light contribution +total = reduce_unrolled(add_light_contribution, lights, RGBSpectrum(0f0), ray, hit_point) +# Where: add_light_contribution(acc, light, ray, hp) = acc + compute_li(light, ray, hp) +``` +""" +@inline function reduce_unrolled(f::F, tuple::Tuple, init, args...) where F + fc = FastClosure(f, args) + return _reduce_unrolled(fc, tuple, init) +end + +@inline _reduce_unrolled(_fc, ::Tuple{}, acc) = acc +@inline function _reduce_unrolled(fc, tuple::Tuple, acc) + new_acc = fc(acc, first(tuple)) + return _reduce_unrolled(fc, Base.tail(tuple), new_acc) +end + +# StaticMultiTypeSet support: compile-time unrolled reduction over type slots +@inline @generated function _reduce_unrolled(fc, smv::StaticMultiTypeSet{Data}, acc) where {Data<:Tuple} + N = length(Data.parameters) + if N == 0 + return :(acc) + end + reductions = Expr[] + for i in 1:N + push!(reductions, quote + for j in eachindex(smv.data[$i]) + @inbounds acc = fc(acc, smv.data[$i][j]) + end + end) + end + quote + $(reductions...) + acc + end +end + +# MultiTypeSet delegates to its static version +@inline function reduce_unrolled(f::F, mts::MultiTypeSet, init, args...) where F + reduce_unrolled(f, get_static(mts), init, args...) +end + +@inline function reduce_unrolled(f::F, smv::StaticMultiTypeSet, init, args...) where F + fc = FastClosure(f, args) + return _reduce_unrolled(fc, smv, init) +end + +# ============================================================================ +# sum_unrolled - Common reduction pattern +# ============================================================================ + +""" + sum_unrolled(f, tuple, args...) -> result + +Sum `f(element, args...)` over all elements of `tuple`. + +# Example +```julia +lights = (sun_light, point_light) +total = sum_unrolled(le, lights, ray) +# Computes: le(sun_light, ray) + le(point_light, ray) +``` +""" +@inline function sum_unrolled(f::F, tuple::Tuple, args...) where F + fc = FastClosure(f, args) + return _sum_unrolled(fc, tuple) +end + +@inline _sum_unrolled(_fc, ::Tuple{}) = nothing # Empty tuple - caller should handle +@inline _sum_unrolled(fc, tuple::Tuple{T}) where T = fc(first(tuple)) +@inline function _sum_unrolled(fc, tuple::Tuple) + return fc(first(tuple)) + _sum_unrolled(fc, Base.tail(tuple)) +end + +# ============================================================================ +# getindex_unrolled - Select element by runtime index, apply function +# ============================================================================ + +""" + getindex_unrolled(f, tuple, idx::Int32, args...) -> result + +Select element at runtime index `idx` from `tuple` and apply `f(element, args...)`. +Uses unrolled if-branches for GPU compatibility - no dynamic dispatch. + +The index is 1-based. If idx is out of bounds, returns `f(tuple[1], args...)` as fallback. + +# Example +```julia +lights = (sun_light, point_light, env_light) +light_idx = Int32(2) + +# Sample from the selected light +sample = getindex_unrolled(sample_light, lights, light_idx, point, lambda, u) +# Equivalent to: sample_light(point_light, point, lambda, u) +``` +""" +@inline function getindex_unrolled(f::F, tuple::Tuple, idx::Int32, args...) where F + fc = FastClosure(f, args) + return _getindex_unrolled(fc, tuple, idx) +end + +# Generated function creates unrolled if-branches for type stability +@generated function _getindex_unrolled(fc, tuple::T, idx::Int32) where T <: Tuple + N = length(T.parameters) + + if N == 0 + # Empty tuple - shouldn't happen, but return nothing + return :(error("getindex_unrolled: empty tuple")) + end + + # Build unrolled if-else chain + # Start from the last index and work backwards to build nested if-else + expr = :(fc(tuple[$N])) # Default/fallback case + + for i in (N-1):-1:1 + expr = quote + if idx == Int32($i) + fc(tuple[$i]) + else + $expr + end + end + end + + return expr +end diff --git a/test/bounds.jl b/test/bounds.jl index 6132d00..87b1d09 100644 --- a/test/bounds.jl +++ b/test/bounds.jl @@ -191,9 +191,9 @@ end @testset "Bounding sphere" begin b = Raycore.Bounds3(Point3f(0, 0, 0), Point3f(2, 2, 2)) - center, radius = Raycore.bounding_sphere(b) - @test center == Point3f(1, 1, 1) - @test radius ≈ sqrt(3.0f0) + sphere = Raycore.bounding_sphere(b) + @test sphere.center == Point3f(1, 1, 1) + @test sphere.r ≈ sqrt(3.0f0) end @testset "Ray-Bounds intersection" begin diff --git a/test/runtests.jl b/test/runtests.jl index bc1cd3a..1a4c20f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,21 +1,133 @@ +# NOTE: GPU kernel tests are skipped under --check-bounds=yes (the Pkg.test default) +# because bounds checking injects error paths that can't compile to SPIR-V. +# For full test coverage: Pkg.test("Raycore"; julia_args=`--check-bounds=auto`) +# +# Backend selection (CI matrix): +# RAYCORE_TEST_BACKEND=cpu — KA.CPU() (default; runs on every CI worker) +# RAYCORE_TEST_BACKEND=lavapipe — Lava on lavapipe ICD (apt: mesa-vulkan-drivers). +# CURRENTLY BAILS OUT EARLY: lavapipe's mesa-LLVM +# JIT aborts the process on `X86ISD::MGATHER` +# selection for `Aligned 1` loads emitted by +# Lava's SPIR-V backend in AcceleratedKernels' +# shared-memory reductions. A process abort +# can't be recorded as `@test_broken`, so the +# suite emits placeholder broken-tests until +# the alignment hint is fixed in Lava's +# SPIR-V emitter. +# RAYCORE_TEST_BACKEND=lava — Lava on whatever Vulkan ICD is found +# (developer GPUs / RADV / NVIDIA / etc.). +# Full Lava-using suite runs. + using Test using GeometryBasics using LinearAlgebra +using StaticArrays using Raycore using JET using Aqua +using KernelAbstractions +const KA = KernelAbstractions + +const RAYCORE_TEST_BACKEND_NAME = lowercase(get(ENV, "RAYCORE_TEST_BACKEND", "cpu")) +const _USE_LAVA = RAYCORE_TEST_BACKEND_NAME in ("lava", "lavapipe") +const _IS_LAVAPIPE = RAYCORE_TEST_BACKEND_NAME == "lavapipe" + +if _USE_LAVA + using Lava +end + +""" + test_backend() + +KernelAbstractions backend the current CI matrix entry asks for. +`KA.CPU()` (default) or `Lava.LavaBackend()` (when env var selects lava/lavapipe). +""" +test_backend() = _USE_LAVA ? Lava.LavaBackend() : KA.CPU() + +"""Whether we're running on lavapipe specifically (mesa software Vulkan). +Used to mark tests broken that hit lavapipe-specific JIT bugs.""" +test_is_lavapipe() = _IS_LAVAPIPE + +"""Whether the current backend has VK_KHR_ray_tracing_pipeline. +Used to gate HWTLAS tests.""" +test_has_hw_rt() = _USE_LAVA && Lava.vk_context().rt_pipeline_properties !== nothing + +""" + @lavapipe_broken expr + +Mark a `@test` expression as broken on lavapipe specifically. Use only on +tests whose failure on lavapipe is a Julia exception (catchable). Tests +whose failure crashes the Julia process — e.g. mesa-LLVM JIT aborts — +can't be caught by this; whole testset must be replaced with a +`@test_broken false` placeholder upstream. +""" +macro lavapipe_broken(ex) + quote + if test_is_lavapipe() + @test_broken $(esc(ex)) + else + @test $(esc(ex)) + end + end +end # ambiguities come from GeometryBasics.@fixed_vector Normal = StaticVector Aqua.test_all(Raycore; ambiguities=(; broken=true)) @testset "Raycore Tests" begin + # CPU-only suites — run on every backend matrix entry. @testset "Intersection" begin include("test_intersection.jl") end - @testset "Type Stability" begin - include("test_type_stability.jl") - end @testset "Bounds" begin include("bounds.jl") end + @testset "Unrolled" begin + include("test_unrolled.jl") + end + + # Backend-using suites. On lavapipe, the FIRST kernel dispatch in any + # of these aborts the Julia process via mesa's LLVM JIT before + # `@testset` machinery can record anything — so on lavapipe we replace + # the whole include with a placeholder broken-test. When the + # underlying alignment hint in Lava's SPIR-V emitter is fixed, drop + # these placeholders and let the testsets run normally. + if test_is_lavapipe() + @testset "Instanced BVH (lavapipe placeholder)" begin + @test_broken false + end + @testset "MultiTypeSet (lavapipe placeholder)" begin + @test_broken false + end + @testset "Mesh Update (lavapipe placeholder)" begin + @test_broken false + end + @testset "AbstractAccel contract (lavapipe placeholder)" begin + @test_broken false + end + @testset "TLAS Stress (lavapipe placeholder)" begin + @test_broken false + end + else + # Either KA.CPU() (cpu matrix entry) or LavaBackend on real GPU. + @testset "Instanced BVH" begin + include("test_instanced_bvh.jl") + end + if _USE_LAVA + # Suites that hard-depend on Lava-specific types (LavaArray / + # HWTLAS). Don't run on the cpu matrix entry. + @testset "MultiTypeSet" begin + include("test_multitypeset.jl") + end + @testset "Mesh Update (Lava SW)" begin + include("test_mesh_update.jl") + end + @testset "AbstractAccel contract" begin + include("test_abstract_accel_contract.jl") + end + @testset "TLAS Stress" begin + include("test_tlas_stress.jl") + end + end + end end diff --git a/test/test_abstract_accel_contract.jl b/test/test_abstract_accel_contract.jl new file mode 100644 index 0000000..311c7ea --- /dev/null +++ b/test/test_abstract_accel_contract.jl @@ -0,0 +1,34 @@ +using Test, Raycore, GeometryBasics, StaticArrays, LinearAlgebra +using KernelAbstractions; const KA = KernelAbstractions +using Adapt +using Lava + +@testset "AbstractAccel — surface" begin + backend = Lava.LavaBackend() + tlas = Raycore.TLAS(backend) + mesh = GeometryBasics.normal_mesh(Sphere(Point3f(0), 1f0)) + push!(tlas, mesh, SMatrix{4,4,Float32}(I)) + Raycore.sync!(tlas) + + @test Raycore.n_instances(tlas) == 1 + @test Raycore.n_geometries(tlas) == 1 + @test Raycore.world_bound(tlas) isa Raycore.Bounds3 + + # wait_for_gpu! returns `accel` so it's chainable; smoke-test the contract. + @test_nowarn Raycore.wait_for_gpu!(tlas) + @test Raycore.wait_for_gpu!(tlas) === tlas +end + +@testset "AbstractAccel contract — Lava.HWTLAS" begin + backend = Lava.LavaBackend() + hwtlas = Lava.HWTLAS(backend) + mesh = GeometryBasics.normal_mesh(Sphere(Point3f(0), 1f0)) + push!(hwtlas, mesh, SMatrix{4,4,Float32}(I); instance_id=UInt32(1)) + Raycore.sync!(hwtlas) + + @test Raycore.n_instances(hwtlas) == 1 + @test Raycore.n_geometries(hwtlas) == 1 + @test Raycore.world_bound(hwtlas) isa Raycore.Bounds3 + @test_nowarn Raycore.wait_for_gpu!(hwtlas) + @test Raycore.wait_for_gpu!(hwtlas) === hwtlas +end diff --git a/test/test_instanced_bvh.jl b/test/test_instanced_bvh.jl new file mode 100644 index 0000000..5e550ca --- /dev/null +++ b/test/test_instanced_bvh.jl @@ -0,0 +1,1195 @@ +# ============================================================================== +# Instanced BVH Tests +# ============================================================================== + +using Test +using Raycore +using GeometryBasics +using StaticArrays +using LinearAlgebra +using KernelAbstractions + +# Use qualified names to avoid conflicts with other packages +const RTriangle = Raycore.Triangle # Conflicts with GeometryBasics.Triangle +const RBLAS = Raycore.BLAS # Conflicts with LinearAlgebra.BLAS +const is_leaf = Raycore.is_leaf +const is_interior = Raycore.is_interior + +@testset "Instanced BVH" begin + +@testset "Morton Code Generation" begin + # Test Morton code for known points + p1 = Point3f(0.0, 0.0, 0.0) + p2 = Point3f(1.0, 1.0, 1.0) + p3 = Point3f(0.5, 0.5, 0.5) + + code1 = Raycore.morton_code_30bit(p1) + code2 = Raycore.morton_code_30bit(p2) + code3 = Raycore.morton_code_30bit(p3) + + @test code1 isa UInt32 + @test code2 isa UInt32 + @test code3 isa UInt32 + + # Morton codes should order points along Z-curve + @test code1 < code2 + @test code1 < code3 < code2 +end + +@testset "BLAS Construction - Single Triangle" begin + # Create a single triangle + v1, v2, v3 = Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0) + tri = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(0, 1)), + nothing + ) + + primitives = [tri] + blas = build_blas(primitives) + + @test blas isa RBLAS + @test length(blas.nodes) == 1 # Single triangle = 1 node (leaf) + @test length(blas.primitives) == 1 + @test is_leaf(blas.nodes[1]) + @test blas.nodes[1].child1 == UInt32(1) # Points to primitive 1 +end + +@testset "BLAS Construction - Multiple Triangles" begin + # Create a simple quad (2 triangles) + v1 = Point3f(0, 0, 0) + v2 = Point3f(1, 0, 0) + v3 = Point3f(1, 1, 0) + v4 = Point3f(0, 1, 0) + + tri1 = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(1, 1)), + nothing + ) + + tri2 = RTriangle( + SVector(v1, v3, v4), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 1), Point2f(0, 1)), + nothing + ) + + primitives = [tri1, tri2] + blas = build_blas(primitives) + + @test blas isa RBLAS + @test length(blas.nodes) == 3 # 2 leaves + 1 interior = 3 nodes + @test length(blas.primitives) == 2 + + # Check root is interior node + @test is_interior(blas.nodes[1]) + + # Check root AABB contains all primitives + root_aabb = blas.root_aabb + @test root_aabb.p_min[1] ≈ 0.0f0 + @test root_aabb.p_min[2] ≈ 0.0f0 + @test root_aabb.p_max[1] ≈ 1.0f0 + @test root_aabb.p_max[2] ≈ 1.0f0 +end + +@testset "BLAS Type Stability" begin + # Test type stability of build_blas + v1 = Point3f(0, 0, 0) + v2 = Point3f(1, 0, 0) + v3 = Point3f(0, 1, 0) + + tri = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(0, 1)), + nothing + ) + + primitives = [tri] + + # build_blas should be type-stable + result_type = @inferred build_blas(primitives) + @test result_type isa RBLAS +end + +@testset "Transform Utilities" begin + # Test point transformation + identity = Mat4f(I) + p = Point3f(1, 2, 3) + p_transformed = Raycore.transform_point(identity, p) + @test p_transformed ≈ p + + # Test translation + translation = Mat4f( + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 5, 10, 15, 1 + ) + p_translated = Raycore.transform_point(translation, p) + @test p_translated ≈ Point3f(6, 12, 18) + + # Test direction transformation (should ignore translation) + v = Vec3f(1, 0, 0) + v_transformed = Raycore.transform_direction(translation, v) + @test v_transformed ≈ v + + # Test type stability + @test (@inferred Raycore.transform_point(identity, p)) isa Point3f + @test (@inferred Raycore.transform_direction(identity, v)) isa Vec3f +end + +@testset "BVHNode2 Utilities" begin + # Test leaf detection + leaf_node = BVHNode2( + Point3f(0), Point3f(1), + Point3f(0), Point3f(0), + INVALID_NODE, UInt32(5), INVALID_NODE + ) + @test is_leaf(leaf_node) + @test !is_interior(leaf_node) + + # Test interior detection + interior_node = BVHNode2( + Point3f(0), Point3f(1), + Point3f(0), Point3f(1), + UInt32(2), UInt32(3), INVALID_NODE + ) + @test !is_leaf(interior_node) + @test is_interior(interior_node) + + # Test AABB extraction + aabb = Raycore.get_node_aabb(interior_node, true) + @test aabb isa Bounds3 + @test aabb.p_min == Point3f(0, 0, 0) + @test aabb.p_max == Point3f(1, 1, 1) +end + +@testset "AABB Utilities" begin + # Test expand_bits + @test Raycore.expand_bits(UInt32(0)) == UInt32(0) + @test Raycore.expand_bits(UInt32(1)) isa UInt32 + + # Test clz32 + @test Raycore.clz32(UInt32(0)) == Int32(32) + @test Raycore.clz32(UInt32(1)) == Int32(31) + @test Raycore.clz32(UInt32(0x80000000)) == Int32(0) +end + +@testset "Delta Function (LCP)" begin + # Test longest common prefix calculation + codes = UInt32[0x00000001, 0x00000002, 0x00000004, 0x00000008] + + # Adjacent codes with different prefixes + d1 = Raycore.delta(Int32(1), Int32(2), codes, Int32(4)) + d2 = Raycore.delta(Int32(2), Int32(3), codes, Int32(4)) + + @test d1 isa Int32 + @test d2 isa Int32 + + # Out of bounds should return -1 + d_oob = Raycore.delta(Int32(1), Int32(10), codes, Int32(4)) + @test d_oob == Int32(-1) +end + +# ============================================================================== +# TLAS Construction Tests +# ============================================================================== + +@testset "TLAS Construction - Single Instance" begin + # Create a single triangle as BLAS + v1, v2, v3 = Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0) + tri = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(0, 1)), + UInt32(1) + ) + + blas = build_blas([tri]) + identity = Mat4f(I) + instances = [InstanceDescriptor(UInt32(1), UInt32(1), identity, identity, UInt32(0))] + + tlas = build_tlas([blas], instances) + + @test tlas isa Raycore.StaticTLAS + @test length(tlas.instances) == 1 + @test length(tlas.blas_descriptors) == 1 + @test length(tlas.nodes) == 1 # Single instance = 1 node (leaf) +end + +@testset "TLAS Construction - Multiple Instances" begin + # Create a triangle BLAS + v1, v2, v3 = Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0) + tri = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(0, 1)), + UInt32(1) + ) + + blas = build_blas([tri]) + + # Create two instances with different transforms + identity = Mat4f(I) + translation = Mat4f( + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 5, 0, 0, 1 + ) + inv_translation = Mat4f(inv(translation)) + + instances = [ + InstanceDescriptor(UInt32(1), UInt32(1), identity, identity, UInt32(0)), + InstanceDescriptor(UInt32(1), UInt32(2), translation, inv_translation, UInt32(0)) + ] + + tlas = build_tlas([blas], instances) + + @test tlas isa Raycore.StaticTLAS + @test length(tlas.instances) == 2 + @test length(tlas.blas_descriptors) == 1 + @test length(tlas.nodes) == 3 # 2 leaves + 1 interior = 3 nodes + + # World bound should encompass both instances + wb = world_bound(tlas) + @test wb.p_min[1] ≈ 0.0f0 + @test wb.p_max[1] ≈ 6.0f0 # Original + translated +end + +# ============================================================================== +# TLAS Ray Intersection Tests +# ============================================================================== + +@testset "TLAS closest_hit - Basic" begin + # Create a unit triangle in XY plane at z=0 + v1, v2, v3 = Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0) + tri = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(0, 1)), + UInt32(42) + ) + + blas = build_blas([tri]) + identity = Mat4f(I) + instances = [InstanceDescriptor(UInt32(1), UInt32(1), identity, identity, UInt32(0))] + tlas = build_tlas([blas], instances) + + # Ray pointing down at center of triangle + ray = Ray(o=Point3f(0.25, 0.25, 1.0), d=Vec3f(0, 0, -1)) + hit, prim, dist, bary, inst_id = closest_hit(tlas, ray) + + @test hit == true + @test dist ≈ 1.0f0 + @test prim.metadata == UInt32(42) + + # Ray missing the triangle + ray_miss = Ray(o=Point3f(2, 2, 1.0), d=Vec3f(0, 0, -1)) + hit_miss, _, _, _, _ = closest_hit(tlas, ray_miss) + @test hit_miss == false +end + +@testset "TLAS closest_hit - Transformed Instance" begin + # Create a unit triangle + v1, v2, v3 = Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0) + tri = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(0, 1)), + UInt32(1) + ) + + blas = build_blas([tri]) + + # Translate instance by (10, 0, 0) + translation = Mat4f( + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 10, 0, 0, 1 + ) + inv_translation = Mat4f(inv(translation)) + instances = [InstanceDescriptor(UInt32(1), UInt32(1), translation, inv_translation, UInt32(0))] + + tlas = build_tlas([blas], instances) + + # Ray at original position should miss + ray_miss = Ray(o=Point3f(0.25, 0.25, 1.0), d=Vec3f(0, 0, -1)) + hit_miss, _, _, _, _ = closest_hit(tlas, ray_miss) + @test hit_miss == false + + # Ray at translated position should hit + ray_hit = Ray(o=Point3f(10.25, 0.25, 1.0), d=Vec3f(0, 0, -1)) + hit, _, dist, _, _ = closest_hit(tlas, ray_hit) + @test hit == true + @test dist ≈ 1.0f0 +end + +@testset "TLAS closest_hit - Multiple Instances (Closest Selection)" begin + # Create a unit triangle + v1, v2, v3 = Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0) + tri = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(0, 1)), + UInt32(1) + ) + + blas = build_blas([tri]) + identity = Mat4f(I) + + # Two instances: one at z=0, one at z=-5 (further from camera) + translate_back = Mat4f( + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, -5, 1 + ) + inv_translate_back = Mat4f(inv(translate_back)) + + instances = [ + InstanceDescriptor(UInt32(1), UInt32(1), identity, identity, UInt32(0)), + InstanceDescriptor(UInt32(1), UInt32(2), translate_back, inv_translate_back, UInt32(0)) + ] + + tlas = build_tlas([blas], instances) + + # Ray should hit the closer one (z=0) + ray = Ray(o=Point3f(0.25, 0.25, 1.0), d=Vec3f(0, 0, -1)) + hit, _, dist, _, inst_id = closest_hit(tlas, ray) + + @test hit == true + @test dist ≈ 1.0f0 # Distance to z=0 plane + @test inst_id == UInt32(1) # First instance +end + +@testset "TLAS any_hit - Basic" begin + # Create a unit triangle + v1, v2, v3 = Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0) + tri = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(0, 1)), + UInt32(1) + ) + + blas = build_blas([tri]) + identity = Mat4f(I) + instances = [InstanceDescriptor(UInt32(1), UInt32(1), identity, identity, UInt32(0))] + tlas = build_tlas([blas], instances) + + # Ray hitting triangle + ray = Ray(o=Point3f(0.25, 0.25, 1.0), d=Vec3f(0, 0, -1)) + hit, _, _, _, _ = any_hit(tlas, ray) + @test hit == true + + # Ray missing triangle + ray_miss = Ray(o=Point3f(2, 2, 1.0), d=Vec3f(0, 0, -1)) + hit_miss, _, _, _, _ = any_hit(tlas, ray_miss) + @test hit_miss == false +end + +# ============================================================================== +# GB.Mesh TLAS API Tests +# ============================================================================== + +# Helper to create a GB.Mesh with normals +function make_test_mesh(verts, normals) + faces = [GLTriangleFace(1, 2, 3)] + GeometryBasics.mesh(verts, faces; normal=normals) +end + +@testset "TLASHandle and n_instances" begin + mesh1 = make_test_mesh( + [Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + mesh2 = make_test_mesh( + [Point3f(5, 0, 0), Point3f(6, 0, 0), Point3f(5, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + + tlas, handles = TLAS([mesh1, mesh2]) + + @test length(handles) == 2 + @test handles[1] isa TLASHandle + @test handles[2] isa TLASHandle + + count1 = n_instances(tlas, handles[1]) + count2 = n_instances(tlas, handles[2]) + + @test count1 == 1 + @test count2 == 1 + @test is_valid(tlas, handles[1]) + @test is_valid(tlas, handles[2]) +end + +@testset "TLAS with multi-transform push!" begin + mesh1 = make_test_mesh( + [Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + mesh2 = make_test_mesh( + [Point3f(5, 0, 0), Point3f(6, 0, 0), Point3f(5, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + + # Use multi-transform push! for mesh1 (instancing) + transforms = [ + Mat4f(I), + Mat4f(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 2, 0, 0, 1) + ] + + tlas = Raycore.TLAS(KernelAbstractions.CPU()) + h1 = push!(tlas, mesh1, transforms) + h2 = push!(tlas, mesh2) + sync!(tlas) + + @test n_geometries(tlas) == 2 # 2 unique BLAS + @test n_instances(tlas) == 3 # 2 + 1 = 3 instance descriptors + + @test n_instances(tlas, h1) == 2 # First handle has 2 instances + @test n_instances(tlas, h2) == 1 # Second handle has 1 instance +end + +@testset "TLAS from GB.Mesh Vector" begin + mesh1 = make_test_mesh( + [Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + mesh2 = make_test_mesh( + [Point3f(5, 0, 0), Point3f(6, 0, 0), Point3f(5, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + + tlas, handles = TLAS([mesh1, mesh2]) + + @test n_geometries(tlas) == 2 + @test n_instances(tlas) == 2 + @test length(handles) == 2 +end + +# ============================================================================== +# Dynamic Update Tests +# ============================================================================== + +@testset "update_transform! (single instance)" begin + mesh = make_test_mesh( + [Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + + tlas, handles = TLAS([mesh]) + + new_transform = Mat4f( + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 10, 0, 0, 1 + ) + update_transform!(tlas, handles[1], new_transform) + + # TLAS stores transforms as Mat3x4f (Vulkan row-major 3×4); compare in + # that form, since `≈` between SMatrix{4,3} and SMatrix{4,4} would + # throw on size mismatch. + @test get_instance(tlas, handles[1]).transform ≈ Raycore.mat4_to_mat3x4(new_transform) +end + +@testset "update_transforms! (multiple instances)" begin + mesh = make_test_mesh( + [Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + + # Create with 3 transforms using multi-transform push! + initial_transforms = [Mat4f(I), Mat4f(I), Mat4f(I)] + tlas = Raycore.TLAS(KernelAbstractions.CPU()) + h = push!(tlas, mesh, initial_transforms) + sync!(tlas) + handles = [h] + + # Update all transforms + new_transforms = [ + Mat4f(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1), + Mat4f(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 2, 0, 0, 1), + Mat4f(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 3, 0, 0, 1) + ] + update_transforms!(tlas, handles[1], new_transforms) + + instances = get_instances(tlas, handles[1]) + for (i, inst) in enumerate(instances) + @test inst.transform ≈ Raycore.mat4_to_mat3x4(new_transforms[i]) + end +end + +@testset "push! GB.Mesh" begin + mesh1 = make_test_mesh( + [Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + mesh2 = make_test_mesh( + [Point3f(5, 0, 0), Point3f(6, 0, 0), Point3f(5, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + + tlas, handles = TLAS([mesh1]) + + @test n_geometries(tlas) == 1 + @test n_instances(tlas) == 1 + + # Add new mesh using push! + sync! + new_handle = push!(tlas, mesh2) + sync!(tlas) + + @test n_geometries(tlas) == 2 + @test n_instances(tlas) == 2 + @test new_handle isa TLASHandle +end + +@testset "delete! and sync!" begin + mesh1 = make_test_mesh( + [Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + mesh2 = make_test_mesh( + [Point3f(5, 0, 0), Point3f(6, 0, 0), Point3f(5, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + + tlas, handles = TLAS([mesh1, mesh2]) + + @test n_instances(tlas) == 2 + @test is_valid(tlas, handles[1]) + @test is_valid(tlas, handles[2]) + + deleted = delete!(tlas, handles[1]) + @test deleted == true + + @test !is_valid(tlas, handles[1]) + + sync!(tlas) + + @test n_instances(tlas) == 1 + @test is_valid(tlas, handles[2]) +end + +# ============================================================================== +# Type Stability Tests for TLAS +# ============================================================================== + +@testset "TLAS Type Stability" begin + # Create triangle with concrete metadata type + v1, v2, v3 = Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0) + tri = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(0, 1)), + UInt32(1) + ) + + blas = build_blas([tri]) + identity = Mat4f(I) + instances = [InstanceDescriptor(UInt32(1), UInt32(1), identity, identity, UInt32(0))] + tlas = build_tlas([blas], instances) + + ray = Ray(o=Point3f(0.25, 0.25, 1.0), d=Vec3f(0, 0, -1)) + + # Test type stability of closest_hit + result_type = @inferred closest_hit(tlas, ray) + @test result_type[1] isa Bool + @test result_type[2] isa RTriangle{UInt32} + @test result_type[3] isa Float32 + @test result_type[4] isa SVector{3, Float32} + @test result_type[5] isa UInt32 + + # Test type stability of any_hit + result_type_any = @inferred any_hit(tlas, ray) + @test result_type_any[1] isa Bool +end + +@testset "TLAS eltype" begin + # Verify eltype returns the correct triangle type without indexing into arrays + v1, v2, v3 = Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0) + tri = RTriangle( + SVector(v1, v2, v3), + SVector(Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)), + SVector(Vec3f(0), Vec3f(0), Vec3f(0)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(0, 1)), + UInt32(1) + ) + + blas = build_blas([tri]) + identity = Mat4f(I) + instances = [InstanceDescriptor(UInt32(1), UInt32(1), identity, identity, UInt32(0))] + tlas = build_tlas([blas], instances) + + @test eltype(tlas) == RTriangle{UInt32} +end + +@testset "n_instances and n_geometries" begin + mesh = make_test_mesh( + [Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0)], + [Normal3f(0, 0, 1), Normal3f(0, 0, 1), Normal3f(0, 0, 1)] + ) + + # 5 instances of same geometry using multi-transform push! + transforms = [Mat4f(I) for _ in 1:5] + tlas = Raycore.TLAS(KernelAbstractions.CPU()) + push!(tlas, mesh, transforms) + sync!(tlas) + + @test n_geometries(tlas) == 1 + @test n_instances(tlas) == 5 +end + +end # main testset "Instanced BVH" + +# ============================================================================== +# KernelAbstractions Dynamic Scene Tests (Lava backend) +# ============================================================================== + +# Load packages at top-level (required for KA.@kernel macro) +using Lava +import KernelAbstractions as KA +using KernelAbstractions: @index, @Const +import Adapt + +# =========================================================================== +# Kernel definitions - at top-level so KA.@kernel macro works +# =========================================================================== + +# Kernel 1: Basic closest_hit - returns hit/distance +KA.@kernel function closest_hit_kernel!(hits, distances, tlas, origins, directions) + i = @index(Global, Linear) + @inbounds begin + ray = Ray(o=origins[i], d=directions[i]) + hit, _, dist, _, _ = closest_hit(tlas, ray) + hits[i] = hit + distances[i] = dist + end +end + +# Kernel 2: any_hit for shadow/occlusion testing +KA.@kernel function any_hit_kernel!(hits, tlas, origins, directions) + i = @index(Global, Linear) + @inbounds begin + ray = Ray(o=origins[i], d=directions[i]) + hit, _, _, _, _ = any_hit(tlas, ray) + hits[i] = hit + end +end + +# Kernel 3: closest_hit with instance ID retrieval +KA.@kernel function closest_hit_instance_id_kernel!(hits, distances, instance_ids, tlas, origins, directions) + i = @index(Global, Linear) + @inbounds begin + ray = Ray(o=origins[i], d=directions[i]) + hit, _, dist, _, inst_id = closest_hit(tlas, ray) + hits[i] = hit + distances[i] = dist + instance_ids[i] = inst_id + end +end + +# Kernel 4: closest_hit with primitive metadata retrieval +KA.@kernel function closest_hit_metadata_kernel!(hits, metadata_out, tlas, origins, directions) + i = @index(Global, Linear) + @inbounds begin + ray = Ray(o=origins[i], d=directions[i]) + hit, prim, _, _, _ = closest_hit(tlas, ray) + hits[i] = hit + # Only access metadata if hit + metadata_out[i] = hit ? prim.metadata : UInt32(0) + end +end + +# Kernel 5: closest_hit with barycentric coordinates +KA.@kernel function closest_hit_bary_kernel!(hits, barys, tlas, origins, directions) + i = @index(Global, Linear) + @inbounds begin + ray = Ray(o=origins[i], d=directions[i]) + hit, _, _, bary, _ = closest_hit(tlas, ray) + hits[i] = hit + barys[i] = bary + end +end + +# Kernel 6: Batch trace with all outputs (stress test) +KA.@kernel function full_trace_kernel!(hits, distances, instance_ids, metadata_out, barys, tlas, origins, directions) + i = @index(Global, Linear) + @inbounds begin + ray = Ray(o=origins[i], d=directions[i]) + hit, prim, dist, bary, inst_id = closest_hit(tlas, ray) + hits[i] = hit + distances[i] = dist + instance_ids[i] = inst_id + metadata_out[i] = hit ? prim.metadata : UInt32(0) + barys[i] = bary + end +end + +# GPU kernel compilation is incompatible with --check-bounds=yes (Pkg.test default) +# because bounds checking injects error-throwing paths that can't compile to SPIR-V. +# Use: Pkg.test("Raycore"; julia_args=`--check-bounds=auto`) +if Base.JLOptions().check_bounds == 1 # 1 = --check-bounds=yes + @testset "KernelAbstractions Dynamic Scenes" begin + @test_broken false # skipped: --check-bounds=yes is incompatible with GPU kernel compilation + end +else +@testset "KernelAbstractions Dynamic Scenes" begin + cl_backend = test_backend() + + # Helper to create a simple GB.Mesh + function make_triangle_mesh(offset::Vec3f=Vec3f(0, 0, 0)) + verts = [ + Point3f(0, 0, 0) + offset, + Point3f(1, 0, 0) + offset, + Point3f(0, 1, 0) + offset + ] + norms = fill(Normal3f(0, 0, 1), 3) + faces = [GLTriangleFace(1, 2, 3)] + return GeometryBasics.mesh(verts, faces; normal=norms) + end + + @testset "TLAS adapt to LavaArray" begin + mesh = make_triangle_mesh() + tlas, handles = TLAS([mesh]; backend=cl_backend) + + # Adapt TLAS to Lava arrays (GPU-first: backend must match) + cl_tlas = Adapt.adapt(cl_backend, tlas) + + @test cl_tlas isa Raycore.StaticTLAS + # GPU arrays (LavaArray) are not isbits on the host — KA handles + # the device pointer conversion during kernel launch. + # The kernel tests below verify that the TLAS works correctly on GPU. + if cl_backend isa KA.CPU + @test cl_tlas.nodes isa Vector + else + @test cl_tlas.nodes isa LavaArray + end + end + + @testset "TLAS sync with many instances" begin + mesh = make_triangle_mesh() + transforms = [Mat4f(1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + Float32(mod(i - 1, 9)) * 1.5f0, + Float32((i - 1) ÷ 9) * 1.25f0, + 0, + 1) for i in 1:81] + + tlas = Raycore.TLAS(cl_backend) + push!(tlas, mesh, transforms) + sync!(tlas) + + @test length(tlas.instances) == 81 + @test length(tlas.nodes) == 161 + @test Raycore.world_bound(tlas) isa Bounds3 + end + + @testset "closest_hit_kernel! - basic intersection" begin + mesh = make_triangle_mesh() + tlas, _ = TLAS([mesh]; backend=cl_backend) + cl_tlas = Adapt.adapt(cl_backend, tlas) + + n = 4 + origins = KA.allocate(cl_backend, Point3f, n) + directions = KA.allocate(cl_backend, Vec3f, n) + hits = KA.allocate(cl_backend, Bool, n) + distances = KA.allocate(cl_backend, Float32, n) + + # Test rays: 2 hits, 2 misses + KA.copyto!(cl_backend, origins, [ + Point3f(0.25, 0.25, 1.0), # hit + Point3f(0.5, 0.25, 1.0), # hit + Point3f(5.0, 5.0, 1.0), # miss + Point3f(-1.0, -1.0, 1.0) # miss + ]) + KA.copyto!(cl_backend, directions, fill(Vec3f(0, 0, -1), n)) + + kernel = closest_hit_kernel!(cl_backend) + kernel(hits, distances, cl_tlas, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + + hits_cpu = Array(hits) + distances_cpu = Array(distances) + + @test hits_cpu[1] == true + @test hits_cpu[2] == true + @test hits_cpu[3] == false + @test hits_cpu[4] == false + @test distances_cpu[1] ≈ 1.0f0 + @test distances_cpu[2] ≈ 1.0f0 + end + + @testset "any_hit_kernel! - shadow/occlusion test" begin + mesh = make_triangle_mesh() + tlas, _ = TLAS([mesh]; backend=cl_backend) + cl_tlas = Adapt.adapt(cl_backend, tlas) + + n = 4 + origins = KA.allocate(cl_backend, Point3f, n) + directions = KA.allocate(cl_backend, Vec3f, n) + hits = KA.allocate(cl_backend, Bool, n) + + # Test rays + KA.copyto!(cl_backend, origins, [ + Point3f(0.25, 0.25, 1.0), # hit + Point3f(0.1, 0.1, 1.0), # hit + Point3f(5.0, 5.0, 1.0), # miss + Point3f(0.9, 0.9, 1.0) # miss (outside triangle) + ]) + KA.copyto!(cl_backend, directions, fill(Vec3f(0, 0, -1), n)) + + kernel = any_hit_kernel!(cl_backend) + kernel(hits, cl_tlas, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + + hits_cpu = Array(hits) + @test hits_cpu[1] == true + @test hits_cpu[2] == true + @test hits_cpu[3] == false + @test hits_cpu[4] == false + end + + @testset "closest_hit_instance_id_kernel! - instance identification" begin + mesh = make_triangle_mesh() + + # Three instances at different positions. Traversal returns the + # 1-based instance array index; here we push 3 instances so each + # ray hits position 1, 2, 3. + transforms = [ + Mat4f(I), # Instance 1 at origin + Mat4f(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 5, 0, 0, 1), # Instance 2 at x=5 + Mat4f(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 5, 0, 1) # Instance 3 at y=5 + ] + tlas_tmp = Raycore.TLAS(cl_backend) + push!(tlas_tmp, mesh, transforms) + sync!(tlas_tmp) + tlas = tlas_tmp + cl_tlas = Adapt.adapt(cl_backend, tlas) + + n = 3 + origins = KA.allocate(cl_backend, Point3f, n) + directions = KA.allocate(cl_backend, Vec3f, n) + hits = KA.allocate(cl_backend, Bool, n) + distances = KA.allocate(cl_backend, Float32, n) + instance_ids = KA.allocate(cl_backend, UInt32, n) + + # Each ray targets a different instance + KA.copyto!(cl_backend, origins, [ + Point3f(0.25, 0.25, 1.0), # hits instance 1 + Point3f(5.25, 0.25, 1.0), # hits instance 2 + Point3f(0.25, 5.25, 1.0) # hits instance 3 + ]) + KA.copyto!(cl_backend, directions, fill(Vec3f(0, 0, -1), n)) + + kernel = closest_hit_instance_id_kernel!(cl_backend) + kernel(hits, distances, instance_ids, cl_tlas, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + + hits_cpu = Array(hits) + instance_ids_cpu = Array(instance_ids) + + @test all(hits_cpu) + # closest_hit returns the 1-based instance array index + @test instance_ids_cpu[1] == UInt32(1) + @test instance_ids_cpu[2] == UInt32(2) + @test instance_ids_cpu[3] == UInt32(3) + end + + @testset "closest_hit_metadata_kernel! - primitive metadata" begin + # Create meshes at different positions (metadata test simplified - mesh default is 0) + mesh1 = make_triangle_mesh(Vec3f(0, 0, 0)) + mesh2 = make_triangle_mesh(Vec3f(5, 0, 0)) + mesh3 = make_triangle_mesh(Vec3f(0, 5, 0)) + + tlas, _ = TLAS([mesh1, mesh2, mesh3]; backend=cl_backend) + cl_tlas = Adapt.adapt(cl_backend, tlas) + + n = 4 + origins = KA.allocate(cl_backend, Point3f, n) + directions = KA.allocate(cl_backend, Vec3f, n) + hits = KA.allocate(cl_backend, Bool, n) + metadata_out = KA.allocate(cl_backend, UInt32, n) + + KA.copyto!(cl_backend, origins, [ + Point3f(0.25, 0.25, 1.0), # hits mesh1 + Point3f(5.25, 0.25, 1.0), # hits mesh2 + Point3f(0.25, 5.25, 1.0), # hits mesh3 + Point3f(10.0, 10.0, 1.0) # miss + ]) + KA.copyto!(cl_backend, directions, fill(Vec3f(0, 0, -1), n)) + + kernel = closest_hit_metadata_kernel!(cl_backend) + kernel(hits, metadata_out, cl_tlas, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + + hits_cpu = Array(hits) + + @test hits_cpu[1] == true + @test hits_cpu[2] == true + @test hits_cpu[3] == true + @test hits_cpu[4] == false + # Note: metadata from mesh is 0 by default, so we just test hits work + end + + @testset "closest_hit_bary_kernel! - barycentric coordinates" begin + mesh = make_triangle_mesh() + tlas, _ = TLAS([mesh]; backend=cl_backend) + cl_tlas = Adapt.adapt(cl_backend, tlas) + + n = 3 + origins = KA.allocate(cl_backend, Point3f, n) + directions = KA.allocate(cl_backend, Vec3f, n) + hits = KA.allocate(cl_backend, Bool, n) + barys = KA.allocate(cl_backend, SVector{3, Float32}, n) + + # Triangle vertices: (0,0,0), (1,0,0), (0,1,0) + # Hit points chosen to give predictable barycentrics + KA.copyto!(cl_backend, origins, [ + Point3f(0.25, 0.25, 1.0), # should give bary ≈ (0.25, 0.25, 0.5) + Point3f(0.1, 0.1, 1.0), # should give bary ≈ (0.1, 0.1, 0.8) + Point3f(0.5, 0.0, 1.0) # edge hit, bary ≈ (0.5, 0.0, 0.5) + ]) + KA.copyto!(cl_backend, directions, fill(Vec3f(0, 0, -1), n)) + + kernel = closest_hit_bary_kernel!(cl_backend) + kernel(hits, barys, cl_tlas, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + + hits_cpu = Array(hits) + barys_cpu = Array(barys) + + @test all(hits_cpu) + # Barycentrics are (w, u, v) where w = 1-u-v + # For hit at (0.25, 0.25): u=0.25, v=0.25, w=0.5 + @test barys_cpu[1][1] ≈ 0.5f0 atol=0.01 # w + @test barys_cpu[1][2] ≈ 0.25f0 atol=0.01 # u + # For hit at (0.1, 0.1): u=0.1, v=0.1, w=0.8 + @test barys_cpu[2][1] ≈ 0.8f0 atol=0.01 # w + @test barys_cpu[2][2] ≈ 0.1f0 atol=0.01 # u + # For edge hit at (0.5, 0.0): u=0.5, v=0.0, w=0.5 + @test barys_cpu[3][1] ≈ 0.5f0 atol=0.01 # w + @test barys_cpu[3][2] ≈ 0.5f0 atol=0.01 # u + end + + @testset "full_trace_kernel! - comprehensive output" begin + mesh1 = make_triangle_mesh(Vec3f(0, 0, 0)) + mesh2 = make_triangle_mesh(Vec3f(5, 0, 0)) + + # Two default-override (inherit) instances; closest_hit returns + # their 1-based array positions (1 and 2). + tlas, _ = TLAS([mesh1, mesh2]; backend=cl_backend) + cl_tlas = Adapt.adapt(cl_backend, tlas) + + n = 3 + origins = KA.allocate(cl_backend, Point3f, n) + directions = KA.allocate(cl_backend, Vec3f, n) + hits = KA.allocate(cl_backend, Bool, n) + distances = KA.allocate(cl_backend, Float32, n) + instance_ids = KA.allocate(cl_backend, UInt32, n) + metadata_out = KA.allocate(cl_backend, UInt32, n) + barys = KA.allocate(cl_backend, SVector{3, Float32}, n) + + KA.copyto!(cl_backend, origins, [ + Point3f(0.25, 0.25, 2.0), # hits mesh1 at dist=2 + Point3f(5.25, 0.25, 3.0), # hits mesh2 at dist=3 + Point3f(10.0, 10.0, 1.0) # miss + ]) + KA.copyto!(cl_backend, directions, fill(Vec3f(0, 0, -1), n)) + + kernel = full_trace_kernel!(cl_backend) + kernel(hits, distances, instance_ids, metadata_out, barys, cl_tlas, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + + hits_cpu = Array(hits) + distances_cpu = Array(distances) + instance_ids_cpu = Array(instance_ids) + barys_cpu = Array(barys) + + @test hits_cpu[1] == true + @test hits_cpu[2] == true + @test hits_cpu[3] == false + + @test distances_cpu[1] ≈ 2.0f0 + @test distances_cpu[2] ≈ 3.0f0 + + @test instance_ids_cpu[1] == UInt32(1) + @test instance_ids_cpu[2] == UInt32(2) + + # Barycentrics are (w, u, v) where w = 1-u-v + # For hit at (0.25, 0.25): u=0.25, v=0.25, w=0.5 + @test barys_cpu[1][1] ≈ 0.5f0 atol=0.01 # w + @test barys_cpu[2][1] ≈ 0.5f0 atol=0.01 # w + end + + @testset "Dynamic transform updates via kernel" begin + mesh = make_triangle_mesh() + + # Create mutable TLAS with backend for dynamic updates + tlas = Raycore.TLAS(cl_backend) + handle = push!(tlas, mesh) + Raycore.sync!(tlas) + + # Initial position: ray at origin should hit + cl_tlas1 = Adapt.adapt(cl_backend, tlas) + + n = 1 + origins = KA.allocate(cl_backend, Point3f, n) + directions = KA.allocate(cl_backend, Vec3f, n) + hits = KA.allocate(cl_backend, Bool, n) + distances = KA.allocate(cl_backend, Float32, n) + + KA.copyto!(cl_backend, origins, [Point3f(0.25, 0.25, 1.0)]) + KA.copyto!(cl_backend, directions, [Vec3f(0, 0, -1)]) + + kernel = closest_hit_kernel!(cl_backend) + kernel(hits, distances, cl_tlas1, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + @test Array(hits)[1] == true + + # Update transform: move to x=10 + new_transform = Mat4f(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 10, 0, 0, 1) + Raycore.update_transform!(tlas, handle, new_transform) + Raycore.sync!(tlas) + + # Adapt again after update + cl_tlas2 = Adapt.adapt(cl_backend, tlas) + + # Now ray at origin should miss + kernel(hits, distances, cl_tlas2, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + @test Array(hits)[1] == false + + # Ray at x=10 should hit + KA.copyto!(cl_backend, origins, [Point3f(10.25, 0.25, 1.0)]) + kernel(hits, distances, cl_tlas2, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + @test Array(hits)[1] == true + @test Array(distances)[1] ≈ 1.0f0 + end + + @testset "Dynamic scene: add instances via kernel" begin + mesh1 = make_triangle_mesh() + mesh2 = make_triangle_mesh(Vec3f(5, 0, 0)) + + # Create mutable TLAS + tlas = Raycore.TLAS(cl_backend) + h1 = push!(tlas, mesh1) + Raycore.sync!(tlas) + + n = 2 + origins = KA.allocate(cl_backend, Point3f, n) + directions = KA.allocate(cl_backend, Vec3f, n) + hits = KA.allocate(cl_backend, Bool, n) + distances = KA.allocate(cl_backend, Float32, n) + + KA.copyto!(cl_backend, origins, [Point3f(0.25, 0.25, 1.0), Point3f(5.25, 0.25, 1.0)]) + KA.copyto!(cl_backend, directions, fill(Vec3f(0, 0, -1), n)) + + kernel = closest_hit_kernel!(cl_backend) + + # Test with just first instance + cl_tlas1 = Adapt.adapt(cl_backend, tlas) + kernel(hits, distances, cl_tlas1, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + + hits_cpu = Array(hits) + @test hits_cpu[1] == true # first mesh + @test hits_cpu[2] == false # second mesh not added yet + + # Add second instance + h2 = push!(tlas, mesh2) + Raycore.sync!(tlas) + + # Test again with both instances + cl_tlas2 = Adapt.adapt(cl_backend, tlas) + kernel(hits, distances, cl_tlas2, origins, directions; ndrange=n) + KA.synchronize(cl_backend) + + hits_cpu = Array(hits) + @test hits_cpu[1] == true # first mesh + @test hits_cpu[2] == true # second mesh now present + end + + @testset "Batch ray tracing via kernel (64 rays)" begin + mesh = make_triangle_mesh() + tlas, _ = TLAS([mesh]; backend=cl_backend) + cl_tlas = Adapt.adapt(cl_backend, tlas) + + # Create batch of rays + n_rays = 64 + origins_vec = [Point3f(0.25 + 0.5*(i % 8)/7, 0.25 + 0.5*((i ÷ 8) % 8)/7, 1.0) for i in 0:n_rays-1] + directions_vec = fill(Vec3f(0, 0, -1), n_rays) + + origins = KA.allocate(cl_backend, Point3f, n_rays) + directions = KA.allocate(cl_backend, Vec3f, n_rays) + hits = KA.allocate(cl_backend, Bool, n_rays) + distances = KA.allocate(cl_backend, Float32, n_rays) + + KA.copyto!(cl_backend, origins, origins_vec) + KA.copyto!(cl_backend, directions, directions_vec) + + kernel = closest_hit_kernel!(cl_backend) + kernel(hits, distances, cl_tlas, origins, directions; ndrange=n_rays) + KA.synchronize(cl_backend) + + hits_cpu = Array(hits) + n_hits = count(hits_cpu) + @test n_hits > 0 # At least some hits + @test n_hits < n_rays # Some misses near edges + end + + @testset "StaticTLAS field types after adapt" begin + mesh = make_triangle_mesh() + tlas, _ = TLAS([mesh]; backend=cl_backend) + + cl_tlas = Adapt.adapt(cl_backend, tlas) + + # Verify fields land on the right backend after adapt. + ArrayType = cl_backend isa KA.CPU ? Vector : LavaArray + @test cl_tlas.nodes isa ArrayType + @test cl_tlas.instances isa ArrayType + @test cl_tlas.all_blas_nodes isa ArrayType + @test cl_tlas.all_blas_prims isa ArrayType + @test cl_tlas.blas_descriptors isa ArrayType + # root_aabb stays isbits (not an array) + @test isbitstype(typeof(cl_tlas.root_aabb)) + end + + @testset "World bound preserved after adapt" begin + mesh = make_triangle_mesh() + transforms = [ + Mat4f(I), + Mat4f(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 10, 10, 0, 1) + ] + tlas, _ = begin; tlas_tmp = Raycore.TLAS(cl_backend); push!(tlas_tmp, mesh, transforms); sync!(tlas_tmp); (tlas_tmp, [TLASHandle(UInt32(1))]); end + + gpu_bound = tlas.root_aabb + cl_tlas = Adapt.adapt(cl_backend, tlas) + cl_bound = cl_tlas.root_aabb + + @test gpu_bound.p_min ≈ cl_bound.p_min + @test gpu_bound.p_max ≈ cl_bound.p_max + end + +end +end # if check_bounds diff --git a/test/test_intersection.jl b/test/test_intersection.jl index 5b749eb..a3f2cd8 100644 --- a/test/test_intersection.jl +++ b/test/test_intersection.jl @@ -19,33 +19,30 @@ @test !Raycore.intersect_p(b_neg, r1, inv_dir, dir_is_negative) end -# Note: Ray-Sphere intersection tests moved to Trace.jl -# Raycore no longer has Sphere shapes - only low-level triangle intersection - @testset "Test triangle" begin - triangles = Raycore.TriangleMesh( - [Point3f(0, 0, 2), Point3f(1, 0, 2), Point3f(1, 1, 2)], - UInt32[1, 2, 3], - [Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1)], + # Construct Triangle directly (no TriangleMesh) + triangle = Raycore.Triangle( + SVector(Point3f(0, 0, 2), Point3f(1, 0, 2), Point3f(1, 1, 2)), + SVector(Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1)), + SVector(Vec3f(NaN), Vec3f(NaN), Vec3f(NaN)), + SVector(Point2f(0, 0), Point2f(1, 0), Point2f(1, 1)), + nothing ) - triangle = Raycore.Triangle(triangles, 1) tv = Raycore.vertices(triangle) a = norm(tv[1] - tv[2])^2 * 0.5f0 @test Raycore.area(triangle) ≈ a target_wb = Raycore.Bounds3(Point3f(0, 0, 2), Point3f(1, 1, 2)) - # In the refactored API, object_bound returns world bounds since transformation is applied during creation @test Raycore.object_bound(triangle) ≈ target_wb - # Test ray intersection - API has changed: intersect now returns (Bool, Float32, Point3f) with barycentric coords + # Test ray intersection ray = Raycore.Ray(o = Point3f(0, 0, -2), d = Vec3f(0, 0, 1)) intersects_p = Raycore.intersect_p(triangle, ray) intersects, t_hit, bary_coords = Raycore.intersect(triangle, ray) @test intersects_p == intersects == true @test t_hit ≈ 4f0 @test Raycore.apply(ray, t_hit) ≈ Point3f(0, 0, 2) - # Barycentric coordinates for vertex 0 (corner hit) @test bary_coords ≈ Point3f(1, 0, 0) # Test ray intersection (different point). @@ -57,57 +54,89 @@ end @test Raycore.apply(ray, t_hit) ≈ Point3f(0.5, 0.25, 2) end -# BVH tests with spheres removed - refactored Raycore only supports triangle meshes in BVH -@testset "BVH" begin - # Create triangle meshes instead of spheres - triangle_meshes = [] - for i in 0:1:3 # Use fewer triangles for simpler test - core = Raycore.translate(Vec3f(i*3, i*3, 0)) - mesh = Raycore.TriangleMesh( - core.([Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(1, 1, 0)]), - UInt32[1, 2, 3], - [Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1)], - ) - push!(triangle_meshes, mesh) +@testset "TLAS with triangle meshes" begin + using GeometryBasics + using LinearAlgebra + + # Create simple triangle meshes as GB.Mesh at different positions + function make_gb_mesh(offset::Vec3f=Vec3f(0, 0, 0)) + verts = [Point3f(0, 0, 0) + offset, Point3f(1, 0, 0) + offset, Point3f(1, 1, 0) + offset] + faces = [GLTriangleFace(1, 2, 3)] + normals = [Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1)] + GeometryBasics.mesh(verts, faces; normal=normals) end - bvh = Raycore.BVH(triangle_meshes) - # Test basic BVH functionality with triangle meshes - @test !isnothing(Raycore.world_bound(bvh)) + meshes = [make_gb_mesh(Vec3f(i*3, i*3, 0)) for i in 0:3] + + tlas = Raycore.TLAS(meshes, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) + @test !isnothing(Raycore.world_bound(tlas)) # Simple intersection test ray = Raycore.Ray(o = Point3f(0.5, 0.5, -1), d = Vec3f(0, 0, 1)) - intersects, interaction = Raycore.closest_hit(bvh, ray) - @test intersects + hit, tri, dist, bary, inst_id = Raycore.closest_hit(tlas, ray) + @test hit + @test tri isa Raycore.Triangle end -# BVH test with spheres removed - using triangle meshes instead -@testset "Test BVH with triangle meshes in a row" begin - triangle_meshes = [] - - # Create triangle meshes at different z positions - positions = [0, 4, 8] - vertices = [Point3f(-1, -1, 0), Point3f(1, -1, 0), Point3f(0, 1, 0)] - for (i, z) in enumerate(positions) - core = Raycore.translate(Vec3f(0, 0, z)) - vs = core.(vertices) - mesh = Raycore.TriangleMesh( - vs, - UInt32[1, 2, 3], - [Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1)], - ) - push!(triangle_meshes, mesh) +@testset "TLAS with triangle meshes in a row" begin + using GeometryBasics + using LinearAlgebra + + function make_gb_mesh_at_z(z::Float32) + verts = [Point3f(-1, -1, z), Point3f(1, -1, z), Point3f(0, 1, z)] + faces = [GLTriangleFace(1, 2, 3)] + normals = [Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1), Raycore.Normal3f(0, 0, -1)] + GeometryBasics.mesh(verts, faces; normal=normals) end - bvh = Raycore.BVH(triangle_meshes) - # Test that BVH can be created and has a valid bound - bound = Raycore.world_bound(bvh) + meshes = [make_gb_mesh_at_z(Float32(z)) for z in [0, 4, 8]] + + tlas = Raycore.TLAS(meshes, (mesh_idx, tri_idx) -> UInt32(mesh_idx)) + bound = Raycore.world_bound(tlas) @test !isnothing(bound) # Test intersection with the first triangle ray = Raycore.Ray(o = Point3f(0, 0, -2), d = Vec3f(0, 0, 1)) - intersects, triangle = Raycore.closest_hit(bvh, ray) - @test intersects - # BVH closest_hit returns Triangle object, not SurfaceInteraction - @test triangle isa Raycore.Triangle + hit, tri, dist, bary, inst_id = Raycore.closest_hit(tlas, ray) + @test hit + @test tri isa Raycore.Triangle +end + +@testset "empty_triangle" begin + e = Raycore.empty_triangle(Raycore.Triangle{UInt32}) + @test e isa Raycore.Triangle{UInt32} + @test all(v -> all(iszero, v), e.vertices) + @test all(n -> all(iszero, n), e.normals) + @test all(t -> all(iszero, t), e.tangents) + @test all(u -> all(iszero, u), e.uv) + @test e.metadata == zero(UInt32) + + # Works for arbitrary metadata types that have `zero(T)`. + e2 = Raycore.empty_triangle(Raycore.Triangle{Int32}) + @test e2 isa Raycore.Triangle{Int32} + @test e2.metadata == zero(Int32) +end + +@testset "closest_hit no-hit returns empty_triangle sentinel" begin + using GeometryBasics + + function make_unit_mesh() + verts = [Point3f(0,0,0), Point3f(1,0,0), Point3f(0,1,0)] + faces = [GLTriangleFace(1,2,3)] + normals = [Raycore.Normal3f(0,0,1), Raycore.Normal3f(0,0,1), Raycore.Normal3f(0,0,1)] + GeometryBasics.mesh(verts, faces; normal=normals) + end + + tlas = Raycore.TLAS([make_unit_mesh()], (mesh_idx, tri_idx) -> UInt32(mesh_idx)) + + # Ray clearly misses: origin far away, direction pointing further away + ray = Raycore.Ray(o = Point3f(100, 100, 100), d = Vec3f(1, 0, 0)) + hit, tri, _, _, _ = Raycore.closest_hit(tlas, ray) + + @test !hit + @test tri isa Raycore.Triangle + # Returned sentinel must be the zero triangle, not a storage-indexed triangle + Tri = eltype(tlas.all_blas_prims) + @test tri == Raycore.empty_triangle(Tri) + @test all(v -> all(iszero, v), tri.vertices) end diff --git a/test/test_mesh_update.jl b/test/test_mesh_update.jl new file mode 100644 index 0000000..9108ff4 --- /dev/null +++ b/test/test_mesh_update.jl @@ -0,0 +1,298 @@ +# ============================================================================== +# Mesh update tests: correctness + no-leak under delete+push cycles +# ============================================================================== +# +# The TLAS API has no in-place "update mesh geometry" — to change the triangles +# of an existing instance you `delete!` the old handle and `push!` a new mesh. +# When the new mesh has a *different vertex count*, every backing buffer +# (BLAS nodes, BLAS primitives, per-instance tri/offset buffers on HW) is a +# different size too, so a stale reference from any previous dispatch or cached +# descriptor will fault or return the wrong geometry. +# +# This suite covers both TLAS backings, both driven through Lava + Vulkan: +# +# 1. SW TLAS (`Raycore.TLAS`) — BVH traversed on GPU via a KernelAbstractions +# `closest_hit` kernel, with the backing `LavaBackend`. Verified after +# every mutation. +# 2. HW TLAS (`Raycore.HWTLAS`) — Vulkan hardware ray tracing. Verified via +# `trace_closest_hits!`. +# +# For each backend we oscillate the mesh tessellation count (small ↔ big ↔ +# small) many times and assert: +# - The hit is always at the sphere surface within tolerance (correctness — +# catches stale BLAS captures: ray would miss or come back with wrong t if +# any pointer stayed captured). +# - Internal GPU-side resource counters stay bounded (leak / UAF bound: pool +# blocks and live buffers must not scale with iteration count). +# +# Lava is a hard test dep for Raycore, so this runs as part of the normal suite. +# ============================================================================== + +using Test +using GeometryBasics +using LinearAlgebra +using StaticArrays +using Raycore +using KernelAbstractions +const KA = KernelAbstractions +using Adapt +using Lava + +const GPU_BACKEND = Lava.LavaBackend() + +# ------------------------------------------------------------------------------ +# Shared: sphere mesh with varying tessellation + analytic ray/sphere intersect +# ------------------------------------------------------------------------------ + +"""Unit sphere centred at origin; `n` = tesselation count (higher = more tris). +Vertex count ≈ (n+1)^2, so successive `n`s give meaningfully different BLAS sizes.""" +function sphere_mesh(n::Int) + GeometryBasics.normal_mesh(Tesselation(Sphere(Point3f(0), 1f0), n)) +end + +"""Ray straight down the +z axis from z=5 at (0, 0). For a unit sphere at the +origin translated by `offset`, the closest hit is at z = offset.z + 1, i.e. +t = 5 - (offset.z + 1) = 4 - offset.z.""" +expected_t(offset_z::Real) = Float32(5) - Float32(offset_z) - Float32(1) + +translation(dx, dy, dz) = SMatrix{4,4,Float32,16}( + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + dx, dy, dz, 1, +) + +# ------------------------------------------------------------------------------ +# SW TLAS mesh-update test (Raycore BVH traversal on Lava) +# ------------------------------------------------------------------------------ + +KA.@kernel function sw_trace_one_kernel!(hit_out, t_out, tlas, origin, direction) + ray = Raycore.Ray(; o=origin, d=direction) + hit, _, dist, _, _ = Raycore.closest_hit(tlas, ray) + hit_out[1] = hit + t_out[1] = Float32(dist) +end + +"Trace one ray down the +z axis from (0,0,5) and return (hit, t) on CPU." +function sw_trace_one(tlas) + static_tlas = Adapt.adapt(GPU_BACKEND, tlas) + hit = KA.zeros(GPU_BACKEND, Bool, 1) + t = KA.zeros(GPU_BACKEND, Float32, 1) + origin = Point3f(0f0, 0f0, 5f0) + direction = Vec3f(0f0, 0f0, -1f0) + sw_trace_one_kernel!(GPU_BACKEND)(hit, t, static_tlas, origin, direction; ndrange=1) + KA.synchronize(GPU_BACKEND) + return (hit = Array(hit)[1], t = Array(t)[1]) +end + +"Replace the one mesh in `tlas` with a fresh `sphere_mesh(n)` at `offset_z`." +function sw_swap_mesh!(tlas, handle, n, offset_z) + Raycore.delete!(tlas, handle) + new_handle = push!(tlas, sphere_mesh(n), translation(0, 0, offset_z)) + Raycore.sync!(tlas) + return new_handle +end + +@testset "SW TLAS — mesh update correctness under size oscillation" begin + tlas = Raycore.TLAS(GPU_BACKEND) + handle = push!(tlas, sphere_mesh(16), translation(0, 0, 0)) + Raycore.sync!(tlas) + + # Baseline + r = sw_trace_one(tlas) + @test r.hit + @test isapprox(r.t, expected_t(0); atol=0.05f0) + + # Oscillate small → big → small → bigger → smaller. Vertex count varies + # non-monotonically to cover shrink-in-place and grow-out-of-place paths. + tess_schedule = [32, 8, 48, 12, 64, 16, 8, 32, 96, 16] + for (i, n) in enumerate(tess_schedule) + offset_z = Float32(0.05 * i) # nudge z so a wrong stale geometry shows up as wrong t + handle = sw_swap_mesh!(tlas, handle, n, offset_z) + r = sw_trace_one(tlas) + @test r.hit broken=false + @test isapprox(r.t, expected_t(offset_z); atol=0.1f0) + end +end + +@testset "SW TLAS — adapt-once-then-mutate via tlas.static_tlas" begin + # Invariant: `sync!(tlas)` is the single owner of `tlas.static_tlas`. A + # consumer that re-reads `tlas.static_tlas` (or calls `Adapt.adapt(backend, + # tlas)`) per dispatch MUST see any mutation that went through `push!` / + # `delete!` + `sync!`. A consumer that caches an old static_tlas across a + # mutation MAY see stale data — that's the contract consumers must honour. + # + # The pre-refactor `Hikari.VolPath.get_or_adapt_scene!` cached on + # `objectid(scene)` alone, which silently violated the "re-read per + # dispatch" rule and rendered frozen geometry in the dolphin video. This + # test nails the contract down at the Raycore level. + tlas = Raycore.TLAS(GPU_BACKEND) + handle = push!(tlas, sphere_mesh(16), translation(0, 0, 0)) + + # First adapt triggers build of tlas.static_tlas. + st_before = Adapt.adapt(GPU_BACKEND, tlas) + @test tlas.static_tlas === st_before + + hit_before = let hit = KA.zeros(GPU_BACKEND, Bool, 1), t = KA.zeros(GPU_BACKEND, Float32, 1) + sw_trace_one_kernel!(GPU_BACKEND)(hit, t, st_before, Point3f(0,0,5), Vec3f(0,0,-1); ndrange=1) + KA.synchronize(GPU_BACKEND) + (hit = Array(hit)[1], t = Array(t)[1]) + end + @test hit_before.hit + @test isapprox(hit_before.t, expected_t(0); atol=0.05f0) + + # Mutate the TLAS: swap mesh to one shifted +2 in z. The expected t moves + # from 4.0 to 2.0, so any stale-geometry trace shows up obviously. + Raycore.delete!(tlas, handle) + handle = push!(tlas, sphere_mesh(48), translation(0, 0, 2f0)) + + # A consumer re-reads `tlas.static_tlas` per dispatch (the canonical path): + st_after = Adapt.adapt(GPU_BACKEND, tlas) + hit_after_fresh = let hit = KA.zeros(GPU_BACKEND, Bool, 1), t = KA.zeros(GPU_BACKEND, Float32, 1) + sw_trace_one_kernel!(GPU_BACKEND)(hit, t, st_after, Point3f(0,0,5), Vec3f(0,0,-1); ndrange=1) + KA.synchronize(GPU_BACKEND) + (hit = Array(hit)[1], t = Array(t)[1]) + end + @test hit_after_fresh.hit + @test isapprox(hit_after_fresh.t, expected_t(2); atol=0.1f0) + + # sync! on a clean TLAS is a no-op: doesn't change static_tlas identity, + # doesn't issue a GPU synchronize. The field is already up to date. + st_pinned = tlas.static_tlas + Raycore.sync!(tlas) + @test tlas.static_tlas === st_pinned + + # A consumer that CACHED `st_before` across the mutation may see stale + # data: this is by design — the invariant pushes that work onto the + # consumer (re-read `tlas.static_tlas` per dispatch). Test that we NOTICE + # staleness when it happens, so regressions that make consumers silently + # cache are flagged. + hit_stale = let hit = KA.zeros(GPU_BACKEND, Bool, 1), t = KA.zeros(GPU_BACKEND, Float32, 1) + sw_trace_one_kernel!(GPU_BACKEND)(hit, t, st_before, Point3f(0,0,5), Vec3f(0,0,-1); ndrange=1) + KA.synchronize(GPU_BACKEND) + (hit = Array(hit)[1], t = Array(t)[1]) + end + # Either st_before's backing buffer was reused in place (stale consumer + # accidentally sees new data — allowed but not guaranteed), or it was + # reallocated (stale consumer sees old data). In the reallocation case + # st_before !== st_after and t ≈ expected_t(0). + if st_before !== st_after + @test isapprox(hit_stale.t, expected_t(0); atol=0.05f0) # stale snapshot + end +end + +@testset "SW TLAS — transform-update refit path" begin + # The refit path is the "cheap in-place" branch: update_transform! + sync! + # must update the leaf AABBs in tlas.nodes in place, so a cached static_tlas + # (that wraps the same backing buffer) sees the new position without needing + # a fresh adapt call. + # + # Pre-fix, refit_tlas! had `tlas.dirty || return tlas` (wrong flag) so it + # always short-circuited. sync! would run refit_tlas! but refit was a no-op; + # tlas.transforms_dirty stayed true forever; subsequent clean-path fast + # returns never kicked in. This test pins the refit wiring. + tlas = Raycore.TLAS(GPU_BACKEND) + handle = push!(tlas, sphere_mesh(16), translation(0, 0, 0)) + st_initial = Adapt.adapt(GPU_BACKEND, tlas) + + # Baseline: unit sphere at z=0, ray from z=5 hits at z=1 → t=4. + r = sw_trace_one(tlas) + @test r.hit + @test isapprox(r.t, expected_t(0); atol=0.05f0) + + # Move the instance to z=1.5 via update_transform!. The mesh & handle stay + # the same; only the instance transform changes. + Raycore.update_transform!(tlas, handle, translation(0, 0, 1.5f0)) + @test tlas.transforms_dirty + + # sync! must run refit and clear transforms_dirty — and keep static_tlas + # valid (same object, because refit updates tlas.nodes in place). + Raycore.sync!(tlas) + @test !tlas.transforms_dirty + @test !tlas.dirty + # refit updates AABBs in place in tlas.nodes — static_tlas identity stays. + @test tlas.static_tlas === st_initial + + # New expected t: sphere at z=1.5, hit at z=2.5, t = 5 - 2.5 = 2.5. + r2 = sw_trace_one(tlas) + @test r2.hit + @test isapprox(r2.t, expected_t(1.5); atol=0.05f0) + + # Clean-path sync! is a true no-op: static_tlas identity unchanged, no GPU + # sync, no allocations in the repeated calls. + Raycore.sync!(tlas) + Raycore.sync!(tlas) + Raycore.sync!(tlas) + @test tlas.static_tlas === st_initial +end + +@testset "SW TLAS — only one live static_tlas across many swaps (leak bound)" begin + # The static_tlas field is the single owner of the adapted form. Overwriting + # it on every rebuild means the old StaticTLAS goes unreferenced and is + # collectable. Prior draft designs (kept a cache of adapted_scene keyed by + # objectid in VolPath) accumulated references across mutations — that's + # the regression this test exists to prevent. + tlas = Raycore.TLAS(GPU_BACKEND) + handle = push!(tlas, sphere_mesh(16), translation(0, 0, 0)) + Raycore.sync!(tlas) + + # Hold a weak reference to the first static_tlas; it should be collectable + # after enough swaps since no one else is keeping it alive. + first_static = tlas.static_tlas + wref = WeakRef(first_static) + first_static = nothing # drop the hard local ref + # Ensure the rest of this iteration doesn't pin `first_static` on the stack + # via sentinel bindings by doing enough work in between. + + for iter in 1:20 + n = isodd(iter) ? 32 : 16 + handle = sw_swap_mesh!(tlas, handle, n, Float32(0.01 * iter)) + _ = Adapt.adapt(GPU_BACKEND, tlas) + end + GC.gc(true) + + # The rebuild path reallocates tlas.nodes each swap, so the original + # static_tlas is unreachable. With no cache hanging on to it, GC should + # collect it. + # pre-fix regression guard: a static_tlas from before N swaps must be collectable. + @test wref.value === nothing +end + +@testset "SW TLAS — mesh update leak bound (Julia heap)" begin + tlas = Raycore.TLAS(GPU_BACKEND) + handle = push!(tlas, sphere_mesh(16), translation(0, 0, 0)) + Raycore.sync!(tlas) + + # Warm up cycle so JIT and pool-style caches settle before sampling. + for _ in 1:3 + handle = sw_swap_mesh!(tlas, handle, 32, 0.1f0) + _ = sw_trace_one(tlas) + end + GC.gc(true) + + n_iters = 50 + for iter in 1:n_iters + n = iseven(iter) ? 16 : 48 # oscillate + handle = sw_swap_mesh!(tlas, handle, n, Float32(0.01 * iter)) + @assert sw_trace_one(tlas).hit + + # Tight invariants checked EVERY iteration: any leak that adds even + # one entry per cycle is caught immediately, not buried in slack. + @test length(tlas.instances) == 1 + @test length(tlas.blas_storage) == 1 + @test length(tlas._flat_blas_prims) == length(tlas.blas_storage[1].primitives) + @test length(tlas._flat_blas_nodes) == length(tlas.blas_storage[1].nodes) + @test length(tlas.deleted_handles) == 0 + end + GC.gc(true) + + # Final state: exactly one live mesh, flat arrays match exactly. + @test length(tlas.instances) == 1 + @test length(tlas.blas_storage) == 1 + @test length(tlas._flat_blas_prims) == length(tlas.blas_storage[1].primitives) + @test length(tlas._flat_blas_nodes) == length(tlas.blas_storage[1].nodes) +end + +# HW TLAS mesh-update tests relocated to Lava in Phase F. + +println("\nAll mesh-update tests passed.") diff --git a/test/test_multitypeset.jl b/test/test_multitypeset.jl new file mode 100644 index 0000000..d556ddd --- /dev/null +++ b/test/test_multitypeset.jl @@ -0,0 +1,141 @@ +using Test +using Raycore: MultiTypeSet, StaticMultiTypeSet, SetKey, TextureRef +using Raycore: with_index, deref, is_valid, is_invalid, n_slots, update! +using KernelAbstractions +using Adapt +using Lava + +backend = Lava.LavaBackend() + +# Test structs - used for both CPU and GPU tests +struct SimpleMaterial{T} + color::T +end + +struct GlassMaterial{T} + ior::T +end + +struct MaterialWith2{T, T2} + albedo::T + texture::T2 +end + +@testset "MultiTypeSet basic" begin + dhv = MultiTypeSet(backend) + @test isempty(dhv) + + idx1 = push!(dhv, SimpleMaterial(0.5f0)) + @test idx1.type_idx == 1 + @test idx1.vec_idx == 1 + @test !isempty(dhv) + + idx2 = push!(dhv, GlassMaterial(1.5f0)) + @test idx2.type_idx == 2 + @test idx2.vec_idx == 1 + + idx3 = push!(dhv, SimpleMaterial(0.8f0)) + @test idx3.type_idx == 1 + @test idx3.vec_idx == 2 + + # Static is always up-to-date + @test n_slots(dhv.static) == 2 +end + +@testset "MultiTypeSet update! with invalid SetKey is a no-op" begin + # Regression guard: `push!(set, item)` can return `SetKey()` (the (0,0) + # invalid sentinel) for types that own no slot in the set — e.g. + # Hikari's `NullMaterial` which pbrt-v4 uses as the "Material interface" + # /nullptr equivalent, or a `MediumInterface` side left as `nothing`. + # Callers that reuse that key on `update!` must get a silent no-op, + # NOT a BoundsError. Prior to the fix, `update!` indexed + # `dhv.data_order[0]` → `BoundsError: attempt to access 1-element + # Vector{DataType} at index [0]` — which crashed RayMakie's mesh-swap + # path for volumes built with `MediumInterface(NullMaterial(); inside=…)`. + dhv = MultiTypeSet(backend) + _ = push!(dhv, SimpleMaterial(0.5f0)) # something so data_order is non-empty + before = (n_slots(dhv.static), length(dhv.static)) + @test update!(dhv, SetKey(), SimpleMaterial(0.9f0)) === nothing + @test update!(dhv, SetKey(), GlassMaterial(1.7f0)) === nothing # wrong type, still no-op + @test (n_slots(dhv.static), length(dhv.static)) == before +end + +@testset "Empty MultiTypeSet" begin + dhv = MultiTypeSet(backend) + @test isempty(dhv) + + smv = dhv.static + @test isempty(smv) + @test n_slots(smv) == 0 +end + +@testset "GPU kernel with MaterialWith2" begin + dhv = MultiTypeSet(backend) + arr1 = Float32[1 2; 3 4] + arr2 = Float32[5, 6, 7] + arr3 = Float32[8 9; 10 11] + arr4 = Float32[12, 13, 14] + + idx1 = push!(dhv, MaterialWith2(arr1, arr2)) + idx2 = push!(dhv, MaterialWith2(arr3, arr4)) + + # static field is already GPU-ready + smv = dhv.static + + # Check structure + @test smv.data[1] isa Lava.LavaArray + @test smv.textures[1] isa Lava.LavaArray + @test smv.textures[2] isa Lava.LavaArray + + # Kernel that accesses both texture fields via deref + @kernel function mat2_kernel(out, smv, idxs) + i = @index(Global) + get_sum(mat, s) = begin + t1 = deref(s, mat.albedo) + t2 = deref(s, mat.texture) + t1[1,1] + t2[1] # First element of each texture + end + out[i] = with_index(get_sum, smv, idxs[i], smv) + end + + indices = LavaArray([idx1, idx2]) + output = LavaArray(zeros(Float32, 2)) + + kernel = mat2_kernel(backend) + kernel(output, smv, indices; ndrange=2) + KernelAbstractions.synchronize(backend) + + result = Array(output) + @test result ≈ [arr1[1,1] + arr2[1], arr3[1,1] + arr4[1]] +end + +@testset "StaticMultiTypeSet on GPU (no textures)" begin + dhv = MultiTypeSet(backend) + idx1 = push!(dhv, SimpleMaterial(0.5f0)) + idx2 = push!(dhv, GlassMaterial(1.5f0)) + idx3 = push!(dhv, SimpleMaterial(0.8f0)) + + smv = dhv.static + + # Check that inner arrays are LavaArrays + @test smv.data[1] isa Lava.LavaArray + @test smv.data[2] isa Lava.LavaArray + + # Run kernel + @kernel function simple_kernel(output, hvec, indices) + i = @index(Global) + get_val(m::SimpleMaterial) = m.color + get_val(m::GlassMaterial) = m.ior + output[i] = with_index(get_val, hvec, indices[i]) + end + + indices = LavaArray([idx1, idx2, idx3]) + output = LavaArray(zeros(Float32, 3)) + + kernel = simple_kernel(backend) + kernel(output, smv, indices; ndrange=3) + KernelAbstractions.synchronize(backend) + + result = Array(output) + @test result ≈ [0.5f0, 1.5f0, 0.8f0] +end diff --git a/test/test_tlas_stress.jl b/test/test_tlas_stress.jl new file mode 100644 index 0000000..39d9198 --- /dev/null +++ b/test/test_tlas_stress.jl @@ -0,0 +1,927 @@ +# ============================================================================== +# TLAS Stress / Memory-Safety Tests +# ============================================================================== +# +# Heavy-duty coverage for the mutable TLAS: +# +# - Random churn (push / delete / update_transform / update_transforms / +# sync) with strict invariants between every step. +# - High-instance-count and high-BLAS-count scenarios. +# - Use-after-free attempts on handles (must error, must not crash GPU). +# - Pure refit-only loops — must keep `static_tlas` identity stable, must +# not accumulate flat-array memory. +# - Topology-change rebuild after a long refit-only run. +# - GC-pressure: many `adapt()` calls without retaining results, plus a +# hard leak bound across 200 mesh swaps. +# +# Each test asserts EXACT counts on `tlas._flat_blas_*` / `tlas.blas_storage` +# rather than loose multiples — a leak that adds even one entry per cycle +# trips the test inside a few iterations instead of hiding behind 25× slack. +# ============================================================================== + +using Test +using GeometryBasics +using LinearAlgebra +using StaticArrays +using Raycore +using KernelAbstractions +const KA = KernelAbstractions +using Adapt +using Lava +using Random + +const STRESS_BACKEND = Lava.LavaBackend() + +# ------------------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------------------ + +stress_sphere(n::Int) = GeometryBasics.normal_mesh(Tesselation(Sphere(Point3f(0), 1f0), n)) +stress_box(s::Float32) = GeometryBasics.normal_mesh(Rect3f(Vec3f(-s/2), Vec3f(s))) + +stress_xlat(dx, dy, dz) = SMatrix{4,4,Float32,16}( + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + dx, dy, dz, 1, +) + +KA.@kernel function stress_trace_kernel!(hit_out, t_out, tlas, origins, directions) + i = @index(Global, Linear) + ray = Raycore.Ray(; o=origins[i], d=directions[i]) + hit, _, dist, _, _ = Raycore.closest_hit(tlas, ray) + hit_out[i] = hit + t_out[i] = Float32(dist) +end + +"""Trace `n` rays in parallel; returns (hits::Vector{Bool}, ts::Vector{Float32}).""" +function stress_trace(tlas, origins::Vector{Point3f}, directions::Vector{Vec3f}) + @assert length(origins) == length(directions) + n = length(origins) + static = Adapt.adapt(STRESS_BACKEND, tlas) + hits = KA.zeros(STRESS_BACKEND, Bool, n) + ts = KA.zeros(STRESS_BACKEND, Float32, n) + o_d = Adapt.adapt(STRESS_BACKEND, origins) + d_d = Adapt.adapt(STRESS_BACKEND, directions) + stress_trace_kernel!(STRESS_BACKEND)(hits, ts, static, o_d, d_d; ndrange=n) + KA.synchronize(STRESS_BACKEND) + return Array(hits), Array(ts) +end + +stress_trace_one(tlas, o::Point3f, d::Vec3f) = begin + h, t = stress_trace(tlas, [o], [d]) + (hit = h[1], t = t[1]) +end + +"Sum of primitives across all currently-stored BLASes." +sum_storage_prims(tlas) = isempty(tlas.blas_storage) ? 0 : + sum(length(b.primitives) for b in tlas.blas_storage) +sum_storage_nodes(tlas) = isempty(tlas.blas_storage) ? 0 : + sum(length(b.nodes) for b in tlas.blas_storage) + +"`length` that treats `nothing` (drained-to-empty flat arrays) as 0." +flat_len(x) = x === nothing ? 0 : length(x) + +"Tight invariant: flat arrays MUST equal the sum across `blas_storage` +after `sync!`. Anything else is a leak or stale entry." +function assert_compact!(tlas; ctx::AbstractString="") + @test flat_len(tlas._flat_blas_prims) == sum_storage_prims(tlas) + @test flat_len(tlas._flat_blas_nodes) == sum_storage_nodes(tlas) +end + +# ------------------------------------------------------------------------------ +# 1. Random churn with strict invariants between every operation +# ------------------------------------------------------------------------------ +# +# A scripted-but-randomized sequence of operations. After every `sync!` we +# recompute (a) the total instance count and (b) the flat-array sizes from +# `blas_storage` and assert exact equality. Any leak / off-by-one in the +# compaction path is caught within a few iterations, not eventually. + +@testset "TLAS stress — random churn with exact invariants" begin + rng = MersenneTwister(0xC0FFEE) + tlas = Raycore.TLAS(STRESS_BACKEND) + handles = Raycore.TLASHandle[] # currently live handles (1 instance each) + n_per_handle = Int[] # parallel array: how many instances under each handle + + # Seed with one BLAS so we have something to update. + h0 = push!(tlas, stress_sphere(8), stress_xlat(0, 0, 0)) + push!(handles, h0); push!(n_per_handle, 1) + Raycore.sync!(tlas) + + n_iters = 400 + for iter in 1:n_iters + op = rand(rng, 1:5) + + if op == 1 && length(handles) < 32 + # push! single instance + n = rand(rng, [4, 6, 8, 12]) + x = Float32(rand(rng) * 4 - 2) + h = push!(tlas, stress_sphere(n), stress_xlat(x, 0, 0)) + push!(handles, h); push!(n_per_handle, 1) + elseif op == 2 && length(handles) < 16 + # push! batch (2..6 instances of one BLAS) + k = rand(rng, 2:6) + xfs = [stress_xlat(Float32(rand(rng) * 4 - 2), Float32(rand(rng) * 2), 0) for _ in 1:k] + h = push!(tlas, stress_sphere(rand(rng, [4, 8])), xfs) + push!(handles, h); push!(n_per_handle, k) + elseif op == 3 && length(handles) > 1 + # delete! + i = rand(rng, 1:length(handles)) + Raycore.delete!(tlas, handles[i]) + deleteat!(handles, i); deleteat!(n_per_handle, i) + elseif op == 4 && !isempty(handles) + # update_transform! single (only valid if handle has 1 instance) + i = rand(rng, 1:length(handles)) + if n_per_handle[i] == 1 + Raycore.update_transform!(tlas, handles[i], + stress_xlat(Float32(rand(rng) * 6 - 3), 0, 0)) + end + elseif op == 5 && !isempty(handles) + # update_transforms! batch + i = rand(rng, 1:length(handles)) + k = n_per_handle[i] + xfs = [stress_xlat(Float32(rand(rng) * 6 - 3), 0, 0) for _ in 1:k] + Raycore.update_transforms!(tlas, handles[i], xfs) + end + + # Sync every 5 iterations so we hit the refit + rebuild paths a + # balanced number of times. + if iter % 5 == 0 + Raycore.sync!(tlas) + + # Strict invariants + expected_n_inst = isempty(n_per_handle) ? 0 : sum(n_per_handle) + @test Raycore.n_instances(tlas) == expected_n_inst + @test length(tlas.blas_storage) == length(handles) + @test length(tlas.deleted_handles) == 0 + @test flat_len(tlas._flat_blas_prims) == sum_storage_prims(tlas) + @test flat_len(tlas._flat_blas_nodes) == sum_storage_nodes(tlas) + @test !tlas.dirty + @test !tlas.transforms_dirty + end + end + + # Final sync + invariants + Raycore.sync!(tlas) + expected_n_inst = isempty(n_per_handle) ? 0 : sum(n_per_handle) + @test Raycore.n_instances(tlas) == expected_n_inst + @test length(tlas.blas_storage) == length(handles) + assert_compact!(tlas) + + # And after deleting everything left, the TLAS should drain to empty. + for h in handles + Raycore.delete!(tlas, h) + end + Raycore.sync!(tlas) + @test Raycore.n_instances(tlas) == 0 + @test length(tlas.blas_storage) == 0 + @test flat_len(tlas._flat_blas_prims) == 0 + @test flat_len(tlas._flat_blas_nodes) == 0 +end + +# ------------------------------------------------------------------------------ +# 2. Many BLASes, all live simultaneously +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — 200 distinct BLASes alive at once" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + handles = Raycore.TLASHandle[] + n_blas = 200 + # Spacing > 2*radius so adjacent unit spheres never touch. When we delete + # one, the gap at its position must read as a clean miss instead of a + # tangent hit on a neighbour. + spacing = 4f0 + for i in 1:n_blas + n = 4 + (i % 6) + h = push!(tlas, stress_sphere(n), stress_xlat(Float32(i) * spacing, 0, 0)) + push!(handles, h) + end + Raycore.sync!(tlas) + + @test length(tlas.blas_storage) == n_blas + @test Raycore.n_instances(tlas) == n_blas + assert_compact!(tlas) + + # Trace one ray per instance — each must hit (sphere at x=i*spacing). + origins = [Point3f(Float32(i) * spacing, 0, 5) for i in 1:n_blas] + directions = fill(Vec3f(0, 0, -1), n_blas) + hits, ts = stress_trace(tlas, origins, directions) + @test all(hits) + @test all(t -> isapprox(t, 4f0; atol=0.1f0), ts) + + # Delete every other BLAS — remaining ones must still hit, deleted ones must miss. + for (i, h) in enumerate(handles) + if iseven(i) + Raycore.delete!(tlas, h) + end + end + Raycore.sync!(tlas) + @test length(tlas.blas_storage) == n_blas ÷ 2 + assert_compact!(tlas) + + hits2, _ = stress_trace(tlas, origins, directions) + for i in 1:n_blas + @test hits2[i] == isodd(i) + end +end + +# ------------------------------------------------------------------------------ +# 3. High instance count under one BLAS (TLAS-level instancing stress) +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — 5000 instances of one BLAS, batch update, refit" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + n_inst = 5000 + + # Build initial transforms placing instances along a line + init_xfs = [stress_xlat(Float32(i) * 0.1f0, 0, 0) for i in 1:n_inst] + sphere_h = push!(tlas, stress_sphere(4), init_xfs) + Raycore.sync!(tlas) + + @test Raycore.n_instances(tlas) == n_inst + @test length(tlas.blas_storage) == 1 + assert_compact!(tlas) + + # Sample ray hits at a few positions + sample_idxs = [1, n_inst ÷ 4, n_inst ÷ 2, 3 * n_inst ÷ 4, n_inst] + sample_origins = [Point3f(Float32(i) * 0.1f0, 0, 5) for i in sample_idxs] + sample_directions = fill(Vec3f(0, 0, -1), length(sample_idxs)) + hits, ts = stress_trace(tlas, sample_origins, sample_directions) + @test all(hits) + @test all(t -> isapprox(t, 4f0; atol=0.1f0), ts) + + # Bulk-update every transform via update_transforms!. Move each instance + # by +10 in y. Expected: ray from (x_i, 10, 5) should hit at t≈4. + new_xfs = [stress_xlat(Float32(i) * 0.1f0, 10f0, 0) for i in 1:n_inst] + st_before = tlas.static_tlas + Raycore.update_transforms!(tlas, sphere_h, new_xfs) + @test tlas.transforms_dirty + Raycore.sync!(tlas) + @test !tlas.transforms_dirty + # Refit must keep the same StaticTLAS object alive (in-place AABB update). + @test tlas.static_tlas === st_before + + sample_origins2 = [Point3f(Float32(i) * 0.1f0, 10f0, 5) for i in sample_idxs] + hits2, ts2 = stress_trace(tlas, sample_origins2, sample_directions) + @test all(hits2) + @test all(t -> isapprox(t, 4f0; atol=0.1f0), ts2) + + # Old positions must now MISS (instances moved away). + hits3, _ = stress_trace(tlas, sample_origins, sample_directions) + @test !any(hits3) +end + +# ------------------------------------------------------------------------------ +# 3b. High-instance-count + tight refit loop (combined stress) +# ------------------------------------------------------------------------------ +# +# Test 3 does one update + one sync on 5000 instances. Test 7 does 500 refits +# but with one instance. Real workloads (RayMakie meshscatter at 60fps) hit +# the *combination*: thousands of instances, refit every frame, for hundreds +# of frames. This testset pins that. + +@testset "TLAS stress — 5000 instances + 200 refit frames in a tight loop" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + n_inst = 5000 + + init_xfs = [stress_xlat(Float32(i) * 0.1f0, 0, 0) for i in 1:n_inst] + h = push!(tlas, stress_sphere(4), init_xfs) + Raycore.sync!(tlas) + + # Pin the StaticTLAS identity: the refit path must update tlas.nodes in + # place and reuse the same StaticTLAS. If sync! drops to rebuild for any + # frame, this assertion trips immediately. + st0 = tlas.static_tlas + @test st0 !== nothing + + sample_idxs = [1, n_inst ÷ 2, n_inst] + + n_frames = 200 + for frame in 1:n_frames + # Animate every instance. Each frame is a full bulk update, mirroring + # the meshscatter per-frame call site that produced the original + # CPU-loop performance footgun. + new_xfs = [stress_xlat(Float32(i) * 0.1f0, + Float32(0.1 * frame), + 0) + for i in 1:n_inst] + Raycore.update_transforms!(tlas, h, new_xfs) + Raycore.sync!(tlas) + + # Refit-path invariants — checked every frame, not just at the end. + @test tlas.static_tlas === st0 # in-place AABB update + @test !tlas.dirty + @test !tlas.transforms_dirty + @test Raycore.n_instances(tlas) == n_inst + @test length(tlas.blas_storage) == 1 # no BLAS churn + end + + # Verify a few rays hit at the LAST frame's positions (correctness end-to-end). + last_y = Float32(0.1 * n_frames) + origins = [Point3f(Float32(i) * 0.1f0, last_y, 5) for i in sample_idxs] + dirs = fill(Vec3f(0, 0, -1), length(sample_idxs)) + hits, ts = stress_trace(tlas, origins, dirs) + @test all(hits) + @test all(t -> isapprox(t, 4f0; atol=0.1f0), ts) +end + +# ------------------------------------------------------------------------------ +# 3c. High-instance-count + tight REBUILD loop (topology churn at scale) +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — 2000 instances + 100 rebuild frames in a tight loop" begin + # Different from 3b: every frame DELETES and re-PUSHES the batch, forcing + # a full topology rebuild (not refit). Catches leaks / fragmentation in + # the rebuild path under realistic instance counts. + tlas = Raycore.TLAS(STRESS_BACKEND) + n_inst = 2000 + + init_xfs = [stress_xlat(Float32(i) * 0.1f0, 0, 0) for i in 1:n_inst] + h = push!(tlas, stress_sphere(4), init_xfs) + Raycore.sync!(tlas) + + n_frames = 100 + for frame in 1:n_frames + Raycore.delete!(tlas, h) + new_xfs = [stress_xlat(Float32(i) * 0.1f0, + Float32(0.05 * frame), + 0) + for i in 1:n_inst] + h = push!(tlas, stress_sphere(4), new_xfs) + Raycore.sync!(tlas) + + # Per-frame strict invariants — leaks expose themselves fast. + @test Raycore.n_instances(tlas) == n_inst + @test length(tlas.blas_storage) == 1 + @test length(tlas.deleted_handles) == 0 + @test length(tlas._flat_blas_prims) == length(tlas.blas_storage[1].primitives) + @test length(tlas._flat_blas_nodes) == length(tlas.blas_storage[1].nodes) + end + + # Sanity: ray hits at last frame's positions. + last_y = Float32(0.05 * n_frames) + sample_idxs = [1, n_inst ÷ 2, n_inst] + origins = [Point3f(Float32(i) * 0.1f0, last_y, 5) for i in sample_idxs] + dirs = fill(Vec3f(0, 0, -1), length(sample_idxs)) + hits, _ = stress_trace(tlas, origins, dirs) + @test all(hits) +end + +# ------------------------------------------------------------------------------ +# 3d. Interleaved update + trace + update + trace (UAF / serialization stress) +# ------------------------------------------------------------------------------ +# +# Tight loops above check refit/rebuild correctness via invariants but trace +# only at the END. This testset interleaves: every frame does +# update_transforms! → sync! → trace → verify-this-frame's-positions +# so a trace's GPU read is bracketed by writes from the previous frame +# (already-completed) AND the NEXT frame (about to start). If sync! ever +# fails to serialize the new write against in-flight reads, or hands back a +# stale `static_tlas`, the trace returns wrong t-values that don't match +# THIS frame's transforms, and the test trips immediately — not eventually. + +@testset "TLAS stress — interleaved update + trace + update tight loop (1000 inst, refit)" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + n_inst = 1000 + + init_xfs = [stress_xlat(Float32(i) * 0.1f0, 0, 0) for i in 1:n_inst] + h = push!(tlas, stress_sphere(4), init_xfs) + Raycore.sync!(tlas) + + st0 = tlas.static_tlas + + sample_idxs = [1, n_inst ÷ 4, n_inst ÷ 2, 3 * n_inst ÷ 4, n_inst] + + # Bounded oscillation in z so the unit-sphere top stays reachable from + # ray origin z=5 (sphere top = z_off + 1; need z_off+1 < 5). Use a + # sawtooth that visits 100 distinct z positions in [0, 2] without ever + # walking out of reach. + n_frames = 100 + for frame in 1:n_frames + z_off = Float32((frame % 50) * 0.04) # 0 .. 1.96 + new_xfs = [stress_xlat(Float32(i) * 0.1f0, 0, z_off) for i in 1:n_inst] + Raycore.update_transforms!(tlas, h, new_xfs) + Raycore.sync!(tlas) + + @test tlas.static_tlas === st0 + @test !tlas.dirty + @test !tlas.transforms_dirty + + # Trace THIS frame; t must reflect THIS frame's z_off. + origins = [Point3f(Float32(i) * 0.1f0, 0, 5) for i in sample_idxs] + dirs = fill(Vec3f(0, 0, -1), length(sample_idxs)) + hits, ts = stress_trace(tlas, origins, dirs) + expected_t = 5f0 - z_off - 1f0 + @test all(hits) + @test all(t -> isapprox(t, expected_t; atol=0.1f0), ts) + end +end + +@testset "TLAS stress — interleaved update + trace tight loop (5000 inst, refit)" begin + # Same shape, 5x the instance count — pushes the per-frame refit kernel + # ndrange high enough that timeline ordering bugs would tend to surface + # as flaky frame results. + tlas = Raycore.TLAS(STRESS_BACKEND) + n_inst = 5000 + + init_xfs = [stress_xlat(Float32(i) * 0.1f0, 0, 0) for i in 1:n_inst] + h = push!(tlas, stress_sphere(4), init_xfs) + Raycore.sync!(tlas) + st0 = tlas.static_tlas + + sample_idxs = [1, 1000, 2500, 4000, n_inst] + n_frames = 50 + for frame in 1:n_frames + z_off = Float32(frame * 0.04) # 0.04 .. 2.0 — always reachable + new_xfs = [stress_xlat(Float32(i) * 0.1f0, 0, z_off) for i in 1:n_inst] + Raycore.update_transforms!(tlas, h, new_xfs) + Raycore.sync!(tlas) + @test tlas.static_tlas === st0 + + origins = [Point3f(Float32(i) * 0.1f0, 0, 5) for i in sample_idxs] + dirs = fill(Vec3f(0, 0, -1), length(sample_idxs)) + hits, ts = stress_trace(tlas, origins, dirs) + expected_t = 5f0 - z_off - 1f0 + @test all(hits) + @test all(t -> isapprox(t, expected_t; atol=0.1f0), ts) + end +end + +@testset "TLAS stress — interleaved delete+push+sync+trace tight loop (rebuild path)" begin + # Same interleaving but every frame changes topology (delete+push), + # exercising the rebuild path's reuse / free of tlas.nodes. An older + # frame's trace MUST NOT see node buffers that have been recycled into + # this frame's BVH — KA.synchronize inside sync! is what guarantees this; + # if it ever regresses, the per-frame correctness check trips. + tlas = Raycore.TLAS(STRESS_BACKEND) + n_inst = 500 + + init_xfs = [stress_xlat(Float32(i) * 0.1f0, 0, 0) for i in 1:n_inst] + h = push!(tlas, stress_sphere(4), init_xfs) + Raycore.sync!(tlas) + + sample_idxs = [1, 100, 250, 400, n_inst] + n_frames = 60 + for frame in 1:n_frames + Raycore.delete!(tlas, h) + z_off = Float32((frame % 40) * 0.05) # 0 .. 1.95 — always reachable + # Alternate tessellation each frame to force fresh BLAS buffer sizes + # — node array can't be reused in place. + tess = isodd(frame) ? 4 : 8 + new_xfs = [stress_xlat(Float32(i) * 0.1f0, 0, z_off) for i in 1:n_inst] + h = push!(tlas, stress_sphere(tess), new_xfs) + Raycore.sync!(tlas) + + @test Raycore.n_instances(tlas) == n_inst + @test length(tlas.blas_storage) == 1 + @test length(tlas._flat_blas_prims) == length(tlas.blas_storage[1].primitives) + + origins = [Point3f(Float32(i) * 0.1f0, 0, 5) for i in sample_idxs] + dirs = fill(Vec3f(0, 0, -1), length(sample_idxs)) + hits, ts = stress_trace(tlas, origins, dirs) + expected_t = 5f0 - z_off - 1f0 + @test all(hits) + @test all(t -> isapprox(t, expected_t; atol=0.1f0), ts) + end +end + +# ------------------------------------------------------------------------------ +# 3e. 500-iter mesh grow/shrink with raytracing every iter (correctness + leak) +# ------------------------------------------------------------------------------ +# +# Five grow→shrink cycles of 100 iters each, varying peak tessellation +# (16, 32, 48, 64, 96). Every iter: +# delete! → push!(sphere(tess(iter)), translation(z=offset)) → sync! +# trace → verify hit position matches THIS iter's offset +# Catches: stale node-buffer captures from previous iters' traces (the +# UAF window between sync's KA.synchronize and the next push's allocation), +# leaks of BLAS arrays whose sizes change each iter, off-by-ones in the +# rebuild path's flat-array repacking when the tess count grows or shrinks. + +function grow_shrink_tess(iter::Int) + # 100-iter cycle: linear ramp 8 → peak → 8. Five peaks: 16, 32, 48, 64, 96. + cycle_len = 100 + peaks = (16, 32, 48, 64, 96) + cycle_i = ((iter - 1) ÷ cycle_len) % length(peaks) + 1 + peak = peaks[cycle_i] + phase = (iter - 1) % cycle_len + half = cycle_len ÷ 2 + if phase < half + max(8, Int(round(8 + (peak - 8) * (phase / half)))) + else + max(8, Int(round(peak - (peak - 8) * ((phase - half) / half)))) + end +end + +@testset "TLAS stress — 500-iter mesh grow/shrink + trace per iter (SW)" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + h = push!(tlas, stress_sphere(8), stress_xlat(0, 0, 0)) + Raycore.sync!(tlas) + + n_iters = 500 + saw_min, saw_max = typemax(Int), 0 + for iter in 1:n_iters + tess = grow_shrink_tess(iter) + saw_min, saw_max = min(saw_min, tess), max(saw_max, tess) + # offset_z bounded so the ray (origin z=5) always reaches the sphere top. + z_off = Float32((iter % 30) * 0.05) # 0 .. 1.45 + Raycore.delete!(tlas, h) + h = push!(tlas, stress_sphere(tess), stress_xlat(0, 0, z_off)) + Raycore.sync!(tlas) + + # Topology invariants. + @test Raycore.n_instances(tlas) == 1 + @test length(tlas.blas_storage) == 1 + @test length(tlas._flat_blas_prims) == length(tlas.blas_storage[1].primitives) + @test length(tlas._flat_blas_nodes) == length(tlas.blas_storage[1].nodes) + + # Trace + verify THIS iter's geometry. expected_t = 5 - z_off - 1. + r = stress_trace_one(tlas, Point3f(0, 0, 5), Vec3f(0, 0, -1)) + @test r.hit + @test isapprox(r.t, 5f0 - z_off - 1f0; atol=0.15f0) + end + + # Sanity: the schedule actually swept low and high tessellation values. + @test saw_min <= 10 + @test saw_max >= 90 +end + +# ------------------------------------------------------------------------------ +# 4. Long churn cycle with hard memory bounds +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — 200 swap iterations, exact compaction every step" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + handle = push!(tlas, stress_sphere(16), stress_xlat(0, 0, 0)) + Raycore.sync!(tlas) + + n_iters = 200 + for iter in 1:n_iters + n = (iter % 5 == 0) ? 64 : (iseven(iter) ? 8 : 24) + Raycore.delete!(tlas, handle) + handle = push!(tlas, stress_sphere(n), + stress_xlat(0, 0, Float32(0.001 * iter))) + Raycore.sync!(tlas) + + # Strict equality every iteration — accumulation reveals itself fast. + @test length(tlas.blas_storage) == 1 + @test Raycore.n_instances(tlas) == 1 + @test length(tlas._flat_blas_prims) == length(tlas.blas_storage[1].primitives) + @test length(tlas._flat_blas_nodes) == length(tlas.blas_storage[1].nodes) + @test length(tlas.deleted_handles) == 0 + end + + # Final ray check — geometry still works after the long run. + r = stress_trace_one(tlas, Point3f(0, 0, 5), Vec3f(0, 0, -1)) + @test r.hit + @test isapprox(r.t, Float32(5) - Float32(0.001 * n_iters) - 1f0; atol=0.15f0) +end + +# ------------------------------------------------------------------------------ +# 5. Use-after-free attempts on handles +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — deleted handles must not be usable" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + h = push!(tlas, stress_sphere(8), stress_xlat(0, 0, 0)) + Raycore.sync!(tlas) + @test Raycore.is_valid(tlas, h) + + # Delete + sync (compaction) + Raycore.delete!(tlas, h) + Raycore.sync!(tlas) + @test !Raycore.is_valid(tlas, h) + + # Every mutation / inspection API must reject the handle loudly. + @test_throws ErrorException Raycore.update_transform!(tlas, h, stress_xlat(1, 0, 0)) + @test_throws ErrorException Raycore.update_transforms!(tlas, h, [stress_xlat(1, 0, 0)]) + @test_throws ErrorException Raycore.get_instance(tlas, h) + @test_throws ErrorException Raycore.get_instances(tlas, h) + @test Raycore.delete!(tlas, h) === false # idempotent + + # Pre-compaction path: deleted but not yet sync!'d. + h2 = push!(tlas, stress_sphere(6)) + Raycore.sync!(tlas) + Raycore.delete!(tlas, h2) + @test !Raycore.is_valid(tlas, h2) + @test_throws ErrorException Raycore.update_transform!(tlas, h2, stress_xlat(1, 0, 0)) + @test_throws ErrorException Raycore.update_transforms!(tlas, h2, [stress_xlat(1, 0, 0)]) + + # Wrong-arity update_transform! / update_transforms! (handle/length mismatch). + h3 = push!(tlas, stress_sphere(6), [stress_xlat(0,0,0), stress_xlat(1,0,0)]) + Raycore.sync!(tlas) + @test_throws ErrorException Raycore.update_transform!(tlas, h3, stress_xlat(2, 0, 0)) # 1 vs 2 + @test_throws ErrorException Raycore.update_transforms!(tlas, h3, + [stress_xlat(0,0,0), stress_xlat(1,0,0), stress_xlat(2,0,0)]) # 3 vs 2 +end + +# ------------------------------------------------------------------------------ +# 7. Pure refit-only loop must not allocate / not change static_tlas identity +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — 500 refit-only cycles preserve static_tlas identity" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + h = push!(tlas, stress_sphere(16), stress_xlat(0, 0, 0)) + Raycore.sync!(tlas) + st0 = tlas.static_tlas + @test st0 !== nothing + + # Length of the flat arrays must NEVER change during a pure refit loop. + nodes_len_before = length(tlas._flat_blas_nodes) + prims_len_before = length(tlas._flat_blas_prims) + nodes_len_top = length(tlas.nodes) + + for iter in 1:500 + Raycore.update_transform!(tlas, h, stress_xlat(0, 0, Float32(iter * 0.001))) + Raycore.sync!(tlas) + @test tlas.static_tlas === st0 + @test length(tlas._flat_blas_nodes) == nodes_len_before + @test length(tlas._flat_blas_prims) == prims_len_before + @test length(tlas.nodes) == nodes_len_top + @test !tlas.dirty + @test !tlas.transforms_dirty + end + + # Sanity: refit moved the AABBs, so a ray from above hits at the new t. + r = stress_trace_one(tlas, Point3f(0, 0, 5), Vec3f(0, 0, -1)) + @test r.hit + @test isapprox(r.t, 5f0 - 0.5f0 - 1f0; atol=0.1f0) +end + +# ------------------------------------------------------------------------------ +# 8. Topology change after a long refit-only run (no carry-over staleness) +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — topology change after long refit run" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + h_a = push!(tlas, stress_sphere(8), stress_xlat(-2, 0, 0)) + h_b = push!(tlas, stress_sphere(8), stress_xlat( 2, 0, 0)) + Raycore.sync!(tlas) + + # 100 refits + for iter in 1:100 + Raycore.update_transform!(tlas, h_a, stress_xlat(-2 + iter * 0.01f0, 0, 0)) + Raycore.update_transform!(tlas, h_b, stress_xlat( 2 - iter * 0.01f0, 0, 0)) + Raycore.sync!(tlas) + end + st_pre_topology = tlas.static_tlas + + # Now actually CHANGE topology: delete one handle, push! a new mesh. This + # MUST rebuild static_tlas (different size), and traces must reflect the + # new geometry — not the old. + Raycore.delete!(tlas, h_a) + h_c = push!(tlas, stress_sphere(8), stress_xlat(0, 0, 5)) + Raycore.sync!(tlas) + @test tlas.static_tlas !== st_pre_topology + @test Raycore.n_instances(tlas) == 2 + @test length(tlas.blas_storage) == 2 # b + c (a was deleted) + assert_compact!(tlas) + + # Old h_a position should miss; new h_c position (z=5) should hit. + r_a = stress_trace_one(tlas, Point3f(-2 + 100 * 0.01f0, 0, 5), Vec3f(0, 0, -1)) + r_c = stress_trace_one(tlas, Point3f(0, 0, 10), Vec3f(0, 0, -1)) + @test !r_a.hit + @test r_c.hit + @test isapprox(r_c.t, 10f0 - 5f0 - 1f0; atol=0.1f0) +end + +# ------------------------------------------------------------------------------ +# 9. Hard leak bound across 200 swaps (WeakRefs) +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — 200-swap hard leak bound (multiple WeakRefs)" begin + # Stronger version of the existing one-WeakRef test. We keep WeakRefs to + # 9 different prior static_tlas objects sampled across the run and assert + # ALL are collectable at the end. A regression that pins even one frame's + # static across mutations trips this. + # + # Implementation note: the workload runs inside a function so its locals + # (`tlas`, `handle`, the loop's `iter`/`n` bindings) leave scope before + # `GC.gc(true)`. Putting the same code directly under `@testset` + # observably retains the most recent static_tlas (testset macro keeps + # locals alive for its scope), which would mask real leaks behind a + # benign-looking single-WeakRef survival. `GC.gc(true)` is a full sweep — + # one call must be enough; if a WeakRef survives, that's a real reference. + function workload() + tlas = Raycore.TLAS(STRESS_BACKEND) + handle = push!(tlas, stress_sphere(16), stress_xlat(0, 0, 0)) + Raycore.sync!(tlas) + wrefs = WeakRef[] + sample_at = Set{Int}([1, 25, 50, 75, 100, 125, 150, 175, 195]) + + for iter in 1:200 + n = iseven(iter) ? 12 : 64 + Raycore.delete!(tlas, handle) + handle = push!(tlas, stress_sphere(n), stress_xlat(0, 0, Float32(0.001 * iter))) + Raycore.sync!(tlas) + if iter in sample_at + push!(wrefs, WeakRef(tlas.static_tlas)) + end + end + + # Reseat static_tlas one more time so the last sampled iter is + # also no longer the live one. + Raycore.delete!(tlas, handle) + handle = push!(tlas, stress_sphere(8), stress_xlat(0, 0, 0)) + Raycore.sync!(tlas) + return wrefs, sort!(collect(sample_at)) + end + + wrefs, sorted_iters = workload() + GC.gc(true) + n_leaked = count(w -> w.value !== nothing, wrefs) + if n_leaked != 0 + for (i, w) in enumerate(wrefs) + @info "WeakRef status" iter=sorted_iters[i] alive=(w.value !== nothing) + end + end + @test n_leaked == 0 +end + +# ------------------------------------------------------------------------------ +# 10. Many adapts during a refit loop — must not allocate fresh StaticTLAS +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — adapt-per-frame during refit-only loop is allocation-free" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + h = push!(tlas, stress_sphere(8), stress_xlat(0, 0, 0)) + Raycore.sync!(tlas) + st0 = tlas.static_tlas + + # Adapt once and pin: identity must not change across pure-refit cycles. + pinned_static = Adapt.adapt(STRESS_BACKEND, tlas) + @test pinned_static === st0 + + for iter in 1:200 + Raycore.update_transform!(tlas, h, stress_xlat(0, 0, Float32(iter * 0.001))) + Raycore.sync!(tlas) + # Each adapt must return the SAME object; the "no-op sync" path must + # not silently rebuild StaticTLAS. + @test Adapt.adapt(STRESS_BACKEND, tlas) === pinned_static + end +end + +# ------------------------------------------------------------------------------ +# 11. Mixed delete+push at high churn (exercise compaction edge cases) +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — interleaved delete + push without intermediate sync" begin + # Several deletes + pushes BEFORE a single sync!. This exercises the + # path where compaction sees both fresh-pushed instances and deletion + # marks at the same time. + tlas = Raycore.TLAS(STRESS_BACKEND) + h1 = push!(tlas, stress_sphere(8), stress_xlat(0, 0, 0)) + h2 = push!(tlas, stress_sphere(8), stress_xlat(2, 0, 0)) + h3 = push!(tlas, stress_sphere(8), stress_xlat(4, 0, 0)) + Raycore.sync!(tlas) + @test Raycore.n_instances(tlas) == 3 + + # No sync between these: + Raycore.delete!(tlas, h2) + h4 = push!(tlas, stress_sphere(8), stress_xlat(6, 0, 0)) + Raycore.delete!(tlas, h1) + h5 = push!(tlas, stress_sphere(8), stress_xlat(8, 0, 0)) + + # Single sync should resolve all of it. + Raycore.sync!(tlas) + @test Raycore.is_valid(tlas, h3) + @test Raycore.is_valid(tlas, h4) + @test Raycore.is_valid(tlas, h5) + @test !Raycore.is_valid(tlas, h1) + @test !Raycore.is_valid(tlas, h2) + @test Raycore.n_instances(tlas) == 3 + @test length(tlas.blas_storage) == 3 + assert_compact!(tlas) + + # Hit positions: x = 4 (h3), 6 (h4), 8 (h5). x=0 and x=2 must miss. + origins = [Point3f(Float32(x), 0, 5) for x in (0, 2, 4, 6, 8)] + dirs = fill(Vec3f(0, 0, -1), 5) + hits, _ = stress_trace(tlas, origins, dirs) + @test hits == [false, false, true, true, true] +end + +# ------------------------------------------------------------------------------ +# 12. Empty-TLAS transitions (drain to zero, rebuild from zero) +# ------------------------------------------------------------------------------ + +@testset "TLAS stress — drain to empty + rebuild from empty" begin + tlas = Raycore.TLAS(STRESS_BACKEND) + + # 5 push / sync / delete / sync cycles. After every cycle the TLAS must + # be observably empty, and after every push the previous-empty backing + # buffers must not leak into the new build. + for iter in 1:5 + h = push!(tlas, stress_sphere(8), stress_xlat(Float32(iter), 0, 0)) + Raycore.sync!(tlas) + @test Raycore.n_instances(tlas) == 1 + @test length(tlas.blas_storage) == 1 + + Raycore.delete!(tlas, h) + Raycore.sync!(tlas) + @test Raycore.n_instances(tlas) == 0 + @test length(tlas.blas_storage) == 0 + @test flat_len(tlas._flat_blas_prims) == 0 + @test flat_len(tlas._flat_blas_nodes) == 0 + @test length(tlas.deleted_handles) == 0 + # Empty-state ray traces must not crash and must always miss. + r = stress_trace_one(tlas, Point3f(0, 0, 5), Vec3f(0, 0, -1)) + @test !r.hit + end +end + +# ------------------------------------------------------------------------------ +# 12b. Cross-backend Adapt.adapt errors loudly +# ------------------------------------------------------------------------------ +# +# A TLAS built on backend A handed to `Adapt.adapt(B, tlas)` would silently +# return a `static_tlas` whose arrays still live on A — the error only +# surfaces later as a confusing GPUCompiler "non-bitstype argument" inside +# kernel compilation. Pin the loud-at-the-API-boundary contract. + +@testset "TLAS stress — cross-backend adapt errors loudly" begin + cpu_tlas = Raycore.TLAS(KA.CPU()) + push!(cpu_tlas, stress_sphere(8)) + Raycore.sync!(cpu_tlas) + + # Adapting to the matching backend works. + @test Adapt.adapt(KA.CPU(), cpu_tlas) === cpu_tlas.static_tlas + + # Adapting to a different backend errors loudly. + @test_throws ErrorException Adapt.adapt(STRESS_BACKEND, cpu_tlas) + + # Same the other way: a Lava-backend TLAS adapted to KA.CPU() must error. + lava_tlas = Raycore.TLAS(STRESS_BACKEND) + push!(lava_tlas, stress_sphere(8)) + Raycore.sync!(lava_tlas) + @test Adapt.adapt(STRESS_BACKEND, lava_tlas) === lava_tlas.static_tlas + @test_throws ErrorException Adapt.adapt(KA.CPU(), lava_tlas) +end + +# ------------------------------------------------------------------------------ +# 13. HW TLAS stress — same patterns over Vulkan ray tracing +# ------------------------------------------------------------------------------ + +@testset "HW TLAS stress — random churn with strict invariants" begin + # HW path now supports push! / delete! / update_transform! / + # update_transforms!. The latter is GPU-resident: a compute kernel + # writes new records into the batch's instance_buf, sync! refits. + rng = MersenneTwister(0xBADF00D) + hwtlas = Lava.HWTLAS(STRESS_BACKEND) + handles = Raycore.TLASHandle[] + + h0 = push!(hwtlas, stress_sphere(8), stress_xlat(0, 0, 0); instance_id=UInt32(1)) + push!(handles, h0) + Raycore.sync!(hwtlas) + + for iter in 1:80 + op = rand(rng, 1:3) + if op == 1 && length(handles) < 16 + n = rand(rng, [4, 6, 8]) + x = Float32(rand(rng) * 4 - 2) + h = push!(hwtlas, stress_sphere(n), stress_xlat(x, 0, 0); + instance_id=UInt32(length(handles) + 1)) + push!(handles, h) + elseif op == 2 && length(handles) > 1 + i = rand(rng, 1:length(handles)) + Raycore.delete!(hwtlas, handles[i]) + deleteat!(handles, i) + elseif op == 3 && !isempty(handles) + i = rand(rng, 1:length(handles)) + Raycore.update_transform!(hwtlas, handles[i], + stress_xlat(Float32(rand(rng) * 6 - 3), 0, 0)) + end + + if iter % 5 == 0 + Raycore.sync!(hwtlas) + @test Raycore.n_instances(hwtlas) == length(handles) + end + end + + Raycore.sync!(hwtlas) + @test Raycore.world_bound(hwtlas) isa Raycore.Bounds3 + @test Raycore.wait_for_gpu!(hwtlas) === hwtlas +end + +@testset "HW TLAS stress — long mesh-swap loop, leak bound" begin + # Mirror of the SW TLAS leak-bound test, on the HW path. Each iteration + # drops the previous BLAS and pushes a fresh one. The HW pool / instance + # buffer counts must NOT scale with iteration count. + hwtlas = Lava.HWTLAS(STRESS_BACKEND) + h = push!(hwtlas, stress_sphere(16), stress_xlat(0, 0, 0); instance_id=UInt32(1)) + Raycore.sync!(hwtlas) + + n_iters = 100 + for iter in 1:n_iters + Raycore.delete!(hwtlas, h) + h = push!(hwtlas, stress_sphere(iseven(iter) ? 12 : 32), + stress_xlat(0, 0, Float32(0.001 * iter)); + instance_id=UInt32(1)) + Raycore.sync!(hwtlas) + @test Raycore.n_instances(hwtlas) == 1 + end + GC.gc(true); GC.gc(true) + @test Raycore.n_instances(hwtlas) == 1 +end + +println("\nAll stress tests passed.") diff --git a/test/test_type_stability.jl b/test/test_type_stability.jl deleted file mode 100644 index 16dbb72..0000000 --- a/test/test_type_stability.jl +++ /dev/null @@ -1,451 +0,0 @@ -using LinearAlgebra -using Raycore.StaticArrays -# ==================== Test Data Generators ==================== - -# Basic geometric types -gen_point3f() = Point3f(1.0f0, 2.0f0, 3.0f0) -gen_point2f() = Point2f(0.5f0, 0.5f0) -gen_vec3f() = Vec3f(0.0f0, 0.0f0, 1.0f0) -gen_normal3f() = Raycore.Normal3f(0.0f0, 0.0f0, 1.0f0) - -# Bounds -gen_bounds2() = Raycore.Bounds2(Point2f(0.0f0), Point2f(1.0f0)) -gen_bounds3() = Raycore.Bounds3(Point3f(0.0f0), Point3f(1.0f0, 1.0f0, 1.0f0)) - -# Rays -gen_ray() = Raycore.Ray(o=Point3f(0.0f0), d=Vec3f(0.0f0, 0.0f0, 1.0f0)) -gen_ray_differentials() = Raycore.RayDifferentials(o=Point3f(0.0f0), d=Vec3f(0.0f0, 0.0f0, 1.0f0)) - -# Transformations -gen_transformation() = Raycore.Transformation() -gen_transformation_translate() = Raycore.translate(Vec3f(1.0f0, 0.0f0, 0.0f0)) -gen_transformation_rotate() = Raycore.rotate_x(45.0f0) -gen_transformation_scale() = Raycore.scale(2.0f0, 2.0f0, 2.0f0) - -# Triangle -function gen_triangle() - v1 = Point3f(0.0f0, 0.0f0, 0.0f0) - v2 = Point3f(1.0f0, 0.0f0, 0.0f0) - v3 = Point3f(0.0f0, 1.0f0, 0.0f0) - n1 = Raycore.Normal3f(0.0f0, 0.0f0, 1.0f0) - uv1 = Point2f(0.0f0, 0.0f0) - uv2 = Point2f(1.0f0, 0.0f0) - uv3 = Point2f(0.0f0, 1.0f0) - Raycore.Triangle( - SVector(v1, v2, v3), - SVector(n1, n1, n1), - SVector(Vec3f(NaN), Vec3f(NaN), Vec3f(NaN)), - SVector(uv1, uv2, uv3), - UInt32(1) # metadata (single field replaces mesh_idx and material_idx) - ) -end - -# Triangle Mesh -function gen_triangle_mesh() - vertices = [Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0)] - indices = UInt32[1, 2, 3] # 1-based indices for Julia - normals = [Raycore.Normal3f(0, 0, 1), Raycore.Normal3f(0, 0, 1), Raycore.Normal3f(0, 0, 1)] - Raycore.TriangleMesh(vertices, indices, normals) -end - -# BVH -function gen_bvh_accel() - mesh = Rect3f(Point3f(0), Vec3f(1)) - Raycore.BVH([mesh], 1) -end - -# Quaternion -gen_quaternion() = Raycore.Quaternion() - -# ==================== Custom Test Macros ==================== - -""" - @test_opt_alloc expr - -Combined macro that tests both type stability (via @test_opt) and zero allocations. -This is equivalent to: - @test_opt expr - @test @allocated(expr) == 0 -""" -macro test_opt_alloc(expr) - return esc(quote - $expr # warmup - JET.@test_opt $expr - @test @allocated($expr) == 0 - end) -end - -# ==================== Bounds Tests ==================== - -@testset "Type Stability: bounds.jl" begin - @testset "Bounds2" begin - @test_opt_alloc Raycore.Bounds2() - - @test_opt_alloc Raycore.Bounds2(gen_point2f()) - - @test_opt_alloc Raycore.Bounds2c(gen_point2f(), Point2f(1.0f0, 1.0f0)) - end - - @testset "Bounds3" begin - @test_opt_alloc Raycore.Bounds3() - - @test_opt_alloc Raycore.Bounds3(gen_point3f()) - - @test_opt_alloc Raycore.Bounds3c(gen_point3f(), Point3f(2.0f0, 2.0f0, 2.0f0)) - end - - @testset "Bounds operations" begin - b1 = gen_bounds3() - b2 = Raycore.Bounds3(Point3f(0.5f0), Point3f(1.5f0, 1.5f0, 1.5f0)) - p = gen_point3f() - - @test_opt_alloc Base.:(==)(b1, b2) - @test_opt_alloc Base.:≈(b1, b2) - @test_opt_alloc Base.getindex(b1, 1) - @test_opt_alloc Raycore.is_valid(b1) - @test_opt_alloc Raycore.corner(b1, 1) - @test_opt_alloc Base.union(b1, b2) - @test_opt_alloc Base.intersect(b1, b2) - @test_opt_alloc Raycore.overlaps(b1, b2) - @test_opt_alloc Raycore.inside(b1, p) - @test_opt_alloc Raycore.inside_exclusive(b1, p) - @test_opt_alloc Raycore.expand(b1, 0.1f0) - @test_opt_alloc Raycore.diagonal(b1) - @test_opt_alloc Raycore.surface_area(b1) - @test_opt_alloc Raycore.volume(b1) - @test_opt_alloc Raycore.maximum_extent(b1) - @test_opt_alloc Raycore.sides(b1) - @test_opt_alloc Raycore.inclusive_sides(b1) - @test_opt_alloc Raycore.bounding_sphere(b1) - @test_opt_alloc Raycore.offset(b1, p) - end - - @testset "Bounds with Ray" begin - b = gen_bounds3() - r = gen_ray() - - @test_opt_alloc Raycore.intersect(b, r) - @test_opt_alloc Raycore.is_dir_negative(r.d) - - inv_dir = 1.0f0 ./ r.d - dir_neg = Raycore.is_dir_negative(r.d) - @test_opt_alloc Raycore.intersect_p(b, r, inv_dir, dir_neg) - end - - @testset "Bounds2 iteration" begin - b = gen_bounds2() - @test_opt_alloc Base.length(b) - @test_opt_alloc Base.iterate(b) - @test_opt_alloc Base.iterate(b, Int32(1)) - end - - @testset "Distance functions" begin - p1 = gen_point3f() - p2 = Point3f(2.0f0, 3.0f0, 4.0f0) - - @test_opt_alloc Raycore.distance(p1, p2) - @test_opt_alloc Raycore.distance_squared(p1, p2) - end - - @testset "Lerp functions" begin - b = gen_bounds3() - p = gen_point3f() - - @test_opt_alloc Raycore.lerp(0.0f0, 1.0f0, 0.5f0) - @test_opt_alloc Raycore.lerp(Point3f(0), Point3f(1), 0.5f0) - @test_opt_alloc Raycore.lerp(b, Point3f(0.5f0)) - end - - @testset "Bounds2 area" begin - b = gen_bounds2() - @test_opt_alloc Raycore.area(b) - end -end - -# ==================== Ray Tests ==================== - -@testset "Type Stability: ray.jl" begin - @testset "Ray construction" begin - @test_opt_alloc Raycore.Ray(o=gen_point3f(), d=gen_vec3f()) - @test_opt_alloc Raycore.Ray(o=gen_point3f(), d=gen_vec3f(), t_max=10.0f0) - @test_opt_alloc Raycore.Ray(o=gen_point3f(), d=gen_vec3f(), t_max=10.0f0, time=0.5f0) - end - - @testset "Ray copy constructor" begin - r = gen_ray() - @test_opt_alloc Raycore.Ray(r; o=Point3f(1.0f0)) - @test_opt_alloc Raycore.Ray(r; d=Vec3f(1.0f0, 0.0f0, 0.0f0)) - @test_opt_alloc Raycore.Ray(r; t_max=5.0f0) - end - - @testset "RayDifferentials construction" begin - @test_opt_alloc Raycore.RayDifferentials(o=gen_point3f(), d=gen_vec3f()) - @test_opt_alloc Raycore.RayDifferentials(gen_ray()) - end - - @testset "Ray operations" begin - r = gen_ray() - rd = gen_ray_differentials() - - @test_opt_alloc Raycore.set_direction(r, Vec3f(1.0f0, 0.0f0, 0.0f0)) - @test_opt_alloc Raycore.set_direction(rd, Vec3f(1.0f0, 0.0f0, 0.0f0)) - @test_opt_alloc Raycore.check_direction(r) - @test_opt_alloc Raycore.check_direction(rd) - @test_opt_alloc Raycore.apply(r, 1.0f0) - @test_opt_alloc Raycore.increase_hit(r, 0.5f0) - @test_opt_alloc Raycore.increase_hit(rd, 0.5f0) - end - - @testset "RayDifferentials operations" begin - rd = gen_ray_differentials() - @test_opt_alloc Raycore.scale_differentials(rd, 0.5f0) - end - - @testset "Intersection helpers" begin - t = gen_triangle() - r = gen_ray() - @test_opt_alloc Raycore.intersect_p!(t, r) - end -end - -# ==================== Transformation Tests ==================== - -@testset "Type Stability: transformations.jl" begin - @testset "Transformation construction" begin - @test_opt_alloc Raycore.Transformation() - @test_opt_alloc Raycore.Transformation(Mat4f(I)) - end - - @testset "Basic transformations" begin - @test_opt_alloc Raycore.translate(gen_vec3f()) - @test_opt_alloc Raycore.scale(2.0f0, 2.0f0, 2.0f0) - @test_opt_alloc Raycore.rotate_x(45.0f0) - @test_opt_alloc Raycore.rotate_y(45.0f0) - @test_opt_alloc Raycore.rotate_z(45.0f0) - @test_opt_alloc Raycore.rotate(45.0f0, Vec3f(0, 0, 1)) - end - - @testset "Transformation operations" begin - t1 = gen_transformation_translate() - t2 = gen_transformation_rotate() - - @test_opt_alloc Raycore.is_identity(t1) - @test_opt_alloc Base.transpose(t1) - @test_opt_alloc Base.inv(t1) - @test_opt_alloc Base.:(==)(t1, t2) - @test_opt_alloc Base.:≈(t1, t2) - @test_opt_alloc Base.:*(t1, t2) - end - - @testset "Transformation application" begin - t = gen_transformation_translate() - - @test_opt_alloc t(gen_point3f()) - @test_opt_alloc t(gen_vec3f()) - @test_opt_alloc t(gen_normal3f()) - @test_opt_alloc t(gen_bounds3()) - end - - @testset "Advanced transformations" begin - @test_opt_alloc Raycore.look_at(Point3f(0, 0, 5), Point3f(0), Vec3f(0, 1, 0)) - @test_opt_alloc Raycore.perspective(60.0f0, 0.1f0, 100.0f0) - end - - @testset "Transformation properties" begin - t = gen_transformation_scale() - @test_opt_alloc Raycore.has_scale(t) - @test_opt_alloc Raycore.swaps_handedness(t) - end - - @testset "Transformation with Ray" begin - t = gen_transformation_translate() - r = gen_ray() - rd = gen_ray_differentials() - - @test_opt_alloc Raycore.apply(t, r) - @test_opt_alloc Raycore.apply(t, rd) - end - - @testset "Quaternion" begin - @test_opt_alloc Raycore.Quaternion() - @test_opt_alloc Raycore.Quaternion(gen_transformation()) - - q1 = gen_quaternion() - q2 = Raycore.Quaternion(Vec3f(1, 0, 0), 0.5f0) - - @test_opt_alloc Base.:+(q1, q2) - @test_opt_alloc Base.:-(q1, q2) - @test_opt_alloc Base.:/(q1, 2.0f0) - @test_opt_alloc Base.:*(q1, 2.0f0) - @test_opt_alloc LinearAlgebra.dot(q1, q2) - @test_opt_alloc LinearAlgebra.normalize(q1) - @test_opt_alloc Raycore.Transformation(q1) - @test_opt_alloc Raycore.slerp(q1, q2, 0.5f0) - end -end - -# ==================== Math Tests ==================== - -@testset "Type Stability: math.jl" begin - @testset "Sampling functions" begin - u = gen_point2f() - - @test_opt_alloc Raycore.concentric_sample_disk(u) - @test_opt_alloc Raycore.cosine_sample_hemisphere(u) - @test_opt_alloc Raycore.uniform_sample_sphere(u) - @test_opt_alloc Raycore.uniform_sample_cone(u, 0.5f0) - @test_opt_alloc Raycore.uniform_sample_cone(u, 0.5f0, Vec3f(1,0,0), Vec3f(0,1,0), Vec3f(0,0,1)) - end - - @testset "PDF functions" begin - @test_opt_alloc Raycore.uniform_sphere_pdf() - @test_opt_alloc Raycore.uniform_cone_pdf(0.5f0) - end - - @testset "Shading coordinate system" begin - w = gen_vec3f() - - @test_opt_alloc Raycore.cos_θ(w) - @test_opt_alloc Raycore.sin_θ2(w) - @test_opt_alloc Raycore.sin_θ(w) - @test_opt_alloc Raycore.tan_θ(w) - @test_opt_alloc Raycore.cos_ϕ(w) - @test_opt_alloc Raycore.sin_ϕ(w) - end - - @testset "Vector operations" begin - wo = gen_vec3f() - n = Vec3f(0, 1, 0) - - @test_opt_alloc Raycore.reflect(wo, n) - @test_opt_alloc Raycore.face_forward(n, wo) - end - - @testset "Coordinate system" begin - v = gen_vec3f() - @test_opt_alloc Raycore.coordinate_system(v) - end - - @testset "Spherical functions" begin - @test_opt_alloc Raycore.spherical_direction(0.5f0, 0.5f0, 1.0f0) - @test_opt_alloc Raycore.spherical_direction(0.5f0, 0.5f0, 1.0f0, Vec3f(1,0,0), Vec3f(0,1,0), Vec3f(0,0,1)) - - v = gen_vec3f() - @test_opt_alloc Raycore.spherical_θ(v) - @test_opt_alloc Raycore.spherical_ϕ(v) - end - - @testset "Helper functions" begin - v = gen_vec3f() - @test_opt_alloc Raycore.get_orthogonal_basis(v) - - t = gen_triangle() - @test_opt_alloc Raycore.random_triangle_point(t) - end - - @testset "sum_mul" begin - a = Point3f(0.2f0, 0.3f0, 0.5f0) - b = Raycore.StaticArrays.SVector(Point3f(0,0,0), Point3f(1,0,0), Point3f(0,1,0)) - @test_opt_alloc Raycore.sum_mul(a, b) - end -end - -@testset "Type Stability: triangle_mesh.jl" begin - @testset "TriangleMesh construction" begin - vertices = [Point3f(0, 0, 0), Point3f(1, 0, 0), Point3f(0, 1, 0)] - indices = UInt32[0, 1, 2] - normals = [Raycore.Normal3f(0, 0, 1), Raycore.Normal3f(0, 0, 1), Raycore.Normal3f(0, 0, 1)] - - @test_opt Raycore.TriangleMesh(vertices, indices, normals) - @test_opt Raycore.TriangleMesh(vertices, indices) - end - - @testset "Triangle construction" begin - mesh = gen_triangle_mesh() - @test_opt_alloc Raycore.Triangle(mesh, 1, UInt32(1)) - end - - @testset "Triangle operations" begin - t = gen_triangle() - - @test_opt_alloc Raycore.vertices(t) - @test_opt_alloc Raycore.normals(t) - @test_opt_alloc Raycore.tangents(t) - @test_opt_alloc Raycore.uvs(t) - @test_opt_alloc Raycore.area(t) - @test_opt_alloc Raycore.object_bound(t) - @test_opt_alloc Raycore.world_bound(t) - end - - @testset "Triangle intersection" begin - t = gen_triangle() - r = gen_ray() - - @test_opt_alloc Raycore.intersect(t, r) - @test_opt_alloc Raycore.intersect_p(t, r) - @test_opt_alloc Raycore.intersect_triangle(t.vertices, r) - end - - @testset "Triangle helper functions" begin - t = gen_triangle() - r = gen_ray() - - # Test _to_ray_coordinate_space - @test_opt_alloc Raycore._to_ray_coordinate_space(t.vertices, r) - - # Test partial_derivatives - @test_opt_alloc Raycore.partial_derivatives(t, t.vertices, t.uv) - - # Test normal_derivatives - @test_opt_alloc Raycore.normal_derivatives(t, t.uv) - end - - @testset "Triangle utilities" begin - t = gen_triangle() - @test_opt_alloc Raycore.is_degenerate(t.vertices) - end -end - -# ==================== BVH Tests ==================== - -@testset "Type Stability: bvh.jl" begin - @testset "BVHPrimitiveInfo" begin - b = gen_bounds3() - @test_opt_alloc Raycore.BVHPrimitiveInfo(UInt32(1), b) - end - - @testset "BVHNode construction" begin - b = gen_bounds3() - @test_opt Raycore.BVHNode(UInt32(0), UInt32(1), b) - end - - @testset "LinearBVH construction" begin - b = gen_bounds3() - @test_opt_alloc Raycore.LinearBVHLeaf(b, UInt32(0), UInt32(1)) - @test_opt_alloc Raycore.LinearBVHInterior(b, UInt32(1), UInt8(0)) - end - - @testset "BVH operations" begin - bvh = gen_bvh_accel() - r = gen_ray() - - @test_opt Raycore.world_bound(bvh) - @test_opt Raycore.closest_hit(bvh, r) - @test_opt Raycore.any_hit(bvh, r) - end - - @testset "Ray grid generation" begin - bvh = gen_bvh_accel() - direction = Vec3f(0, 0, 1) - # generate_ray_grid allocates - needs optimization - @test_opt Raycore.generate_ray_grid(bvh, direction, 10) - end -end - -# ==================== Kernels Tests ==================== - -@testset "Type Stability: kernels.jl" begin - @testset "RayHit construction" begin - @test_opt_alloc Raycore.RayHit(true, gen_point3f(), UInt32(1)) - end -end diff --git a/test/test_unrolled.jl b/test/test_unrolled.jl new file mode 100644 index 0000000..c28f9ac --- /dev/null +++ b/test/test_unrolled.jl @@ -0,0 +1,225 @@ +using Test +using Raycore + +# ============================================================================ +# Test FastClosure capture detection +# ============================================================================ + +@testset "FastClosure capture detection" begin + # Regular function - should work + @test begin + add(x, y) = x + y + fc = FastClosure(add, (10,)) + fc(5) == 15 + end + + # Anonymous function without capture - should work + @test begin + fc = FastClosure((x, y) -> x * y, (3,)) + fc(4) == 12 + end + + # Closure WITH capture (in local scope) - should error + # Note: captures only become fields when created in local function scope + function make_capturing_closure() + captured = 42 + return x -> x + captured + end + @test_throws ErrorException FastClosure(make_capturing_closure(), ()) + + # Closure with boxed capture (reassigned after closure creation) - should error + function make_boxed_closure() + captured = 42 + closure = x -> x + captured + captured = 100 # reassignment causes Core.Box + return closure + end + @test_throws ErrorException FastClosure(make_boxed_closure(), ()) +end + +# ============================================================================ +# Test for_unrolled +# ============================================================================ + +@testset "for_unrolled with tuple" begin + # Basic iteration + @test begin + results = Int[] + push_val!(x, arr) = push!(arr, x) + for_unrolled(push_val!, (1, 2, 3), results) + results == [1, 2, 3] + end + + # Heterogeneous tuple + @test begin + results = Any[] + push_val!(x, arr) = push!(arr, x) + for_unrolled(push_val!, (1, "hello", 3.14), results) + results == [1, "hello", 3.14] + end + + # Empty tuple + @test begin + count = Ref(0) + inc!(x, c) = c[] += 1 + for_unrolled(inc!, (), count) + count[] == 0 + end + + # Multiple extra args + @test begin + results = Float64[] + scaled_push!(x, arr, scale, offset) = push!(arr, x * scale + offset) + for_unrolled(scaled_push!, (1, 2, 3), results, 2.0, 0.5) + results == [2.5, 4.5, 6.5] + end +end + +@testset "for_unrolled with Val{N}" begin + @test begin + results = Int32[] + push_val!(i, arr) = push!(arr, i) + for_unrolled(push_val!, Val(5), results) + results == Int32[1, 2, 3, 4, 5] + end + + @test begin + results = Int32[] + push_val!(i, arr) = push!(arr, i) + for_unrolled(push_val!, Val(0), results) + isempty(results) + end +end + +# ============================================================================ +# Test map_unrolled +# ============================================================================ + +@testset "map_unrolled" begin + # Basic mapping + @test begin + double(x) = 2x + map_unrolled(double, (1, 2, 3)) == (2, 4, 6) + end + + # With extra args + @test begin + scale(x, factor) = x * factor + map_unrolled(scale, (1, 2, 3), 10) == (10, 20, 30) + end + + # Heterogeneous tuple - returns heterogeneous result + @test begin + wrap(x) = [x] + result = map_unrolled(wrap, (1, "a", 3.0)) + result == ([1], ["a"], [3.0]) + end + + # Empty tuple + @test map_unrolled(identity, ()) == () + + # Type-changing map + @test begin + to_string(x) = string(x) + map_unrolled(to_string, (1, 2, 3)) == ("1", "2", "3") + end +end + +# ============================================================================ +# Test reduce_unrolled +# ============================================================================ + +@testset "reduce_unrolled" begin + # Sum reduction + @test begin + add_to_acc(acc, x) = acc + x + reduce_unrolled(add_to_acc, (1, 2, 3, 4), 0) == 10 + end + + # With extra args + @test begin + scaled_add(acc, x, scale) = acc + x * scale + reduce_unrolled(scaled_add, (1, 2, 3), 0, 2) == 12 # 2 + 4 + 6 + end + + # Product reduction + @test begin + mul_to_acc(acc, x) = acc * x + reduce_unrolled(mul_to_acc, (1, 2, 3, 4), 1) == 24 + end + + # Empty tuple returns init + @test begin + add_to_acc(acc, x) = acc + x + reduce_unrolled(add_to_acc, (), 42) == 42 + end + + # Collecting into array + @test begin + collect_acc(acc, x) = push!(copy(acc), x) + reduce_unrolled(collect_acc, (1, 2, 3), Int[]) == [1, 2, 3] + end +end + +# ============================================================================ +# Test sum_unrolled +# ============================================================================ + +@testset "sum_unrolled" begin + # Direct sum + @test begin + identity_val(x) = x + sum_unrolled(identity_val, (1, 2, 3, 4)) == 10 + end + + # With transformation + @test begin + square(x) = x^2 + sum_unrolled(square, (1, 2, 3)) == 14 # 1 + 4 + 9 + end + + # With extra args + @test begin + scaled(x, factor) = x * factor + sum_unrolled(scaled, (1, 2, 3), 2) == 12 # 2 + 4 + 6 + end + + # Single element + @test begin + identity_val(x) = x + sum_unrolled(identity_val, (42,)) == 42 + end + + # Empty tuple returns nothing + @test begin + identity_val(x) = x + sum_unrolled(identity_val, ()) === nothing + end + + # Float accumulation + @test begin + identity_val(x) = x + sum_unrolled(identity_val, (1.0, 2.0, 3.0)) ≈ 6.0 + end +end + +# ============================================================================ +# Test compiler limits checking +# ============================================================================ + +@testset "Compiler limits" begin + # Too many args should error + @test_throws ErrorException begin + f(x) = x + big_args = ntuple(i -> i, 40) # 40 > MAX_TUPLE_LENGTH + FastClosure(f, big_args) + end + + # Tuple arg that's too long should error + @test_throws ErrorException begin + f(x, _big_tuple) = x + big_tuple = ntuple(i -> i, 40) + FastClosure(f, (big_tuple,)) + end +end +