diff --git a/CHANGELOG.md b/CHANGELOG.md index 534c784..77b977a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- **simba traits for DualVec**: implemented traits `SimdValue`, `PrimitiveSimdValue`, `SubsetOf`, `AbsDiffEq`, `RelativeEq`, `UlpsEq`, `Field`, `ComplexField`, `RealField` for `DualVec`. + ## [0.5.0] - 2026-03-14 ### Added diff --git a/src/traits/simba_impls.rs b/src/traits/simba_impls.rs index 1ccd7ad..adf75a2 100644 --- a/src/traits/simba_impls.rs +++ b/src/traits/simba_impls.rs @@ -8,6 +8,7 @@ use simba::scalar::{ComplexField, Field, RealField, SubsetOf}; use simba::simd::{PrimitiveSimdValue, SimdValue}; use crate::dual::Dual; +use crate::dual_vec::DualVec; use crate::float::Float; use crate::reverse::Reverse; use crate::tape::TapeThreadLocal; @@ -57,6 +58,47 @@ impl SimdValue for Dual { impl PrimitiveSimdValue for Dual {} +impl SimdValue for DualVec { + const LANES: usize = 1; + type Element = Self; + type SimdBool = bool; + + #[inline(always)] + fn splat(val: Self::Element) -> Self { + val + } + #[inline(always)] + fn extract(&self, _: usize) -> Self::Element { + *self + } + #[inline(always)] + // SAFETY: This is a single-lane (LANES=1) scalar type, so the lane index + // is always 0 and the operation is trivially safe regardless of input. + unsafe fn extract_unchecked(&self, _: usize) -> Self::Element { + *self + } + #[inline(always)] + fn replace(&mut self, _: usize, val: Self::Element) { + *self = val; + } + #[inline(always)] + // SAFETY: This is a single-lane (LANES=1) scalar type, so the lane index + // is always 0 and the operation is trivially safe regardless of input. + unsafe fn replace_unchecked(&mut self, _: usize, val: Self::Element) { + *self = val; + } + #[inline(always)] + fn select(self, cond: Self::SimdBool, other: Self) -> Self { + if cond { + self + } else { + other + } + } +} + +impl PrimitiveSimdValue for DualVec {} + impl SimdValue for Reverse { const LANES: usize = 1; type Element = Self; @@ -103,6 +145,7 @@ impl PrimitiveSimdValue for Reverse {} // ══════════════════════════════════════════════ impl Field for Dual {} +impl Field for DualVec {} impl Field for Reverse {} // ══════════════════════════════════════════════ @@ -190,6 +233,87 @@ impl SubsetOf> for f32 { } } +// Identity: DualVec ⊂ DualVec +impl SubsetOf> for DualVec { + #[inline] + fn to_superset(&self) -> DualVec { + *self + } + #[inline] + fn from_superset_unchecked(element: &DualVec) -> Self { + *element + } + #[inline] + fn is_in_subset(_: &DualVec) -> bool { + true + } +} + +// f64 ⊂ DualVec (lossless: f64 → constant dual vector) +impl SubsetOf> for f64 { + #[inline] + fn to_superset(&self) -> DualVec { + DualVec::constant(*self) + } + #[inline] + fn from_superset_unchecked(element: &DualVec) -> Self { + element.re + } + #[inline] + fn is_in_subset(element: &DualVec) -> bool { + element.eps.into_iter().all(|e| e == 0.0) + } +} + +// f32 ⊂ DualVec (lossless: f32 → constant dual vector) +impl SubsetOf> for f32 { + #[inline] + fn to_superset(&self) -> DualVec { + DualVec::constant(*self) + } + #[inline] + fn from_superset_unchecked(element: &DualVec) -> Self { + element.re + } + #[inline] + fn is_in_subset(element: &DualVec) -> bool { + element.eps.into_iter().all(|e| e == 0.0) + } +} + +// f64 ⊂ DualVec (lossy: f64 → f32 → constant dual vector) +// Required by ComplexField: SupersetOf +impl SubsetOf> for f64 { + #[inline] + fn to_superset(&self) -> DualVec { + DualVec::constant(*self as f32) + } + #[inline] + fn from_superset_unchecked(element: &DualVec) -> Self { + element.re as f64 + } + #[inline] + fn is_in_subset(element: &DualVec) -> bool { + element.eps.into_iter().all(|e| e == 0.0) + } +} + +// f32 ⊂ DualVec (lossless: f32 → f64 → constant dual vector) +impl SubsetOf> for f32 { + #[inline] + fn to_superset(&self) -> DualVec { + DualVec::constant(*self as f64) + } + #[inline] + fn from_superset_unchecked(element: &DualVec) -> Self { + element.re as f32 + } + #[inline] + fn is_in_subset(element: &DualVec) -> bool { + element.eps.into_iter().all(|e| e == 0.0) + } +} + // Identity: Reverse ⊂ Reverse impl SubsetOf> for Reverse { #[inline] @@ -322,6 +446,53 @@ where } } +impl AbsDiffEq for DualVec +where + F: AbsDiffEq, +{ + type Epsilon = Self; + + #[inline] + fn default_epsilon() -> Self { + DualVec::constant(F::default_epsilon()) + } + + #[inline] + fn abs_diff_eq(&self, other: &Self, epsilon: Self) -> bool { + self.re.abs_diff_eq(&other.re, epsilon.re) + } +} + +impl RelativeEq for DualVec +where + F: RelativeEq, +{ + #[inline] + fn default_max_relative() -> Self { + DualVec::constant(F::default_max_relative()) + } + + #[inline] + fn relative_eq(&self, other: &Self, epsilon: Self, max_relative: Self) -> bool { + self.re.relative_eq(&other.re, epsilon.re, max_relative.re) + } +} + +impl UlpsEq for DualVec +where + F: UlpsEq, +{ + #[inline] + fn default_max_ulps() -> u32 { + F::default_max_ulps() + } + + #[inline] + fn ulps_eq(&self, other: &Self, epsilon: Self, max_ulps: u32) -> bool { + self.re.ulps_eq(&other.re, epsilon.re, max_ulps) + } +} + impl AbsDiffEq for Reverse where F: AbsDiffEq, @@ -698,6 +869,334 @@ macro_rules! impl_real_field_dual { impl_real_field_dual!(f32); impl_real_field_dual!(f64); +// ══════════════════════════════════════════════ +// ComplexField for DualVec +// ══════════════════════════════════════════════ + +// We implement ComplexField concretely for f32 and f64 to satisfy all trait +// bounds (SubsetOf conversions require concrete types). Use a macro to avoid +// duplication. + +macro_rules! impl_complex_field_dual_vec { + ($f:ty) => { + impl ComplexField for DualVec<$f, N> { + type RealField = Self; + + #[inline] + fn from_real(re: Self::RealField) -> Self { + re + } + #[inline] + fn real(self) -> Self::RealField { + self + } + #[inline] + fn imaginary(self) -> Self::RealField { + Self::zero() + } + #[inline] + fn modulus(self) -> Self::RealField { + DualVec::abs(self) + } + #[inline] + fn modulus_squared(self) -> Self::RealField { + self * self + } + #[inline] + fn argument(self) -> Self::RealField { + if self.re >= <$f>::zero() { + Self::zero() + } else { + Self::pi() + } + } + #[inline] + fn norm1(self) -> Self::RealField { + DualVec::abs(self) + } + #[inline] + fn scale(self, factor: Self::RealField) -> Self { + self * factor + } + #[inline] + fn unscale(self, factor: Self::RealField) -> Self { + self / factor + } + #[inline] + fn floor(self) -> Self { + DualVec::floor(self) + } + #[inline] + fn ceil(self) -> Self { + DualVec::ceil(self) + } + #[inline] + fn round(self) -> Self { + DualVec::round(self) + } + #[inline] + fn trunc(self) -> Self { + DualVec::trunc(self) + } + #[inline] + fn fract(self) -> Self { + DualVec::fract(self) + } + #[inline] + fn mul_add(self, a: Self, b: Self) -> Self { + DualVec::mul_add(self, a, b) + } + #[inline] + fn abs(self) -> Self::RealField { + DualVec::abs(self) + } + #[inline] + fn hypot(self, other: Self) -> Self::RealField { + DualVec::hypot(self, other) + } + #[inline] + fn recip(self) -> Self { + DualVec::recip(self) + } + #[inline] + fn conjugate(self) -> Self { + self // real type + } + #[inline] + fn sin(self) -> Self { + DualVec::sin(self) + } + #[inline] + fn cos(self) -> Self { + DualVec::cos(self) + } + #[inline] + fn sin_cos(self) -> (Self, Self) { + DualVec::sin_cos(self) + } + #[inline] + fn tan(self) -> Self { + DualVec::tan(self) + } + #[inline] + fn asin(self) -> Self { + DualVec::asin(self) + } + #[inline] + fn acos(self) -> Self { + DualVec::acos(self) + } + #[inline] + fn atan(self) -> Self { + DualVec::atan(self) + } + #[inline] + fn sinh(self) -> Self { + DualVec::sinh(self) + } + #[inline] + fn cosh(self) -> Self { + DualVec::cosh(self) + } + #[inline] + fn tanh(self) -> Self { + DualVec::tanh(self) + } + #[inline] + fn asinh(self) -> Self { + DualVec::asinh(self) + } + #[inline] + fn acosh(self) -> Self { + DualVec::acosh(self) + } + #[inline] + fn atanh(self) -> Self { + DualVec::atanh(self) + } + #[inline] + fn log(self, base: Self::RealField) -> Self { + DualVec::log(self, base) + } + #[inline] + fn log2(self) -> Self { + DualVec::log2(self) + } + #[inline] + fn log10(self) -> Self { + DualVec::log10(self) + } + #[inline] + fn ln(self) -> Self { + DualVec::ln(self) + } + #[inline] + fn ln_1p(self) -> Self { + DualVec::ln_1p(self) + } + #[inline] + fn sqrt(self) -> Self { + DualVec::sqrt(self) + } + #[inline] + fn exp(self) -> Self { + DualVec::exp(self) + } + #[inline] + fn exp2(self) -> Self { + DualVec::exp2(self) + } + #[inline] + fn exp_m1(self) -> Self { + DualVec::exp_m1(self) + } + #[inline] + fn powi(self, n: i32) -> Self { + DualVec::powi(self, n) + } + #[inline] + fn powf(self, n: Self::RealField) -> Self { + DualVec::powf(self, n) + } + #[inline] + fn powc(self, n: Self) -> Self { + DualVec::powf(self, n) + } + #[inline] + fn cbrt(self) -> Self { + DualVec::cbrt(self) + } + #[inline] + fn is_finite(&self) -> bool { + self.re.is_finite() + } + #[inline] + fn try_sqrt(self) -> Option { + if self.re >= <$f>::zero() { + Some(DualVec::sqrt(self)) + } else { + None + } + } + } + }; +} + +impl_complex_field_dual_vec!(f32); +impl_complex_field_dual_vec!(f64); + +// ══════════════════════════════════════════════ +// RealField for DualVec +// ══════════════════════════════════════════════ + +macro_rules! impl_real_field_dual_vec { + ($f:ty) => { + impl RealField for DualVec<$f, N> { + #[inline] + fn is_sign_positive(&self) -> bool { + self.re.is_sign_positive() + } + #[inline] + fn is_sign_negative(&self) -> bool { + self.re.is_sign_negative() + } + #[inline] + fn copysign(self, sign: Self) -> Self { + DualVec::abs(self) * DualVec::signum(sign) + } + #[inline] + fn max(self, other: Self) -> Self { + DualVec::max(self, other) + } + #[inline] + fn min(self, other: Self) -> Self { + DualVec::min(self, other) + } + #[inline] + fn clamp(self, min: Self, max: Self) -> Self { + DualVec::max(DualVec::min(self, max), min) + } + #[inline] + fn atan2(self, other: Self) -> Self { + DualVec::atan2(self, other) + } + #[inline] + fn min_value() -> Option { + Some(DualVec::constant(<$f>::MIN)) + } + #[inline] + fn max_value() -> Option { + Some(DualVec::constant(<$f>::MAX)) + } + + // ── Constants ── + #[inline] + fn pi() -> Self { + DualVec::constant(<$f>::PI()) + } + #[inline] + fn two_pi() -> Self { + DualVec::constant(<$f>::TAU()) + } + #[inline] + fn frac_pi_2() -> Self { + DualVec::constant(<$f>::FRAC_PI_2()) + } + #[inline] + fn frac_pi_3() -> Self { + DualVec::constant(<$f>::FRAC_PI_3()) + } + #[inline] + fn frac_pi_4() -> Self { + DualVec::constant(<$f>::FRAC_PI_4()) + } + #[inline] + fn frac_pi_6() -> Self { + DualVec::constant(<$f>::FRAC_PI_6()) + } + #[inline] + fn frac_pi_8() -> Self { + DualVec::constant(<$f>::FRAC_PI_8()) + } + #[inline] + fn frac_1_pi() -> Self { + DualVec::constant(<$f>::FRAC_1_PI()) + } + #[inline] + fn frac_2_pi() -> Self { + DualVec::constant(<$f>::FRAC_2_PI()) + } + #[inline] + fn frac_2_sqrt_pi() -> Self { + DualVec::constant(<$f>::FRAC_2_SQRT_PI()) + } + #[inline] + fn e() -> Self { + DualVec::constant(<$f>::E()) + } + #[inline] + fn log2_e() -> Self { + DualVec::constant(<$f>::LOG2_E()) + } + #[inline] + fn log10_e() -> Self { + DualVec::constant(<$f>::LOG10_E()) + } + #[inline] + fn ln_2() -> Self { + DualVec::constant(<$f>::LN_2()) + } + #[inline] + fn ln_10() -> Self { + DualVec::constant(<$f>::LN_10()) + } + } + }; +} + +impl_real_field_dual_vec!(f32); +impl_real_field_dual_vec!(f64); + // ══════════════════════════════════════════════ // ComplexField for Reverse // ══════════════════════════════════════════════ diff --git a/tests/nalgebra_integration.rs b/tests/nalgebra_integration.rs index 181473f..8c1ad99 100644 --- a/tests/nalgebra_integration.rs +++ b/tests/nalgebra_integration.rs @@ -3,7 +3,7 @@ #![cfg(feature = "simba")] use approx::assert_relative_eq; -use echidna::{Dual64, Reverse64}; +use echidna::{Dual64, DualVec64, Reverse64}; use nalgebra::{Matrix3, Vector3}; use num_traits::Float; @@ -74,6 +74,81 @@ fn dual_matrix_vector_product() { assert_relative_eq!(result[2].eps, 7.0, max_relative = 1e-12); } +// ── DualVec in nalgebra ── + +#[test] +fn dual_vec_vector3_dot_product() { + let a = Vector3::new( + DualVec64::<3>::new(1.0, [1.0, 0.0, 0.0]), + DualVec64::<3>::new(2.0, [0.0, 1.0, 0.0]), + DualVec64::<3>::new(3.0, [0.0, 0.0, 1.0]), + ); + let b = Vector3::new( + DualVec64::<3>::constant(4.0), + DualVec64::<3>::constant(5.0), + DualVec64::<3>::constant(6.0), + ); + let dot = a.dot(&b); + // dot = 1*4 + 2*5 + 3*6 = 32 + assert_relative_eq!(dot.re, 32.0, max_relative = 1e-12); + // d(dot)/d(a[0]) = b[0] = 4 + // d(dot)/d(a[1]) = b[1] = 5 + // d(dot)/d(a[2]) = b[2] = 6 + assert_relative_eq!(dot.eps[0], 4.0, max_relative = 1e-12); + assert_relative_eq!(dot.eps[1], 5.0, max_relative = 1e-12); + assert_relative_eq!(dot.eps[2], 6.0, max_relative = 1e-12); +} + +#[test] +fn dual_vec_vector3_norm() { + // v = [x, y, z], norm = sqrt(x² + y² + z²) + // d(norm)/dv_i = v_i / norm + let v = Vector3::new( + DualVec64::<3>::new(3.0, [1.0, 0.0, 0.0]), + DualVec64::<3>::new(2.0, [0.0, 1.0, 0.0]), + DualVec64::<3>::new(3.0, [0.0, 0.0, 1.0]), + ); + let n = v.norm(); + let expected_norm = (9.0 + 4.0 + 9.0_f64).sqrt(); + assert_relative_eq!(n.re, expected_norm, max_relative = 1e-12); + let expected_deriv = v.map(|x| x.re / expected_norm); + assert_relative_eq!(n.eps[0], expected_deriv[0], max_relative = 1e-10); + assert_relative_eq!(n.eps[1], expected_deriv[1], max_relative = 1e-10); + assert_relative_eq!(n.eps[2], expected_deriv[2], max_relative = 1e-10); +} + +#[test] +fn dual_vec_matrix_vector_product() { + // M * v where v[0] is the variable + let m = Matrix3::new( + DualVec64::<3>::constant(1.0), + DualVec64::<3>::constant(2.0), + DualVec64::<3>::constant(3.0), + DualVec64::<3>::constant(4.0), + DualVec64::<3>::constant(5.0), + DualVec64::<3>::constant(6.0), + DualVec64::<3>::constant(7.0), + DualVec64::<3>::constant(8.0), + DualVec64::<3>::constant(9.0), + ); + let v = Vector3::new( + DualVec64::<3>::new(1.0, [1.0, 0.0, 0.0]), + DualVec64::<3>::new(3.0, [0.0, 1.0, 0.0]), + DualVec64::<3>::new(5.0, [0.0, 0.0, 1.0]), + ); + let result = m * v; + // result = [22, 49, 76] + assert_relative_eq!(result[0].re, 22.0, max_relative = 1e-12); + assert_relative_eq!(result[1].re, 49.0, max_relative = 1e-12); + assert_relative_eq!(result[2].re, 76.0, max_relative = 1e-12); + // dv_i/dx_j = M_ij + for i in 0..3 { + for j in 0..3 { + assert_relative_eq!(result[i].eps[j], m[(i, j)].re, max_relative = 1e-12); + } + } +} + // ── Reverse in nalgebra ── #[test] @@ -157,6 +232,36 @@ fn dual_matrix3_try_inverse() { } } +#[test] +fn dual_vec_matrix3_try_inverse() { + // A 3×3 matrix of dual vectors + // test that matrix times inv(matrix) is the identity with all entries with zero differential part + let m = Matrix3::new( + DualVec64::<3>::new(2.0, [1.0, 0.0, 0.0]), + DualVec64::<3>::new(1.0, [0.0, 0.0, 0.0]), + DualVec64::<3>::new(0.0, [0.0, 0.0, 0.0]), + DualVec64::<3>::new(1.0, [0.0, 0.0, 0.0]), + DualVec64::<3>::new(3.0, [0.0, 1.0, 0.0]), + DualVec64::<3>::new(1.0, [0.0, 0.0, 0.0]), + DualVec64::<3>::new(0.0, [0.0, 0.0, 0.0]), + DualVec64::<3>::new(1.0, [0.0, 0.0, 0.0]), + DualVec64::<3>::new(2.0, [0.0, 0.0, 1.0]), + ); + dbg!(println!("{}", m.map(|x| x.re))); + let inv = m.try_inverse().expect("matrix should be invertible"); + let identity = m * inv; + for i in 0..3 { + for j in 0..3 { + let actual = identity[(i, j)]; + let expected = if i == j { 1.0 } else { 0.0 }; + assert_relative_eq!(actual.re, expected, max_relative = 1e-10); + for k in 0..3 { + assert_relative_eq!(actual.eps[k], 0.0, max_relative = 1e-10); + } + } + } +} + #[test] fn reverse_matrix3_try_inverse() { // Same test for Reverse — validates the full ComplexField/RealField chain.