diff --git a/ROADMAP.md b/ROADMAP.md index 961c759..587f481 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -69,17 +69,17 @@ Documentation fixes and test coverage gaps. **Complete**. --- -## Phase 4: Deferred Features +## Phase 4: Deferred Features ✅ -Valuable features without current demand. Revisit when a concrete use case arises. +Valuable features that were deferred until concrete use cases arose. **Complete**. -| # | Item | Effort | Revisit when | Source | -|---|------|--------|--------------|--------| -| 4.1 | Indefinite dense STDE (eigendecomposition for indefinite C matrices) | medium | A user needs indefinite C support | [ADR](docs/adr-deferred-work.md) | -| 4.2 | General-K GPU Taylor kernels (beyond K=3) | medium | Need for GPU-accelerated 3rd+ order derivatives | [ADR](docs/adr-deferred-work.md) | -| 4.3 | Chunked GPU Taylor dispatch (exceed 128 MB WebGPU limit) | small | Users hit the buffer limit in practice | [ADR](docs/adr-deferred-work.md) | -| 4.4 | CUDA `laplacian_with_control_gpu_cuda` | small | CUDA users need variance-reduced Laplacian | [ADR](docs/adr-deferred-work.md) | -| 4.5 | `taylor_forward_2nd_batch` in `GpuBackend` trait | small | Multiple backends need to be used generically | [ADR](docs/adr-deferred-work.md) | +| # | Item | Status | +|---|------|--------| +| 4.1 | Indefinite dense STDE (`dense_stde_2nd_indefinite`, eigendecomposition + sign-splitting) | ✅ Done | +| 4.2 | General-K GPU Taylor kernels (K=1..5 via runtime codegen, `taylor_forward_kth_batch`) | ✅ Done | +| 4.3 | Chunked GPU Taylor dispatch (`taylor_forward_2nd_batch_chunked`, buffer + dispatch limits) | ✅ Done | +| 4.4 | Generic `laplacian_with_control_gpu` (works with any `GpuBackend`, replaces CUDA-specific fn) | ✅ Done | +| 4.5 | `taylor_forward_2nd_batch` in `GpuBackend` trait (all stde_gpu fns now generic over backend) | ✅ Done | --- @@ -110,7 +110,7 @@ Nice-to-haves with no urgency. Pursue opportunistically or if the relevant area | Item | Current | Latest | Effort | Notes | |------|---------|--------|--------|-------| -| cudarc | 0.17 | 0.19 | medium | Breaking API changes in 0.18. Defer until GPU backend is actively developed. | +| cudarc | 0.19 | 0.19 | — | ✅ Up to date | --- @@ -132,7 +132,6 @@ These items were evaluated and explicitly rejected. Rationale is in [docs/adr-de ## Dependencies Between Phases ``` -Phase 0–3 (complete) — all done as of 2026-03-14 -Phase 4 (deferred features) — each item independent; all require active use of base feature +Phase 0–4 (complete) — all done as of 2026-03-14 Phase 5 (aspirational) — independent nice-to-haves ``` diff --git a/src/gpu/cuda_backend.rs b/src/gpu/cuda_backend.rs index 14e3a37..54f5536 100644 --- a/src/gpu/cuda_backend.rs +++ b/src/gpu/cuda_backend.rs @@ -6,7 +6,7 @@ //! # Usage //! //! ```no_run -//! use echidna::gpu::{CudaContext, GpuTapeData}; +//! use echidna::gpu::{CudaContext, GpuBackend, GpuTapeData}; //! use echidna::{record, Scalar}; //! //! let ctx = CudaContext::new().expect("CUDA device required"); @@ -19,7 +19,7 @@ use std::sync::Arc; use cudarc::driver::{ - CudaContext as CudarContext, CudaFunction, CudaSlice, CudaStream, LaunchConfig, + CudaContext as CudarContext, CudaFunction, CudaSlice, CudaStream, LaunchConfig, PushKernelArg, }; use cudarc::nvrtc::compile_ptx; @@ -670,14 +670,38 @@ impl GpuBackend for CudaContext { ) -> Result<(f32, Vec, crate::sparse::SparsityPattern, Vec), GpuError> { cuda_sparse_hessian_body!(self, tape, tape_cpu, x, f32, hvp_batch) } + + #[cfg(feature = "stde")] + fn taylor_forward_2nd_batch( + &self, + tape: &CudaTapeBuffers, + primal_inputs: &[f32], + direction_seeds: &[f32], + batch_size: u32, + ) -> Result, GpuError> { + cuda_taylor_fwd_2nd_body!( + self, + tape, + primal_inputs, + direction_seeds, + batch_size, + f32, + constants_f32, + taylor_fwd_2nd_f32 + ) + } } impl CudaContext { /// Batched second-order Taylor forward propagation (f32). /// - /// Each batch element pushes one direction through the tape, producing - /// a Taylor jet with 3 coefficients (c0=value, c1=first derivative, - /// c2=second derivative / 2). + /// Deprecated: this inherent method delegates to the `GpuBackend` trait method. + /// Import `GpuBackend` and call `taylor_forward_2nd_batch` directly. + #[cfg(feature = "stde")] + #[deprecated( + since = "0.5.0", + note = "import GpuBackend trait and call taylor_forward_2nd_batch() directly" + )] pub fn taylor_forward_2nd_batch( &self, tape: &CudaTapeBuffers, @@ -685,15 +709,12 @@ impl CudaContext { direction_seeds: &[f32], batch_size: u32, ) -> Result, GpuError> { - cuda_taylor_fwd_2nd_body!( + ::taylor_forward_2nd_batch( self, tape, primal_inputs, direction_seeds, batch_size, - f32, - constants_f32, - taylor_fwd_2nd_f32 ) } diff --git a/src/gpu/mod.rs b/src/gpu/mod.rs index e6a1249..4f483aa 100644 --- a/src/gpu/mod.rs +++ b/src/gpu/mod.rs @@ -16,7 +16,7 @@ //! - [`sparse_jacobian`](GpuBackend::sparse_jacobian) — GPU-accelerated sparse Jacobian //! - [`hvp_batch`](GpuBackend::hvp_batch) — batched Hessian-vector product //! - [`sparse_hessian`](GpuBackend::sparse_hessian) — GPU-accelerated sparse Hessian -//! - `taylor_forward_2nd_batch` — batched second-order Taylor forward propagation (inherent, requires `stde`) +//! - [`taylor_forward_2nd_batch`](GpuBackend::taylor_forward_2nd_batch) — batched second-order Taylor forward propagation (requires `stde`) //! //! CUDA additionally provides f64 methods as inherent methods on [`CudaContext`]. //! @@ -45,6 +45,9 @@ pub mod cuda_backend; #[cfg(feature = "stde")] pub mod stde_gpu; +#[cfg(feature = "stde")] +pub mod taylor_codegen; + #[cfg(feature = "gpu-wgpu")] pub use wgpu_backend::{WgpuContext, WgpuTapeBuffers}; @@ -134,6 +137,26 @@ pub trait GpuBackend { tape_cpu: &mut BytecodeTape, x: &[f32], ) -> Result<(f32, Vec, crate::sparse::SparsityPattern, Vec), GpuError>; + + /// Batched second-order Taylor forward propagation. + /// + /// Each batch element pushes one direction through the tape, producing + /// a Taylor jet with 3 coefficients (c0=value, c1=first derivative, + /// c2=second derivative / 2). + /// + /// `primal_inputs` is `[f32; batch_size * num_inputs]` — primals for each element. + /// `direction_seeds` is `[f32; batch_size * num_inputs]` — c1 seeds for each element. + /// + /// Returns `TaylorBatchResult` with `values`, `c1s`, `c2s` each of size + /// `[f32; batch_size * num_outputs]`. + #[cfg(feature = "stde")] + fn taylor_forward_2nd_batch( + &self, + tape: &Self::TapeBuffers, + primal_inputs: &[f32], + direction_seeds: &[f32], + batch_size: u32, + ) -> Result, GpuError>; } /// Result of a batched second-order Taylor forward propagation. @@ -152,6 +175,19 @@ pub struct TaylorBatchResult { pub c2s: Vec, } +/// Result of a batched K-th order Taylor forward propagation. +/// +/// `coefficients[k]` has `batch_size * num_outputs` elements for coefficient index k. +/// The Taylor convention is `c[k] = f^(k)(t₀) / k!`. +#[cfg(feature = "stde")] +pub struct TaylorKthBatchResult { + /// Taylor coefficients: `coefficients[k]` is the k-th order coefficient vector + /// with `batch_size * num_outputs` elements. + pub coefficients: Vec>, + /// The Taylor order (number of coefficients per output). + pub order: usize, +} + /// Error type for GPU operations. #[derive(Debug)] pub enum GpuError { @@ -283,3 +319,116 @@ pub struct TapeMeta { pub fn opcode_to_gpu(op: OpCode) -> u32 { op as u32 } + +/// Default maximum buffer size for WebGPU (128 MiB). +/// +/// WebGPU's `maxBufferSize` limit is 256 MiB, but we use 128 MiB as a +/// conservative default to avoid hitting device-specific limits. +#[cfg(feature = "stde")] +pub const WGPU_MAX_BUFFER_BYTES: u64 = 128 * 1024 * 1024; + +/// Maximum workgroup dispatches per dimension in WebGPU (65535). +#[cfg(feature = "stde")] +const MAX_WORKGROUPS_PER_DIM: u64 = 65535; + +/// Workgroup size used by the Taylor forward shader. +#[cfg(feature = "stde")] +const TAYLOR_WORKGROUP_SIZE: u64 = 256; + +/// Chunked batched second-order Taylor forward propagation. +/// +/// Splits a large batch into chunks that fit within GPU buffer size limits, +/// dispatches each chunk, and concatenates results. This avoids hitting WebGPU's +/// 128 MiB buffer limit or workgroup dispatch limits. +/// +/// # Arguments +/// +/// - `backend`: any `GpuBackend` implementation +/// - `tape`: uploaded tape buffers +/// - `primal_inputs`: `[f32; batch_size * num_inputs]` — primals for each element +/// - `direction_seeds`: `[f32; batch_size * num_inputs]` — c1 seeds for each element +/// - `batch_size`: total number of batch elements +/// - `num_inputs`: number of input variables per element +/// - `num_variables`: total tape variable slots (inputs + constants + intermediates) +/// - `max_buffer_bytes`: maximum GPU buffer size in bytes (use [`WGPU_MAX_BUFFER_BYTES`]) +/// +/// # Errors +/// +/// Returns `GpuError::Other` if `max_buffer_bytes` is too small for even a single element. +#[cfg(feature = "stde")] +#[allow(clippy::too_many_arguments)] +pub fn taylor_forward_2nd_batch_chunked( + backend: &B, + tape: &B::TapeBuffers, + primal_inputs: &[f32], + direction_seeds: &[f32], + batch_size: u32, + num_inputs: u32, + num_variables: u32, + max_buffer_bytes: u64, +) -> Result, GpuError> { + if batch_size == 0 { + return Ok(TaylorBatchResult { + values: vec![], + c1s: vec![], + c2s: vec![], + }); + } + + // The largest buffer is the jets working buffer: batch_size * num_variables * 3 * 4 bytes + let bytes_per_element = (num_variables as u64) * 3 * 4; + if bytes_per_element == 0 { + return Err(GpuError::Other("num_variables is zero".into())); + } + + let mut chunk_size = max_buffer_bytes / bytes_per_element; + if chunk_size == 0 { + return Err(GpuError::Other(format!( + "max_buffer_bytes ({max_buffer_bytes}) too small for a single element \ + ({bytes_per_element} bytes per element)" + ))); + } + + // Also cap at workgroup dispatch limit: 65535 workgroups * 256 threads + let dispatch_limit = MAX_WORKGROUPS_PER_DIM * TAYLOR_WORKGROUP_SIZE; + chunk_size = chunk_size.min(dispatch_limit); + + let chunk_size = chunk_size as u32; + + // If everything fits in one chunk, dispatch directly + if batch_size <= chunk_size { + return backend.taylor_forward_2nd_batch(tape, primal_inputs, direction_seeds, batch_size); + } + + // Multi-chunk dispatch + let ni = num_inputs as usize; + let mut all_values = Vec::new(); + let mut all_c1s = Vec::new(); + let mut all_c2s = Vec::new(); + + let mut offset = 0u32; + while offset < batch_size { + let this_chunk = chunk_size.min(batch_size - offset); + let start = (offset as usize) * ni; + let end = start + (this_chunk as usize) * ni; + + let chunk_result = backend.taylor_forward_2nd_batch( + tape, + &primal_inputs[start..end], + &direction_seeds[start..end], + this_chunk, + )?; + + all_values.extend(chunk_result.values); + all_c1s.extend(chunk_result.c1s); + all_c2s.extend(chunk_result.c2s); + + offset += this_chunk; + } + + Ok(TaylorBatchResult { + values: all_values, + c1s: all_c1s, + c2s: all_c2s, + }) +} diff --git a/src/gpu/stde_gpu.rs b/src/gpu/stde_gpu.rs index 8344be8..8a5406d 100644 --- a/src/gpu/stde_gpu.rs +++ b/src/gpu/stde_gpu.rs @@ -4,12 +4,10 @@ //! These use batched second-order Taylor forward propagation on the GPU to evaluate //! many directions in parallel. //! -//! # Supported backends -//! -//! - **wgpu** (`gpu-wgpu`): cross-platform, f32 only -//! - **CUDA** (`gpu-cuda`): NVIDIA, f32 (f64 via inherent methods on `CudaContext`) +//! All functions are generic over `B: GpuBackend`, working with any backend +//! (wgpu, CUDA, or future backends). -use super::{GpuError, TaylorBatchResult}; +use super::{GpuBackend, GpuError, TaylorBatchResult}; use crate::stde::EstimatorResult; /// GPU-accelerated Laplacian estimation via Hutchinson + Taylor-mode. @@ -20,11 +18,10 @@ use crate::stde::EstimatorResult; /// `directions` is `&[&[f32]]` with S direction vectors, each of length n. /// The primal point `x` is replicated for each batch element. /// -/// Works with any backend that provides `taylor_forward_2nd_batch`. -#[cfg(feature = "gpu-wgpu")] -pub fn laplacian_gpu( - backend: &super::WgpuContext, - tape: &super::WgpuTapeBuffers, +/// Works with any backend that implements `GpuBackend`. +pub fn laplacian_gpu( + backend: &B, + tape: &B::TapeBuffers, x: &[f32], directions: &[&[f32]], ) -> Result, GpuError> { @@ -47,32 +44,6 @@ pub fn laplacian_gpu( Ok(aggregate_laplacian(&result, s)) } -/// GPU-accelerated Laplacian estimation (CUDA f32). -#[cfg(feature = "gpu-cuda")] -pub fn laplacian_gpu_cuda( - backend: &super::CudaContext, - tape: &super::CudaTapeBuffers, - x: &[f32], - directions: &[&[f32]], -) -> Result, GpuError> { - let n = x.len(); - let s = directions.len(); - if s == 0 { - return Err(GpuError::Other("no directions provided".into())); - } - - let mut primals = Vec::with_capacity(s * n); - let mut seeds = Vec::with_capacity(s * n); - for dir in directions { - assert_eq!(dir.len(), n, "direction length must match x"); - primals.extend_from_slice(x); - seeds.extend_from_slice(dir); - } - - let result = backend.taylor_forward_2nd_batch(tape, &primals, &seeds, s as u32)?; - Ok(aggregate_laplacian(&result, s)) -} - /// CPU-side Welford aggregation of c2 values into a Laplacian estimate. fn aggregate_laplacian(result: &TaylorBatchResult, s: usize) -> EstimatorResult { // For Hutchinson: E[v^T H v] = tr(H), and v^T H v = 2 * c2 for unit-variance v. @@ -110,10 +81,9 @@ fn aggregate_laplacian(result: &TaylorBatchResult, s: usize) -> EstimatorRe /// /// Uses one batch element per input dimension, with each direction being a /// standard basis vector e_j. Returns `(f(x), diag(H))`. -#[cfg(feature = "gpu-wgpu")] -pub fn hessian_diagonal_gpu( - backend: &super::WgpuContext, - tape: &super::WgpuTapeBuffers, +pub fn hessian_diagonal_gpu( + backend: &B, + tape: &B::TapeBuffers, x: &[f32], ) -> Result<(f32, Vec), GpuError> { let n = x.len(); @@ -135,38 +105,13 @@ pub fn hessian_diagonal_gpu( Ok((value, diag)) } -/// GPU-accelerated exact Hessian diagonal (CUDA f32). -#[cfg(feature = "gpu-cuda")] -pub fn hessian_diagonal_gpu_cuda( - backend: &super::CudaContext, - tape: &super::CudaTapeBuffers, - x: &[f32], -) -> Result<(f32, Vec), GpuError> { - let n = x.len(); - - let mut primals = Vec::with_capacity(n * n); - let mut seeds = vec![0.0f32; n * n]; - for j in 0..n { - primals.extend_from_slice(x); - seeds[j * n + j] = 1.0; - } - - let result = backend.taylor_forward_2nd_batch(tape, &primals, &seeds, n as u32)?; - - let value = result.values[0]; - let diag: Vec = result.c2s.iter().map(|&c2| 2.0 * c2).collect(); - - Ok((value, diag)) -} - /// GPU-accelerated Laplacian with diagonal control variate. /// /// Uses a precomputed Hessian diagonal to reduce estimator variance. /// The control variate estimate is: `tr(H_diag) + mean(v^T H v - v^T diag(H) v)`. -#[cfg(feature = "gpu-wgpu")] -pub fn laplacian_with_control_gpu( - backend: &super::WgpuContext, - tape: &super::WgpuTapeBuffers, +pub fn laplacian_with_control_gpu( + backend: &B, + tape: &B::TapeBuffers, x: &[f32], directions: &[&[f32]], control_diagonal: &[f32], @@ -230,3 +175,34 @@ pub fn laplacian_with_control_gpu( num_samples: s, }) } + +// ── Deprecated backend-specific wrappers ── + +/// Deprecated: Use [`laplacian_gpu`] instead (now generic over any `GpuBackend`). +#[cfg(feature = "gpu-cuda")] +#[deprecated( + since = "0.5.0", + note = "use laplacian_gpu() which is now generic over GpuBackend" +)] +pub fn laplacian_gpu_cuda( + backend: &super::CudaContext, + tape: &super::CudaTapeBuffers, + x: &[f32], + directions: &[&[f32]], +) -> Result, GpuError> { + laplacian_gpu(backend, tape, x, directions) +} + +/// Deprecated: Use [`hessian_diagonal_gpu`] instead (now generic over any `GpuBackend`). +#[cfg(feature = "gpu-cuda")] +#[deprecated( + since = "0.5.0", + note = "use hessian_diagonal_gpu() which is now generic over GpuBackend" +)] +pub fn hessian_diagonal_gpu_cuda( + backend: &super::CudaContext, + tape: &super::CudaTapeBuffers, + x: &[f32], +) -> Result<(f32, Vec), GpuError> { + hessian_diagonal_gpu(backend, tape, x) +} diff --git a/src/gpu/taylor_codegen.rs b/src/gpu/taylor_codegen.rs new file mode 100644 index 0000000..0d0b603 --- /dev/null +++ b/src/gpu/taylor_codegen.rs @@ -0,0 +1,1643 @@ +//! Runtime codegen for K-specialized Taylor forward GPU shaders. +//! +//! Generates fully unrolled WGSL and CUDA shader source for Taylor jets of order +//! K ∈ {1, 2, 3, 4, 5}. Each generated shader has K-specific Cauchy products and +//! recurrences with no dynamic loops, preserving the performance of the handwritten +//! K=3 shader while supporting arbitrary orders. + +use std::fmt::Write; + +/// Generate a complete WGSL shader for K-th order Taylor forward propagation. +/// +/// The generated shader has the same bind group layout as `taylor_forward_2nd.wgsl` +/// but works with K coefficients per jet instead of 3. +/// +/// # Panics +/// Panics if `k` is not in 1..=5. +pub fn generate_taylor_wgsl(k: usize) -> String { + assert!((1..=5).contains(&k), "K must be in 1..=5, got {k}"); + + let mut s = String::with_capacity(16384); + writeln!(s, "// Auto-generated Taylor forward K={k} kernel.").unwrap(); + writeln!(s, "// Do not edit — generated by taylor_codegen.rs.\n").unwrap(); + + // Opcode constants + write_wgsl_opcodes(&mut s); + + // TapeMeta, bind groups + write_wgsl_bindings(&mut s, k); + + // Helper builtins missing from WGSL + write_wgsl_helpers(&mut s); + + // JetK struct and operations + write_wgsl_jet_type(&mut s, k); + write_wgsl_jet_arithmetic(&mut s, k); + write_wgsl_jet_transcendental(&mut s, k); + write_wgsl_jet_inverse_trig(&mut s, k); + + // Main kernel + write_wgsl_main_kernel(&mut s, k); + + s +} + +/// Generate a complete CUDA kernel for K-th order Taylor forward propagation. +/// +/// Expects `#define FLOAT_TYPE float` or `double` to be prepended by the caller. +/// +/// # Panics +/// Panics if `k` is not in 1..=5. +pub fn generate_taylor_cuda(k: usize) -> String { + assert!((1..=5).contains(&k), "K must be in 1..=5, got {k}"); + + let mut s = String::with_capacity(16384); + writeln!(s, "// Auto-generated Taylor forward K={k} kernel.").unwrap(); + writeln!(s, "// Do not edit — generated by taylor_codegen.rs.\n").unwrap(); + writeln!(s, "typedef FLOAT_TYPE F;").unwrap(); + + // Opcode constants + write_cuda_opcodes(&mut s); + + // Math helpers + write_cuda_helpers(&mut s); + + // JetK struct and operations + write_cuda_jet_type(&mut s, k); + write_cuda_jet_arithmetic(&mut s, k); + write_cuda_jet_transcendental(&mut s, k); + write_cuda_jet_inverse_trig(&mut s, k); + + // Main kernel + write_cuda_main_kernel(&mut s, k); + + s +} + +// ══════════════════════════════════════════════ +// WGSL code generation +// ══════════════════════════════════════════════ + +fn write_wgsl_opcodes(s: &mut String) { + let ops = [ + ("OP_INPUT", 0), + ("OP_CONST", 1), + ("OP_ADD", 2), + ("OP_SUB", 3), + ("OP_MUL", 4), + ("OP_DIV", 5), + ("OP_REM", 6), + ("OP_POWF", 7), + ("OP_ATAN2", 8), + ("OP_HYPOT", 9), + ("OP_MAX", 10), + ("OP_MIN", 11), + ("OP_NEG", 12), + ("OP_RECIP", 13), + ("OP_SQRT", 14), + ("OP_CBRT", 15), + ("OP_POWI", 16), + ("OP_EXP", 17), + ("OP_EXP2", 18), + ("OP_EXPM1", 19), + ("OP_LN", 20), + ("OP_LOG2", 21), + ("OP_LOG10", 22), + ("OP_LN1P", 23), + ("OP_SIN", 24), + ("OP_COS", 25), + ("OP_TAN", 26), + ("OP_ASIN", 27), + ("OP_ACOS", 28), + ("OP_ATAN", 29), + ("OP_SINH", 30), + ("OP_COSH", 31), + ("OP_TANH", 32), + ("OP_ASINH", 33), + ("OP_ACOSH", 34), + ("OP_ATANH", 35), + ("OP_ABS", 36), + ("OP_SIGNUM", 37), + ("OP_FLOOR", 38), + ("OP_CEIL", 39), + ("OP_ROUND", 40), + ("OP_TRUNC", 41), + ("OP_FRACT", 42), + ]; + for (name, val) in &ops { + writeln!(s, "const {name}: u32 = {val}u;").unwrap(); + } + writeln!(s).unwrap(); +} + +fn write_wgsl_bindings(s: &mut String, k: usize) { + writeln!( + s, + "struct TapeMeta {{ + num_ops: u32, + num_inputs: u32, + num_variables: u32, + num_outputs: u32, + batch_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +}} + +@group(0) @binding(0) var opcodes: array; +@group(0) @binding(1) var arg0: array; +@group(0) @binding(2) var arg1: array; +@group(0) @binding(3) var constants: array; +@group(0) @binding(4) var tape_meta: TapeMeta; +@group(0) @binding(5) var output_indices: array; + +@group(1) @binding(0) var primal_inputs: array; +@group(1) @binding(1) var direction_seeds: array; +@group(1) @binding(2) var jets: array; +@group(1) @binding(3) var jet_outputs: array; +" + ) + .unwrap(); + let _ = k; // K only affects buffer sizes at dispatch time, not bindings +} + +fn write_wgsl_helpers(s: &mut String) { + writeln!( + s, + "fn sinh_f(x: f32) -> f32 {{ return (exp(x) - exp(-x)) * 0.5; }} +fn cosh_f(x: f32) -> f32 {{ return (exp(x) + exp(-x)) * 0.5; }} +fn asinh_f(x: f32) -> f32 {{ return log(x + sqrt(x * x + 1.0)); }} +fn acosh_f(x: f32) -> f32 {{ return log(x + sqrt(x * x - 1.0)); }} +fn atanh_f(x: f32) -> f32 {{ return 0.5 * log((1.0 + x) / (1.0 - x)); }} +" + ) + .unwrap(); +} + +fn write_wgsl_jet_type(s: &mut String, k: usize) { + writeln!(s, "struct JetK {{ v: array, }}\n").unwrap(); + + // jet_const: create jet from scalar + write!( + s, + "fn jet_const(val: f32) -> JetK {{\n var j: JetK;\n j.v[0] = val;\n" + ) + .unwrap(); + for i in 1..k { + writeln!(s, " j.v[{i}] = 0.0;").unwrap(); + } + writeln!(s, " return j;\n}}\n").unwrap(); + + // jet_load: read jet from buffer + writeln!(s, "fn jet_load(base: u32) -> JetK {{").unwrap(); + writeln!(s, " var j: JetK;").unwrap(); + for i in 0..k { + writeln!(s, " j.v[{i}] = jets[base + {i}u];").unwrap(); + } + writeln!(s, " return j;\n}}\n").unwrap(); + + // jet_store: write jet to buffer + writeln!(s, "fn jet_store(base: u32, j: JetK) {{").unwrap(); + for i in 0..k { + writeln!(s, " jets[base + {i}u] = j.v[{i}];").unwrap(); + } + writeln!(s, "}}\n").unwrap(); +} + +fn write_wgsl_jet_arithmetic(s: &mut String, k: usize) { + // Add + writeln!(s, "fn jet_add(a: JetK, b: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + for i in 0..k { + writeln!(s, " c.v[{i}] = a.v[{i}] + b.v[{i}];").unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Sub + writeln!(s, "fn jet_sub(a: JetK, b: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + for i in 0..k { + writeln!(s, " c.v[{i}] = a.v[{i}] - b.v[{i}];").unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Neg + writeln!(s, "fn jet_neg(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + for i in 0..k { + writeln!(s, " c.v[{i}] = -a.v[{i}];").unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Scale + writeln!(s, "fn jet_scale(a: JetK, s: f32) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + for i in 0..k { + writeln!(s, " c.v[{i}] = a.v[{i}] * s;").unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Mul (Cauchy product): c[i] = Σ_{j=0}^{i} a[j] * b[i-j] + writeln!(s, "fn jet_mul(a: JetK, b: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + for i in 0..k { + let mut terms = Vec::new(); + for j in 0..=i { + terms.push(format!("a.v[{j}] * b.v[{}]", i - j)); + } + writeln!(s, " c.v[{i}] = {};", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Div: c[i] = (a[i] - Σ_{j=1}^{i} b[j]*c[i-j]) / b[0] + writeln!(s, "fn jet_div(a: JetK, b: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " let inv_b0 = 1.0 / b.v[0];").unwrap(); + writeln!(s, " c.v[0] = a.v[0] * inv_b0;").unwrap(); + for i in 1..k { + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("b.v[{j}] * c.v[{}]", i - j)); + } + writeln!( + s, + " c.v[{i}] = (a.v[{i}] - ({})) * inv_b0;", + terms.join(" + ") + ) + .unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Recip: c = 1/a — special case: c[i] = -(Σ_{j=1}^{i} a[j]*c[i-j]) * c[0] + writeln!(s, "fn jet_recip(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " c.v[0] = 1.0 / a.v[0];").unwrap(); + for i in 1..k { + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("a.v[{j}] * c.v[{}]", i - j)); + } + writeln!(s, " c.v[{i}] = -({}) * c.v[0];", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); +} + +fn write_wgsl_jet_transcendental(s: &mut String, k: usize) { + // Exp: c[0] = exp(a[0]), c[i] = (1/i) * Σ_{j=1}^{i} j*a[j]*c[i-j] + writeln!(s, "fn jet_exp(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " c.v[0] = exp(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("{:.1} * a.v[{j}] * c.v[{}]", j as f64, i - j)); + } + writeln!(s, " c.v[{i}] = {inv_i:.10} * ({});", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Ln: c[0] = ln(a[0]), c[i] = (a[i] - (1/i)*Σ_{j=1}^{i-1} j*c[j]*a[i-j]) / a[0] + writeln!(s, "fn jet_ln(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " let inv_a0 = 1.0 / a.v[0];").unwrap(); + writeln!(s, " c.v[0] = log(a.v[0]);").unwrap(); + for i in 1..k { + if i == 1 { + writeln!(s, " c.v[1] = a.v[1] * inv_a0;").unwrap(); + } else { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..i { + terms.push(format!("{:.1} * c.v[{j}] * a.v[{}]", j as f64, i - j)); + } + writeln!( + s, + " c.v[{i}] = (a.v[{i}] - {inv_i:.10} * ({})) * inv_a0;", + terms.join(" + ") + ) + .unwrap(); + } + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Sqrt: c[0] = sqrt(a[0]), c[i] = (a[i] - Σ_{j=1}^{i-1} c[j]*c[i-j]) / (2*c[0]) + writeln!(s, "fn jet_sqrt(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " c.v[0] = sqrt(a.v[0]);").unwrap(); + if k > 1 { + writeln!(s, " let inv_2c0 = 0.5 / c.v[0];").unwrap(); + } + for i in 1..k { + if i == 1 { + writeln!(s, " c.v[1] = a.v[1] * inv_2c0;").unwrap(); + } else { + let mut terms = Vec::new(); + for j in 1..i { + terms.push(format!("c.v[{j}] * c.v[{}]", i - j)); + } + writeln!( + s, + " c.v[{i}] = (a.v[{i}] - ({})) * inv_2c0;", + terms.join(" + ") + ) + .unwrap(); + } + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Sin/Cos coupled recurrence: + // s[0] = sin(a[0]), co[0] = cos(a[0]) + // s[i] = (1/i) * Σ_{j=1}^{i} j*a[j]*co[i-j] + // co[i] = -(1/i) * Σ_{j=1}^{i} j*a[j]*s[i-j] + writeln!(s, "struct JetPair {{ a: JetK, b: JetK, }}\n").unwrap(); + + writeln!(s, "fn jet_sin_cos(a: JetK) -> JetPair {{").unwrap(); + writeln!(s, " var sn: JetK;").unwrap(); + writeln!(s, " var co: JetK;").unwrap(); + writeln!(s, " sn.v[0] = sin(a.v[0]);").unwrap(); + writeln!(s, " co.v[0] = cos(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut sterms = Vec::new(); + let mut cterms = Vec::new(); + for j in 1..=i { + sterms.push(format!("{:.1} * a.v[{j}] * co.v[{}]", j as f64, i - j)); + cterms.push(format!("{:.1} * a.v[{j}] * sn.v[{}]", j as f64, i - j)); + } + writeln!(s, " sn.v[{i}] = {inv_i:.10} * ({});", sterms.join(" + ")).unwrap(); + writeln!( + s, + " co.v[{i}] = -{inv_i:.10} * ({});", + cterms.join(" + ") + ) + .unwrap(); + } + writeln!(s, " return JetPair(sn, co);\n}}\n").unwrap(); + + // Sinh/Cosh coupled: same but positive signs for cosh + writeln!(s, "fn jet_sinh_cosh(a: JetK) -> JetPair {{").unwrap(); + writeln!(s, " var sh: JetK;").unwrap(); + writeln!(s, " var ch: JetK;").unwrap(); + writeln!(s, " sh.v[0] = sinh_f(a.v[0]);").unwrap(); + writeln!(s, " ch.v[0] = cosh_f(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut shterms = Vec::new(); + let mut chterms = Vec::new(); + for j in 1..=i { + shterms.push(format!("{:.1} * a.v[{j}] * ch.v[{}]", j as f64, i - j)); + chterms.push(format!("{:.1} * a.v[{j}] * sh.v[{}]", j as f64, i - j)); + } + writeln!( + s, + " sh.v[{i}] = {inv_i:.10} * ({});", + shterms.join(" + ") + ) + .unwrap(); + writeln!( + s, + " ch.v[{i}] = {inv_i:.10} * ({});", + chterms.join(" + ") + ) + .unwrap(); + } + writeln!(s, " return JetPair(sh, ch);\n}}\n").unwrap(); + + // Tan: c' = a' * (1 + c²), where s = 1 + c² is a helper jet + // s[0] = 1 + c[0]², then integration: c[i] = (1/i) * Σ j*a[j]*s[i-j] + // But s depends on c, so we update s as we compute c. + // s = jet_mul(c,c); s.v[0] += 1 + // c[i] = (1/i) * Σ_{j=1}^{i} j*a[j]*s[i-j], then update s[i] = Σ c[j]*c[i-j] for j=0..i + writeln!(s, "fn jet_tan(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " var sc: JetK; // 1 + c²").unwrap(); + writeln!(s, " c.v[0] = tan(a.v[0]);").unwrap(); + writeln!(s, " sc.v[0] = 1.0 + c.v[0] * c.v[0];").unwrap(); + for i in 1..k { + // First update sc[i] (c² contribution) from already-computed c[0..i-1] + // Actually, we need sc[i-j] in the integration, so we compute sc up to i-1 first + // Then compute c[i], then sc[i]. + // sc[p] = Σ_{j=0}^{p} c[j]*c[p-j] for the c² part + // Actually the recurrence is: c[i] = (1/i) * Σ_{j=1}^{i} j*a[j]*sc[i-j] + // where sc = 1+c². So sc[0] = 1 + c[0]², sc[p] = Σ_{j=0}^{p} c[j]*c[p-j] for p>0 + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("{:.1} * a.v[{j}] * sc.v[{}]", j as f64, i - j)); + } + writeln!(s, " c.v[{i}] = {inv_i:.10} * ({});", terms.join(" + ")).unwrap(); + // Now update sc[i] = Σ c[j]*c[i-j] + let mut sc_terms = Vec::new(); + for j in 0..=i { + sc_terms.push(format!("c.v[{j}] * c.v[{}]", i - j)); + } + writeln!(s, " sc.v[{i}] = {};", sc_terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Tanh: c' = a' * (1 - c²) + writeln!(s, "fn jet_tanh(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " var sc: JetK; // 1 - c²").unwrap(); + writeln!(s, " c.v[0] = tanh(a.v[0]);").unwrap(); + writeln!(s, " sc.v[0] = 1.0 - c.v[0] * c.v[0];").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("{:.1} * a.v[{j}] * sc.v[{}]", j as f64, i - j)); + } + writeln!(s, " c.v[{i}] = {inv_i:.10} * ({});", terms.join(" + ")).unwrap(); + let mut sc_terms = Vec::new(); + for j in 0..=i { + sc_terms.push(format!("c.v[{j}] * c.v[{}]", i - j)); + } + writeln!(s, " sc.v[{i}] = -({});", sc_terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); +} + +fn write_wgsl_jet_inverse_trig(s: &mut String, k: usize) { + // atan: c' = a' / (1 + a²), so g = 1/(1+a²) + writeln!(s, "fn jet_atan(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " let asq = jet_mul(a, a);").unwrap(); + write!(s, " var d: JetK;\n d.v[0] = 1.0 + asq.v[0];\n").unwrap(); + for i in 1..k { + writeln!(s, " d.v[{i}] = asq.v[{i}];").unwrap(); + } + writeln!(s, " let g = jet_recip(d);").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " c.v[0] = atan(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("{:.1} * a.v[{j}] * g.v[{}]", j as f64, i - j)); + } + writeln!(s, " c.v[{i}] = {inv_i:.10} * ({});", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // asin: c' = a' / sqrt(1 - a²) + writeln!(s, "fn jet_asin(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " let asq = jet_mul(a, a);").unwrap(); + write!(s, " var d: JetK;\n d.v[0] = 1.0 - asq.v[0];\n").unwrap(); + for i in 1..k { + writeln!(s, " d.v[{i}] = -asq.v[{i}];").unwrap(); + } + writeln!(s, " let g = jet_recip(jet_sqrt(d));").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " c.v[0] = asin(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("{:.1} * a.v[{j}] * g.v[{}]", j as f64, i - j)); + } + writeln!(s, " c.v[{i}] = {inv_i:.10} * ({});", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // acos: π/2 - asin → negate higher coefficients + writeln!(s, "fn jet_acos(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " let asq = jet_mul(a, a);").unwrap(); + write!(s, " var d: JetK;\n d.v[0] = 1.0 - asq.v[0];\n").unwrap(); + for i in 1..k { + writeln!(s, " d.v[{i}] = -asq.v[{i}];").unwrap(); + } + writeln!(s, " let g = jet_recip(jet_sqrt(d));").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " c.v[0] = acos(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("{:.1} * a.v[{j}] * g.v[{}]", j as f64, i - j)); + } + writeln!(s, " c.v[{i}] = -{inv_i:.10} * ({});", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // asinh: c' = a' / sqrt(1 + a²) + writeln!(s, "fn jet_asinh(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " let asq = jet_mul(a, a);").unwrap(); + write!(s, " var d: JetK;\n d.v[0] = 1.0 + asq.v[0];\n").unwrap(); + for i in 1..k { + writeln!(s, " d.v[{i}] = asq.v[{i}];").unwrap(); + } + writeln!(s, " let g = jet_recip(jet_sqrt(d));").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " c.v[0] = asinh_f(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("{:.1} * a.v[{j}] * g.v[{}]", j as f64, i - j)); + } + writeln!(s, " c.v[{i}] = {inv_i:.10} * ({});", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // acosh: c' = a' / sqrt(a² - 1) + writeln!(s, "fn jet_acosh(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " let asq = jet_mul(a, a);").unwrap(); + write!(s, " var d: JetK;\n d.v[0] = asq.v[0] - 1.0;\n").unwrap(); + for i in 1..k { + writeln!(s, " d.v[{i}] = asq.v[{i}];").unwrap(); + } + writeln!(s, " let g = jet_recip(jet_sqrt(d));").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " c.v[0] = acosh_f(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("{:.1} * a.v[{j}] * g.v[{}]", j as f64, i - j)); + } + writeln!(s, " c.v[{i}] = {inv_i:.10} * ({});", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // atanh: c' = a' / (1 - a²) + writeln!(s, "fn jet_atanh(a: JetK) -> JetK {{").unwrap(); + writeln!(s, " let asq = jet_mul(a, a);").unwrap(); + write!(s, " var d: JetK;\n d.v[0] = 1.0 - asq.v[0];\n").unwrap(); + for i in 1..k { + writeln!(s, " d.v[{i}] = -asq.v[{i}];").unwrap(); + } + writeln!(s, " let g = jet_recip(d);").unwrap(); + writeln!(s, " var c: JetK;").unwrap(); + writeln!(s, " c.v[0] = atanh_f(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("{:.1} * a.v[{j}] * g.v[{}]", j as f64, i - j)); + } + writeln!(s, " c.v[{i}] = {inv_i:.10} * ({});", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); +} + +fn write_wgsl_main_kernel(s: &mut String, k: usize) { + writeln!( + s, + "@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) {{ + let bid = gid.x; + if bid >= tape_meta.batch_size {{ + return; + }} + + let nv = tape_meta.num_variables; + let ni = tape_meta.num_inputs; + let num_ops = tape_meta.num_ops; + let n_out = tape_meta.num_outputs; + let K = {k}u; + + let j_base = bid * nv * K;" + ) + .unwrap(); + + // Initialize all variables from constants + writeln!(s, "\n // Initialize from constants").unwrap(); + writeln!(s, " for (var i = 0u; i < nv; i = i + 1u) {{").unwrap(); + writeln!(s, " let off = j_base + i * K;").unwrap(); + writeln!(s, " jets[off] = constants[i];").unwrap(); + for c in 1..k { + writeln!(s, " jets[off + {c}u] = 0.0;").unwrap(); + } + writeln!(s, " }}").unwrap(); + + // Set input variables + writeln!(s, "\n // Set input jets").unwrap(); + writeln!(s, " let in_base = bid * ni;").unwrap(); + writeln!(s, " for (var i = 0u; i < ni; i = i + 1u) {{").unwrap(); + writeln!(s, " let off = j_base + i * K;").unwrap(); + writeln!(s, " jets[off] = primal_inputs[in_base + i];").unwrap(); + if k > 1 { + writeln!(s, " jets[off + 1u] = direction_seeds[in_base + i];").unwrap(); + } + writeln!(s, " }}").unwrap(); + + // Tape walk + writeln!(s, "\n // Walk the tape").unwrap(); + writeln!(s, " for (var i = ni; i < num_ops; i = i + 1u) {{").unwrap(); + writeln!(s, " let op = opcodes[i];").unwrap(); + writeln!(s, " if op == OP_CONST {{ continue; }}").unwrap(); + writeln!(s, " let a_idx = arg0[i];").unwrap(); + writeln!(s, " let b_idx = arg1[i];").unwrap(); + writeln!(s, " let a = jet_load(j_base + a_idx * K);").unwrap(); + writeln!(s, " var r = jet_const(0.0);").unwrap(); + writeln!(s, " switch op {{").unwrap(); + + // Binary ops + for (case, name) in &[ + (2, "jet_add"), + (3, "jet_sub"), + (4, "jet_mul"), + (5, "jet_div"), + ] { + writeln!(s, " case {case}u: {{").unwrap(); + writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap(); + writeln!(s, " r = {name}(a, b);").unwrap(); + writeln!(s, " }}").unwrap(); + } + + // REM + writeln!(s, " case 6u: {{").unwrap(); + writeln!(s, " let b_val = jets[j_base + b_idx * K];").unwrap(); + writeln!( + s, + " r = jet_const(a.v[0] - trunc(a.v[0] / b_val) * b_val);" + ) + .unwrap(); + writeln!(s, " }}").unwrap(); + + // POWF + writeln!(s, " case 7u: {{").unwrap(); + writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap(); + writeln!(s, " let lna = jet_ln(a);").unwrap(); + writeln!(s, " let product = jet_mul(b, lna);").unwrap(); + writeln!(s, " r = jet_exp(product);").unwrap(); + writeln!(s, " r.v[0] = pow(a.v[0], b.v[0]);").unwrap(); + writeln!(s, " }}").unwrap(); + + // ATAN2 + writeln!(s, " case 8u: {{").unwrap(); + writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap(); + writeln!(s, " let ratio = jet_div(a, b);").unwrap(); + writeln!(s, " let at = jet_atan(ratio);").unwrap(); + writeln!(s, " r = at;").unwrap(); + writeln!(s, " r.v[0] = atan2(a.v[0], b.v[0]);").unwrap(); + writeln!(s, " }}").unwrap(); + + // HYPOT + writeln!(s, " case 9u: {{").unwrap(); + writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap(); + writeln!(s, " let asq = jet_mul(a, a);").unwrap(); + writeln!(s, " let bsq = jet_mul(b, b);").unwrap(); + writeln!(s, " let sum2 = jet_add(asq, bsq);").unwrap(); + writeln!(s, " r = jet_sqrt(sum2);").unwrap(); + writeln!( + s, + " r.v[0] = sqrt(a.v[0] * a.v[0] + b.v[0] * b.v[0]);" + ) + .unwrap(); + writeln!(s, " }}").unwrap(); + + // MAX, MIN + writeln!(s, " case 10u: {{").unwrap(); + writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap(); + writeln!( + s, + " if a.v[0] >= b.v[0] {{ r = a; }} else {{ r = b; }}" + ) + .unwrap(); + writeln!(s, " }}").unwrap(); + writeln!(s, " case 11u: {{").unwrap(); + writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap(); + writeln!( + s, + " if a.v[0] <= b.v[0] {{ r = a; }} else {{ r = b; }}" + ) + .unwrap(); + writeln!(s, " }}").unwrap(); + + // Unary ops + writeln!(s, " case 12u: {{ r = jet_neg(a); }}").unwrap(); + writeln!(s, " case 13u: {{ r = jet_recip(a); }}").unwrap(); + writeln!(s, " case 14u: {{ r = jet_sqrt(a); }}").unwrap(); + + // CBRT + writeln!(s, " case 15u: {{").unwrap(); + writeln!(s, " let sg = sign(a.v[0]);").unwrap(); + write!( + s, + " var abs_a: JetK;\n abs_a.v[0] = abs(a.v[0]);\n" + ) + .unwrap(); + for i in 1..k { + writeln!(s, " abs_a.v[{i}] = sg * a.v[{i}];").unwrap(); + } + writeln!(s, " let lna = jet_ln(abs_a);").unwrap(); + writeln!(s, " let third = jet_scale(lna, 1.0 / 3.0);").unwrap(); + writeln!(s, " let e = jet_exp(third);").unwrap(); + writeln!(s, " r.v[0] = sg * e.v[0];").unwrap(); + for i in 1..k { + writeln!(s, " r.v[{i}] = sg * e.v[{i}];").unwrap(); + } + writeln!(s, " }}").unwrap(); + + // POWI + writeln!(s, " case 16u: {{").unwrap(); + writeln!(s, " let n = f32(bitcast(b_idx));").unwrap(); + writeln!(s, " if n == 0.0 {{").unwrap(); + writeln!(s, " r = jet_const(1.0);").unwrap(); + writeln!(s, " }} else if n == 1.0 {{").unwrap(); + writeln!(s, " r = a;").unwrap(); + writeln!(s, " }} else {{").unwrap(); + writeln!(s, " let lna = jet_ln(a);").unwrap(); + writeln!(s, " let nlna = jet_scale(lna, n);").unwrap(); + writeln!(s, " r = jet_exp(nlna);").unwrap(); + writeln!(s, " r.v[0] = pow(a.v[0], n);").unwrap(); + writeln!(s, " }}").unwrap(); + writeln!(s, " }}").unwrap(); + + // Transcendental unary + writeln!(s, " case 17u: {{ r = jet_exp(a); }}").unwrap(); + writeln!(s, " case 18u: {{").unwrap(); + writeln!(s, " let ln2 = log(2.0);").unwrap(); + writeln!(s, " let scaled = jet_scale(a, ln2);").unwrap(); + writeln!(s, " r = jet_exp(scaled);").unwrap(); + writeln!(s, " r.v[0] = exp2(a.v[0]);").unwrap(); + writeln!(s, " }}").unwrap(); + writeln!(s, " case 19u: {{").unwrap(); + writeln!(s, " r = jet_exp(a);").unwrap(); + writeln!(s, " r.v[0] = exp(a.v[0]) - 1.0;").unwrap(); + writeln!(s, " }}").unwrap(); + writeln!(s, " case 20u: {{ r = jet_ln(a); }}").unwrap(); + writeln!(s, " case 21u: {{").unwrap(); + writeln!(s, " r = jet_ln(a);").unwrap(); + writeln!(s, " let inv_ln2 = 1.0 / log(2.0);").unwrap(); + writeln!(s, " r.v[0] = log2(a.v[0]);").unwrap(); + for i in 1..k { + writeln!(s, " r.v[{i}] = r.v[{i}] * inv_ln2;").unwrap(); + } + writeln!(s, " }}").unwrap(); + writeln!(s, " case 22u: {{").unwrap(); + writeln!(s, " r = jet_ln(a);").unwrap(); + writeln!(s, " let inv_ln10 = 1.0 / log(10.0);").unwrap(); + writeln!(s, " r.v[0] = log(a.v[0]) * inv_ln10;").unwrap(); + for i in 1..k { + writeln!(s, " r.v[{i}] = r.v[{i}] * inv_ln10;").unwrap(); + } + writeln!(s, " }}").unwrap(); + writeln!(s, " case 23u: {{").unwrap(); + write!( + s, + " var one_plus_a: JetK;\n one_plus_a.v[0] = 1.0 + a.v[0];\n" + ) + .unwrap(); + for i in 1..k { + writeln!(s, " one_plus_a.v[{i}] = a.v[{i}];").unwrap(); + } + writeln!(s, " r = jet_ln(one_plus_a);").unwrap(); + writeln!(s, " r.v[0] = log(1.0 + a.v[0]);").unwrap(); + writeln!(s, " }}").unwrap(); + + // Sin, Cos + writeln!( + s, + " case 24u: {{ let sc = jet_sin_cos(a); r = sc.a; }}" + ) + .unwrap(); + writeln!( + s, + " case 25u: {{ let sc = jet_sin_cos(a); r = sc.b; }}" + ) + .unwrap(); + writeln!(s, " case 26u: {{ r = jet_tan(a); }}").unwrap(); + writeln!(s, " case 27u: {{ r = jet_asin(a); }}").unwrap(); + writeln!(s, " case 28u: {{ r = jet_acos(a); }}").unwrap(); + writeln!(s, " case 29u: {{ r = jet_atan(a); }}").unwrap(); + writeln!( + s, + " case 30u: {{ let sc = jet_sinh_cosh(a); r = sc.a; }}" + ) + .unwrap(); + writeln!( + s, + " case 31u: {{ let sc = jet_sinh_cosh(a); r = sc.b; }}" + ) + .unwrap(); + writeln!(s, " case 32u: {{ r = jet_tanh(a); }}").unwrap(); + writeln!(s, " case 33u: {{ r = jet_asinh(a); }}").unwrap(); + writeln!(s, " case 34u: {{ r = jet_acosh(a); }}").unwrap(); + writeln!(s, " case 35u: {{ r = jet_atanh(a); }}").unwrap(); + + // Nonsmooth + writeln!(s, " case 36u: {{").unwrap(); + writeln!(s, " let sg = sign(a.v[0]);").unwrap(); + writeln!(s, " r.v[0] = abs(a.v[0]);").unwrap(); + for i in 1..k { + writeln!(s, " r.v[{i}] = sg * a.v[{i}];").unwrap(); + } + writeln!(s, " }}").unwrap(); + + writeln!(s, " case 37u, 38u, 39u, 40u, 41u: {{").unwrap(); + writeln!(s, " var val = 0.0f;").unwrap(); + writeln!(s, " switch op {{").unwrap(); + writeln!(s, " case 37u: {{ val = sign(a.v[0]); }}").unwrap(); + writeln!( + s, + " case 38u: {{ val = floor(a.v[0]); }}" + ) + .unwrap(); + writeln!(s, " case 39u: {{ val = ceil(a.v[0]); }}").unwrap(); + writeln!( + s, + " case 40u: {{ val = round(a.v[0]); }}" + ) + .unwrap(); + writeln!( + s, + " case 41u: {{ val = trunc(a.v[0]); }}" + ) + .unwrap(); + writeln!(s, " default: {{}}").unwrap(); + writeln!(s, " }}").unwrap(); + writeln!(s, " r = jet_const(val);").unwrap(); + writeln!(s, " }}").unwrap(); + + writeln!(s, " case 42u: {{").unwrap(); + writeln!(s, " r.v[0] = fract(a.v[0]);").unwrap(); + for i in 1..k { + writeln!(s, " r.v[{i}] = a.v[{i}];").unwrap(); + } + writeln!(s, " }}").unwrap(); + + writeln!(s, " default: {{}}").unwrap(); + writeln!(s, " }}").unwrap(); + + // Store result + writeln!(s, " jet_store(j_base + i * K, r);").unwrap(); + writeln!(s, " }}").unwrap(); + + // Write output jets + writeln!(s, "\n // Write output jets").unwrap(); + writeln!(s, " let out_base = bid * n_out * K;").unwrap(); + writeln!(s, " for (var j = 0u; j < n_out; j = j + 1u) {{").unwrap(); + writeln!(s, " let oi = output_indices[j];").unwrap(); + writeln!(s, " let src = j_base + oi * K;").unwrap(); + writeln!(s, " let dst = out_base + j * K;").unwrap(); + for c in 0..k { + writeln!(s, " jet_outputs[dst + {c}u] = jets[src + {c}u];").unwrap(); + } + writeln!(s, " }}").unwrap(); + writeln!(s, "}}").unwrap(); +} + +// ══════════════════════════════════════════════ +// CUDA code generation +// ══════════════════════════════════════════════ + +fn write_cuda_opcodes(s: &mut String) { + let ops = [ + "OP_INPUT", + "OP_CONST", + "OP_ADD", + "OP_SUB", + "OP_MUL", + "OP_DIV", + "OP_REM", + "OP_POWF", + "OP_ATAN2", + "OP_HYPOT", + "OP_MAX", + "OP_MIN", + "OP_NEG", + "OP_RECIP", + "OP_SQRT", + "OP_CBRT", + "OP_POWI", + "OP_EXP", + "OP_EXP2", + "OP_EXPM1", + "OP_LN", + "OP_LOG2", + "OP_LOG10", + "OP_LN1P", + "OP_SIN", + "OP_COS", + "OP_TAN", + "OP_ASIN", + "OP_ACOS", + "OP_ATAN", + "OP_SINH", + "OP_COSH", + "OP_TANH", + "OP_ASINH", + "OP_ACOSH", + "OP_ATANH", + "OP_ABS", + "OP_SIGNUM", + "OP_FLOOR", + "OP_CEIL", + "OP_ROUND", + "OP_TRUNC", + "OP_FRACT", + ]; + for (i, name) in ops.iter().enumerate() { + writeln!(s, "#define {name} {i}").unwrap(); + } + writeln!(s).unwrap(); +} + +fn write_cuda_helpers(s: &mut String) { + writeln!( + s, + "__device__ F _sign(F x) {{ return (x > (F)0) - (x < (F)0); }} +__device__ F _cbrt_f(F x) {{ return (x >= (F)0) ? pow(x, (F)(1.0/3.0)) : -pow(-x, (F)(1.0/3.0)); }} +__device__ F _fract(F x) {{ return x - floor(x); }} +" + ) + .unwrap(); +} + +fn write_cuda_jet_type(s: &mut String, k: usize) { + writeln!( + s, + "struct JetK {{ + F v[{k}]; + __device__ JetK() {{ for(int i=0;i<{k};i++) v[i]=(F)0; }} +}}; +" + ) + .unwrap(); + + // jet_const + writeln!(s, "__device__ JetK jet_const(F val) {{").unwrap(); + writeln!(s, " JetK j;").unwrap(); + writeln!(s, " j.v[0] = val;").unwrap(); + writeln!(s, " return j;\n}}\n").unwrap(); +} + +fn write_cuda_jet_arithmetic(s: &mut String, k: usize) { + // Add + writeln!(s, "__device__ JetK jet_add(JetK a, JetK b) {{").unwrap(); + writeln!(s, " JetK c;").unwrap(); + for i in 0..k { + writeln!(s, " c.v[{i}] = a.v[{i}] + b.v[{i}];").unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Sub + writeln!(s, "__device__ JetK jet_sub(JetK a, JetK b) {{").unwrap(); + writeln!(s, " JetK c;").unwrap(); + for i in 0..k { + writeln!(s, " c.v[{i}] = a.v[{i}] - b.v[{i}];").unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Neg + writeln!(s, "__device__ JetK jet_neg(JetK a) {{").unwrap(); + writeln!(s, " JetK c;").unwrap(); + for i in 0..k { + writeln!(s, " c.v[{i}] = -a.v[{i}];").unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Scale + writeln!(s, "__device__ JetK jet_scale(JetK a, F s) {{").unwrap(); + writeln!(s, " JetK c;").unwrap(); + for i in 0..k { + writeln!(s, " c.v[{i}] = a.v[{i}] * s;").unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Mul (Cauchy product) + writeln!(s, "__device__ JetK jet_mul(JetK a, JetK b) {{").unwrap(); + writeln!(s, " JetK c;").unwrap(); + for i in 0..k { + let mut terms = Vec::new(); + for j in 0..=i { + terms.push(format!("a.v[{j}] * b.v[{}]", i - j)); + } + writeln!(s, " c.v[{i}] = {};", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Div + writeln!(s, "__device__ JetK jet_div(JetK a, JetK b) {{").unwrap(); + writeln!(s, " JetK c;").unwrap(); + writeln!(s, " F inv_b0 = (F)1 / b.v[0];").unwrap(); + writeln!(s, " c.v[0] = a.v[0] * inv_b0;").unwrap(); + for i in 1..k { + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("b.v[{j}] * c.v[{}]", i - j)); + } + writeln!( + s, + " c.v[{i}] = (a.v[{i}] - ({})) * inv_b0;", + terms.join(" + ") + ) + .unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Recip + writeln!(s, "__device__ JetK jet_recip(JetK a) {{").unwrap(); + writeln!(s, " JetK c;").unwrap(); + writeln!(s, " c.v[0] = (F)1 / a.v[0];").unwrap(); + for i in 1..k { + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("a.v[{j}] * c.v[{}]", i - j)); + } + writeln!(s, " c.v[{i}] = -({}) * c.v[0];", terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); +} + +fn write_cuda_jet_transcendental(s: &mut String, k: usize) { + // Exp + writeln!(s, "__device__ JetK jet_exp(JetK a) {{").unwrap(); + writeln!(s, " JetK c;").unwrap(); + writeln!(s, " c.v[0] = exp(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("(F){:.1} * a.v[{j}] * c.v[{}]", j as f64, i - j)); + } + writeln!( + s, + " c.v[{i}] = (F){inv_i:.10} * ({});", + terms.join(" + ") + ) + .unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Ln + writeln!(s, "__device__ JetK jet_ln(JetK a) {{").unwrap(); + writeln!(s, " JetK c;").unwrap(); + writeln!(s, " F inv_a0 = (F)1 / a.v[0];").unwrap(); + writeln!(s, " c.v[0] = log(a.v[0]);").unwrap(); + for i in 1..k { + if i == 1 { + writeln!(s, " c.v[1] = a.v[1] * inv_a0;").unwrap(); + } else { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..i { + terms.push(format!("(F){:.1} * c.v[{j}] * a.v[{}]", j as f64, i - j)); + } + writeln!( + s, + " c.v[{i}] = (a.v[{i}] - (F){inv_i:.10} * ({})) * inv_a0;", + terms.join(" + ") + ) + .unwrap(); + } + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Sqrt + writeln!(s, "__device__ JetK jet_sqrt(JetK a) {{").unwrap(); + writeln!(s, " JetK c;").unwrap(); + writeln!(s, " c.v[0] = sqrt(a.v[0]);").unwrap(); + if k > 1 { + writeln!(s, " F inv_2c0 = (F)0.5 / c.v[0];").unwrap(); + } + for i in 1..k { + if i == 1 { + writeln!(s, " c.v[1] = a.v[1] * inv_2c0;").unwrap(); + } else { + let mut terms = Vec::new(); + for j in 1..i { + terms.push(format!("c.v[{j}] * c.v[{}]", i - j)); + } + writeln!( + s, + " c.v[{i}] = (a.v[{i}] - ({})) * inv_2c0;", + terms.join(" + ") + ) + .unwrap(); + } + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Sin/Cos + writeln!(s, "struct JetPair {{ JetK a; JetK b; }};\n").unwrap(); + + writeln!(s, "__device__ JetPair jet_sin_cos(JetK a) {{").unwrap(); + writeln!(s, " JetK sn, co;").unwrap(); + writeln!(s, " sn.v[0] = sin(a.v[0]);").unwrap(); + writeln!(s, " co.v[0] = cos(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut sterms = Vec::new(); + let mut cterms = Vec::new(); + for j in 1..=i { + sterms.push(format!("(F){:.1} * a.v[{j}] * co.v[{}]", j as f64, i - j)); + cterms.push(format!("(F){:.1} * a.v[{j}] * sn.v[{}]", j as f64, i - j)); + } + writeln!( + s, + " sn.v[{i}] = (F){inv_i:.10} * ({});", + sterms.join(" + ") + ) + .unwrap(); + writeln!( + s, + " co.v[{i}] = -(F){inv_i:.10} * ({});", + cterms.join(" + ") + ) + .unwrap(); + } + writeln!(s, " JetPair p; p.a = sn; p.b = co;").unwrap(); + writeln!(s, " return p;\n}}\n").unwrap(); + + // Sinh/Cosh + writeln!(s, "__device__ JetPair jet_sinh_cosh(JetK a) {{").unwrap(); + writeln!(s, " JetK sh, ch;").unwrap(); + writeln!(s, " sh.v[0] = sinh(a.v[0]);").unwrap(); + writeln!(s, " ch.v[0] = cosh(a.v[0]);").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut shterms = Vec::new(); + let mut chterms = Vec::new(); + for j in 1..=i { + shterms.push(format!("(F){:.1} * a.v[{j}] * ch.v[{}]", j as f64, i - j)); + chterms.push(format!("(F){:.1} * a.v[{j}] * sh.v[{}]", j as f64, i - j)); + } + writeln!( + s, + " sh.v[{i}] = (F){inv_i:.10} * ({});", + shterms.join(" + ") + ) + .unwrap(); + writeln!( + s, + " ch.v[{i}] = (F){inv_i:.10} * ({});", + chterms.join(" + ") + ) + .unwrap(); + } + writeln!(s, " JetPair p; p.a = sh; p.b = ch;").unwrap(); + writeln!(s, " return p;\n}}\n").unwrap(); + + // Tan + writeln!(s, "__device__ JetK jet_tan(JetK a) {{").unwrap(); + writeln!(s, " JetK c, sc;").unwrap(); + writeln!(s, " c.v[0] = tan(a.v[0]);").unwrap(); + writeln!(s, " sc.v[0] = (F)1 + c.v[0] * c.v[0];").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("(F){:.1} * a.v[{j}] * sc.v[{}]", j as f64, i - j)); + } + writeln!( + s, + " c.v[{i}] = (F){inv_i:.10} * ({});", + terms.join(" + ") + ) + .unwrap(); + let mut sc_terms = Vec::new(); + for j in 0..=i { + sc_terms.push(format!("c.v[{j}] * c.v[{}]", i - j)); + } + writeln!(s, " sc.v[{i}] = {};", sc_terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + + // Tanh + writeln!(s, "__device__ JetK jet_tanh(JetK a) {{").unwrap(); + writeln!(s, " JetK c, sc;").unwrap(); + writeln!(s, " c.v[0] = tanh(a.v[0]);").unwrap(); + writeln!(s, " sc.v[0] = (F)1 - c.v[0] * c.v[0];").unwrap(); + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("(F){:.1} * a.v[{j}] * sc.v[{}]", j as f64, i - j)); + } + writeln!( + s, + " c.v[{i}] = (F){inv_i:.10} * ({});", + terms.join(" + ") + ) + .unwrap(); + let mut sc_terms = Vec::new(); + for j in 0..=i { + sc_terms.push(format!("c.v[{j}] * c.v[{}]", i - j)); + } + writeln!(s, " sc.v[{i}] = -({});", sc_terms.join(" + ")).unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); +} + +fn write_cuda_jet_inverse_trig(s: &mut String, k: usize) { + // Helper macro for inverse trig/hyp functions + // All follow the pattern: compute derivative factor g, then integrate + let inv_trig_fns = [ + ("atan", "1.0 + asq", false, "atan(a.v[0])", false), // g = 1/(1+a²) + ("asin", "1.0 - asq", true, "asin(a.v[0])", false), // g = 1/sqrt(1-a²) + ("acos", "1.0 - asq", true, "acos(a.v[0])", true), // g = -1/sqrt(1-a²) + ("asinh", "1.0 + asq", true, "asinh(a.v[0])", false), // g = 1/sqrt(1+a²) + ("acosh", "asq - 1.0", true, "acosh(a.v[0])", false), // g = 1/sqrt(a²-1) + ("atanh", "1.0 - asq", false, "atanh(a.v[0])", false), // g = 1/(1-a²) + ]; + + for (name, d_expr, use_sqrt, c0_expr, negate) in &inv_trig_fns { + writeln!(s, "__device__ JetK jet_{name}(JetK a) {{").unwrap(); + writeln!(s, " JetK asq = jet_mul(a, a);").unwrap(); + writeln!(s, " JetK d;").unwrap(); + + // Build d from expression template + if d_expr.starts_with("1.0 +") { + writeln!(s, " d.v[0] = (F)1 + asq.v[0];").unwrap(); + for i in 1..k { + writeln!(s, " d.v[{i}] = asq.v[{i}];").unwrap(); + } + } else if d_expr.starts_with("1.0 -") { + writeln!(s, " d.v[0] = (F)1 - asq.v[0];").unwrap(); + for i in 1..k { + writeln!(s, " d.v[{i}] = -asq.v[{i}];").unwrap(); + } + } else { + // asq - 1.0 + writeln!(s, " d.v[0] = asq.v[0] - (F)1;").unwrap(); + for i in 1..k { + writeln!(s, " d.v[{i}] = asq.v[{i}];").unwrap(); + } + } + + if *use_sqrt { + writeln!(s, " JetK g = jet_recip(jet_sqrt(d));").unwrap(); + } else { + writeln!(s, " JetK g = jet_recip(d);").unwrap(); + } + + writeln!(s, " JetK c;").unwrap(); + writeln!(s, " c.v[0] = {c0_expr};").unwrap(); + let sign_str = if *negate { "-" } else { "" }; + for i in 1..k { + let inv_i = 1.0 / i as f64; + let mut terms = Vec::new(); + for j in 1..=i { + terms.push(format!("(F){:.1} * a.v[{j}] * g.v[{}]", j as f64, i - j)); + } + writeln!( + s, + " c.v[{i}] = {sign_str}(F){inv_i:.10} * ({});", + terms.join(" + ") + ) + .unwrap(); + } + writeln!(s, " return c;\n}}\n").unwrap(); + } +} + +fn write_cuda_main_kernel(s: &mut String, k: usize) { + writeln!(s, "extern \"C\" __global__ void taylor_forward_kth(").unwrap(); + writeln!(s, " const unsigned int* __restrict__ opcodes,").unwrap(); + writeln!(s, " const unsigned int* __restrict__ arg0,").unwrap(); + writeln!(s, " const unsigned int* __restrict__ arg1,").unwrap(); + writeln!(s, " const F* __restrict__ constants,").unwrap(); + writeln!(s, " const F* __restrict__ primal_inputs,").unwrap(); + writeln!(s, " const F* __restrict__ direction_seeds,").unwrap(); + writeln!(s, " F* __restrict__ jets,").unwrap(); + writeln!(s, " F* __restrict__ jet_outputs,").unwrap(); + writeln!(s, " const unsigned int* __restrict__ output_indices,").unwrap(); + writeln!(s, " unsigned int num_ops,").unwrap(); + writeln!(s, " unsigned int num_inputs,").unwrap(); + writeln!(s, " unsigned int num_variables,").unwrap(); + writeln!(s, " unsigned int num_outputs,").unwrap(); + writeln!(s, " unsigned int batch_size").unwrap(); + writeln!(s, ") {{").unwrap(); + writeln!( + s, + " unsigned int bid = blockIdx.x * blockDim.x + threadIdx.x;" + ) + .unwrap(); + writeln!(s, " if (bid >= batch_size) return;").unwrap(); + writeln!(s, " const unsigned int K = {k};").unwrap(); + writeln!(s, " unsigned int j_base = bid * num_variables * K;").unwrap(); + + // Initialize + writeln!(s, " for (unsigned int i = 0; i < num_variables; i++) {{").unwrap(); + writeln!(s, " unsigned int off = j_base + i * K;").unwrap(); + writeln!(s, " jets[off] = constants[i];").unwrap(); + for c in 1..k { + writeln!(s, " jets[off + {c}] = (F)0;").unwrap(); + } + writeln!(s, " }}").unwrap(); + + // Set inputs + writeln!(s, " unsigned int in_base = bid * num_inputs;").unwrap(); + writeln!(s, " for (unsigned int i = 0; i < num_inputs; i++) {{").unwrap(); + writeln!(s, " unsigned int off = j_base + i * K;").unwrap(); + writeln!(s, " jets[off] = primal_inputs[in_base + i];").unwrap(); + if k > 1 { + writeln!(s, " jets[off + 1] = direction_seeds[in_base + i];").unwrap(); + } + writeln!(s, " }}").unwrap(); + + // Tape walk + writeln!( + s, + " for (unsigned int i = num_inputs; i < num_ops; i++) {{" + ) + .unwrap(); + writeln!(s, " unsigned int op = opcodes[i];").unwrap(); + writeln!(s, " if (op == OP_CONST) continue;").unwrap(); + writeln!(s, " unsigned int a_idx = arg0[i];").unwrap(); + writeln!(s, " unsigned int b_idx = arg1[i];").unwrap(); + + // Load a + writeln!(s, " JetK a;").unwrap(); + writeln!(s, " unsigned int a_off = j_base + a_idx * K;").unwrap(); + for c in 0..k { + writeln!(s, " a.v[{c}] = jets[a_off + {c}];").unwrap(); + } + writeln!(s, " JetK r;").unwrap(); + + // Switch on opcode + writeln!(s, " switch (op) {{").unwrap(); + + // Binary ops + for (case, func) in &[ + (2, "jet_add"), + (3, "jet_sub"), + (4, "jet_mul"), + (5, "jet_div"), + ] { + writeln!(s, " case {case}: {{").unwrap(); + writeln!( + s, + " JetK b; unsigned int b_off = j_base + b_idx * K;" + ) + .unwrap(); + for c in 0..k { + writeln!(s, " b.v[{c}] = jets[b_off + {c}];").unwrap(); + } + writeln!(s, " r = {func}(a, b); break;").unwrap(); + writeln!(s, " }}").unwrap(); + } + + // REM + writeln!(s, " case 6: {{").unwrap(); + writeln!(s, " F b_val = jets[j_base + b_idx * K];").unwrap(); + writeln!( + s, + " r = jet_const(a.v[0] - trunc(a.v[0] / b_val) * b_val); break;" + ) + .unwrap(); + writeln!(s, " }}").unwrap(); + + // POWF + writeln!(s, " case 7: {{").unwrap(); + writeln!( + s, + " JetK b; unsigned int b_off = j_base + b_idx * K;" + ) + .unwrap(); + for c in 0..k { + writeln!(s, " b.v[{c}] = jets[b_off + {c}];").unwrap(); + } + writeln!(s, " JetK lna = jet_ln(a);").unwrap(); + writeln!(s, " JetK product = jet_mul(b, lna);").unwrap(); + writeln!(s, " r = jet_exp(product);").unwrap(); + writeln!(s, " r.v[0] = pow(a.v[0], b.v[0]); break;").unwrap(); + writeln!(s, " }}").unwrap(); + + // ATAN2 + writeln!(s, " case 8: {{").unwrap(); + writeln!( + s, + " JetK b; unsigned int b_off = j_base + b_idx * K;" + ) + .unwrap(); + for c in 0..k { + writeln!(s, " b.v[{c}] = jets[b_off + {c}];").unwrap(); + } + writeln!(s, " JetK ratio = jet_div(a, b);").unwrap(); + writeln!(s, " r = jet_atan(ratio);").unwrap(); + writeln!(s, " r.v[0] = atan2(a.v[0], b.v[0]); break;").unwrap(); + writeln!(s, " }}").unwrap(); + + // HYPOT + writeln!(s, " case 9: {{").unwrap(); + writeln!( + s, + " JetK b; unsigned int b_off = j_base + b_idx * K;" + ) + .unwrap(); + for c in 0..k { + writeln!(s, " b.v[{c}] = jets[b_off + {c}];").unwrap(); + } + writeln!( + s, + " r = jet_sqrt(jet_add(jet_mul(a,a), jet_mul(b,b)));" + ) + .unwrap(); + writeln!( + s, + " r.v[0] = sqrt(a.v[0]*a.v[0] + b.v[0]*b.v[0]); break;" + ) + .unwrap(); + writeln!(s, " }}").unwrap(); + + // MAX, MIN + for (case, cmp) in &[(10, ">="), (11, "<=")] { + writeln!(s, " case {case}: {{").unwrap(); + writeln!( + s, + " JetK b; unsigned int b_off = j_base + b_idx * K;" + ) + .unwrap(); + for c in 0..k { + writeln!(s, " b.v[{c}] = jets[b_off + {c}];").unwrap(); + } + writeln!(s, " r = (a.v[0] {cmp} b.v[0]) ? a : b; break;").unwrap(); + writeln!(s, " }}").unwrap(); + } + + // Unary + writeln!(s, " case 12: r = jet_neg(a); break;").unwrap(); + writeln!(s, " case 13: r = jet_recip(a); break;").unwrap(); + writeln!(s, " case 14: r = jet_sqrt(a); break;").unwrap(); + + // CBRT + writeln!(s, " case 15: {{").unwrap(); + writeln!(s, " F sg = _sign(a.v[0]);").unwrap(); + writeln!(s, " JetK abs_a; abs_a.v[0] = fabs(a.v[0]);").unwrap(); + for i in 1..k { + writeln!(s, " abs_a.v[{i}] = sg * a.v[{i}];").unwrap(); + } + writeln!( + s, + " JetK e = jet_exp(jet_scale(jet_ln(abs_a), (F)(1.0/3.0)));" + ) + .unwrap(); + for i in 0..k { + writeln!(s, " r.v[{i}] = sg * e.v[{i}];").unwrap(); + } + writeln!(s, " break;").unwrap(); + writeln!(s, " }}").unwrap(); + + // POWI + writeln!(s, " case 16: {{").unwrap(); + writeln!(s, " F n = (F)((int)b_idx);").unwrap(); + writeln!(s, " if (n == (F)0) {{ r = jet_const((F)1); }}").unwrap(); + writeln!(s, " else if (n == (F)1) {{ r = a; }}").unwrap(); + writeln!( + s, + " else {{ r = jet_exp(jet_scale(jet_ln(a), n)); r.v[0] = pow(a.v[0], n); }}" + ) + .unwrap(); + writeln!(s, " break;").unwrap(); + writeln!(s, " }}").unwrap(); + + // Transcendental + writeln!(s, " case 17: r = jet_exp(a); break;").unwrap(); + writeln!(s, " case 18: {{ r = jet_exp(jet_scale(a, log((F)2))); r.v[0] = exp2(a.v[0]); break; }}").unwrap(); + writeln!( + s, + " case 19: {{ r = jet_exp(a); r.v[0] = exp(a.v[0]) - (F)1; break; }}" + ) + .unwrap(); + writeln!(s, " case 20: r = jet_ln(a); break;").unwrap(); + + // LOG2 + writeln!(s, " case 21: {{").unwrap(); + writeln!(s, " r = jet_ln(a);").unwrap(); + writeln!(s, " F inv_ln2 = (F)1 / log((F)2);").unwrap(); + writeln!(s, " r.v[0] = log2(a.v[0]);").unwrap(); + for i in 1..k { + writeln!(s, " r.v[{i}] *= inv_ln2;").unwrap(); + } + writeln!(s, " break;").unwrap(); + writeln!(s, " }}").unwrap(); + + // LOG10 + writeln!(s, " case 22: {{").unwrap(); + writeln!(s, " r = jet_ln(a);").unwrap(); + writeln!(s, " F inv_ln10 = (F)1 / log((F)10);").unwrap(); + writeln!(s, " r.v[0] = log(a.v[0]) * inv_ln10;").unwrap(); + for i in 1..k { + writeln!(s, " r.v[{i}] *= inv_ln10;").unwrap(); + } + writeln!(s, " break;").unwrap(); + writeln!(s, " }}").unwrap(); + + // LN1P + writeln!(s, " case 23: {{").unwrap(); + writeln!(s, " JetK opa; opa.v[0] = (F)1 + a.v[0];").unwrap(); + for i in 1..k { + writeln!(s, " opa.v[{i}] = a.v[{i}];").unwrap(); + } + writeln!( + s, + " r = jet_ln(opa); r.v[0] = log((F)1 + a.v[0]); break;" + ) + .unwrap(); + writeln!(s, " }}").unwrap(); + + // Sin, Cos, Tan, Asin, Acos, Atan, Sinh, Cosh, Tanh, Asinh, Acosh, Atanh + writeln!( + s, + " case 24: {{ JetPair sc = jet_sin_cos(a); r = sc.a; break; }}" + ) + .unwrap(); + writeln!( + s, + " case 25: {{ JetPair sc = jet_sin_cos(a); r = sc.b; break; }}" + ) + .unwrap(); + writeln!(s, " case 26: r = jet_tan(a); break;").unwrap(); + writeln!(s, " case 27: r = jet_asin(a); break;").unwrap(); + writeln!(s, " case 28: r = jet_acos(a); break;").unwrap(); + writeln!(s, " case 29: r = jet_atan(a); break;").unwrap(); + writeln!( + s, + " case 30: {{ JetPair sc = jet_sinh_cosh(a); r = sc.a; break; }}" + ) + .unwrap(); + writeln!( + s, + " case 31: {{ JetPair sc = jet_sinh_cosh(a); r = sc.b; break; }}" + ) + .unwrap(); + writeln!(s, " case 32: r = jet_tanh(a); break;").unwrap(); + writeln!(s, " case 33: r = jet_asinh(a); break;").unwrap(); + writeln!(s, " case 34: r = jet_acosh(a); break;").unwrap(); + writeln!(s, " case 35: r = jet_atanh(a); break;").unwrap(); + + // ABS + writeln!(s, " case 36: {{").unwrap(); + writeln!(s, " F sg = _sign(a.v[0]);").unwrap(); + writeln!(s, " r.v[0] = fabs(a.v[0]);").unwrap(); + for i in 1..k { + writeln!(s, " r.v[{i}] = sg * a.v[{i}];").unwrap(); + } + writeln!(s, " break;").unwrap(); + writeln!(s, " }}").unwrap(); + + // SIGNUM..TRUNC + writeln!(s, " case 37: r = jet_const(_sign(a.v[0])); break;").unwrap(); + writeln!(s, " case 38: r = jet_const(floor(a.v[0])); break;").unwrap(); + writeln!(s, " case 39: r = jet_const(ceil(a.v[0])); break;").unwrap(); + writeln!(s, " case 40: r = jet_const(round(a.v[0])); break;").unwrap(); + writeln!(s, " case 41: r = jet_const(trunc(a.v[0])); break;").unwrap(); + + // FRACT + writeln!(s, " case 42: {{").unwrap(); + writeln!(s, " r.v[0] = _fract(a.v[0]);").unwrap(); + for i in 1..k { + writeln!(s, " r.v[{i}] = a.v[{i}];").unwrap(); + } + writeln!(s, " break;").unwrap(); + writeln!(s, " }}").unwrap(); + + writeln!(s, " default: break;").unwrap(); + writeln!(s, " }}").unwrap(); + + // Store result + writeln!(s, " unsigned int r_off = j_base + i * K;").unwrap(); + for c in 0..k { + writeln!(s, " jets[r_off + {c}] = r.v[{c}];").unwrap(); + } + writeln!(s, " }}").unwrap(); + + // Write outputs + writeln!(s, " unsigned int out_base = bid * num_outputs * K;").unwrap(); + writeln!(s, " for (unsigned int j = 0; j < num_outputs; j++) {{").unwrap(); + writeln!(s, " unsigned int oi = output_indices[j];").unwrap(); + writeln!(s, " unsigned int src = j_base + oi * K;").unwrap(); + writeln!(s, " unsigned int dst = out_base + j * K;").unwrap(); + for c in 0..k { + writeln!(s, " jet_outputs[dst + {c}] = jets[src + {c}];").unwrap(); + } + writeln!(s, " }}").unwrap(); + writeln!(s, "}}").unwrap(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn wgsl_k3_compiles() { + let shader = generate_taylor_wgsl(3); + assert!(shader.contains("struct JetK { v: array, }")); + assert!(shader.contains("fn jet_mul")); + assert!(shader.contains("fn jet_exp")); + assert!(shader.contains("fn main")); + } + + #[test] + fn cuda_k3_compiles() { + let kernel = generate_taylor_cuda(3); + assert!(kernel.contains("F v[3]")); + assert!(kernel.contains("jet_mul")); + assert!(kernel.contains("jet_exp")); + assert!(kernel.contains("taylor_forward_kth")); + } + + #[test] + fn all_k_values_generate() { + for k in 1..=5 { + let wgsl = generate_taylor_wgsl(k); + assert!(wgsl.contains(&format!("array"))); + let cuda = generate_taylor_cuda(k); + assert!(cuda.contains(&format!("F v[{k}]"))); + } + } +} diff --git a/src/gpu/wgpu_backend.rs b/src/gpu/wgpu_backend.rs index c00f5fc..3110d51 100644 --- a/src/gpu/wgpu_backend.rs +++ b/src/gpu/wgpu_backend.rs @@ -26,6 +26,9 @@ pub struct WgpuContext { tangent_fwd_pipeline: wgpu::ComputePipeline, tangent_rev_pipeline: wgpu::ComputePipeline, taylor_fwd_2nd_pipeline: wgpu::ComputePipeline, + /// K-specialized Taylor forward pipelines for K=1..5 (index = K-1). + #[cfg(feature = "stde")] + taylor_fwd_kth_pipelines: [wgpu::ComputePipeline; 5], tape_bind_group_layout: wgpu::BindGroupLayout, forward_io_bind_group_layout: wgpu::BindGroupLayout, reverse_io_bind_group_layout: wgpu::BindGroupLayout, @@ -264,6 +267,28 @@ impl WgpuContext { cache: None, }); + // Compile K-specialized Taylor forward pipelines for K=1..5 + #[cfg(feature = "stde")] + let taylor_fwd_kth_pipelines = { + use super::taylor_codegen::generate_taylor_wgsl; + std::array::from_fn(|idx| { + let k = idx + 1; + let wgsl_src = generate_taylor_wgsl(k); + let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some(&format!("echidna_taylor_fwd_k{k}_shader")), + source: wgpu::ShaderSource::Wgsl(wgsl_src.into()), + }); + device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some(&format!("echidna_taylor_fwd_k{k}_pipeline")), + layout: Some(&taylor2_layout), + module: &shader, + entry_point: Some("main"), + compilation_options: Default::default(), + cache: None, + }) + }) + }; + Some(WgpuContext { device, queue, @@ -272,6 +297,8 @@ impl WgpuContext { tangent_fwd_pipeline, tangent_rev_pipeline, taylor_fwd_2nd_pipeline, + #[cfg(feature = "stde")] + taylor_fwd_kth_pipelines, tape_bind_group_layout, forward_io_bind_group_layout, reverse_io_bind_group_layout, @@ -319,26 +346,33 @@ impl WgpuContext { }) } - /// Batched second-order Taylor forward propagation on GPU. + /// Batched K-th order Taylor forward propagation on GPU. /// - /// Each batch element pushes one direction through the tape, producing - /// a Taylor jet with 3 coefficients (c0=value, c1=first derivative, - /// c2=second derivative / 2). + /// Supports `order` (K) from 1 to 5. Each batch element pushes one direction + /// through the tape, producing a Taylor jet with K coefficients. /// - /// `primal_inputs` is `[f32; batch_size * num_inputs]` — primals for each element. - /// `direction_seeds` is `[f32; batch_size * num_inputs]` — c1 seeds for each element. + /// `primal_inputs` is `[f32; batch_size * num_inputs]`. + /// `direction_seeds` is `[f32; batch_size * num_inputs]` — only c1 seeds are used. /// - /// Returns `TaylorBatchResult` with `values`, `c1s`, `c2s` each of size - /// `[f32; batch_size * num_outputs]`. - pub fn taylor_forward_2nd_batch( + /// Returns `TaylorKthBatchResult` with K coefficient vectors. + #[cfg(feature = "stde")] + pub fn taylor_forward_kth_batch( &self, tape: &WgpuTapeBuffers, primal_inputs: &[f32], direction_seeds: &[f32], batch_size: u32, - ) -> Result, GpuError> { + order: usize, + ) -> Result, GpuError> { use wgpu::util::DeviceExt; + if !(1..=5).contains(&order) { + return Err(GpuError::Other(format!( + "unsupported Taylor order {order}, must be 1..=5" + ))); + } + + let k = order as u32; let ni = tape.num_inputs; let nv = tape.num_variables; let no = tape.num_outputs; @@ -366,7 +400,7 @@ impl WgpuContext { let meta_buf = self .device .create_buffer_init(&wgpu::util::BufferInitDescriptor { - label: Some("taylor2_meta"), + label: Some("taylor_kth_meta"), contents: bytemuck::bytes_of(&meta), usage: wgpu::BufferUsages::UNIFORM, }); @@ -374,39 +408,39 @@ impl WgpuContext { let primal_buf = self .device .create_buffer_init(&wgpu::util::BufferInitDescriptor { - label: Some("taylor2_primals"), + label: Some("taylor_kth_primals"), contents: bytemuck::cast_slice(primal_inputs), usage: wgpu::BufferUsages::STORAGE, }); let seed_buf = self .device .create_buffer_init(&wgpu::util::BufferInitDescriptor { - label: Some("taylor2_seeds"), + label: Some("taylor_kth_seeds"), contents: bytemuck::cast_slice(direction_seeds), usage: wgpu::BufferUsages::STORAGE, }); - // Jets working buffer: B * nv * 3 floats - let jets_size = (batch_size as u64) * (nv as u64) * 3 * 4; + // Jets working buffer: B * nv * K floats + let jets_size = (batch_size as u64) * (nv as u64) * (k as u64) * 4; let jets_buf = self.device.create_buffer(&wgpu::BufferDescriptor { - label: Some("taylor2_jets"), + label: Some("taylor_kth_jets"), size: jets_size, usage: wgpu::BufferUsages::STORAGE, mapped_at_creation: false, }); - // Jet outputs: B * n_out * 3 floats (interleaved c0,c1,c2) - let out_count = (batch_size as u64) * (no as u64) * 3; + // Jet outputs: B * n_out * K floats + let out_count = (batch_size as u64) * (no as u64) * (k as u64); let out_size = out_count * 4; let jet_out_buf = self.device.create_buffer(&wgpu::BufferDescriptor { - label: Some("taylor2_jet_out"), + label: Some("taylor_kth_jet_out"), size: out_size, usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, mapped_at_creation: false, }); let staging_buf = self.device.create_buffer(&wgpu::BufferDescriptor { - label: Some("taylor2_staging"), + label: Some("taylor_kth_staging"), size: out_size, usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, mapped_at_creation: false, @@ -415,7 +449,7 @@ impl WgpuContext { let tape_bg = self.create_tape_bind_group(tape, &meta_buf); let io_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor { - label: Some("taylor2_io_bg"), + label: Some("taylor_kth_io_bg"), layout: &self.taylor_fwd_2nd_io_bind_group_layout, entries: &[ wgpu::BindGroupEntry { @@ -440,15 +474,15 @@ impl WgpuContext { let mut encoder = self .device .create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("taylor2_enc"), + label: Some("taylor_kth_enc"), }); { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { - label: Some("taylor2_pass"), + label: Some("taylor_kth_pass"), timestamp_writes: None, }); - pass.set_pipeline(&self.taylor_fwd_2nd_pipeline); + pass.set_pipeline(&self.taylor_fwd_kth_pipelines[order - 1]); pass.set_bind_group(0, &tape_bg, &[]); pass.set_bind_group(1, &io_bg, &[]); pass.dispatch_workgroups(batch_size.div_ceil(256), 1, 1); @@ -474,22 +508,24 @@ impl WgpuContext { let data = slice.get_mapped_range(); let raw: &[f32] = bytemuck::cast_slice(&data); - // Deinterleave: raw is [c0_0, c1_0, c2_0, c0_1, c1_1, c2_1, ...] per batch×output + // Deinterleave: raw is [c0, c1, ..., c_{K-1}] per output per batch element let total_out = (batch_size * no) as usize; - let mut values = Vec::with_capacity(total_out); - let mut c1s = Vec::with_capacity(total_out); - let mut c2s = Vec::with_capacity(total_out); + let mut coefficients: Vec> = + (0..order).map(|_| Vec::with_capacity(total_out)).collect(); for i in 0..total_out { - values.push(raw[i * 3]); - c1s.push(raw[i * 3 + 1]); - c2s.push(raw[i * 3 + 2]); + for c in 0..order { + coefficients[c].push(raw[i * order + c]); + } } drop(data); staging_buf.unmap(); - Ok(super::TaylorBatchResult { values, c1s, c2s }) + Ok(super::TaylorKthBatchResult { + coefficients, + order, + }) } } @@ -1381,6 +1417,169 @@ impl GpuBackend for WgpuContext { Ok((value, gradient, pattern, hess_values)) } + + #[cfg(feature = "stde")] + fn taylor_forward_2nd_batch( + &self, + tape: &WgpuTapeBuffers, + primal_inputs: &[f32], + direction_seeds: &[f32], + batch_size: u32, + ) -> Result, GpuError> { + use wgpu::util::DeviceExt; + + let ni = tape.num_inputs; + let nv = tape.num_variables; + let no = tape.num_outputs; + let total_inputs = (batch_size * ni) as usize; + + assert_eq!( + primal_inputs.len(), + total_inputs, + "primal_inputs length mismatch" + ); + assert_eq!( + direction_seeds.len(), + total_inputs, + "direction_seeds length mismatch" + ); + + let meta = TapeMeta { + num_ops: tape.num_ops, + num_inputs: ni, + num_variables: nv, + num_outputs: no, + batch_size, + _pad: [0; 3], + }; + let meta_buf = self + .device + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("taylor2_meta"), + contents: bytemuck::bytes_of(&meta), + usage: wgpu::BufferUsages::UNIFORM, + }); + + let primal_buf = self + .device + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("taylor2_primals"), + contents: bytemuck::cast_slice(primal_inputs), + usage: wgpu::BufferUsages::STORAGE, + }); + let seed_buf = self + .device + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("taylor2_seeds"), + contents: bytemuck::cast_slice(direction_seeds), + usage: wgpu::BufferUsages::STORAGE, + }); + + // Jets working buffer: B * nv * 3 floats + let jets_size = (batch_size as u64) * (nv as u64) * 3 * 4; + let jets_buf = self.device.create_buffer(&wgpu::BufferDescriptor { + label: Some("taylor2_jets"), + size: jets_size, + usage: wgpu::BufferUsages::STORAGE, + mapped_at_creation: false, + }); + + // Jet outputs: B * n_out * 3 floats (interleaved c0,c1,c2) + let out_count = (batch_size as u64) * (no as u64) * 3; + let out_size = out_count * 4; + let jet_out_buf = self.device.create_buffer(&wgpu::BufferDescriptor { + label: Some("taylor2_jet_out"), + size: out_size, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let staging_buf = self.device.create_buffer(&wgpu::BufferDescriptor { + label: Some("taylor2_staging"), + size: out_size, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + let tape_bg = self.create_tape_bind_group(tape, &meta_buf); + + let io_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("taylor2_io_bg"), + layout: &self.taylor_fwd_2nd_io_bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: primal_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: seed_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: jets_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 3, + resource: jet_out_buf.as_entire_binding(), + }, + ], + }); + + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("taylor2_enc"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("taylor2_pass"), + timestamp_writes: None, + }); + pass.set_pipeline(&self.taylor_fwd_2nd_pipeline); + pass.set_bind_group(0, &tape_bg, &[]); + pass.set_bind_group(1, &io_bg, &[]); + pass.dispatch_workgroups(batch_size.div_ceil(256), 1, 1); + } + + encoder.copy_buffer_to_buffer(&jet_out_buf, 0, &staging_buf, 0, out_size); + let sub_idx = self.queue.submit(std::iter::once(encoder.finish())); + + let slice = staging_buf.slice(..); + let (tx, rx) = std::sync::mpsc::channel(); + slice.map_async(wgpu::MapMode::Read, move |result| { + let _ = tx.send(result); + }); + let _ = self.device.poll(wgpu::PollType::Wait { + submission_index: Some(sub_idx), + timeout: None, + }); + + rx.recv() + .map_err(|e| GpuError::Other(format!("channel recv failed: {e}")))? + .map_err(|e| GpuError::Other(format!("buffer map failed: {e}")))?; + + let data = slice.get_mapped_range(); + let raw: &[f32] = bytemuck::cast_slice(&data); + + // Deinterleave: raw is [c0_0, c1_0, c2_0, c0_1, c1_1, c2_1, ...] per batch×output + let total_out = (batch_size * no) as usize; + let mut values = Vec::with_capacity(total_out); + let mut c1s = Vec::with_capacity(total_out); + let mut c2s = Vec::with_capacity(total_out); + + for i in 0..total_out { + values.push(raw[i * 3]); + c1s.push(raw[i * 3 + 1]); + c2s.push(raw[i * 3 + 2]); + } + + drop(data); + staging_buf.unmap(); + + Ok(super::TaylorBatchResult { values, c1s, c2s }) + } } // ── Helpers ── diff --git a/src/stde/mod.rs b/src/stde/mod.rs index 315dc40..74368ca 100644 --- a/src/stde/mod.rs +++ b/src/stde/mod.rs @@ -71,6 +71,12 @@ //! `v = L · z` is computed, then pushed through a second-order Taylor jet. When //! `L = I`, this reduces to the Hutchinson Laplacian estimator. //! +//! # Dense STDE for Indefinite Operators (requires `nalgebra`) +//! +//! [`dense_stde_2nd_indefinite`] handles arbitrary symmetric C matrices by +//! eigendecomposing into positive and negative parts. Near-zero eigenvalues are +//! clamped to prevent sign-flipping from floating-point noise. +//! //! # Parabolic PDE σ-Transform //! //! [`parabolic_diffusion`] computes `½ tr(σσ^T · Hess u)` for parabolic PDEs @@ -107,6 +113,8 @@ pub use laplacian::{ hessian_diagonal, hessian_diagonal_with_buf, laplacian, laplacian_hutchpp, laplacian_with_control, laplacian_with_stats, }; +#[cfg(feature = "nalgebra")] +pub use pde::dense_stde_2nd_indefinite; pub use pde::{dense_stde_2nd, divergence, parabolic_diffusion, parabolic_diffusion_stochastic}; pub use pipeline::{estimate, estimate_weighted}; pub use types::{DivergenceResult, EstimatorResult}; diff --git a/src/stde/pde.rs b/src/stde/pde.rs index e60554e..d555e79 100644 --- a/src/stde/pde.rs +++ b/src/stde/pde.rs @@ -4,6 +4,42 @@ use crate::bytecode_tape::BytecodeTape; use crate::dual::Dual; use crate::Float; +/// Inner loop for STDE: apply mat-vec, push through Taylor jet, Welford-accumulate. +/// +/// `matvec` takes `(z, v)` and writes `v = M · z` for some matrix M. +/// Each z-vector produces a sample `2 * c2` where `c2 = v^T H v / 2`. +fn stde_2nd_inner( + tape: &BytecodeTape, + x: &[F], + z_vectors: &[&[F]], + matvec: impl Fn(&[F], &mut [F]), +) -> EstimatorResult { + let n = tape.num_inputs(); + let two = F::from(2.0).unwrap(); + let mut buf = Vec::new(); + let mut v = vec![F::zero(); n]; + let mut value = F::zero(); + let mut acc = WelfordAccumulator::new(); + + for z in z_vectors.iter() { + assert_eq!(z.len(), n, "z_vector length must match tape.num_inputs()"); + matvec(z, &mut v); + let (c0, _, c2) = taylor_jet_2nd_with_buf(tape, x, &v, &mut buf); + value = c0; + acc.update(two * c2); + } + + let (estimate, sample_variance, standard_error) = acc.finalize(); + + EstimatorResult { + value, + estimate, + sample_variance, + standard_error, + num_samples: z_vectors.len(), + } +} + /// Estimate the divergence (trace of Jacobian) of a vector field f: R^n → R^n. /// /// Uses Hutchinson's trace estimator with first-order forward-mode AD (`Dual`). @@ -220,30 +256,199 @@ pub fn dense_stde_2nd( "cholesky_rows.len() must match tape.num_inputs()" ); - let two = F::from(2.0).unwrap(); + stde_2nd_inner(tape, x, z_vectors, |z, v| { + // Lower-triangular mat-vec: v = L · z + for i in 0..n { + let row = cholesky_rows[i]; + let mut sum = F::zero(); + for j in 0..=i { + sum = sum + row[j] * z[j]; + } + v[i] = sum; + } + }) +} + +/// Dense STDE for a possibly-indefinite 2nd-order operator. +/// +/// Given a symmetric (possibly indefinite) coefficient matrix `C`, estimates +/// `tr(C · H_u(x)) = Σ_{ij} C_{ij} ∂²u/∂x_i∂x_j`. +/// +/// # Algorithm +/// +/// 1. Eigendecompose: `C = Q Λ Qᵀ` via `nalgebra::SymmetricEigen`. +/// 2. Clamp near-zero eigenvalues: `|λᵢ| < ε` → 0, where `ε = eps_factor * max(|λᵢ|)`. +/// 3. Split: `λ⁺ᵢ = max(λᵢ, 0)`, `λ⁻ᵢ = max(-λᵢ, 0)`. +/// 4. Form square-root factors: `L⁺ = Q · diag(√λ⁺)`, `L⁻ = Q · diag(√λ⁻)`. +/// 5. For each z-vector, compute `v⁺ = L⁺·z`, `v⁻ = L⁻·z`, push through Taylor jets, +/// and accumulate `2·c2⁺ - 2·c2⁻`. +/// +/// If all eigenvalues are non-negative (or non-positive), the negative (or positive) +/// half is skipped entirely for efficiency. +/// +/// # Arguments +/// +/// - `c_matrix`: symmetric `n×n` coefficient matrix (row-major `DMatrix`). +/// - `z_vectors`: standard Gaussian random vectors, each of length n. +/// - `eps_factor`: relative threshold for clamping near-zero eigenvalues. +/// Typical value: `1e-12`. Eigenvalues with `|λ| < eps_factor * max(|λ|)` are +/// treated as zero to prevent sign-flipping from floating-point noise. +/// +/// # Panics +/// +/// Panics if `z_vectors` is empty, `c_matrix` is not square with dimension +/// matching `tape.num_inputs()`, or any z-vector has the wrong length. +#[cfg(feature = "nalgebra")] +pub fn dense_stde_2nd_indefinite( + tape: &BytecodeTape, + x: &[f64], + c_matrix: &nalgebra::DMatrix, + z_vectors: &[&[f64]], + eps_factor: f64, +) -> EstimatorResult { + assert!(!z_vectors.is_empty(), "z_vectors must not be empty"); + let n = tape.num_inputs(); + assert_eq!(x.len(), n, "x.len() must match tape.num_inputs()"); + assert_eq!(c_matrix.nrows(), n, "c_matrix rows must match num_inputs"); + assert_eq!(c_matrix.ncols(), n, "c_matrix cols must match num_inputs"); + + // Eigendecompose C = Q Λ Qᵀ + let eigen = nalgebra::SymmetricEigen::new(c_matrix.clone()); + let eigenvalues = &eigen.eigenvalues; + let q = &eigen.eigenvectors; + + // Epsilon threshold based on largest eigenvalue magnitude + let max_abs = eigenvalues.iter().fold(0.0f64, |m, &v| m.max(v.abs())); + let eps = eps_factor * max_abs; + + // Classify eigenvalues + let mut has_positive = false; + let mut has_negative = false; + let mut sqrt_pos = vec![0.0f64; n]; + let mut sqrt_neg = vec![0.0f64; n]; + + for i in 0..n { + let lam = eigenvalues[i]; + if lam > eps { + sqrt_pos[i] = lam.sqrt(); + has_positive = true; + } else if lam < -eps { + sqrt_neg[i] = (-lam).sqrt(); + has_negative = true; + } + // else: near-zero, both stay 0 + } + + // All eigenvalues are zero → result is zero + if !has_positive && !has_negative { + let mut buf = Vec::new(); + let v = vec![0.0f64; n]; + let (value, _, _) = taylor_jet_2nd_with_buf(tape, x, &v, &mut buf); + return EstimatorResult { + value, + estimate: 0.0, + sample_variance: 0.0, + standard_error: 0.0, + num_samples: z_vectors.len(), + }; + } + + // Build L⁺ = Q · diag(√λ⁺) and/or L⁻ = Q · diag(√λ⁻) + // Store as column-major n×n for mat-vec + let l_pos = if has_positive { + let mut m = nalgebra::DMatrix::zeros(n, n); + for j in 0..n { + if sqrt_pos[j] > 0.0 { + for i in 0..n { + m[(i, j)] = q[(i, j)] * sqrt_pos[j]; + } + } + } + Some(m) + } else { + None + }; + + let l_neg = if has_negative { + let mut m = nalgebra::DMatrix::zeros(n, n); + for j in 0..n { + if sqrt_neg[j] > 0.0 { + for i in 0..n { + m[(i, j)] = q[(i, j)] * sqrt_neg[j]; + } + } + } + Some(m) + } else { + None + }; + + // Optimization: if all eigenvalues have same sign, use single-pass + if has_positive && !has_negative { + let lp = l_pos.as_ref().unwrap(); + return stde_2nd_inner(tape, x, z_vectors, |z, v| { + // Full mat-vec: v = L⁺ · z + for i in 0..n { + let mut sum = 0.0; + for j in 0..n { + sum += lp[(i, j)] * z[j]; + } + v[i] = sum; + } + }); + } + if !has_positive && has_negative { + let ln = l_neg.as_ref().unwrap(); + // All negative: tr(C·H) = -E[v⁻ᵀ H v⁻], so negate + let mut result = stde_2nd_inner(tape, x, z_vectors, |z, v| { + for i in 0..n { + let mut sum = 0.0; + for j in 0..n { + sum += ln[(i, j)] * z[j]; + } + v[i] = sum; + } + }); + result.estimate = -result.estimate; + // variance and SE stay the same magnitude (negation doesn't change variance) + return result; + } + + // Mixed sign: need both passes per z-vector + let lp = l_pos.as_ref().unwrap(); + let ln = l_neg.as_ref().unwrap(); + let mut buf = Vec::new(); - let mut v = vec![F::zero(); n]; - let mut value = F::zero(); + let mut v_pos = vec![0.0f64; n]; + let mut v_neg = vec![0.0f64; n]; + let mut value = 0.0f64; let mut acc = WelfordAccumulator::new(); for z in z_vectors.iter() { assert_eq!(z.len(), n, "z_vector length must match tape.num_inputs()"); - // Compute v = L · z (lower-triangular mat-vec) + // v⁺ = L⁺ · z for i in 0..n { - let row = cholesky_rows[i]; - let mut sum = F::zero(); - // Only use lower-triangular entries: j <= i - for j in 0..=i { - sum = sum + row[j] * z[j]; + let mut sum = 0.0; + for j in 0..n { + sum += lp[(i, j)] * z[j]; } - v[i] = sum; + v_pos[i] = sum; + } + // v⁻ = L⁻ · z + for i in 0..n { + let mut sum = 0.0; + for j in 0..n { + sum += ln[(i, j)] * z[j]; + } + v_neg[i] = sum; } - let (c0, _, c2) = taylor_jet_2nd_with_buf(tape, x, &v, &mut buf); + let (c0, _, c2_pos) = taylor_jet_2nd_with_buf(tape, x, &v_pos, &mut buf); + let (_, _, c2_neg) = taylor_jet_2nd_with_buf(tape, x, &v_neg, &mut buf); value = c0; - // 2 * c2 = v^T H v, and E[v^T H v] = tr(C · H) when E[vv^T] = C - acc.update(two * c2); + // sample = 2·c2⁺ - 2·c2⁻ = v⁺ᵀHv⁺ - v⁻ᵀHv⁻ + acc.update(2.0 * c2_pos - 2.0 * c2_neg); } let (estimate, sample_variance, standard_error) = acc.finalize(); diff --git a/tests/gpu_stde.rs b/tests/gpu_stde.rs index f562a6a..2ae8141 100644 --- a/tests/gpu_stde.rs +++ b/tests/gpu_stde.rs @@ -546,6 +546,429 @@ fn gpu_polynomial_exact_hessian_diagonal() { assert!((diag[1] - 2.0).abs() < 1e-3, "diag[1]: {}", diag[1]); } +// ══════════════════════════════════════════════ +// Section 4: Chunked GPU Taylor dispatch tests +// ══════════════════════════════════════════════ + +#[cfg(all(feature = "gpu-wgpu", feature = "stde"))] +#[test] +fn gpu_chunked_single_chunk() { + // When batch fits in one chunk, results match direct call + let ctx = match gpu_context() { + Some(c) => c, + None => return, + }; + + let x = [3.0_f64, 4.0]; + let (tape, _) = record(|v| polynomial(v), &x); + let gpu_data = GpuTapeData::from_tape_f64_lossy(&tape).unwrap(); + let tape_buf = ctx.upload_tape(&gpu_data); + + let batch_size = 4u32; + let mut primals = Vec::new(); + let mut seeds = Vec::new(); + for b in 0..batch_size { + primals.extend_from_slice(&[3.0f32, 4.0]); + if b % 2 == 0 { + seeds.extend_from_slice(&[1.0f32, 0.0]); + } else { + seeds.extend_from_slice(&[0.0f32, 1.0]); + } + } + + let direct = ctx + .taylor_forward_2nd_batch(&tape_buf, &primals, &seeds, batch_size) + .unwrap(); + + // Use very large max_buffer_bytes so everything fits in one chunk + let chunked = echidna::gpu::taylor_forward_2nd_batch_chunked( + &ctx, + &tape_buf, + &primals, + &seeds, + batch_size, + gpu_data.num_inputs, + gpu_data.num_variables, + 1024 * 1024 * 1024, // 1 GiB + ) + .unwrap(); + + assert_eq!(direct.values.len(), chunked.values.len()); + for i in 0..direct.values.len() { + assert!( + (direct.values[i] - chunked.values[i]).abs() < 1e-6, + "values[{}] mismatch", + i + ); + assert!( + (direct.c1s[i] - chunked.c1s[i]).abs() < 1e-6, + "c1s[{}] mismatch", + i + ); + assert!( + (direct.c2s[i] - chunked.c2s[i]).abs() < 1e-6, + "c2s[{}] mismatch", + i + ); + } +} + +#[cfg(all(feature = "gpu-wgpu", feature = "stde"))] +#[test] +fn gpu_chunked_multi_chunk() { + // Force multi-chunk by setting a tiny max_buffer_bytes + let ctx = match gpu_context() { + Some(c) => c, + None => return, + }; + + let x = [3.0_f64, 4.0]; + let (tape, _) = record(|v| polynomial(v), &x); + let gpu_data = GpuTapeData::from_tape_f64_lossy(&tape).unwrap(); + let tape_buf = ctx.upload_tape(&gpu_data); + + let batch_size = 8u32; + let mut primals = Vec::new(); + let mut seeds = Vec::new(); + for b in 0..batch_size { + primals.extend_from_slice(&[3.0f32, 4.0]); + if b % 2 == 0 { + seeds.extend_from_slice(&[1.0f32, 0.0]); + } else { + seeds.extend_from_slice(&[0.0f32, 1.0]); + } + } + + let direct = ctx + .taylor_forward_2nd_batch(&tape_buf, &primals, &seeds, batch_size) + .unwrap(); + + // bytes_per_element = num_variables * 3 * 4 + // With a tiny limit, each chunk gets only ~2 elements + let bytes_per_element = (gpu_data.num_variables as u64) * 3 * 4; + let max_bytes = bytes_per_element * 2; // force chunks of 2 + + let chunked = echidna::gpu::taylor_forward_2nd_batch_chunked( + &ctx, + &tape_buf, + &primals, + &seeds, + batch_size, + gpu_data.num_inputs, + gpu_data.num_variables, + max_bytes, + ) + .unwrap(); + + assert_eq!(direct.values.len(), chunked.values.len()); + for i in 0..direct.values.len() { + assert!( + (direct.values[i] - chunked.values[i]).abs() < 1e-5, + "values[{}]: {} vs {}", + i, + direct.values[i], + chunked.values[i] + ); + assert!( + (direct.c1s[i] - chunked.c1s[i]).abs() < 1e-5, + "c1s[{}]: {} vs {}", + i, + direct.c1s[i], + chunked.c1s[i] + ); + assert!( + (direct.c2s[i] - chunked.c2s[i]).abs() < 1e-5, + "c2s[{}]: {} vs {}", + i, + direct.c2s[i], + chunked.c2s[i] + ); + } +} + +#[cfg(all(feature = "gpu-wgpu", feature = "stde"))] +#[test] +fn gpu_chunked_exact_boundary() { + // Batch exactly fills one chunk (boundary condition) + let ctx = match gpu_context() { + Some(c) => c, + None => return, + }; + + let x = [2.0_f64, 3.0]; + let (tape, _) = record(|v| polynomial(v), &x); + let gpu_data = GpuTapeData::from_tape_f64_lossy(&tape).unwrap(); + let tape_buf = ctx.upload_tape(&gpu_data); + + let batch_size = 4u32; + let mut primals = Vec::new(); + let mut seeds = Vec::new(); + for _ in 0..batch_size { + primals.extend_from_slice(&[2.0f32, 3.0]); + seeds.extend_from_slice(&[1.0f32, 0.0]); + } + + let bytes_per_element = (gpu_data.num_variables as u64) * 3 * 4; + let max_bytes = bytes_per_element * (batch_size as u64); // exact fit + + let result = echidna::gpu::taylor_forward_2nd_batch_chunked( + &ctx, + &tape_buf, + &primals, + &seeds, + batch_size, + gpu_data.num_inputs, + gpu_data.num_variables, + max_bytes, + ) + .unwrap(); + + assert_eq!(result.values.len(), batch_size as usize); + for v in &result.values { + assert!((v - 13.0).abs() < 1e-4, "value: {}", v); + } +} + +#[cfg(all(feature = "gpu-wgpu", feature = "stde"))] +#[test] +fn gpu_chunked_zero_batch() { + let ctx = match gpu_context() { + Some(c) => c, + None => return, + }; + + let x = [1.0_f64, 2.0]; + let (tape, _) = record(|v| polynomial(v), &x); + let gpu_data = GpuTapeData::from_tape_f64_lossy(&tape).unwrap(); + let tape_buf = ctx.upload_tape(&gpu_data); + + let result = echidna::gpu::taylor_forward_2nd_batch_chunked( + &ctx, + &tape_buf, + &[], + &[], + 0, + gpu_data.num_inputs, + gpu_data.num_variables, + 1024, + ) + .unwrap(); + + assert!(result.values.is_empty()); + assert!(result.c1s.is_empty()); + assert!(result.c2s.is_empty()); +} + +// ══════════════════════════════════════════════ +// Section 5: General-K Taylor forward tests +// ══════════════════════════════════════════════ + +#[cfg(all(feature = "gpu-wgpu", feature = "stde"))] +#[test] +fn gpu_taylor_kth_polynomial_all_orders() { + // f(x,y) = x² + y², at (3,4), direction (1,0) + // c0 = 25, c1 = 6, c2 = 1, c3+ = 0 (polynomial of degree 2) + let ctx = match gpu_context() { + Some(c) => c, + None => return, + }; + + let x = [3.0_f64, 4.0]; + let (tape, _) = record(|v| polynomial(v), &x); + let gpu_data = GpuTapeData::from_tape_f64_lossy(&tape).unwrap(); + let tape_buf = ctx.upload_tape(&gpu_data); + + for order in 1..=5 { + let result = ctx + .taylor_forward_kth_batch(&tape_buf, &[3.0f32, 4.0], &[1.0f32, 0.0], 1, order) + .unwrap(); + + assert_eq!(result.order, order); + assert_eq!(result.coefficients.len(), order); + assert_eq!(result.coefficients[0].len(), 1); + + // c0 = 25 + assert!( + (result.coefficients[0][0] - 25.0).abs() < 1e-3, + "K={order} c0: {}", + result.coefficients[0][0] + ); + + if order >= 2 { + // c1 = 6 + assert!( + (result.coefficients[1][0] - 6.0).abs() < 1e-3, + "K={order} c1: {}", + result.coefficients[1][0] + ); + } + if order >= 3 { + // c2 = 1 (since d²/dt² (3+t)² = 2, and c2 = f''/2! = 1) + assert!( + (result.coefficients[2][0] - 1.0).abs() < 1e-3, + "K={order} c2: {}", + result.coefficients[2][0] + ); + } + if order >= 4 { + // c3 = 0 (polynomial degree 2) + assert!( + result.coefficients[3][0].abs() < 1e-3, + "K={order} c3: {}", + result.coefficients[3][0] + ); + } + if order >= 5 { + assert!( + result.coefficients[4][0].abs() < 1e-3, + "K={order} c4: {}", + result.coefficients[4][0] + ); + } + } +} + +#[cfg(all(feature = "gpu-wgpu", feature = "stde"))] +#[test] +fn gpu_taylor_kth_k3_matches_2nd() { + // K=3 should match taylor_forward_2nd_batch exactly + let ctx = match gpu_context() { + Some(c) => c, + None => return, + }; + + let x = [1.5_f64, 2.5]; + let (tape, _) = record(|v| rosenbrock(v), &x); + let gpu_data = GpuTapeData::from_tape_f64_lossy(&tape).unwrap(); + let tape_buf = ctx.upload_tape(&gpu_data); + + let primals = [1.5f32, 2.5]; + let seeds = [0.6f32, 0.8]; + + let result_2nd = ctx + .taylor_forward_2nd_batch(&tape_buf, &primals, &seeds, 1) + .unwrap(); + + let result_kth = ctx + .taylor_forward_kth_batch(&tape_buf, &primals, &seeds, 1, 3) + .unwrap(); + + assert_eq!(result_kth.order, 3); + assert!( + (result_2nd.values[0] - result_kth.coefficients[0][0]).abs() < 1e-4, + "c0: {} vs {}", + result_2nd.values[0], + result_kth.coefficients[0][0] + ); + assert!( + (result_2nd.c1s[0] - result_kth.coefficients[1][0]).abs() < 1e-4, + "c1: {} vs {}", + result_2nd.c1s[0], + result_kth.coefficients[1][0] + ); + assert!( + (result_2nd.c2s[0] - result_kth.coefficients[2][0]).abs() < 1e-3, + "c2: {} vs {}", + result_2nd.c2s[0], + result_kth.coefficients[2][0] + ); +} + +#[cfg(all(feature = "gpu-wgpu", feature = "stde"))] +#[test] +fn gpu_taylor_kth_exp_higher_order() { + // f(x) = exp(x) at x=1, direction 1 + // c_k = exp(1) / k! for all k (since exp^(k) = exp) + // c0 = e, c1 = e, c2 = e/2, c3 = e/6, c4 = e/24 + let ctx = match gpu_context() { + Some(c) => c, + None => return, + }; + + fn f_exp(x: &[T]) -> T { + x[0].exp() + } + + let x = [1.0_f64]; + let (tape, _) = record(f_exp, &x); + let gpu_data = GpuTapeData::from_tape_f64_lossy(&tape).unwrap(); + let tape_buf = ctx.upload_tape(&gpu_data); + + // Also compute CPU reference + let cpu_coeffs = echidna::stde::taylor_jet_dyn(&tape, &x, &[1.0], 5); + + let result = ctx + .taylor_forward_kth_batch(&tape_buf, &[1.0f32], &[1.0f32], 1, 5) + .unwrap(); + + let e = std::f64::consts::E; + let expected = [e, e, e / 2.0, e / 6.0, e / 24.0]; + + for (k, exp_val) in expected.iter().enumerate() { + let gpu_val = result.coefficients[k][0] as f64; + let tol = 0.05 * exp_val.abs(); + assert!( + (gpu_val - exp_val).abs() < tol.max(1e-2), + "K=5 c{k}: gpu={gpu_val} expected={exp_val} cpu={:.6}", + cpu_coeffs[k] + ); + } +} + +#[cfg(all(feature = "gpu-wgpu", feature = "stde"))] +#[test] +fn gpu_taylor_kth_unsupported_order() { + let ctx = match gpu_context() { + Some(c) => c, + None => return, + }; + + let x = [1.0_f64]; + let (tape, _) = record(|v: &[echidna::BReverse]| v[0] * v[0], &x); + let gpu_data = GpuTapeData::from_tape_f64_lossy(&tape).unwrap(); + let tape_buf = ctx.upload_tape(&gpu_data); + + let result = ctx.taylor_forward_kth_batch(&tape_buf, &[1.0f32], &[1.0f32], 1, 6); + assert!(result.is_err()); +} + +#[cfg(all(feature = "gpu-wgpu", feature = "stde"))] +#[test] +fn gpu_taylor_kth_multi_batch() { + // Verify deinterleaving is correct with batch_size > 1 + // f(x,y) = x² + y², directions (1,0) and (0,1) + let ctx = match gpu_context() { + Some(c) => c, + None => return, + }; + + let x = [3.0_f64, 4.0]; + let (tape, _) = record(|v| polynomial(v), &x); + let gpu_data = GpuTapeData::from_tape_f64_lossy(&tape).unwrap(); + let tape_buf = ctx.upload_tape(&gpu_data); + + let primals = [3.0f32, 4.0, 3.0, 4.0]; + let seeds = [1.0f32, 0.0, 0.0, 1.0]; + + let result = ctx + .taylor_forward_kth_batch(&tape_buf, &primals, &seeds, 2, 4) + .unwrap(); + + assert_eq!(result.order, 4); + // Both batch elements: c0 = 25 + assert!((result.coefficients[0][0] - 25.0).abs() < 1e-3); + assert!((result.coefficients[0][1] - 25.0).abs() < 1e-3); + // Batch 0 dir (1,0): c1 = 2*3 = 6, c2 = 1 + assert!((result.coefficients[1][0] - 6.0).abs() < 1e-3); + assert!((result.coefficients[2][0] - 1.0).abs() < 1e-3); + // Batch 1 dir (0,1): c1 = 2*4 = 8, c2 = 1 + assert!((result.coefficients[1][1] - 8.0).abs() < 1e-3); + assert!((result.coefficients[2][1] - 1.0).abs() < 1e-3); + // c3 = 0 for both (polynomial degree 2) + assert!(result.coefficients[3][0].abs() < 1e-3); + assert!(result.coefficients[3][1].abs() < 1e-3); +} + #[cfg(all(feature = "gpu-wgpu", feature = "stde"))] #[test] fn gpu_trig_taylor_2nd() { diff --git a/tests/stde.rs b/tests/stde.rs index 0d6bddc..b255d40 100644 --- a/tests/stde.rs +++ b/tests/stde.rs @@ -1472,3 +1472,159 @@ mod sparse_stde_tests { assert_relative_eq!(result.estimate, expected, epsilon = 1e-6); } } + +// ══════════════════════════════════════════════ +// 22. Indefinite Dense STDE (requires stde + nalgebra) +// ══════════════════════════════════════════════ + +#[cfg(feature = "nalgebra")] +mod indefinite_stde_tests { + use super::*; + + /// PD matrix: verify result matches dense_stde_2nd with same z-vectors. + #[test] + fn indefinite_stde_matches_positive_definite() { + // C = [[4, 1], [1, 3]] (positive definite) + // Cholesky: L = [[2, 0], [0.5, sqrt(2.75)]] + let tape = record_fn(sum_of_squares, &[1.0, 2.0]); + let x = [1.0, 2.0]; + + let c = nalgebra::DMatrix::from_row_slice(2, 2, &[4.0, 1.0, 1.0, 3.0]); + + // Cholesky factor for comparison + let l00 = 2.0; + let l10 = 0.5; + let l11 = (3.0 - 0.25_f64).sqrt(); // sqrt(2.75) + let row0 = vec![l00, 0.0]; + let row1 = vec![l10, l11]; + let cholesky: Vec<&[f64]> = vec![&row0, &row1]; + + // Use Rademacher-like z-vectors + let z0 = vec![1.0, 1.0]; + let z1 = vec![1.0, -1.0]; + let z2 = vec![-1.0, 1.0]; + let z3 = vec![-1.0, -1.0]; + let z_vecs: Vec<&[f64]> = vec![&z0, &z1, &z2, &z3]; + + let chol_result = echidna::stde::dense_stde_2nd(&tape, &x, &cholesky, &z_vecs); + let indef_result = echidna::stde::dense_stde_2nd_indefinite(&tape, &x, &c, &z_vecs, 1e-12); + + // H = diag(2, 2), tr(C·H) = 4*2 + 3*2 = 14 + assert_relative_eq!(chol_result.estimate, 14.0, epsilon = 1e-8); + assert_relative_eq!(indef_result.estimate, 14.0, epsilon = 1e-8); + } + + /// Diagonal indefinite C = diag(2, -3): verify tr(C·H) against analytical 2·H₀₀ - 3·H₁₁. + #[test] + fn indefinite_stde_diagonal_indefinite() { + // f(x,y) = x² + y², H = diag(2, 2) + // C = diag(2, -3), tr(C·H) = 2·2 + (-3)·2 = 4 - 6 = -2 + let tape = record_fn(sum_of_squares, &[1.0, 2.0]); + let x = [1.0, 2.0]; + + let c = nalgebra::DMatrix::from_row_slice(2, 2, &[2.0, 0.0, 0.0, -3.0]); + + let z0 = vec![1.0, 1.0]; + let z1 = vec![1.0, -1.0]; + let z2 = vec![-1.0, 1.0]; + let z3 = vec![-1.0, -1.0]; + let z_vecs: Vec<&[f64]> = vec![&z0, &z1, &z2, &z3]; + + let result = echidna::stde::dense_stde_2nd_indefinite(&tape, &x, &c, &z_vecs, 1e-12); + assert_relative_eq!(result.estimate, -2.0, epsilon = 1e-8); + } + + /// Full symmetric indefinite C, verify against tr(C·H) computed from dense Hessian. + #[test] + fn indefinite_stde_full_indefinite() { + // f(x,y,z) = x²y + y³, H at (1,2,3): + // H = [[2y, 2x, 0], [2x, 6y, 0], [0, 0, 0]] = [[4, 2, 0], [2, 12, 0], [0, 0, 0]] + let tape = record_fn(cubic_mix, &[1.0, 2.0, 3.0]); + let x = [1.0, 2.0, 3.0]; + + // C = [[1, 0, -1], [0, -2, 0], [-1, 0, 3]] — indefinite + let c = nalgebra::DMatrix::from_row_slice( + 3, + 3, + &[1.0, 0.0, -1.0, 0.0, -2.0, 0.0, -1.0, 0.0, 3.0], + ); + + // tr(C·H) = Σ_{ij} C_{ij} H_{ij} + // = 1·4 + 0·2 + (-1)·0 + 0·2 + (-2)·12 + 0·0 + (-1)·0 + 0·0 + 3·0 + // = 4 - 24 = -20 + let expected = -20.0; + + // Use many z-vectors for convergence (Rademacher-like from all sign combos) + let signs: Vec> = vec![ + vec![1.0, 1.0, 1.0], + vec![1.0, 1.0, -1.0], + vec![1.0, -1.0, 1.0], + vec![1.0, -1.0, -1.0], + vec![-1.0, 1.0, 1.0], + vec![-1.0, 1.0, -1.0], + vec![-1.0, -1.0, 1.0], + vec![-1.0, -1.0, -1.0], + ]; + let z_vecs: Vec<&[f64]> = signs.iter().map(|v| v.as_slice()).collect(); + + let result = echidna::stde::dense_stde_2nd_indefinite(&tape, &x, &c, &z_vecs, 1e-12); + assert_relative_eq!(result.estimate, expected, epsilon = 1e-6); + } + + /// All-negative eigenvalues: C = diag(-2, -3), H = diag(2, 2). + /// tr(C·H) = -2·2 + (-3)·2 = -10. + #[test] + fn indefinite_stde_all_negative() { + let tape = record_fn(sum_of_squares, &[1.0, 2.0]); + let x = [1.0, 2.0]; + + let c = nalgebra::DMatrix::from_row_slice(2, 2, &[-2.0, 0.0, 0.0, -3.0]); + + let z0 = vec![1.0, 1.0]; + let z1 = vec![1.0, -1.0]; + let z2 = vec![-1.0, 1.0]; + let z3 = vec![-1.0, -1.0]; + let z_vecs: Vec<&[f64]> = vec![&z0, &z1, &z2, &z3]; + + let result = echidna::stde::dense_stde_2nd_indefinite(&tape, &x, &c, &z_vecs, 1e-12); + assert_relative_eq!(result.estimate, -10.0, epsilon = 1e-8); + } + + /// C = 0: result should be 0. + #[test] + fn indefinite_stde_zero_matrix() { + let tape = record_fn(sum_of_squares, &[1.0, 2.0]); + let x = [1.0, 2.0]; + + let c = nalgebra::DMatrix::zeros(2, 2); + + let z0 = vec![1.0, 1.0]; + let z1 = vec![1.0, -1.0]; + let z_vecs: Vec<&[f64]> = vec![&z0, &z1]; + + let result = echidna::stde::dense_stde_2nd_indefinite(&tape, &x, &c, &z_vecs, 1e-12); + assert_relative_eq!(result.estimate, 0.0, epsilon = 1e-10); + } + + /// C with eigenvalue ~1e-15: verify epsilon clamping prevents sign-flip. + #[test] + fn indefinite_stde_near_zero_eigenvalue() { + // C = [[1, 0], [0, 1e-15]] — the tiny eigenvalue should be clamped to zero + let tape = record_fn(sum_of_squares, &[1.0, 2.0]); + let x = [1.0, 2.0]; + + let c = nalgebra::DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1e-15]); + + let z0 = vec![1.0, 1.0]; + let z1 = vec![1.0, -1.0]; + let z2 = vec![-1.0, 1.0]; + let z3 = vec![-1.0, -1.0]; + let z_vecs: Vec<&[f64]> = vec![&z0, &z1, &z2, &z3]; + + // With eps_factor=1e-12, threshold = 1e-12 * 1.0 = 1e-12. + // The eigenvalue 1e-15 < 1e-12, so it's clamped to zero. + // Result should be tr(diag(1,0) · diag(2,2)) = 1·2 + 0·2 = 2 + let result = echidna::stde::dense_stde_2nd_indefinite(&tape, &x, &c, &z_vecs, 1e-12); + assert_relative_eq!(result.estimate, 2.0, epsilon = 1e-8); + } +}