From 0c664523198bee063da51dc2524cbe75e8882c2d Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 10 Mar 2025 09:46:31 -0600 Subject: [PATCH 01/18] include uom example in README --- README.md | 7 +++++-- src/lib.rs | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3dddd22..c8b9f44 100644 --- a/README.md +++ b/README.md @@ -35,12 +35,15 @@ Also see the [`examples`](https://github.com/NREL/ninterp/tree/62a62ccd2b3c28591 meaning strategy concrete types are determined at compilation. This gives increased performance at the cost of runtime flexibility. To allow swapping strategies at runtime, - use *dynamic dispatch* by providing a trait object `Box`/etc. to the `new` method. + use *dynamic dispatch* by providing a boxed trait object + `Box`/etc. to the `new` method. - Interpolator dynamic dispatch using `Box`: [`dynamic_interpolator.rs`](https://github.com/NREL/ninterp/blob/62a62ccd2b3c285919baae609089dee287dc3842/examples/dynamic_interpolator.rs) - Defining custom strategies: [`custom_strategy.rs`](https://github.com/NREL/ninterp/blob/62a62ccd2b3c285919baae609089dee287dc3842/examples/custom_strategy.rs) +- Using transmutable (transparent) types, such as `uom::si::Quantity`: [`uom.rs`](https://github.com/NREL/ninterp/blob/de2c770dc3614ba43af9e015481fecdc20538380/examples/uom.rs) + ## Overview A prelude module has been defined: ```rust @@ -65,7 +68,7 @@ call [`Interpolator::validate`](https://docs.rs/ninterp/latest/ninterp/trait.Int To change the extrapolation setting, call `set_extrapolate`. To change the interpolation strategy, -supply a `Box`/etc. in the new method, +supply a `Box`/etc. upon instantiation, and call `set_strategy`. ### Strategies diff --git a/src/lib.rs b/src/lib.rs index 8da9628..724568a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,8 @@ //! meaning strategy concrete types are determined at compilation. //! This gives increased performance at the cost of runtime flexibility. //! To allow swapping strategies at runtime, -//! use *dynamic dispatch* by providing a trait object `Box`/etc. to the `new` method. +//! use *dynamic dispatch* by providing a boxed trait object +//! `Box`/etc. to the `new` method. //! //! - Interpolator dynamic dispatch using `Box`: //! [`dynamic_interpolator.rs`](https://github.com/NREL/ninterp/blob/62a62ccd2b3c285919baae609089dee287dc3842/examples/dynamic_interpolator.rs) @@ -43,6 +44,8 @@ //! - Defining custom strategies: //! [`custom_strategy.rs`](https://github.com/NREL/ninterp/blob/62a62ccd2b3c285919baae609089dee287dc3842/examples/custom_strategy.rs) //! +//! - Using transmutable (transparent) types, such as `uom::si::Quantity`: [`uom.rs`](https://github.com/NREL/ninterp/blob/de2c770dc3614ba43af9e015481fecdc20538380/examples/uom.rs) +//! //! # Overview //! A prelude module has been defined: //! ```rust,text From e0bebac2abbc396aec5dce7cee356fa980cce305 Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 10 Mar 2025 10:26:10 -0600 Subject: [PATCH 02/18] starting point for cubic interpolation --- src/lib.rs | 4 ++-- src/strategy.rs | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 724568a..ae2c185 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,7 +75,7 @@ //! and call `set_strategy`. //! //! ## Strategies -//! An interpolation strategy (e.g. [`Linear`], [`Nearest`], [`LeftNearest`], [`RightNearest`]) must be specified. +//! An interpolation strategy (e.g. [`Linear`], [`Cubic`], [`Nearest`], [`LeftNearest`], [`RightNearest`]) must be specified. //! Not all interpolation strategies are implemented for every dimensionality. //! [`Linear`] and [`Nearest`] are implemented for all dimensionalities. //! @@ -117,7 +117,7 @@ /// - The extrapolation setting enum: [`Extrapolate`] pub mod prelude { pub use crate::interpolator::*; - pub use crate::strategy::{LeftNearest, Linear, Nearest, RightNearest}; + pub use crate::strategy::{Cubic, LeftNearest, Linear, Nearest, RightNearest}; pub use crate::Extrapolate; pub use crate::Interpolator; } diff --git a/src/strategy.rs b/src/strategy.rs index 59bb10b..cb0c2f2 100644 --- a/src/strategy.rs +++ b/src/strategy.rs @@ -163,6 +163,14 @@ pub fn find_nearest_index(arr: ArrayView1, target: T) -> usize #[derive(Debug)] pub struct Linear; +// TODO: `pub struct Quadratic;` +// Maybe `pub struct Polynomial(usize);` as well? +// with `pub type Quadratic = Polynomial(2)` and `pub type Cubic = Polynomial(3)` + +/// Cubic spline interpolation: TODO +#[derive(Debug)] +pub struct Cubic; + /// Nearest value interpolation: /// /// # Note From 75473f0db971df223390e308adf144458c381e2a Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Fri, 14 Mar 2025 13:31:43 -0600 Subject: [PATCH 03/18] cubic derives --- src/strategy.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/strategy.rs b/src/strategy.rs index 4c320e0..f5357bf 100644 --- a/src/strategy.rs +++ b/src/strategy.rs @@ -172,7 +172,8 @@ pub struct Linear; // with `pub type Quadratic = Polynomial(2)` and `pub type Cubic = Polynomial(3)` /// Cubic spline interpolation: TODO -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct Cubic; /// Nearest value interpolation: From 3ba86931a66df16e000b3ea5075f8ebfa85a6e09 Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Fri, 14 Mar 2025 22:32:57 -0600 Subject: [PATCH 04/18] start cubic init --- src/one/strategies.rs | 23 ++++++++++++++++ src/strategy.rs | 63 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/src/one/strategies.rs b/src/one/strategies.rs index baf1d45..6184800 100644 --- a/src/one/strategies.rs +++ b/src/one/strategies.rs @@ -36,6 +36,29 @@ where } } +impl Strategy1D for Cubic +where + D: Data + RawDataClone + Clone, + D::Elem: Num + PartialOrd + Copy + Default + Debug, +{ + fn init(&self, data: &InterpData1D) -> Result<(), ValidateError> { + Ok(()) + } + + fn interpolate( + &self, + data: &InterpData1D, + point: &[::Elem; 1], + ) -> Result<::Elem, InterpolateError> { + todo!() + } + + /// Returns `true` + fn allow_extrapolate(&self) -> bool { + true + } +} + impl Strategy1D for Nearest where D: Data + RawDataClone + Clone, diff --git a/src/strategy.rs b/src/strategy.rs index ec91706..527525d 100644 --- a/src/strategy.rs +++ b/src/strategy.rs @@ -207,14 +207,65 @@ pub fn find_nearest_index(arr: ArrayView1, target: &T) -> usiz #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct Linear; -// TODO: `pub struct Quadratic;` -// Maybe `pub struct Polynomial(usize);` as well? -// with `pub type Quadratic = Polynomial(2)` and `pub type Cubic = Polynomial(3)` - /// Cubic spline interpolation: TODO -#[derive(Debug, Clone, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -pub struct Cubic; +pub struct Cubic { + pub boundary_cond: CubicBC, + pub coeffs: ArrayD, +} + +impl Cubic +where + T: Default, +{ + pub fn new(bc: CubicBC) -> Self { + Self { + boundary_cond: bc, + coeffs: as Default>::default(), + } + } + + pub fn natural() -> Self { + Self::new(CubicBC::Natural) + } + + pub fn clamped(a: T, b: T) -> Self { + Self::new(CubicBC::Clamped(a, b)) + } + + pub fn not_a_knot() -> Self { + Self::new(CubicBC::NotAKnot) + } + + pub fn periodic() -> Self { + Self::new(CubicBC::Periodic) + } + + pub fn solve_coeffs(&mut self) { + match &self.boundary_cond { + CubicBC::Natural => { + todo!() + } + _ => todo!(), + } + } +} + +/// Cubic boundary conditions. +#[derive(Copy, Clone, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +pub enum CubicBC { + /// Second derivatives at endpoints are 0, thus extrapolation is linear. + // https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf + #[default] + Natural, + /// Specific first derivatives at endpoints. + Clamped(T, T), + NotAKnot, + // https://math.ou.edu/~npetrov/project-5093-s11.pdf + Periodic, +} /// Nearest value interpolation: /// From 95a020770f7d065f9e11c9ed22c19206bbf48b3d Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Sun, 16 Mar 2025 15:57:36 -0600 Subject: [PATCH 05/18] cubic interpolation for natural and clamped --- src/lib.rs | 2 +- src/one/mod.rs | 3 +- src/one/strategies.rs | 130 ++++++++++++++++++++++++++- src/strategy/cubic.rs | 97 ++++++++++++++++++++ src/{strategy.rs => strategy/mod.rs} | 63 +------------ 5 files changed, 230 insertions(+), 65 deletions(-) create mode 100644 src/strategy/cubic.rs rename src/{strategy.rs => strategy/mod.rs} (80%) diff --git a/src/lib.rs b/src/lib.rs index 46930c2..d9504d8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -155,7 +155,7 @@ pub use ndarray; pub(crate) use ndarray::prelude::*; pub(crate) use ndarray::{Data, Ix, RawDataClone}; -pub(crate) use num_traits::{clamp, Euclid, Num, One}; +pub(crate) use num_traits::{clamp, Euclid, Float, Num, NumCast, One, Zero}; pub(crate) use dyn_clone::*; diff --git a/src/one/mod.rs b/src/one/mod.rs index 1fff495..72c98e0 100644 --- a/src/one/mod.rs +++ b/src/one/mod.rs @@ -102,12 +102,13 @@ where strategy: S, extrapolate: Extrapolate, ) -> Result { - let interpolator = Self { + let mut interpolator = Self { data: InterpData1D::new(x, f_x)?, strategy, extrapolate, }; interpolator.check_extrapolate(&interpolator.extrapolate)?; + interpolator.strategy.init(&interpolator.data)?; Ok(interpolator) } } diff --git a/src/one/strategies.rs b/src/one/strategies.rs index 6184800..fee7db5 100644 --- a/src/one/strategies.rs +++ b/src/one/strategies.rs @@ -39,9 +39,79 @@ where impl Strategy1D for Cubic where D: Data + RawDataClone + Clone, - D::Elem: Num + PartialOrd + Copy + Default + Debug, + D::Elem: Float + Default + Debug, { - fn init(&self, data: &InterpData1D) -> Result<(), ValidateError> { + /// Solves coefficients + + /// Reference: https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf + /// ```text + /// ┌─ ─┐┌─ ─┐ ┌─ ─┐ + /// │ V1 H1 ││ Z1 │ │ U1 │ + /// │ H1 V2 H2 ││ Z2 │ │ U2 │ + /// │ H2 V3 H3 ││ Z3 │ │ U3 │ + /// │ . . . ││ . │ ── │ . │ + /// │ . . . ││ . │ ── │ . │ + /// │ . . . ││ . │ │ . │ + /// │ . Vn-2 Hn-2 ││ Zn-2 │ │ Un-2 │ + /// │ Hn-2 Vn-1 ││ Zn-1 │ │ Un-1 │ + /// └─ ─┘└─ ─┘ └─ ─┘ + /// ``` + fn init(&mut self, data: &InterpData1D) -> Result<(), ValidateError> { + // Number of segments + let n = data.grid[0].len() - 1; + + let zero = D::Elem::zero(); + let one = D::Elem::one(); + let two = ::from(2.).unwrap(); + let six = ::from(6.).unwrap(); + + // let h = Array1::from_shape_fn(n, |i| data.grid[0][i + 1] - data.grid[0][i]); + // let v = Array1::from_shape_fn(n - 1, |i| two * (h[i] + h[i + 1])); + // let b = Array1::from_shape_fn(n, |i| (data.values[i + 1] - data.values[i]) / h[i]); + // let u = Array1::from_shape_fn(n - 1, |i| six * (b[i + 1] - b[i])); + + let h = Array1::from_shape_fn(n, |i| data.grid[0][i + 1] - data.grid[0][i]); + let v = Array1::from_shape_fn(n + 1, |i| { + if i == 0 || i == n { + match &self.boundary_cond { + CubicBC::Natural => one, + CubicBC::Clamped(_, _) => two * h[0], + _ => todo!(), + } + } else { + two * (h[i - 1] + h[i]) + } + }); + let b = Array1::from_shape_fn(n, |i| (data.values[i + 1] - data.values[i]) / h[i]); + let u = Array1::from_shape_fn(n + 1, |i| { + if i == 0 || i == n { + match &self.boundary_cond { + CubicBC::Natural => zero, + CubicBC::Clamped(l, r) => { + if i == 0 { + six * (b[i] - *l) + } else { + six * (*r - b[i - 1]) + } + } + _ => todo!(), + } + } else { + six * (b[i] - b[i - 1]) + } + }); + + let (sub, sup) = match &self.boundary_cond { + CubicBC::Natural => ( + &Array1::from_shape_fn(n, |i| if i == n - 1 { zero } else { h[i] }), + &Array1::from_shape_fn(n, |i| if i == 0 { zero } else { h[i] }), + ), + CubicBC::Clamped(_, _) => (&h, &h), + _ => todo!(), + }; + + self.z = Self::thomas(sub.view(), v.view(), sup.view(), u.view()); + Ok(()) } @@ -50,7 +120,24 @@ where data: &InterpData1D, point: &[::Elem; 1], ) -> Result<::Elem, InterpolateError> { - todo!() + let l = if &point[0] < data.grid[0].first().unwrap() { + 0 + } else if &point[0] > data.grid[0].last().unwrap() { + data.grid[0].len() - 2 + } else { + find_nearest_index(data.grid[0].view(), &point[0]) + }; + let u = l + 1; + + let six = ::from(6.).unwrap(); + let h_i = data.grid[0][u] - data.grid[0][l]; + + Ok( + self.z[u] / (six * h_i) * (point[0] - data.grid[0][l]).powi(3) + + self.z[l] / (six * h_i) * (data.grid[0][u] - point[0]).powi(3) + + (data.values[u] / h_i - self.z[u] * h_i / six) * (point[0] - data.grid[0][l]) + + (data.values[l] / h_i - self.z[l] * h_i / six) * (data.grid[0][u] - point[0]), + ) } /// Returns `true` @@ -59,6 +146,43 @@ where } } +#[cfg(test)] +mod tests { + use super::*; + use ndarray::*; + + #[test] + fn test_cubic_natural() { + let x = array![1., 2., 3., 5., 7., 8.]; + let f_x = array![3., 6., 19., 99., 291., 444.]; + + let interp = + Interp1D::new(x.view(), f_x.view(), Cubic::natural(), Extrapolate::Enable).unwrap(); + + for i in 0..x.len() { + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]) + } + } + + #[test] + fn test_cubic_clamped() { + let x = array![1., 2., 3., 5., 7., 8.]; + let f_x = array![3., 6., 19., 99., 291., 444.]; + + let interp = Interp1D::new( + x.view(), + f_x.view(), + Cubic::clamped(0., 0.), + Extrapolate::Enable, + ) + .unwrap(); + + for i in 0..x.len() { + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]) + } + } +} + impl Strategy1D for Nearest where D: Data + RawDataClone + Clone, diff --git a/src/strategy/cubic.rs b/src/strategy/cubic.rs new file mode 100644 index 0000000..ec965fa --- /dev/null +++ b/src/strategy/cubic.rs @@ -0,0 +1,97 @@ +use super::*; + +#[derive(Clone, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +pub struct Cubic { + pub boundary_cond: CubicBC, + pub z: Array1, +} + +impl Cubic +where + T: Default, +{ + pub fn new(bc: CubicBC) -> Self { + Self { + boundary_cond: bc, + z: as Default>::default(), + } + } + + pub fn natural() -> Self { + Self::new(CubicBC::Natural) + } + + pub fn clamped(a: T, b: T) -> Self { + Self::new(CubicBC::Clamped(a, b)) + } + + pub fn not_a_knot() -> Self { + Self::new(CubicBC::NotAKnot) + } + + pub fn periodic() -> Self { + Self::new(CubicBC::Periodic) + } +} + +impl Cubic +where + T: Num + Copy + Default, +{ + /// Solves Ax = d for a tridiagonal matrix A using the [Thomas algorithm](https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm). + /// - `a`: sub-diagonal (1 element shorter than `b` and `d`) + /// - `b`: diagonal + /// - `c`: super-diagonal (1 element shorter than `b` and `d`) + /// - `d`: right-hand side + pub fn thomas( + a: ArrayView1, + b: ArrayView1, + c: ArrayView1, + d: ArrayView1, + ) -> Array1 { + let n = d.len(); + assert_eq!(a.len(), n - 1); + assert_eq!(b.len(), n); + assert_eq!(c.len(), n - 1); + + let mut c_prime = Array1::default(n - 1); + let mut d_prime = Array1::default(n); + let mut x = Array1::default(n); + + // Forward sweep + c_prime[0] = c[0] / b[0]; + d_prime[0] = d[0] / b[0]; + + for i in 1..n - 1 { + let denom = b[i] - a[i - 1] * c_prime[i - 1]; + c_prime[i] = c[i] / denom; + d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom; + } + d_prime[n - 1] = + (d[n - 1] - a[n - 2] * d_prime[n - 2]) / (b[n - 1] - a[n - 2] * c_prime[n - 2]); + + // Back substitution + x[n - 1] = d_prime[n - 1]; + for i in (0..n - 1).rev() { + x[i] = d_prime[i] - c_prime[i] * x[i + 1]; + } + + x + } +} + +/// Cubic spline boundary conditions. +#[derive(Copy, Clone, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +pub enum CubicBC { + /// Second derivatives at endpoints are 0, thus extrapolation is linear. + // https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf + #[default] + Natural, + /// Specific first derivatives at endpoints. + Clamped(T, T), + NotAKnot, + // https://math.ou.edu/~npetrov/project-5093-s11.pdf + Periodic, +} diff --git a/src/strategy.rs b/src/strategy/mod.rs similarity index 80% rename from src/strategy.rs rename to src/strategy/mod.rs index ccee040..5335dae 100644 --- a/src/strategy.rs +++ b/src/strategy/mod.rs @@ -2,6 +2,9 @@ use super::*; +mod cubic; +pub use cubic::*; + pub trait Strategy1D: Debug + DynClone where D: Data + RawDataClone + Clone, @@ -207,66 +210,6 @@ pub fn find_nearest_index(arr: ArrayView1, target: &T) -> usiz #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct Linear; -/// Cubic spline interpolation: TODO -#[derive(Clone, Debug, Default, PartialEq)] -#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -pub struct Cubic { - pub boundary_cond: CubicBC, - pub coeffs: ArrayD, -} - -impl Cubic -where - T: Default, -{ - pub fn new(bc: CubicBC) -> Self { - Self { - boundary_cond: bc, - coeffs: as Default>::default(), - } - } - - pub fn natural() -> Self { - Self::new(CubicBC::Natural) - } - - pub fn clamped(a: T, b: T) -> Self { - Self::new(CubicBC::Clamped(a, b)) - } - - pub fn not_a_knot() -> Self { - Self::new(CubicBC::NotAKnot) - } - - pub fn periodic() -> Self { - Self::new(CubicBC::Periodic) - } - - pub fn solve_coeffs(&mut self) { - match &self.boundary_cond { - CubicBC::Natural => { - todo!() - } - _ => todo!(), - } - } -} - -/// Cubic boundary conditions. -#[derive(Copy, Clone, Debug, Default, PartialEq)] -#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -pub enum CubicBC { - /// Second derivatives at endpoints are 0, thus extrapolation is linear. - // https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf - #[default] - Natural, - /// Specific first derivatives at endpoints. - Clamped(T, T), - NotAKnot, - // https://math.ou.edu/~npetrov/project-5093-s11.pdf - Periodic, -} - /// Nearest value interpolation: /// /// # Note From 2df45c4c0870d959b27d19c0c0d8bc2a9afd0aac Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Sun, 16 Mar 2025 23:30:52 -0600 Subject: [PATCH 06/18] natural and clamped working with good test coverage --- src/one/mod.rs | 108 ++++++++++++++++++++++++++++++++++++++++++ src/one/strategies.rs | 85 +++++++++------------------------ 2 files changed, 131 insertions(+), 62 deletions(-) diff --git a/src/one/mod.rs b/src/one/mod.rs index 72c98e0..ef67735 100644 --- a/src/one/mod.rs +++ b/src/one/mod.rs @@ -334,4 +334,112 @@ mod tests { assert_approx_eq!(interp.interpolate(&[-0.75]).unwrap(), 0.05); assert_eq!(interp.interpolate(&[5.]).unwrap(), 1.2); } + + #[test] + fn test_cubic_natural() { + let x = array![1., 2., 3., 5., 7., 8.]; + let f_x = array![3., 6., 19., 99., 291., 444.]; + + let interp = + Interp1D::new(x.view(), f_x.view(), Cubic::natural(), Extrapolate::Enable).unwrap(); + + // Interpolating at knots returns values + for i in 0..x.len() { + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]) + } + + let x0 = x.first().unwrap(); + let xn = x.last().unwrap(); + let y0 = f_x.first().unwrap(); + let yn = f_x.last().unwrap(); + + let range = xn - x0; + + let x_low = x0 - 0.2 * range; + let y_low = interp.interpolate(&[x_low]).unwrap(); + let slope_low = (y0 - y_low) / (x0 - x_low); + + let x_high = xn + 0.2 * range; + let y_high = interp.interpolate(&[x_high]).unwrap(); + let slope_high = (y_high - yn) / (x_high - xn); + + let xs_left = Array1::linspace(x0 - 1e-6, x0 + 1e-6, 50); + let xs_right = Array1::linspace(xn - 1e-6, xn + 1e-6, 50); + + // Left extrapolation is linear + let ys: Array1 = xs_left + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_left + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), slope_low); + + // Right extrapolation is linear + let ys: Array1 = xs_right + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_right + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), slope_high); + } + + #[test] + fn test_cubic_clamped() { + let x = array![1., 2., 3., 5., 7., 8.]; + let f_x = array![3., -90., 19., 99., 291., 444.]; + + let xs_left = Array1::linspace(x.first().unwrap() - 1e-6, x.first().unwrap() + 1e-6, 50); + let xs_right = Array1::linspace(x.last().unwrap() - 1e-6, x.last().unwrap() + 1e-6, 50); + + for (a, b) in [(-5., 10.), (0., 0.), (2.4, -5.2)] { + let interp = Interp1D::new( + x.view(), + f_x.view(), + Cubic::clamped(a, b), + Extrapolate::Enable, + ) + .unwrap(); + + // Interpolating at knots returns values + for i in 0..x.len() { + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]) + } + + // Left slope = a + let ys: Array1 = xs_left + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_left + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), a); + + // Right slope = b + let ys: Array1 = xs_right + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_right + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), b); + } + } } diff --git a/src/one/strategies.rs b/src/one/strategies.rs index fee7db5..212b863 100644 --- a/src/one/strategies.rs +++ b/src/one/strategies.rs @@ -41,21 +41,6 @@ where D: Data + RawDataClone + Clone, D::Elem: Float + Default + Debug, { - /// Solves coefficients - - /// Reference: https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf - /// ```text - /// ┌─ ─┐┌─ ─┐ ┌─ ─┐ - /// │ V1 H1 ││ Z1 │ │ U1 │ - /// │ H1 V2 H2 ││ Z2 │ │ U2 │ - /// │ H2 V3 H3 ││ Z3 │ │ U3 │ - /// │ . . . ││ . │ ── │ . │ - /// │ . . . ││ . │ ── │ . │ - /// │ . . . ││ . │ │ . │ - /// │ . Vn-2 Hn-2 ││ Zn-2 │ │ Un-2 │ - /// │ Hn-2 Vn-1 ││ Zn-1 │ │ Un-1 │ - /// └─ ─┘└─ ─┘ └─ ─┘ - /// ``` fn init(&mut self, data: &InterpData1D) -> Result<(), ValidateError> { // Number of segments let n = data.grid[0].len() - 1; @@ -65,11 +50,6 @@ where let two = ::from(2.).unwrap(); let six = ::from(6.).unwrap(); - // let h = Array1::from_shape_fn(n, |i| data.grid[0][i + 1] - data.grid[0][i]); - // let v = Array1::from_shape_fn(n - 1, |i| two * (h[i] + h[i + 1])); - // let b = Array1::from_shape_fn(n, |i| (data.values[i + 1] - data.values[i]) / h[i]); - // let u = Array1::from_shape_fn(n - 1, |i| six * (b[i + 1] - b[i])); - let h = Array1::from_shape_fn(n, |i| data.grid[0][i + 1] - data.grid[0][i]); let v = Array1::from_shape_fn(n + 1, |i| { if i == 0 || i == n { @@ -120,16 +100,34 @@ where data: &InterpData1D, point: &[::Elem; 1], ) -> Result<::Elem, InterpolateError> { - let l = if &point[0] < data.grid[0].first().unwrap() { - 0 - } else if &point[0] > data.grid[0].last().unwrap() { - data.grid[0].len() - 2 + let six = ::from(6.).unwrap(); + let last = data.grid[0].len() - 1; + let l = if point[0] < data.grid[0][0] { + match &self.boundary_cond { + CubicBC::Natural => { + // linear extrapolation + let h0 = data.grid[0][1] - data.grid[0][0]; + let k0 = (data.values[1] - data.values[0]) / h0 - h0 * self.z[1] / six; + return Ok(k0 * (point[0] - data.grid[0][0]) + data.values[0]); + } + _ => 0, + } + } else if point[0] > data.grid[0][last] { + match &self.boundary_cond { + CubicBC::Natural => { + // linear extrapolation + let hn = data.grid[0][last] - data.grid[0][last - 1]; + let kn = (data.values[last] - data.values[last - 1]) / hn + + hn * self.z[last - 1] / six; + return Ok(kn * (point[0] - data.grid[0][last]) + data.values[last]); + } + _ => last - 1, + } } else { find_nearest_index(data.grid[0].view(), &point[0]) }; let u = l + 1; - let six = ::from(6.).unwrap(); let h_i = data.grid[0][u] - data.grid[0][l]; Ok( @@ -146,43 +144,6 @@ where } } -#[cfg(test)] -mod tests { - use super::*; - use ndarray::*; - - #[test] - fn test_cubic_natural() { - let x = array![1., 2., 3., 5., 7., 8.]; - let f_x = array![3., 6., 19., 99., 291., 444.]; - - let interp = - Interp1D::new(x.view(), f_x.view(), Cubic::natural(), Extrapolate::Enable).unwrap(); - - for i in 0..x.len() { - assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]) - } - } - - #[test] - fn test_cubic_clamped() { - let x = array![1., 2., 3., 5., 7., 8.]; - let f_x = array![3., 6., 19., 99., 291., 444.]; - - let interp = Interp1D::new( - x.view(), - f_x.view(), - Cubic::clamped(0., 0.), - Extrapolate::Enable, - ) - .unwrap(); - - for i in 0..x.len() { - assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]) - } - } -} - impl Strategy1D for Nearest where D: Data + RawDataClone + Clone, From 24a6e4527d33c5d820e3ec9f0e03103ad3a2ce89 Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 08:20:59 -0600 Subject: [PATCH 07/18] introduce CubicExtrapolate --- src/one/mod.rs | 4 +- src/one/strategies.rs | 49 ++++++++------- src/strategy/cubic.rs | 140 ++++++++++++++++++++++++++++++++++-------- 3 files changed, 141 insertions(+), 52 deletions(-) diff --git a/src/one/mod.rs b/src/one/mod.rs index ef67735..fd00b12 100644 --- a/src/one/mod.rs +++ b/src/one/mod.rs @@ -363,8 +363,8 @@ mod tests { let y_high = interp.interpolate(&[x_high]).unwrap(); let slope_high = (y_high - yn) / (x_high - xn); - let xs_left = Array1::linspace(x0 - 1e-6, x0 + 1e-6, 50); - let xs_right = Array1::linspace(xn - 1e-6, xn + 1e-6, 50); + let xs_left = Array1::linspace(*x0, x0 + 2e-6, 50); + let xs_right = Array1::linspace(xn - 2e-6, *xn, 50); // Left extrapolation is linear let ys: Array1 = xs_left diff --git a/src/one/strategies.rs b/src/one/strategies.rs index 212b863..1136f78 100644 --- a/src/one/strategies.rs +++ b/src/one/strategies.rs @@ -39,7 +39,7 @@ where impl Strategy1D for Cubic where D: Data + RawDataClone + Clone, - D::Elem: Float + Default + Debug, + D::Elem: Float + Euclid + Default + Debug, { fn init(&mut self, data: &InterpData1D) -> Result<(), ValidateError> { // Number of segments @@ -53,7 +53,7 @@ where let h = Array1::from_shape_fn(n, |i| data.grid[0][i + 1] - data.grid[0][i]); let v = Array1::from_shape_fn(n + 1, |i| { if i == 0 || i == n { - match &self.boundary_cond { + match &self.boundary_condition { CubicBC::Natural => one, CubicBC::Clamped(_, _) => two * h[0], _ => todo!(), @@ -65,7 +65,7 @@ where let b = Array1::from_shape_fn(n, |i| (data.values[i + 1] - data.values[i]) / h[i]); let u = Array1::from_shape_fn(n + 1, |i| { if i == 0 || i == n { - match &self.boundary_cond { + match &self.boundary_condition { CubicBC::Natural => zero, CubicBC::Clamped(l, r) => { if i == 0 { @@ -81,7 +81,7 @@ where } }); - let (sub, sup) = match &self.boundary_cond { + let (sub, sup) = match &self.boundary_condition { CubicBC::Natural => ( &Array1::from_shape_fn(n, |i| if i == n - 1 { zero } else { h[i] }), &Array1::from_shape_fn(n, |i| if i == 0 { zero } else { h[i] }), @@ -100,42 +100,41 @@ where data: &InterpData1D, point: &[::Elem; 1], ) -> Result<::Elem, InterpolateError> { - let six = ::from(6.).unwrap(); let last = data.grid[0].len() - 1; let l = if point[0] < data.grid[0][0] { - match &self.boundary_cond { - CubicBC::Natural => { - // linear extrapolation + match &self.extrapolate { + CubicExtrapolate::Linear => { let h0 = data.grid[0][1] - data.grid[0][0]; - let k0 = (data.values[1] - data.values[0]) / h0 - h0 * self.z[1] / six; + let k0 = (data.values[1] - data.values[0]) / h0 + - h0 * self.z[1] / ::from(6.).unwrap(); return Ok(k0 * (point[0] - data.grid[0][0]) + data.values[0]); } - _ => 0, + CubicExtrapolate::Spline => 0, + CubicExtrapolate::Wrap => { + let point = [wrap(point[0], data.grid[0][0], data.grid[0][last])]; + let l = find_nearest_index(data.grid[0].view(), &point[0]); + return self.evaluate_1d(&point, l, data); + } } } else if point[0] > data.grid[0][last] { - match &self.boundary_cond { - CubicBC::Natural => { - // linear extrapolation + match &self.extrapolate { + CubicExtrapolate::Linear => { let hn = data.grid[0][last] - data.grid[0][last - 1]; let kn = (data.values[last] - data.values[last - 1]) / hn - + hn * self.z[last - 1] / six; + + hn * self.z[last - 1] / ::from(6.).unwrap(); return Ok(kn * (point[0] - data.grid[0][last]) + data.values[last]); } - _ => last - 1, + CubicExtrapolate::Spline => last - 1, + CubicExtrapolate::Wrap => { + let point = [wrap(point[0], data.grid[0][0], data.grid[0][last])]; + let l = find_nearest_index(data.grid[0].view(), &point[0]); + return self.evaluate_1d(&point, l, data); + } } } else { find_nearest_index(data.grid[0].view(), &point[0]) }; - let u = l + 1; - - let h_i = data.grid[0][u] - data.grid[0][l]; - - Ok( - self.z[u] / (six * h_i) * (point[0] - data.grid[0][l]).powi(3) - + self.z[l] / (six * h_i) * (data.grid[0][u] - point[0]).powi(3) - + (data.values[u] / h_i - self.z[u] * h_i / six) * (point[0] - data.grid[0][l]) - + (data.values[l] / h_i - self.z[l] * h_i / six) * (data.grid[0][u] - point[0]), - ) + self.evaluate_1d(point, l, data) } /// Returns `true` diff --git a/src/strategy/cubic.rs b/src/strategy/cubic.rs index ec965fa..75f9d3c 100644 --- a/src/strategy/cubic.rs +++ b/src/strategy/cubic.rs @@ -1,37 +1,142 @@ use super::*; -#[derive(Clone, Debug, Default, PartialEq)] +#[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct Cubic { - pub boundary_cond: CubicBC, + /// Cubic spline boundary conditions. + pub boundary_condition: CubicBC, + /// Behavior of [`Extrapolate::Enable`]. + pub extrapolate: CubicExtrapolate, + /// Solved second derivatives. pub z: Array1, } +/// Cubic spline boundary conditions. +#[derive(Copy, Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +pub enum CubicBC { + Natural, + Clamped(T, T), + NotAKnot, + // https://math.ou.edu/~npetrov/project-5093-s11.pdf + Periodic, +} + +/// [`Extrapolate::Enable`] behavior for cubic splines +#[derive(Copy, Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +pub enum CubicExtrapolate { + /// Linear extrapolation, default for natural splines. + Linear, + /// Use nearest spline to extrapolate. + Spline, + /// Same as [`Extrapolate::Wrap`]. Default for periodic splines. + Wrap, +} + impl Cubic where T: Default, { - pub fn new(bc: CubicBC) -> Self { + /// Cubic spline with given boundary condition and extrapolation behavior. + pub fn new(boundary_condition: CubicBC, extrapolate: CubicExtrapolate) -> Self { Self { - boundary_cond: bc, + boundary_condition, + extrapolate, z: as Default>::default(), } } + /// Natural cubic spline + /// (splines straighten at outermost knots). + /// + /// 2nd derivatives at outermost knots are zero: + /// z0 = zn = 0 + /// + /// [`Extrapolate::Enable`] defaults to [`CubicExtrapolate::Linear`]. pub fn natural() -> Self { - Self::new(CubicBC::Natural) + Self::new(CubicBC::Natural, CubicExtrapolate::Linear) } - pub fn clamped(a: T, b: T) -> Self { - Self::new(CubicBC::Clamped(a, b)) + /// Clamped cubic spline. + /// + /// 1st derivatives at outermost knots (k0, kn) are given. + /// + /// [`Extrapolate::Enable`] defaults to [`CubicExtrapolate::Spline`]. + pub fn clamped(k0: T, kn: T) -> Self { + Self::new(CubicBC::Clamped(k0, kn), CubicExtrapolate::Spline) } + /// Not-a-knot cubic spline. + /// + /// Spline 3rd derivatives at second and second-to-last knots are equal, respectively: + /// S'''0(x1) = S'''1(x1) and + /// S'''n-1(xn-1) = S'''n(xn-1). + /// + /// In other words, this means the first and second spline at data boundaries are the same cubic. + /// + /// [`Extrapolate::Enable`] defaults to [`CubicExtrapolate::Spline`]. pub fn not_a_knot() -> Self { - Self::new(CubicBC::NotAKnot) + Self::new(CubicBC::NotAKnot, CubicExtrapolate::Spline) } + /// Periodic cubic spline. + /// + /// Spline 1st & 2nd derivatives at outermost knots are equal: + /// k0 = kn, z0 = zn + /// + /// [`Extrapolate::Enable`] defaults to [`CubicExtrapolate::Wrap`]. pub fn periodic() -> Self { - Self::new(CubicBC::Periodic) + Self::new(CubicBC::Periodic, CubicExtrapolate::Wrap) + } +} + +impl Cubic +where + T: Float + Default + Debug, +{ + pub(crate) fn evaluate_1d + RawDataClone + Clone>( + &self, + point: &[T; 1], + l: usize, + data: &InterpData1D, + ) -> Result { + let six = ::from(6.).unwrap(); + let u = l + 1; + let h_i = data.grid[0][u] - data.grid[0][l]; + Ok( + self.z[u] / (six * h_i) * (point[0] - data.grid[0][l]).powi(3) + + self.z[l] / (six * h_i) * (data.grid[0][u] - point[0]).powi(3) + + (data.values[u] / h_i - self.z[u] * h_i / six) * (point[0] - data.grid[0][l]) + + (data.values[l] / h_i - self.z[l] * h_i / six) * (data.grid[0][u] - point[0]), + ) + } + + pub(crate) fn evaluate_2d + RawDataClone + Clone>( + &self, + point: &[T; 2], + l: usize, + data: &InterpData2D, + ) -> Result { + todo!() + } + + pub(crate) fn evaluate_3d + RawDataClone + Clone>( + &self, + point: &[T; 3], + l: usize, + data: &InterpData3D, + ) -> Result { + todo!() + } + + pub(crate) fn evaluate_nd + RawDataClone + Clone>( + &self, + point: &[T], + l: usize, + data: &InterpDataND, + ) -> Result { + todo!() } } @@ -44,7 +149,7 @@ where /// - `b`: diagonal /// - `c`: super-diagonal (1 element shorter than `b` and `d`) /// - `d`: right-hand side - pub fn thomas( + pub(crate) fn thomas( a: ArrayView1, b: ArrayView1, c: ArrayView1, @@ -80,18 +185,3 @@ where x } } - -/// Cubic spline boundary conditions. -#[derive(Copy, Clone, Debug, Default, PartialEq)] -#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -pub enum CubicBC { - /// Second derivatives at endpoints are 0, thus extrapolation is linear. - // https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf - #[default] - Natural, - /// Specific first derivatives at endpoints. - Clamped(T, T), - NotAKnot, - // https://math.ou.edu/~npetrov/project-5093-s11.pdf - Periodic, -} From e29b9f7dc08ca0e64911ad50b9aa6bbfc4e62ee5 Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 09:25:28 -0600 Subject: [PATCH 08/18] remove cubic Default bound --- src/one/strategies.rs | 2 +- src/strategy/cubic.rs | 26 ++++++++++++-------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/one/strategies.rs b/src/one/strategies.rs index 1136f78..1db77dc 100644 --- a/src/one/strategies.rs +++ b/src/one/strategies.rs @@ -39,7 +39,7 @@ where impl Strategy1D for Cubic where D: Data + RawDataClone + Clone, - D::Elem: Float + Euclid + Default + Debug, + D::Elem: Float + Euclid + Debug, { fn init(&mut self, data: &InterpData1D) -> Result<(), ValidateError> { // Number of segments diff --git a/src/strategy/cubic.rs b/src/strategy/cubic.rs index 75f9d3c..1fa4215 100644 --- a/src/strategy/cubic.rs +++ b/src/strategy/cubic.rs @@ -2,7 +2,7 @@ use super::*; #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -pub struct Cubic { +pub struct Cubic { /// Cubic spline boundary conditions. pub boundary_condition: CubicBC, /// Behavior of [`Extrapolate::Enable`]. @@ -30,20 +30,17 @@ pub enum CubicExtrapolate { Linear, /// Use nearest spline to extrapolate. Spline, - /// Same as [`Extrapolate::Wrap`]. Default for periodic splines. + /// Same as [`Extrapolate::Wrap`], default for periodic splines. Wrap, } -impl Cubic -where - T: Default, -{ +impl Cubic { /// Cubic spline with given boundary condition and extrapolation behavior. pub fn new(boundary_condition: CubicBC, extrapolate: CubicExtrapolate) -> Self { Self { boundary_condition, extrapolate, - z: as Default>::default(), + z: Array1::from_vec(Vec::new()), } } @@ -69,7 +66,7 @@ where /// Not-a-knot cubic spline. /// - /// Spline 3rd derivatives at second and second-to-last knots are equal, respectively: + /// Spline 3rd derivatives at second and second-to-last knots are continuous, respectively: /// S'''0(x1) = S'''1(x1) and /// S'''n-1(xn-1) = S'''n(xn-1). /// @@ -93,7 +90,7 @@ where impl Cubic where - T: Float + Default + Debug, + T: Float + Debug, { pub(crate) fn evaluate_1d + RawDataClone + Clone>( &self, @@ -142,9 +139,10 @@ where impl Cubic where - T: Num + Copy + Default, + T: Num + Copy, { - /// Solves Ax = d for a tridiagonal matrix A using the [Thomas algorithm](https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm). + /// Solves Ax = d for a tridiagonal matrix A using the + /// [Thomas algorithm](https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm). /// - `a`: sub-diagonal (1 element shorter than `b` and `d`) /// - `b`: diagonal /// - `c`: super-diagonal (1 element shorter than `b` and `d`) @@ -160,9 +158,9 @@ where assert_eq!(b.len(), n); assert_eq!(c.len(), n - 1); - let mut c_prime = Array1::default(n - 1); - let mut d_prime = Array1::default(n); - let mut x = Array1::default(n); + let mut c_prime = Array1::zeros(n - 1); + let mut d_prime = Array1::zeros(n); + let mut x = Array1::zeros(n); // Forward sweep c_prime[0] = c[0] / b[0]; From a1a7ed7176962a11fc8e98a3a75e5a540d5806ba Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 09:51:42 -0600 Subject: [PATCH 09/18] draft test for periodic --- src/one/mod.rs | 45 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/one/mod.rs b/src/one/mod.rs index fd00b12..ec270e0 100644 --- a/src/one/mod.rs +++ b/src/one/mod.rs @@ -345,7 +345,7 @@ mod tests { // Interpolating at knots returns values for i in 0..x.len() { - assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]) + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]); } let x0 = x.first().unwrap(); @@ -412,7 +412,7 @@ mod tests { // Interpolating at knots returns values for i in 0..x.len() { - assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]) + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]); } // Left slope = a @@ -442,4 +442,45 @@ mod tests { assert_approx_eq!(slopes.mean().unwrap(), b); } } + + #[test] + fn test_cubic_periodic() { + let x = array![1., 2., 3., 5., 7., 8.]; + let f_x = array![3., -90., 19., 99., 291., 444.]; + + let x0 = x.first().unwrap(); + let xn = x.last().unwrap(); + let range = xn - x0; + let x_low = x0 - 0.2 * range; + let x_high = x0 + 0.2 * range; + let xs_left = Array1::linspace(x_low, *x0, 50); + let xs_right = Array1::linspace(*xn, x_high, 50); + + let interp_extrap_enable = + Interp1D::new(x.view(), f_x.view(), Cubic::periodic(), Extrapolate::Enable).unwrap(); + let interp_extrap_wrap = + Interp1D::new(x.view(), f_x.view(), Cubic::periodic(), Extrapolate::Wrap).unwrap(); + + // Interpolating at knots returns values + for i in 0..x.len() { + assert_approx_eq!(interp_extrap_enable.interpolate(&[x[i]]).unwrap(), f_x[i]); + assert_approx_eq!(interp_extrap_wrap.interpolate(&[x[i]]).unwrap(), f_x[i]); + } + + // Extrapolate::Enable is equivalent to Extrapolate::Wrap for Cubic::periodic() + for x in xs_left { + assert_eq!( + interp_extrap_enable.interpolate(&[x]).unwrap(), + interp_extrap_wrap.interpolate(&[x]).unwrap() + ); + } + for x in xs_right { + assert_eq!( + interp_extrap_enable.interpolate(&[x]).unwrap(), + interp_extrap_wrap.interpolate(&[x]).unwrap() + ); + } + + // TODO: test for slopes? + } } From ef31126d778cb6aedb09b84a27436521271a182b Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 09:59:43 -0600 Subject: [PATCH 10/18] draft test for periodic --- src/one/mod.rs | 44 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/src/one/mod.rs b/src/one/mod.rs index ec270e0..a9d674e 100644 --- a/src/one/mod.rs +++ b/src/one/mod.rs @@ -448,14 +448,6 @@ mod tests { let x = array![1., 2., 3., 5., 7., 8.]; let f_x = array![3., -90., 19., 99., 291., 444.]; - let x0 = x.first().unwrap(); - let xn = x.last().unwrap(); - let range = xn - x0; - let x_low = x0 - 0.2 * range; - let x_high = x0 + 0.2 * range; - let xs_left = Array1::linspace(x_low, *x0, 50); - let xs_right = Array1::linspace(*xn, x_high, 50); - let interp_extrap_enable = Interp1D::new(x.view(), f_x.view(), Cubic::periodic(), Extrapolate::Enable).unwrap(); let interp_extrap_wrap = @@ -468,6 +460,13 @@ mod tests { } // Extrapolate::Enable is equivalent to Extrapolate::Wrap for Cubic::periodic() + let x0 = x.first().unwrap(); + let xn = x.last().unwrap(); + let range = xn - x0; + let x_low = x0 - 0.2 * range; + let x_high = x0 + 0.2 * range; + let xs_left = Array1::linspace(x_low, *x0, 50); + let xs_right = Array1::linspace(*xn, x_high, 50); for x in xs_left { assert_eq!( interp_extrap_enable.interpolate(&[x]).unwrap(), @@ -481,6 +480,33 @@ mod tests { ); } - // TODO: test for slopes? + // Slope left + let xs_left = Array1::linspace(x_low, x_low + 2e6, 50); + let ys_left: Array1 = xs_left + .iter() + .map(|&x| interp_extrap_enable.interpolate(&[x]).unwrap()) + .collect(); + let slopes_left: Array1 = xs_left + .windows(2) + .into_iter() + .zip(ys_left.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + let slope_left = slopes_left.mean().unwrap(); + // Slope right + let xs_right = Array1::linspace(x_high - 2e6, x_high, 50); + let ys_right: Array1 = xs_right + .iter() + .map(|&x| interp_extrap_enable.interpolate(&[x]).unwrap()) + .collect(); + let slopes_right: Array1 = xs_right + .windows(2) + .into_iter() + .zip(ys_right.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + let slope_right = slopes_right.mean().unwrap(); + // Slopes at left and right are equal + assert_approx_eq!(slope_left, slope_right); } } From dfcd591135b3bba3b4d2af238183b4aa188135a6 Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 10:01:34 -0600 Subject: [PATCH 11/18] draft test for periodic --- src/one/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/one/mod.rs b/src/one/mod.rs index a9d674e..3a0c825 100644 --- a/src/one/mod.rs +++ b/src/one/mod.rs @@ -508,5 +508,8 @@ mod tests { let slope_right = slopes_right.mean().unwrap(); // Slopes at left and right are equal assert_approx_eq!(slope_left, slope_right); + // Second derivatives at left and right are equal + let z = interp_extrap_enable.strategy.z; + assert_approx_eq!(z.first().unwrap(), z.last().unwrap()); } } From 954a7394030fd62c8a0574a116c2b86e7ecd9d5c Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 10:03:54 -0600 Subject: [PATCH 12/18] draft test for periodic --- src/strategy/cubic.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/strategy/cubic.rs b/src/strategy/cubic.rs index 1fa4215..bb2308c 100644 --- a/src/strategy/cubic.rs +++ b/src/strategy/cubic.rs @@ -82,7 +82,8 @@ impl Cubic { /// Spline 1st & 2nd derivatives at outermost knots are equal: /// k0 = kn, z0 = zn /// - /// [`Extrapolate::Enable`] defaults to [`CubicExtrapolate::Wrap`]. + /// [`Extrapolate::Enable`] defaults to [`CubicExtrapolate::Wrap`], + /// thus is equivalent to [`Extrapolate::Wrap`]. pub fn periodic() -> Self { Self::new(CubicBC::Periodic, CubicExtrapolate::Wrap) } From feff80233d874ea615e20cee982ac6e2431afcea Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 10:28:44 -0600 Subject: [PATCH 13/18] reorganize tests into files --- src/n/mod.rs | 380 +----------------------------------------- src/n/tests.rs | 373 +++++++++++++++++++++++++++++++++++++++++ src/one/mod.rs | 335 +------------------------------------ src/one/tests.rs | 329 ++++++++++++++++++++++++++++++++++++ src/strategy/cubic.rs | 1 + src/three/mod.rs | 208 +---------------------- src/three/tests.rs | 202 ++++++++++++++++++++++ src/two/mod.rs | 168 +------------------ src/two/tests.rs | 162 ++++++++++++++++++ 9 files changed, 1076 insertions(+), 1082 deletions(-) create mode 100644 src/n/tests.rs create mode 100644 src/one/tests.rs create mode 100644 src/three/tests.rs create mode 100644 src/two/tests.rs diff --git a/src/n/mod.rs b/src/n/mod.rs index b1faff2..370118a 100644 --- a/src/n/mod.rs +++ b/src/n/mod.rs @@ -5,6 +5,9 @@ use super::*; use ndarray::prelude::*; mod strategies; +#[cfg(test)] +mod tests; + /// Interpolator data where N is determined at runtime #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] @@ -263,380 +266,3 @@ where self.check_extrapolate(&self.extrapolate) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_linear() { - let x = array![0.05, 0.10, 0.15]; - let y = array![0.10, 0.20, 0.30]; - let z = array![0.20, 0.40, 0.60]; - let grid = vec![x.view(), y.view(), z.view()]; - let values = array![ - [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], - [[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]], - [[18., 19., 20.], [21., 22., 23.], [24., 25., 26.]], - ] - .into_dyn(); - let interp = InterpND::new(grid, values.view(), Linear, Extrapolate::Error).unwrap(); - // Check that interpolating at grid points just retrieves the value - for i in 0..x.len() { - for j in 0..y.len() { - for k in 0..z.len() { - assert_eq!( - &interp.interpolate(&[x[i], y[j], z[k]]).unwrap(), - values.slice(s![i, j, k]).first().unwrap() - ); - } - } - } - assert_approx_eq!(interp.interpolate(&[x[0], y[0], 0.3]).unwrap(), 0.5); - assert_approx_eq!(interp.interpolate(&[x[0], 0.15, z[0]]).unwrap(), 1.5); - assert_approx_eq!(interp.interpolate(&[x[0], 0.15, 0.3]).unwrap(), 2.0); - assert_approx_eq!(interp.interpolate(&[0.075, y[0], z[0]]).unwrap(), 4.5); - assert_approx_eq!(interp.interpolate(&[0.075, y[0], 0.3]).unwrap(), 5.); - assert_approx_eq!(interp.interpolate(&[0.075, 0.15, z[0]]).unwrap(), 6.); - } - - #[test] - fn test_linear_offset() { - let interp = InterpND::new( - vec![array![0., 1.], array![0., 1.], array![0., 1.]], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(), - Linear, - Extrapolate::Error, - ) - .unwrap(); - assert_approx_eq!(interp.interpolate(&[0.25, 0.65, 0.9]).unwrap(), 3.2) - } - - #[test] - fn test_linear_extrapolation_2d() { - let interp_2d = crate::interpolator::Interp2D::new( - array![0.05, 0.10, 0.15], - array![0.10, 0.20, 0.30], - array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], - Linear, - Extrapolate::Enable, - ) - .unwrap(); - let interp_nd = InterpND::new( - vec![array![0.05, 0.10, 0.15], array![0.10, 0.20, 0.30]], - array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]].into_dyn(), - Linear, - Extrapolate::Enable, - ) - .unwrap(); - // below x, below y - assert_eq!( - interp_2d.interpolate(&[0.0, 0.0]).unwrap(), - interp_nd.interpolate(&[0.0, 0.0]).unwrap() - ); - assert_eq!( - interp_2d.interpolate(&[0.03, 0.04]).unwrap(), - interp_nd.interpolate(&[0.03, 0.04]).unwrap(), - ); - // below x, above y - assert_eq!( - interp_2d.interpolate(&[0.0, 0.32]).unwrap(), - interp_nd.interpolate(&[0.0, 0.32]).unwrap(), - ); - assert_eq!( - interp_2d.interpolate(&[0.03, 0.36]).unwrap(), - interp_nd.interpolate(&[0.03, 0.36]).unwrap() - ); - // above x, below y - assert_eq!( - interp_2d.interpolate(&[0.17, 0.0]).unwrap(), - interp_nd.interpolate(&[0.17, 0.0]).unwrap(), - ); - assert_eq!( - interp_2d.interpolate(&[0.19, 0.04]).unwrap(), - interp_nd.interpolate(&[0.19, 0.04]).unwrap(), - ); - // above x, above y - assert_eq!( - interp_2d.interpolate(&[0.17, 0.32]).unwrap(), - interp_nd.interpolate(&[0.17, 0.32]).unwrap() - ); - assert_eq!( - interp_2d.interpolate(&[0.19, 0.36]).unwrap(), - interp_nd.interpolate(&[0.19, 0.36]).unwrap() - ); - } - - #[test] - fn test_linear_extrapolate_3d() { - let interp_3d = crate::interpolator::Interp3D::new( - array![0.05, 0.10, 0.15], - array![0.10, 0.20, 0.30], - array![0.20, 0.40, 0.60], - array![ - [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], - [[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]], - [[18., 19., 20.], [21., 22., 23.], [24., 25., 26.],], - ], - Linear, - Extrapolate::Enable, - ) - .unwrap(); - let interp_nd = InterpND::new( - vec![ - array![0.05, 0.10, 0.15], - array![0.10, 0.20, 0.30], - array![0.20, 0.40, 0.60], - ], - array![ - [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], - [[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]], - [[18., 19., 20.], [21., 22., 23.], [24., 25., 26.]], - ] - .into_dyn(), - Linear, - Extrapolate::Enable, - ) - .unwrap(); - // below x, below y, below z - assert_eq!( - interp_3d.interpolate(&[0.01, 0.06, 0.17]).unwrap(), - interp_nd.interpolate(&[0.01, 0.06, 0.17]).unwrap() - ); - assert_eq!( - interp_3d.interpolate(&[0.02, 0.08, 0.19]).unwrap(), - interp_nd.interpolate(&[0.02, 0.08, 0.19]).unwrap() - ); - // below x, below y, above z - assert_eq!( - interp_3d.interpolate(&[0.01, 0.06, 0.63]).unwrap(), - interp_nd.interpolate(&[0.01, 0.06, 0.63]).unwrap() - ); - assert_eq!( - interp_3d.interpolate(&[0.02, 0.08, 0.65]).unwrap(), - interp_nd.interpolate(&[0.02, 0.08, 0.65]).unwrap() - ); - // below x, above y, below z - assert_eq!( - interp_3d.interpolate(&[0.01, 0.33, 0.17]).unwrap(), - interp_nd.interpolate(&[0.01, 0.33, 0.17]).unwrap() - ); - assert_eq!( - interp_3d.interpolate(&[0.02, 0.36, 0.19]).unwrap(), - interp_nd.interpolate(&[0.02, 0.36, 0.19]).unwrap() - ); - // below x, above y, above z - assert_eq!( - interp_3d.interpolate(&[0.01, 0.33, 0.63]).unwrap(), - interp_nd.interpolate(&[0.01, 0.33, 0.63]).unwrap() - ); - assert_eq!( - interp_3d.interpolate(&[0.02, 0.36, 0.65]).unwrap(), - interp_nd.interpolate(&[0.02, 0.36, 0.65]).unwrap() - ); - // above x, below y, below z - assert_eq!( - interp_3d.interpolate(&[0.17, 0.06, 0.17]).unwrap(), - interp_nd.interpolate(&[0.17, 0.06, 0.17]).unwrap() - ); - assert_eq!( - interp_3d.interpolate(&[0.19, 0.08, 0.19]).unwrap(), - interp_nd.interpolate(&[0.19, 0.08, 0.19]).unwrap() - ); - // above x, below y, above z - assert_eq!( - interp_3d.interpolate(&[0.17, 0.06, 0.63]).unwrap(), - interp_nd.interpolate(&[0.17, 0.06, 0.63]).unwrap() - ); - assert_eq!( - interp_3d.interpolate(&[0.19, 0.08, 0.65]).unwrap(), - interp_nd.interpolate(&[0.19, 0.08, 0.65]).unwrap() - ); - // above x, above y, below z - assert_eq!( - interp_3d.interpolate(&[0.17, 0.33, 0.17]).unwrap(), - interp_nd.interpolate(&[0.17, 0.33, 0.17]).unwrap() - ); - assert_eq!( - interp_3d.interpolate(&[0.19, 0.36, 0.19]).unwrap(), - interp_nd.interpolate(&[0.19, 0.36, 0.19]).unwrap() - ); - // above x, above y, above z - assert_eq!( - interp_3d.interpolate(&[0.17, 0.33, 0.63]).unwrap(), - interp_nd.interpolate(&[0.17, 0.33, 0.63]).unwrap() - ); - assert_eq!( - interp_3d.interpolate(&[0.19, 0.36, 0.65]).unwrap(), - interp_nd.interpolate(&[0.19, 0.36, 0.65]).unwrap() - ); - } - - #[test] - fn test_nearest() { - let x = array![0., 1.]; - let y = array![0., 1.]; - let z = array![0., 1.]; - let grid = vec![x.view(), y.view(), z.view()]; - let values = array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(); - let interp = InterpND::new(grid, values.view(), Nearest, Extrapolate::Error).unwrap(); - // Check that interpolating at grid points just retrieves the value - for i in 0..x.len() { - for j in 0..y.len() { - for k in 0..z.len() { - assert_eq!( - &interp.interpolate(&[x[i], y[j], z[k]]).unwrap(), - values.slice(s![i, j, k]).first().unwrap() - ); - } - } - } - assert_eq!(interp.interpolate(&[0.25, 0.25, 0.25]).unwrap(), 0.); - assert_eq!(interp.interpolate(&[0.25, 0.75, 0.25]).unwrap(), 2.); - assert_eq!(interp.interpolate(&[0.75, 0.25, 0.75]).unwrap(), 5.); - assert_eq!(interp.interpolate(&[0.75, 0.75, 0.75]).unwrap(), 7.); - } - - #[test] - fn test_extrapolate_inputs() { - // Extrapolate::Extrapolate - assert!(matches!( - InterpND::new( - vec![array![0., 1.], array![0., 1.], array![0., 1.]], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(), - Nearest, - Extrapolate::Enable, - ) - .unwrap_err(), - ValidateError::ExtrapolateSelection(_) - )); - // Extrapolate::Error - let interp = InterpND::new( - vec![array![0., 1.], array![0., 1.], array![0., 1.]], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(), - Linear, - Extrapolate::Error, - ) - .unwrap(); - assert!(matches!( - interp.interpolate(&[-1., -1., -1.]).unwrap_err(), - InterpolateError::ExtrapolateError(_) - )); - assert!(matches!( - interp.interpolate(&[2., 2., 2.]).unwrap_err(), - InterpolateError::ExtrapolateError(_) - )); - } - - #[test] - fn test_extrapolate_fill() { - let interp = InterpND::new( - vec![array![0.1, 1.1], array![0.2, 1.2], array![0.3, 1.3]], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(), - Linear, - Extrapolate::Fill(f64::NAN), - ) - .unwrap(); - assert!(interp.interpolate(&[0., 0., 0.]).unwrap().is_nan()); - assert!(interp.interpolate(&[0., 0., 2.]).unwrap().is_nan()); - assert!(interp.interpolate(&[0., 2., 0.]).unwrap().is_nan()); - assert!(interp.interpolate(&[0., 2., 2.]).unwrap().is_nan()); - assert!(interp.interpolate(&[2., 0., 0.]).unwrap().is_nan()); - assert!(interp.interpolate(&[2., 0., 2.]).unwrap().is_nan()); - assert!(interp.interpolate(&[2., 2., 0.]).unwrap().is_nan()); - assert!(interp.interpolate(&[2., 2., 2.]).unwrap().is_nan()); - } - - #[test] - fn test_extrapolate_clamp() { - let x = array![0.1, 1.1]; - let y = array![0.2, 1.2]; - let z = array![0.3, 1.3]; - let values = array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(); - let interp = InterpND::new( - vec![x.view(), y.view(), z.view()], - values.view(), - Linear, - Extrapolate::Clamp, - ) - .unwrap(); - assert_eq!( - interp.interpolate(&[-1., -1., -1.]).unwrap(), - values[[0, 0, 0]] - ); - assert_eq!( - interp.interpolate(&[-1., 2., -1.]).unwrap(), - values[[0, 1, 0]] - ); - assert_eq!( - interp.interpolate(&[2., -1., 2.]).unwrap(), - values[[1, 0, 1]] - ); - assert_eq!( - interp.interpolate(&[2., 2., 2.]).unwrap(), - values[[1, 1, 1]] - ); - } - - #[test] - fn test_extrapolate_wrap() { - let interp = InterpND::new( - vec![array![0., 1.], array![0., 1.], array![0., 1.]], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(), - Linear, - Extrapolate::Wrap, - ) - .unwrap(); - assert_eq!( - interp.interpolate(&[-0.25, -0.2, -0.4]).unwrap(), - interp.interpolate(&[0.75, 0.8, 0.6]).unwrap(), - ); - assert_eq!( - interp.interpolate(&[-0.25, 2.1, -0.4]).unwrap(), - interp.interpolate(&[0.75, 0.1, 0.6]).unwrap(), - ); - assert_eq!( - interp.interpolate(&[-0.25, 2.1, 2.3]).unwrap(), - interp.interpolate(&[0.75, 0.1, 0.3]).unwrap(), - ); - assert_eq!( - interp.interpolate(&[2.5, 2.1, 2.3]).unwrap(), - interp.interpolate(&[0.5, 0.1, 0.3]).unwrap(), - ); - } - - #[test] - fn test_mismatched_grid() { - assert!(matches!( - InterpND::new( - // 3-D grid - vec![array![0., 1.], array![0., 1.], array![0., 1.]], - // 2-D values - array![[0., 1.], [2., 3.]].into_dyn(), - Linear, - Extrapolate::Error, - ) - .unwrap_err(), - ValidateError::Other(_) - )); - assert!(InterpND::new( - vec![array![]], - array![0.].into_dyn(), - Linear, - Extrapolate::Error, - ) - .is_ok(),); - assert!(matches!( - InterpND::new( - // non-empty grid - vec![array![1.]], - // 0-D values - array![0.].into_dyn(), - Linear, - Extrapolate::Error, - ) - .unwrap_err(), - ValidateError::Other(_) - )); - } -} diff --git a/src/n/tests.rs b/src/n/tests.rs new file mode 100644 index 0000000..4b4c837 --- /dev/null +++ b/src/n/tests.rs @@ -0,0 +1,373 @@ +use super::*; + +#[test] +fn test_linear() { + let x = array![0.05, 0.10, 0.15]; + let y = array![0.10, 0.20, 0.30]; + let z = array![0.20, 0.40, 0.60]; + let grid = vec![x.view(), y.view(), z.view()]; + let values = array![ + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], + [[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]], + [[18., 19., 20.], [21., 22., 23.], [24., 25., 26.]], + ] + .into_dyn(); + let interp = InterpND::new(grid, values.view(), Linear, Extrapolate::Error).unwrap(); + // Check that interpolating at grid points just retrieves the value + for i in 0..x.len() { + for j in 0..y.len() { + for k in 0..z.len() { + assert_eq!( + &interp.interpolate(&[x[i], y[j], z[k]]).unwrap(), + values.slice(s![i, j, k]).first().unwrap() + ); + } + } + } + assert_approx_eq!(interp.interpolate(&[x[0], y[0], 0.3]).unwrap(), 0.5); + assert_approx_eq!(interp.interpolate(&[x[0], 0.15, z[0]]).unwrap(), 1.5); + assert_approx_eq!(interp.interpolate(&[x[0], 0.15, 0.3]).unwrap(), 2.0); + assert_approx_eq!(interp.interpolate(&[0.075, y[0], z[0]]).unwrap(), 4.5); + assert_approx_eq!(interp.interpolate(&[0.075, y[0], 0.3]).unwrap(), 5.); + assert_approx_eq!(interp.interpolate(&[0.075, 0.15, z[0]]).unwrap(), 6.); +} + +#[test] +fn test_linear_offset() { + let interp = InterpND::new( + vec![array![0., 1.], array![0., 1.], array![0., 1.]], + array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(), + Linear, + Extrapolate::Error, + ) + .unwrap(); + assert_approx_eq!(interp.interpolate(&[0.25, 0.65, 0.9]).unwrap(), 3.2) +} + +#[test] +fn test_linear_extrapolation_2d() { + let interp_2d = crate::interpolator::Interp2D::new( + array![0.05, 0.10, 0.15], + array![0.10, 0.20, 0.30], + array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], + Linear, + Extrapolate::Enable, + ) + .unwrap(); + let interp_nd = InterpND::new( + vec![array![0.05, 0.10, 0.15], array![0.10, 0.20, 0.30]], + array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]].into_dyn(), + Linear, + Extrapolate::Enable, + ) + .unwrap(); + // below x, below y + assert_eq!( + interp_2d.interpolate(&[0.0, 0.0]).unwrap(), + interp_nd.interpolate(&[0.0, 0.0]).unwrap() + ); + assert_eq!( + interp_2d.interpolate(&[0.03, 0.04]).unwrap(), + interp_nd.interpolate(&[0.03, 0.04]).unwrap(), + ); + // below x, above y + assert_eq!( + interp_2d.interpolate(&[0.0, 0.32]).unwrap(), + interp_nd.interpolate(&[0.0, 0.32]).unwrap(), + ); + assert_eq!( + interp_2d.interpolate(&[0.03, 0.36]).unwrap(), + interp_nd.interpolate(&[0.03, 0.36]).unwrap() + ); + // above x, below y + assert_eq!( + interp_2d.interpolate(&[0.17, 0.0]).unwrap(), + interp_nd.interpolate(&[0.17, 0.0]).unwrap(), + ); + assert_eq!( + interp_2d.interpolate(&[0.19, 0.04]).unwrap(), + interp_nd.interpolate(&[0.19, 0.04]).unwrap(), + ); + // above x, above y + assert_eq!( + interp_2d.interpolate(&[0.17, 0.32]).unwrap(), + interp_nd.interpolate(&[0.17, 0.32]).unwrap() + ); + assert_eq!( + interp_2d.interpolate(&[0.19, 0.36]).unwrap(), + interp_nd.interpolate(&[0.19, 0.36]).unwrap() + ); +} + +#[test] +fn test_linear_extrapolate_3d() { + let interp_3d = crate::interpolator::Interp3D::new( + array![0.05, 0.10, 0.15], + array![0.10, 0.20, 0.30], + array![0.20, 0.40, 0.60], + array![ + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], + [[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]], + [[18., 19., 20.], [21., 22., 23.], [24., 25., 26.],], + ], + Linear, + Extrapolate::Enable, + ) + .unwrap(); + let interp_nd = InterpND::new( + vec![ + array![0.05, 0.10, 0.15], + array![0.10, 0.20, 0.30], + array![0.20, 0.40, 0.60], + ], + array![ + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], + [[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]], + [[18., 19., 20.], [21., 22., 23.], [24., 25., 26.]], + ] + .into_dyn(), + Linear, + Extrapolate::Enable, + ) + .unwrap(); + // below x, below y, below z + assert_eq!( + interp_3d.interpolate(&[0.01, 0.06, 0.17]).unwrap(), + interp_nd.interpolate(&[0.01, 0.06, 0.17]).unwrap() + ); + assert_eq!( + interp_3d.interpolate(&[0.02, 0.08, 0.19]).unwrap(), + interp_nd.interpolate(&[0.02, 0.08, 0.19]).unwrap() + ); + // below x, below y, above z + assert_eq!( + interp_3d.interpolate(&[0.01, 0.06, 0.63]).unwrap(), + interp_nd.interpolate(&[0.01, 0.06, 0.63]).unwrap() + ); + assert_eq!( + interp_3d.interpolate(&[0.02, 0.08, 0.65]).unwrap(), + interp_nd.interpolate(&[0.02, 0.08, 0.65]).unwrap() + ); + // below x, above y, below z + assert_eq!( + interp_3d.interpolate(&[0.01, 0.33, 0.17]).unwrap(), + interp_nd.interpolate(&[0.01, 0.33, 0.17]).unwrap() + ); + assert_eq!( + interp_3d.interpolate(&[0.02, 0.36, 0.19]).unwrap(), + interp_nd.interpolate(&[0.02, 0.36, 0.19]).unwrap() + ); + // below x, above y, above z + assert_eq!( + interp_3d.interpolate(&[0.01, 0.33, 0.63]).unwrap(), + interp_nd.interpolate(&[0.01, 0.33, 0.63]).unwrap() + ); + assert_eq!( + interp_3d.interpolate(&[0.02, 0.36, 0.65]).unwrap(), + interp_nd.interpolate(&[0.02, 0.36, 0.65]).unwrap() + ); + // above x, below y, below z + assert_eq!( + interp_3d.interpolate(&[0.17, 0.06, 0.17]).unwrap(), + interp_nd.interpolate(&[0.17, 0.06, 0.17]).unwrap() + ); + assert_eq!( + interp_3d.interpolate(&[0.19, 0.08, 0.19]).unwrap(), + interp_nd.interpolate(&[0.19, 0.08, 0.19]).unwrap() + ); + // above x, below y, above z + assert_eq!( + interp_3d.interpolate(&[0.17, 0.06, 0.63]).unwrap(), + interp_nd.interpolate(&[0.17, 0.06, 0.63]).unwrap() + ); + assert_eq!( + interp_3d.interpolate(&[0.19, 0.08, 0.65]).unwrap(), + interp_nd.interpolate(&[0.19, 0.08, 0.65]).unwrap() + ); + // above x, above y, below z + assert_eq!( + interp_3d.interpolate(&[0.17, 0.33, 0.17]).unwrap(), + interp_nd.interpolate(&[0.17, 0.33, 0.17]).unwrap() + ); + assert_eq!( + interp_3d.interpolate(&[0.19, 0.36, 0.19]).unwrap(), + interp_nd.interpolate(&[0.19, 0.36, 0.19]).unwrap() + ); + // above x, above y, above z + assert_eq!( + interp_3d.interpolate(&[0.17, 0.33, 0.63]).unwrap(), + interp_nd.interpolate(&[0.17, 0.33, 0.63]).unwrap() + ); + assert_eq!( + interp_3d.interpolate(&[0.19, 0.36, 0.65]).unwrap(), + interp_nd.interpolate(&[0.19, 0.36, 0.65]).unwrap() + ); +} + +#[test] +fn test_nearest() { + let x = array![0., 1.]; + let y = array![0., 1.]; + let z = array![0., 1.]; + let grid = vec![x.view(), y.view(), z.view()]; + let values = array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(); + let interp = InterpND::new(grid, values.view(), Nearest, Extrapolate::Error).unwrap(); + // Check that interpolating at grid points just retrieves the value + for i in 0..x.len() { + for j in 0..y.len() { + for k in 0..z.len() { + assert_eq!( + &interp.interpolate(&[x[i], y[j], z[k]]).unwrap(), + values.slice(s![i, j, k]).first().unwrap() + ); + } + } + } + assert_eq!(interp.interpolate(&[0.25, 0.25, 0.25]).unwrap(), 0.); + assert_eq!(interp.interpolate(&[0.25, 0.75, 0.25]).unwrap(), 2.); + assert_eq!(interp.interpolate(&[0.75, 0.25, 0.75]).unwrap(), 5.); + assert_eq!(interp.interpolate(&[0.75, 0.75, 0.75]).unwrap(), 7.); +} + +#[test] +fn test_extrapolate_inputs() { + // Extrapolate::Extrapolate + assert!(matches!( + InterpND::new( + vec![array![0., 1.], array![0., 1.], array![0., 1.]], + array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(), + Nearest, + Extrapolate::Enable, + ) + .unwrap_err(), + ValidateError::ExtrapolateSelection(_) + )); + // Extrapolate::Error + let interp = InterpND::new( + vec![array![0., 1.], array![0., 1.], array![0., 1.]], + array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(), + Linear, + Extrapolate::Error, + ) + .unwrap(); + assert!(matches!( + interp.interpolate(&[-1., -1., -1.]).unwrap_err(), + InterpolateError::ExtrapolateError(_) + )); + assert!(matches!( + interp.interpolate(&[2., 2., 2.]).unwrap_err(), + InterpolateError::ExtrapolateError(_) + )); +} + +#[test] +fn test_extrapolate_fill() { + let interp = InterpND::new( + vec![array![0.1, 1.1], array![0.2, 1.2], array![0.3, 1.3]], + array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(), + Linear, + Extrapolate::Fill(f64::NAN), + ) + .unwrap(); + assert!(interp.interpolate(&[0., 0., 0.]).unwrap().is_nan()); + assert!(interp.interpolate(&[0., 0., 2.]).unwrap().is_nan()); + assert!(interp.interpolate(&[0., 2., 0.]).unwrap().is_nan()); + assert!(interp.interpolate(&[0., 2., 2.]).unwrap().is_nan()); + assert!(interp.interpolate(&[2., 0., 0.]).unwrap().is_nan()); + assert!(interp.interpolate(&[2., 0., 2.]).unwrap().is_nan()); + assert!(interp.interpolate(&[2., 2., 0.]).unwrap().is_nan()); + assert!(interp.interpolate(&[2., 2., 2.]).unwrap().is_nan()); +} + +#[test] +fn test_extrapolate_clamp() { + let x = array![0.1, 1.1]; + let y = array![0.2, 1.2]; + let z = array![0.3, 1.3]; + let values = array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(); + let interp = InterpND::new( + vec![x.view(), y.view(), z.view()], + values.view(), + Linear, + Extrapolate::Clamp, + ) + .unwrap(); + assert_eq!( + interp.interpolate(&[-1., -1., -1.]).unwrap(), + values[[0, 0, 0]] + ); + assert_eq!( + interp.interpolate(&[-1., 2., -1.]).unwrap(), + values[[0, 1, 0]] + ); + assert_eq!( + interp.interpolate(&[2., -1., 2.]).unwrap(), + values[[1, 0, 1]] + ); + assert_eq!( + interp.interpolate(&[2., 2., 2.]).unwrap(), + values[[1, 1, 1]] + ); +} + +#[test] +fn test_extrapolate_wrap() { + let interp = InterpND::new( + vec![array![0., 1.], array![0., 1.], array![0., 1.]], + array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(), + Linear, + Extrapolate::Wrap, + ) + .unwrap(); + assert_eq!( + interp.interpolate(&[-0.25, -0.2, -0.4]).unwrap(), + interp.interpolate(&[0.75, 0.8, 0.6]).unwrap(), + ); + assert_eq!( + interp.interpolate(&[-0.25, 2.1, -0.4]).unwrap(), + interp.interpolate(&[0.75, 0.1, 0.6]).unwrap(), + ); + assert_eq!( + interp.interpolate(&[-0.25, 2.1, 2.3]).unwrap(), + interp.interpolate(&[0.75, 0.1, 0.3]).unwrap(), + ); + assert_eq!( + interp.interpolate(&[2.5, 2.1, 2.3]).unwrap(), + interp.interpolate(&[0.5, 0.1, 0.3]).unwrap(), + ); +} + +#[test] +fn test_mismatched_grid() { + assert!(matches!( + InterpND::new( + // 3-D grid + vec![array![0., 1.], array![0., 1.], array![0., 1.]], + // 2-D values + array![[0., 1.], [2., 3.]].into_dyn(), + Linear, + Extrapolate::Error, + ) + .unwrap_err(), + ValidateError::Other(_) + )); + assert!(InterpND::new( + vec![array![]], + array![0.].into_dyn(), + Linear, + Extrapolate::Error, + ) + .is_ok(),); + assert!(matches!( + InterpND::new( + // non-empty grid + vec![array![1.]], + // 0-D values + array![0.].into_dyn(), + Linear, + Extrapolate::Error, + ) + .unwrap_err(), + ValidateError::Other(_) + )); +} diff --git a/src/one/mod.rs b/src/one/mod.rs index 3a0c825..b035436 100644 --- a/src/one/mod.rs +++ b/src/one/mod.rs @@ -3,6 +3,8 @@ use super::*; mod strategies; +#[cfg(test)] +mod tests; const N: usize = 1; @@ -180,336 +182,3 @@ where self.check_extrapolate(&self.extrapolate) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_invalid_args() { - let interp = Interp1D::new( - array![0., 1., 2., 3., 4.], - array![0.2, 0.4, 0.6, 0.8, 1.0], - Linear, - Extrapolate::Error, - ) - .unwrap(); - assert!(matches!( - interp.interpolate(&[]).unwrap_err(), - InterpolateError::PointLength(_) - )); - assert_eq!(interp.interpolate(&[1.0]).unwrap(), 0.4); - } - - #[test] - fn test_linear() { - let x = array![0., 1., 2., 3., 4.]; - let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0]; - let interp = Interp1D::new(x.view(), f_x.view(), Linear, Extrapolate::Error).unwrap(); - // Check that interpolating at grid points just retrieves the value - for (i, x_i) in x.iter().enumerate() { - assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]); - } - assert_eq!(interp.interpolate(&[3.00]).unwrap(), 0.8); - assert_eq!(interp.interpolate(&[3.75]).unwrap(), 0.95); - assert_eq!(interp.interpolate(&[4.00]).unwrap(), 1.0); - } - - #[test] - fn test_left_nearest() { - let x = array![0., 1., 2., 3., 4.]; - let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0]; - let interp = Interp1D::new(x.view(), f_x.view(), LeftNearest, Extrapolate::Error).unwrap(); - // Check that interpolating at grid points just retrieves the value - for (i, x_i) in x.iter().enumerate() { - assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]); - } - assert_eq!(interp.interpolate(&[3.00]).unwrap(), 0.8); - assert_eq!(interp.interpolate(&[3.75]).unwrap(), 0.8); - assert_eq!(interp.interpolate(&[4.00]).unwrap(), 1.0); - } - - #[test] - fn test_right_nearest() { - let x = array![0., 1., 2., 3., 4.]; - let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0]; - let interp = Interp1D::new(x.view(), f_x.view(), RightNearest, Extrapolate::Error).unwrap(); - // Check that interpolating at grid points just retrieves the value - for (i, x_i) in x.iter().enumerate() { - assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]); - } - assert_eq!(interp.interpolate(&[3.00]).unwrap(), 0.8); - assert_eq!(interp.interpolate(&[3.25]).unwrap(), 1.0); - assert_eq!(interp.interpolate(&[4.00]).unwrap(), 1.0); - } - - #[test] - fn test_nearest() { - let x = array![0., 1., 2., 3., 4.]; - let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0]; - let interp = Interp1D::new(x.view(), f_x.view(), Nearest, Extrapolate::Error).unwrap(); - // Check that interpolating at grid points just retrieves the value - for (i, x_i) in x.iter().enumerate() { - assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]); - } - assert_eq!(interp.interpolate(&[3.00]).unwrap(), 0.8); - assert_eq!(interp.interpolate(&[3.25]).unwrap(), 0.8); - assert_eq!(interp.interpolate(&[3.50]).unwrap(), 1.0); - assert_eq!(interp.interpolate(&[3.75]).unwrap(), 1.0); - assert_eq!(interp.interpolate(&[4.00]).unwrap(), 1.0); - } - - #[test] - fn test_extrapolate_inputs() { - // Incorrect extrapolation selection - assert!(matches!( - Interp1D::new( - array![0., 1., 2., 3., 4.], - array![0.2, 0.4, 0.6, 0.8, 1.0], - Nearest, - Extrapolate::Enable, - ) - .unwrap_err(), - ValidateError::ExtrapolateSelection(_) - )); - - // Extrapolate::Error - let interp = Interp1D::new( - array![0., 1., 2., 3., 4.], - array![0.2, 0.4, 0.6, 0.8, 1.0], - Linear, - Extrapolate::Error, - ) - .unwrap(); - // Fail to extrapolate below lowest grid value - assert!(matches!( - interp.interpolate(&[-1.]).unwrap_err(), - InterpolateError::ExtrapolateError(_) - )); - // Fail to extrapolate above highest grid value - assert!(matches!( - interp.interpolate(&[5.]).unwrap_err(), - InterpolateError::ExtrapolateError(_) - )); - } - - #[test] - fn test_extrapolate_fill() { - let interp = Interp1D::new( - array![0., 1., 2., 3., 4.], - array![0.2, 0.4, 0.6, 0.8, 1.0], - Linear, - Extrapolate::Fill(f64::NAN), - ) - .unwrap(); - assert_eq!(interp.interpolate(&[1.5]).unwrap(), 0.5); - assert_eq!(interp.interpolate(&[2.]).unwrap(), 0.6); - assert!(interp.interpolate(&[-1.]).unwrap().is_nan()); - assert!(interp.interpolate(&[5.]).unwrap().is_nan()); - } - - #[test] - fn test_extrapolate_clamp() { - let interp = Interp1D::new( - array![0., 1., 2., 3., 4.], - array![0.2, 0.4, 0.6, 0.8, 1.0], - Linear, - Extrapolate::Clamp, - ) - .unwrap(); - assert_eq!(interp.interpolate(&[-1.]).unwrap(), 0.2); - assert_eq!(interp.interpolate(&[5.]).unwrap(), 1.0); - } - - #[test] - fn test_extrapolate() { - let interp = Interp1D::new( - array![0., 1., 2., 3., 4.], - array![0.2, 0.4, 0.6, 0.8, 1.0], - Linear, - Extrapolate::Enable, - ) - .unwrap(); - assert_eq!(interp.interpolate(&[-1.]).unwrap(), 0.0); - assert_approx_eq!(interp.interpolate(&[-0.75]).unwrap(), 0.05); - assert_eq!(interp.interpolate(&[5.]).unwrap(), 1.2); - } - - #[test] - fn test_cubic_natural() { - let x = array![1., 2., 3., 5., 7., 8.]; - let f_x = array![3., 6., 19., 99., 291., 444.]; - - let interp = - Interp1D::new(x.view(), f_x.view(), Cubic::natural(), Extrapolate::Enable).unwrap(); - - // Interpolating at knots returns values - for i in 0..x.len() { - assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]); - } - - let x0 = x.first().unwrap(); - let xn = x.last().unwrap(); - let y0 = f_x.first().unwrap(); - let yn = f_x.last().unwrap(); - - let range = xn - x0; - - let x_low = x0 - 0.2 * range; - let y_low = interp.interpolate(&[x_low]).unwrap(); - let slope_low = (y0 - y_low) / (x0 - x_low); - - let x_high = xn + 0.2 * range; - let y_high = interp.interpolate(&[x_high]).unwrap(); - let slope_high = (y_high - yn) / (x_high - xn); - - let xs_left = Array1::linspace(*x0, x0 + 2e-6, 50); - let xs_right = Array1::linspace(xn - 2e-6, *xn, 50); - - // Left extrapolation is linear - let ys: Array1 = xs_left - .iter() - .map(|&x| interp.interpolate(&[x]).unwrap()) - .collect(); - let slopes: Array1 = xs_left - .windows(2) - .into_iter() - .zip(ys.windows(2)) - .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) - .collect(); - assert_approx_eq!(slopes.mean().unwrap(), slope_low); - - // Right extrapolation is linear - let ys: Array1 = xs_right - .iter() - .map(|&x| interp.interpolate(&[x]).unwrap()) - .collect(); - let slopes: Array1 = xs_right - .windows(2) - .into_iter() - .zip(ys.windows(2)) - .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) - .collect(); - assert_approx_eq!(slopes.mean().unwrap(), slope_high); - } - - #[test] - fn test_cubic_clamped() { - let x = array![1., 2., 3., 5., 7., 8.]; - let f_x = array![3., -90., 19., 99., 291., 444.]; - - let xs_left = Array1::linspace(x.first().unwrap() - 1e-6, x.first().unwrap() + 1e-6, 50); - let xs_right = Array1::linspace(x.last().unwrap() - 1e-6, x.last().unwrap() + 1e-6, 50); - - for (a, b) in [(-5., 10.), (0., 0.), (2.4, -5.2)] { - let interp = Interp1D::new( - x.view(), - f_x.view(), - Cubic::clamped(a, b), - Extrapolate::Enable, - ) - .unwrap(); - - // Interpolating at knots returns values - for i in 0..x.len() { - assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]); - } - - // Left slope = a - let ys: Array1 = xs_left - .iter() - .map(|&x| interp.interpolate(&[x]).unwrap()) - .collect(); - let slopes: Array1 = xs_left - .windows(2) - .into_iter() - .zip(ys.windows(2)) - .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) - .collect(); - assert_approx_eq!(slopes.mean().unwrap(), a); - - // Right slope = b - let ys: Array1 = xs_right - .iter() - .map(|&x| interp.interpolate(&[x]).unwrap()) - .collect(); - let slopes: Array1 = xs_right - .windows(2) - .into_iter() - .zip(ys.windows(2)) - .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) - .collect(); - assert_approx_eq!(slopes.mean().unwrap(), b); - } - } - - #[test] - fn test_cubic_periodic() { - let x = array![1., 2., 3., 5., 7., 8.]; - let f_x = array![3., -90., 19., 99., 291., 444.]; - - let interp_extrap_enable = - Interp1D::new(x.view(), f_x.view(), Cubic::periodic(), Extrapolate::Enable).unwrap(); - let interp_extrap_wrap = - Interp1D::new(x.view(), f_x.view(), Cubic::periodic(), Extrapolate::Wrap).unwrap(); - - // Interpolating at knots returns values - for i in 0..x.len() { - assert_approx_eq!(interp_extrap_enable.interpolate(&[x[i]]).unwrap(), f_x[i]); - assert_approx_eq!(interp_extrap_wrap.interpolate(&[x[i]]).unwrap(), f_x[i]); - } - - // Extrapolate::Enable is equivalent to Extrapolate::Wrap for Cubic::periodic() - let x0 = x.first().unwrap(); - let xn = x.last().unwrap(); - let range = xn - x0; - let x_low = x0 - 0.2 * range; - let x_high = x0 + 0.2 * range; - let xs_left = Array1::linspace(x_low, *x0, 50); - let xs_right = Array1::linspace(*xn, x_high, 50); - for x in xs_left { - assert_eq!( - interp_extrap_enable.interpolate(&[x]).unwrap(), - interp_extrap_wrap.interpolate(&[x]).unwrap() - ); - } - for x in xs_right { - assert_eq!( - interp_extrap_enable.interpolate(&[x]).unwrap(), - interp_extrap_wrap.interpolate(&[x]).unwrap() - ); - } - - // Slope left - let xs_left = Array1::linspace(x_low, x_low + 2e6, 50); - let ys_left: Array1 = xs_left - .iter() - .map(|&x| interp_extrap_enable.interpolate(&[x]).unwrap()) - .collect(); - let slopes_left: Array1 = xs_left - .windows(2) - .into_iter() - .zip(ys_left.windows(2)) - .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) - .collect(); - let slope_left = slopes_left.mean().unwrap(); - // Slope right - let xs_right = Array1::linspace(x_high - 2e6, x_high, 50); - let ys_right: Array1 = xs_right - .iter() - .map(|&x| interp_extrap_enable.interpolate(&[x]).unwrap()) - .collect(); - let slopes_right: Array1 = xs_right - .windows(2) - .into_iter() - .zip(ys_right.windows(2)) - .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) - .collect(); - let slope_right = slopes_right.mean().unwrap(); - // Slopes at left and right are equal - assert_approx_eq!(slope_left, slope_right); - // Second derivatives at left and right are equal - let z = interp_extrap_enable.strategy.z; - assert_approx_eq!(z.first().unwrap(), z.last().unwrap()); - } -} diff --git a/src/one/tests.rs b/src/one/tests.rs new file mode 100644 index 0000000..4f989ea --- /dev/null +++ b/src/one/tests.rs @@ -0,0 +1,329 @@ +use super::*; + +#[test] +fn test_invalid_args() { + let interp = Interp1D::new( + array![0., 1., 2., 3., 4.], + array![0.2, 0.4, 0.6, 0.8, 1.0], + Linear, + Extrapolate::Error, + ) + .unwrap(); + assert!(matches!( + interp.interpolate(&[]).unwrap_err(), + InterpolateError::PointLength(_) + )); + assert_eq!(interp.interpolate(&[1.0]).unwrap(), 0.4); +} + +#[test] +fn test_linear() { + let x = array![0., 1., 2., 3., 4.]; + let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0]; + let interp = Interp1D::new(x.view(), f_x.view(), Linear, Extrapolate::Error).unwrap(); + // Check that interpolating at grid points just retrieves the value + for (i, x_i) in x.iter().enumerate() { + assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]); + } + assert_eq!(interp.interpolate(&[3.00]).unwrap(), 0.8); + assert_eq!(interp.interpolate(&[3.75]).unwrap(), 0.95); + assert_eq!(interp.interpolate(&[4.00]).unwrap(), 1.0); +} + +#[test] +fn test_left_nearest() { + let x = array![0., 1., 2., 3., 4.]; + let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0]; + let interp = Interp1D::new(x.view(), f_x.view(), LeftNearest, Extrapolate::Error).unwrap(); + // Check that interpolating at grid points just retrieves the value + for (i, x_i) in x.iter().enumerate() { + assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]); + } + assert_eq!(interp.interpolate(&[3.00]).unwrap(), 0.8); + assert_eq!(interp.interpolate(&[3.75]).unwrap(), 0.8); + assert_eq!(interp.interpolate(&[4.00]).unwrap(), 1.0); +} + +#[test] +fn test_right_nearest() { + let x = array![0., 1., 2., 3., 4.]; + let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0]; + let interp = Interp1D::new(x.view(), f_x.view(), RightNearest, Extrapolate::Error).unwrap(); + // Check that interpolating at grid points just retrieves the value + for (i, x_i) in x.iter().enumerate() { + assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]); + } + assert_eq!(interp.interpolate(&[3.00]).unwrap(), 0.8); + assert_eq!(interp.interpolate(&[3.25]).unwrap(), 1.0); + assert_eq!(interp.interpolate(&[4.00]).unwrap(), 1.0); +} + +#[test] +fn test_nearest() { + let x = array![0., 1., 2., 3., 4.]; + let f_x = array![0.2, 0.4, 0.6, 0.8, 1.0]; + let interp = Interp1D::new(x.view(), f_x.view(), Nearest, Extrapolate::Error).unwrap(); + // Check that interpolating at grid points just retrieves the value + for (i, x_i) in x.iter().enumerate() { + assert_eq!(interp.interpolate(&[*x_i]).unwrap(), f_x[i]); + } + assert_eq!(interp.interpolate(&[3.00]).unwrap(), 0.8); + assert_eq!(interp.interpolate(&[3.25]).unwrap(), 0.8); + assert_eq!(interp.interpolate(&[3.50]).unwrap(), 1.0); + assert_eq!(interp.interpolate(&[3.75]).unwrap(), 1.0); + assert_eq!(interp.interpolate(&[4.00]).unwrap(), 1.0); +} + +#[test] +fn test_extrapolate_inputs() { + // Incorrect extrapolation selection + assert!(matches!( + Interp1D::new( + array![0., 1., 2., 3., 4.], + array![0.2, 0.4, 0.6, 0.8, 1.0], + Nearest, + Extrapolate::Enable, + ) + .unwrap_err(), + ValidateError::ExtrapolateSelection(_) + )); + + // Extrapolate::Error + let interp = Interp1D::new( + array![0., 1., 2., 3., 4.], + array![0.2, 0.4, 0.6, 0.8, 1.0], + Linear, + Extrapolate::Error, + ) + .unwrap(); + // Fail to extrapolate below lowest grid value + assert!(matches!( + interp.interpolate(&[-1.]).unwrap_err(), + InterpolateError::ExtrapolateError(_) + )); + // Fail to extrapolate above highest grid value + assert!(matches!( + interp.interpolate(&[5.]).unwrap_err(), + InterpolateError::ExtrapolateError(_) + )); +} + +#[test] +fn test_extrapolate_fill() { + let interp = Interp1D::new( + array![0., 1., 2., 3., 4.], + array![0.2, 0.4, 0.6, 0.8, 1.0], + Linear, + Extrapolate::Fill(f64::NAN), + ) + .unwrap(); + assert_eq!(interp.interpolate(&[1.5]).unwrap(), 0.5); + assert_eq!(interp.interpolate(&[2.]).unwrap(), 0.6); + assert!(interp.interpolate(&[-1.]).unwrap().is_nan()); + assert!(interp.interpolate(&[5.]).unwrap().is_nan()); +} + +#[test] +fn test_extrapolate_clamp() { + let interp = Interp1D::new( + array![0., 1., 2., 3., 4.], + array![0.2, 0.4, 0.6, 0.8, 1.0], + Linear, + Extrapolate::Clamp, + ) + .unwrap(); + assert_eq!(interp.interpolate(&[-1.]).unwrap(), 0.2); + assert_eq!(interp.interpolate(&[5.]).unwrap(), 1.0); +} + +#[test] +fn test_extrapolate() { + let interp = Interp1D::new( + array![0., 1., 2., 3., 4.], + array![0.2, 0.4, 0.6, 0.8, 1.0], + Linear, + Extrapolate::Enable, + ) + .unwrap(); + assert_eq!(interp.interpolate(&[-1.]).unwrap(), 0.0); + assert_approx_eq!(interp.interpolate(&[-0.75]).unwrap(), 0.05); + assert_eq!(interp.interpolate(&[5.]).unwrap(), 1.2); +} + +#[test] +fn test_cubic_natural() { + let x = array![1., 2., 3., 5., 7., 8.]; + let f_x = array![3., 6., 19., 99., 291., 444.]; + + let interp = + Interp1D::new(x.view(), f_x.view(), Cubic::natural(), Extrapolate::Enable).unwrap(); + + // Interpolating at knots returns values + for i in 0..x.len() { + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]); + } + + let x0 = x.first().unwrap(); + let xn = x.last().unwrap(); + let y0 = f_x.first().unwrap(); + let yn = f_x.last().unwrap(); + + let range = xn - x0; + + let x_low = x0 - 0.2 * range; + let y_low = interp.interpolate(&[x_low]).unwrap(); + let slope_low = (y0 - y_low) / (x0 - x_low); + + let x_high = xn + 0.2 * range; + let y_high = interp.interpolate(&[x_high]).unwrap(); + let slope_high = (y_high - yn) / (x_high - xn); + + let xs_left = Array1::linspace(*x0, x0 + 2e-6, 50); + let xs_right = Array1::linspace(xn - 2e-6, *xn, 50); + + // Left extrapolation is linear + let ys: Array1 = xs_left + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_left + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), slope_low); + + // Right extrapolation is linear + let ys: Array1 = xs_right + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_right + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), slope_high); +} + +#[test] +fn test_cubic_clamped() { + let x = array![1., 2., 3., 5., 7., 8.]; + let f_x = array![3., -90., 19., 99., 291., 444.]; + + let xs_left = Array1::linspace(x.first().unwrap() - 1e-6, x.first().unwrap() + 1e-6, 50); + let xs_right = Array1::linspace(x.last().unwrap() - 1e-6, x.last().unwrap() + 1e-6, 50); + + for (a, b) in [(-5., 10.), (0., 0.), (2.4, -5.2)] { + let interp = Interp1D::new( + x.view(), + f_x.view(), + Cubic::clamped(a, b), + Extrapolate::Enable, + ) + .unwrap(); + + // Interpolating at knots returns values + for i in 0..x.len() { + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]); + } + + // Left slope = a + let ys: Array1 = xs_left + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_left + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), a); + + // Right slope = b + let ys: Array1 = xs_right + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_right + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), b); + } +} + +// #[test] +// fn test_cubic_periodic() { +// let x = array![1., 2., 3., 5., 7., 8.]; +// let f_x = array![3., -90., 19., 99., 291., 444.]; + +// let interp_extrap_enable = +// Interp1D::new(x.view(), f_x.view(), Cubic::periodic(), Extrapolate::Enable).unwrap(); +// let interp_extrap_wrap = +// Interp1D::new(x.view(), f_x.view(), Cubic::periodic(), Extrapolate::Wrap).unwrap(); + +// // Interpolating at knots returns values +// for i in 0..x.len() { +// assert_approx_eq!(interp_extrap_enable.interpolate(&[x[i]]).unwrap(), f_x[i]); +// assert_approx_eq!(interp_extrap_wrap.interpolate(&[x[i]]).unwrap(), f_x[i]); +// } + +// // Extrapolate::Enable is equivalent to Extrapolate::Wrap for Cubic::periodic() +// let x0 = x.first().unwrap(); +// let xn = x.last().unwrap(); +// let range = xn - x0; +// let x_low = x0 - 0.2 * range; +// let x_high = x0 + 0.2 * range; +// let xs_left = Array1::linspace(x_low, *x0, 50); +// let xs_right = Array1::linspace(*xn, x_high, 50); +// for x in xs_left { +// assert_eq!( +// interp_extrap_enable.interpolate(&[x]).unwrap(), +// interp_extrap_wrap.interpolate(&[x]).unwrap() +// ); +// } +// for x in xs_right { +// assert_eq!( +// interp_extrap_enable.interpolate(&[x]).unwrap(), +// interp_extrap_wrap.interpolate(&[x]).unwrap() +// ); +// } + +// // Slope left +// let xs_left = Array1::linspace(x_low, x_low + 2e6, 50); +// let ys_left: Array1 = xs_left +// .iter() +// .map(|&x| interp_extrap_enable.interpolate(&[x]).unwrap()) +// .collect(); +// let slopes_left: Array1 = xs_left +// .windows(2) +// .into_iter() +// .zip(ys_left.windows(2)) +// .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) +// .collect(); +// let slope_left = slopes_left.mean().unwrap(); +// // Slope right +// let xs_right = Array1::linspace(x_high - 2e6, x_high, 50); +// let ys_right: Array1 = xs_right +// .iter() +// .map(|&x| interp_extrap_enable.interpolate(&[x]).unwrap()) +// .collect(); +// let slopes_right: Array1 = xs_right +// .windows(2) +// .into_iter() +// .zip(ys_right.windows(2)) +// .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) +// .collect(); +// let slope_right = slopes_right.mean().unwrap(); +// // Slopes at left and right are equal +// assert_approx_eq!(slope_left, slope_right); +// // Second derivatives at left and right are equal +// let z = interp_extrap_enable.strategy.z; +// assert_approx_eq!(z.first().unwrap(), z.last().unwrap()); +// } diff --git a/src/strategy/cubic.rs b/src/strategy/cubic.rs index bb2308c..d117187 100644 --- a/src/strategy/cubic.rs +++ b/src/strategy/cubic.rs @@ -93,6 +93,7 @@ impl Cubic where T: Float + Debug, { + // Reference: https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf pub(crate) fn evaluate_1d + RawDataClone + Clone>( &self, point: &[T; 1], diff --git a/src/three/mod.rs b/src/three/mod.rs index 787410b..47d8e16 100644 --- a/src/three/mod.rs +++ b/src/three/mod.rs @@ -3,6 +3,8 @@ use super::*; mod strategies; +#[cfg(test)] +mod tests; const N: usize = 3; @@ -209,209 +211,3 @@ where self.check_extrapolate(&self.extrapolate) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_linear() { - let x = array![0.05, 0.10, 0.15]; - let y = array![0.10, 0.20, 0.30]; - let z = array![0.20, 0.40, 0.60]; - let f_xyz = array![ - [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], - [[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]], - [[18., 19., 20.], [21., 22., 23.], [24., 25., 26.],], - ]; - let interp = Interp3D::new( - x.view(), - y.view(), - z.view(), - f_xyz.view(), - Linear, - Extrapolate::Error, - ) - .unwrap(); - // Check that interpolating at grid points just retrieves the value - for (i, x_i) in x.iter().enumerate() { - for (j, y_j) in y.iter().enumerate() { - for (k, z_k) in z.iter().enumerate() { - assert_eq!( - interp.interpolate(&[*x_i, *y_j, *z_k]).unwrap(), - f_xyz[[i, j, k]] - ); - } - } - } - assert_approx_eq!(interp.interpolate(&[x[0], y[0], 0.3]).unwrap(), 0.5); - assert_approx_eq!(interp.interpolate(&[x[0], 0.15, z[0]]).unwrap(), 1.5); - assert_approx_eq!(interp.interpolate(&[x[0], 0.15, 0.3]).unwrap(), 2.); - assert_approx_eq!(interp.interpolate(&[0.075, y[0], z[0]]).unwrap(), 4.5); - assert_approx_eq!(interp.interpolate(&[0.075, y[0], 0.3]).unwrap(), 5.); - assert_approx_eq!(interp.interpolate(&[0.075, 0.15, z[0]]).unwrap(), 6.); - } - - #[test] - fn test_linear_extrapolation() { - let interp = Interp3D::new( - array![0.05, 0.10, 0.15], - array![0.10, 0.20, 0.30], - array![0.20, 0.40, 0.60], - array![ - [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], - [[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]], - [[18., 19., 20.], [21., 22., 23.], [24., 25., 26.],], - ], - Linear, - Extrapolate::Enable, - ) - .unwrap(); - // below x, below y, below z - assert_approx_eq!(interp.interpolate(&[0.01, 0.06, 0.17]).unwrap(), -8.55); - assert_approx_eq!(interp.interpolate(&[0.02, 0.08, 0.19]).unwrap(), -6.05); - // below x, below y, above z - assert_approx_eq!(interp.interpolate(&[0.01, 0.06, 0.63]).unwrap(), -6.25); - assert_approx_eq!(interp.interpolate(&[0.02, 0.08, 0.65]).unwrap(), -3.75); - // below x, above y, below z - assert_approx_eq!(interp.interpolate(&[0.01, 0.33, 0.17]).unwrap(), -0.45); - assert_approx_eq!(interp.interpolate(&[0.02, 0.36, 0.19]).unwrap(), 2.35); - // below x, above y, above z - assert_approx_eq!(interp.interpolate(&[0.01, 0.33, 0.63]).unwrap(), 1.85); - assert_approx_eq!(interp.interpolate(&[0.02, 0.36, 0.65]).unwrap(), 4.65); - // above x, below y, below z - assert_approx_eq!(interp.interpolate(&[0.17, 0.06, 0.17]).unwrap(), 20.25); - assert_approx_eq!(interp.interpolate(&[0.19, 0.08, 0.19]).unwrap(), 24.55); - // above x, below y, above z - assert_approx_eq!(interp.interpolate(&[0.17, 0.06, 0.63]).unwrap(), 22.55); - assert_approx_eq!(interp.interpolate(&[0.19, 0.08, 0.65]).unwrap(), 26.85); - // above x, above y, below z - assert_approx_eq!(interp.interpolate(&[0.17, 0.33, 0.17]).unwrap(), 28.35); - assert_approx_eq!(interp.interpolate(&[0.19, 0.36, 0.19]).unwrap(), 32.95); - // above x, above y, above z - assert_approx_eq!(interp.interpolate(&[0.17, 0.33, 0.63]).unwrap(), 30.65); - assert_approx_eq!(interp.interpolate(&[0.19, 0.36, 0.65]).unwrap(), 35.25); - } - - #[test] - fn test_linear_offset() { - let interp = Interp3D::new( - array![0., 1.], - array![0., 1.], - array![0., 1.], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], - Linear, - Extrapolate::Error, - ) - .unwrap(); - assert_approx_eq!(interp.interpolate(&[0.25, 0.65, 0.9]).unwrap(), 3.2); - } - - #[test] - fn test_nearest() { - let x = array![0., 1.]; - let y = array![0., 1.]; - let z = array![0., 1.]; - let f_xyz = array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],]; - let interp = Interp3D::new( - x.view(), - y.view(), - z.view(), - f_xyz.view(), - Nearest, - Extrapolate::Error, - ) - .unwrap(); - // Check that interpolating at grid points just retrieves the value - for (i, x_i) in x.iter().enumerate() { - for (j, y_j) in y.iter().enumerate() { - for (k, z_k) in z.iter().enumerate() { - assert_eq!( - interp.interpolate(&[*x_i, *y_j, *z_k]).unwrap(), - f_xyz[[i, j, k]] - ); - } - } - } - assert_eq!(interp.interpolate(&[0., 0., 0.]).unwrap(), 0.); - assert_eq!(interp.interpolate(&[0.25, 0.25, 0.25]).unwrap(), 0.); - assert_eq!(interp.interpolate(&[0.25, 0.75, 0.25]).unwrap(), 2.); - assert_eq!(interp.interpolate(&[0., 1., 0.]).unwrap(), 2.); - assert_eq!(interp.interpolate(&[0.75, 0.25, 0.75]).unwrap(), 5.); - assert_eq!(interp.interpolate(&[0.75, 0.75, 0.75]).unwrap(), 7.); - assert_eq!(interp.interpolate(&[1., 1., 1.]).unwrap(), 7.); - } - - #[test] - fn test_extrapolate_inputs() { - // Extrapolate::Extrapolate - assert!(matches!( - Interp3D::new( - array![0.1, 1.1], - array![0.2, 1.2], - array![0.3, 1.3], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], - Nearest, - Extrapolate::Enable, - ) - .unwrap_err(), - ValidateError::ExtrapolateSelection(_) - )); - // Extrapolate::Error - let interp = Interp3D::new( - array![0.1, 1.1], - array![0.2, 1.2], - array![0.3, 1.3], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], - Linear, - Extrapolate::Error, - ) - .unwrap(); - assert!(matches!( - interp.interpolate(&[-1., -1., -1.]).unwrap_err(), - InterpolateError::ExtrapolateError(_) - )); - assert!(matches!( - interp.interpolate(&[2., 2., 2.]).unwrap_err(), - InterpolateError::ExtrapolateError(_) - )); - } - - #[test] - fn test_extrapolate_fill() { - let interp = Interp3D::new( - array![0.1, 1.1], - array![0.2, 1.2], - array![0.3, 1.3], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], - Linear, - Extrapolate::Fill(f64::NAN), - ) - .unwrap(); - assert_approx_eq!(interp.interpolate(&[0.4, 0.4, 0.4]).unwrap(), 1.7); - assert_approx_eq!(interp.interpolate(&[0.8, 0.8, 0.8]).unwrap(), 4.5); - assert!(interp.interpolate(&[0., 0., 0.]).unwrap().is_nan()); - assert!(interp.interpolate(&[0., 0., 2.]).unwrap().is_nan()); - assert!(interp.interpolate(&[0., 2., 0.]).unwrap().is_nan()); - assert!(interp.interpolate(&[0., 2., 2.]).unwrap().is_nan()); - assert!(interp.interpolate(&[2., 0., 0.]).unwrap().is_nan()); - assert!(interp.interpolate(&[2., 0., 2.]).unwrap().is_nan()); - assert!(interp.interpolate(&[2., 2., 0.]).unwrap().is_nan()); - assert!(interp.interpolate(&[2., 2., 2.]).unwrap().is_nan()); - } - - #[test] - fn test_extrapolate_clamp() { - let interp = Interp3D::new( - array![0.1, 1.1], - array![0.2, 1.2], - array![0.3, 1.3], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], - Linear, - Extrapolate::Clamp, - ) - .unwrap(); - assert_eq!(interp.interpolate(&[-1., -1., -1.]).unwrap(), 0.); - assert_eq!(interp.interpolate(&[2., 2., 2.]).unwrap(), 7.); - } -} diff --git a/src/three/tests.rs b/src/three/tests.rs new file mode 100644 index 0000000..2a3f0ef --- /dev/null +++ b/src/three/tests.rs @@ -0,0 +1,202 @@ +use super::*; + +#[test] +fn test_linear() { + let x = array![0.05, 0.10, 0.15]; + let y = array![0.10, 0.20, 0.30]; + let z = array![0.20, 0.40, 0.60]; + let f_xyz = array![ + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], + [[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]], + [[18., 19., 20.], [21., 22., 23.], [24., 25., 26.],], + ]; + let interp = Interp3D::new( + x.view(), + y.view(), + z.view(), + f_xyz.view(), + Linear, + Extrapolate::Error, + ) + .unwrap(); + // Check that interpolating at grid points just retrieves the value + for (i, x_i) in x.iter().enumerate() { + for (j, y_j) in y.iter().enumerate() { + for (k, z_k) in z.iter().enumerate() { + assert_eq!( + interp.interpolate(&[*x_i, *y_j, *z_k]).unwrap(), + f_xyz[[i, j, k]] + ); + } + } + } + assert_approx_eq!(interp.interpolate(&[x[0], y[0], 0.3]).unwrap(), 0.5); + assert_approx_eq!(interp.interpolate(&[x[0], 0.15, z[0]]).unwrap(), 1.5); + assert_approx_eq!(interp.interpolate(&[x[0], 0.15, 0.3]).unwrap(), 2.); + assert_approx_eq!(interp.interpolate(&[0.075, y[0], z[0]]).unwrap(), 4.5); + assert_approx_eq!(interp.interpolate(&[0.075, y[0], 0.3]).unwrap(), 5.); + assert_approx_eq!(interp.interpolate(&[0.075, 0.15, z[0]]).unwrap(), 6.); +} + +#[test] +fn test_linear_extrapolation() { + let interp = Interp3D::new( + array![0.05, 0.10, 0.15], + array![0.10, 0.20, 0.30], + array![0.20, 0.40, 0.60], + array![ + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], + [[9., 10., 11.], [12., 13., 14.], [15., 16., 17.]], + [[18., 19., 20.], [21., 22., 23.], [24., 25., 26.],], + ], + Linear, + Extrapolate::Enable, + ) + .unwrap(); + // below x, below y, below z + assert_approx_eq!(interp.interpolate(&[0.01, 0.06, 0.17]).unwrap(), -8.55); + assert_approx_eq!(interp.interpolate(&[0.02, 0.08, 0.19]).unwrap(), -6.05); + // below x, below y, above z + assert_approx_eq!(interp.interpolate(&[0.01, 0.06, 0.63]).unwrap(), -6.25); + assert_approx_eq!(interp.interpolate(&[0.02, 0.08, 0.65]).unwrap(), -3.75); + // below x, above y, below z + assert_approx_eq!(interp.interpolate(&[0.01, 0.33, 0.17]).unwrap(), -0.45); + assert_approx_eq!(interp.interpolate(&[0.02, 0.36, 0.19]).unwrap(), 2.35); + // below x, above y, above z + assert_approx_eq!(interp.interpolate(&[0.01, 0.33, 0.63]).unwrap(), 1.85); + assert_approx_eq!(interp.interpolate(&[0.02, 0.36, 0.65]).unwrap(), 4.65); + // above x, below y, below z + assert_approx_eq!(interp.interpolate(&[0.17, 0.06, 0.17]).unwrap(), 20.25); + assert_approx_eq!(interp.interpolate(&[0.19, 0.08, 0.19]).unwrap(), 24.55); + // above x, below y, above z + assert_approx_eq!(interp.interpolate(&[0.17, 0.06, 0.63]).unwrap(), 22.55); + assert_approx_eq!(interp.interpolate(&[0.19, 0.08, 0.65]).unwrap(), 26.85); + // above x, above y, below z + assert_approx_eq!(interp.interpolate(&[0.17, 0.33, 0.17]).unwrap(), 28.35); + assert_approx_eq!(interp.interpolate(&[0.19, 0.36, 0.19]).unwrap(), 32.95); + // above x, above y, above z + assert_approx_eq!(interp.interpolate(&[0.17, 0.33, 0.63]).unwrap(), 30.65); + assert_approx_eq!(interp.interpolate(&[0.19, 0.36, 0.65]).unwrap(), 35.25); +} + +#[test] +fn test_linear_offset() { + let interp = Interp3D::new( + array![0., 1.], + array![0., 1.], + array![0., 1.], + array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], + Linear, + Extrapolate::Error, + ) + .unwrap(); + assert_approx_eq!(interp.interpolate(&[0.25, 0.65, 0.9]).unwrap(), 3.2); +} + +#[test] +fn test_nearest() { + let x = array![0., 1.]; + let y = array![0., 1.]; + let z = array![0., 1.]; + let f_xyz = array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],]; + let interp = Interp3D::new( + x.view(), + y.view(), + z.view(), + f_xyz.view(), + Nearest, + Extrapolate::Error, + ) + .unwrap(); + // Check that interpolating at grid points just retrieves the value + for (i, x_i) in x.iter().enumerate() { + for (j, y_j) in y.iter().enumerate() { + for (k, z_k) in z.iter().enumerate() { + assert_eq!( + interp.interpolate(&[*x_i, *y_j, *z_k]).unwrap(), + f_xyz[[i, j, k]] + ); + } + } + } + assert_eq!(interp.interpolate(&[0., 0., 0.]).unwrap(), 0.); + assert_eq!(interp.interpolate(&[0.25, 0.25, 0.25]).unwrap(), 0.); + assert_eq!(interp.interpolate(&[0.25, 0.75, 0.25]).unwrap(), 2.); + assert_eq!(interp.interpolate(&[0., 1., 0.]).unwrap(), 2.); + assert_eq!(interp.interpolate(&[0.75, 0.25, 0.75]).unwrap(), 5.); + assert_eq!(interp.interpolate(&[0.75, 0.75, 0.75]).unwrap(), 7.); + assert_eq!(interp.interpolate(&[1., 1., 1.]).unwrap(), 7.); +} + +#[test] +fn test_extrapolate_inputs() { + // Extrapolate::Extrapolate + assert!(matches!( + Interp3D::new( + array![0.1, 1.1], + array![0.2, 1.2], + array![0.3, 1.3], + array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], + Nearest, + Extrapolate::Enable, + ) + .unwrap_err(), + ValidateError::ExtrapolateSelection(_) + )); + // Extrapolate::Error + let interp = Interp3D::new( + array![0.1, 1.1], + array![0.2, 1.2], + array![0.3, 1.3], + array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], + Linear, + Extrapolate::Error, + ) + .unwrap(); + assert!(matches!( + interp.interpolate(&[-1., -1., -1.]).unwrap_err(), + InterpolateError::ExtrapolateError(_) + )); + assert!(matches!( + interp.interpolate(&[2., 2., 2.]).unwrap_err(), + InterpolateError::ExtrapolateError(_) + )); +} + +#[test] +fn test_extrapolate_fill() { + let interp = Interp3D::new( + array![0.1, 1.1], + array![0.2, 1.2], + array![0.3, 1.3], + array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], + Linear, + Extrapolate::Fill(f64::NAN), + ) + .unwrap(); + assert_approx_eq!(interp.interpolate(&[0.4, 0.4, 0.4]).unwrap(), 1.7); + assert_approx_eq!(interp.interpolate(&[0.8, 0.8, 0.8]).unwrap(), 4.5); + assert!(interp.interpolate(&[0., 0., 0.]).unwrap().is_nan()); + assert!(interp.interpolate(&[0., 0., 2.]).unwrap().is_nan()); + assert!(interp.interpolate(&[0., 2., 0.]).unwrap().is_nan()); + assert!(interp.interpolate(&[0., 2., 2.]).unwrap().is_nan()); + assert!(interp.interpolate(&[2., 0., 0.]).unwrap().is_nan()); + assert!(interp.interpolate(&[2., 0., 2.]).unwrap().is_nan()); + assert!(interp.interpolate(&[2., 2., 0.]).unwrap().is_nan()); + assert!(interp.interpolate(&[2., 2., 2.]).unwrap().is_nan()); +} + +#[test] +fn test_extrapolate_clamp() { + let interp = Interp3D::new( + array![0.1, 1.1], + array![0.2, 1.2], + array![0.3, 1.3], + array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], + Linear, + Extrapolate::Clamp, + ) + .unwrap(); + assert_eq!(interp.interpolate(&[-1., -1., -1.]).unwrap(), 0.); + assert_eq!(interp.interpolate(&[2., 2., 2.]).unwrap(), 7.); +} diff --git a/src/two/mod.rs b/src/two/mod.rs index 86d81df..7d240a4 100644 --- a/src/two/mod.rs +++ b/src/two/mod.rs @@ -3,6 +3,8 @@ use super::*; mod strategies; +#[cfg(test)] +mod tests; const N: usize = 2; @@ -195,169 +197,3 @@ where self.check_extrapolate(&self.extrapolate) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_linear() { - let x = array![0.05, 0.10, 0.15]; - let y = array![0.10, 0.20, 0.30]; - let f_xy = array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]; - let interp = - Interp2D::new(x.view(), y.view(), f_xy.view(), Linear, Extrapolate::Error).unwrap(); - // Check that interpolating at grid points just retrieves the value - for (i, x_i) in x.iter().enumerate() { - for (j, y_j) in y.iter().enumerate() { - assert_eq!(interp.interpolate(&[*x_i, *y_j]).unwrap(), f_xy[[i, j]]); - } - } - assert_eq!(interp.interpolate(&[x[2], y[1]]).unwrap(), f_xy[[2, 1]]); - assert_eq!(interp.interpolate(&[0.075, 0.25]).unwrap(), 3.); - } - - #[test] - fn test_linear_offset() { - let interp = Interp2D::new( - array![0., 1.], - array![0., 1.], - array![[0., 1.], [2., 3.]], - Linear, - Extrapolate::Error, - ) - .unwrap(); - assert_approx_eq!(interp.interpolate(&[0.25, 0.65]).unwrap(), 1.15); - } - - #[test] - fn test_linear_extrapolation() { - let interp = Interp2D::new( - array![0.05, 0.10, 0.15], - array![0.10, 0.20, 0.30], - array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], - Linear, - Extrapolate::Enable, - ) - .unwrap(); - // RHS are coplanar neighboring data planes according to: - // https://www.ambrbit.com/TrigoCalc/Plan3D/PointsCoplanar.htm - // below x, below y - assert_approx_eq!(interp.interpolate(&[0.0, 0.0]).unwrap(), -4.); - assert_approx_eq!(interp.interpolate(&[0.03, 0.04]).unwrap(), -1.8); - // below x, above y - assert_approx_eq!(interp.interpolate(&[0.0, 0.32]).unwrap(), -0.8); - assert_approx_eq!(interp.interpolate(&[0.03, 0.36]).unwrap(), 1.4); - // above x, below y - assert_approx_eq!(interp.interpolate(&[0.17, 0.0]).unwrap(), 6.2); - assert_approx_eq!(interp.interpolate(&[0.19, 0.04]).unwrap(), 7.8); - // above x, above y - assert_approx_eq!(interp.interpolate(&[0.17, 0.32]).unwrap(), 9.4); - assert_approx_eq!(interp.interpolate(&[0.19, 0.36]).unwrap(), 11.); - } - - #[test] - fn test_nearest() { - let x = array![0.05, 0.10, 0.15]; - let y = array![0.10, 0.20, 0.30]; - let f_xy = array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]; - let interp = - Interp2D::new(x.view(), y.view(), f_xy.view(), Nearest, Extrapolate::Error).unwrap(); - // Check that interpolating at grid points just retrieves the value - for (i, x_i) in x.iter().enumerate() { - for (j, y_j) in y.iter().enumerate() { - assert_eq!(interp.interpolate(&[*x_i, *y_j]).unwrap(), f_xy[[i, j]]); - } - } - assert_eq!(interp.interpolate(&[0.05, 0.12]).unwrap(), f_xy[[0, 0]]); - assert_eq!( - // float imprecision - interp.interpolate(&[0.07, 0.15 + 0.0001]).unwrap(), - f_xy[[0, 1]] - ); - assert_eq!(interp.interpolate(&[0.08, 0.21]).unwrap(), f_xy[[1, 1]]); - assert_eq!(interp.interpolate(&[0.11, 0.26]).unwrap(), f_xy[[1, 2]]); - assert_eq!(interp.interpolate(&[0.13, 0.12]).unwrap(), f_xy[[2, 0]]); - assert_eq!(interp.interpolate(&[0.14, 0.29]).unwrap(), f_xy[[2, 2]]); - } - - #[test] - fn test_extrapolate_inputs() { - // Extrapolate::Extrapolate - assert!(matches!( - Interp2D::new( - array![0.1, 1.1], - array![0.2, 1.2], - array![[0., 1.], [2., 3.]], - Nearest, - Extrapolate::Enable, - ) - .unwrap_err(), - ValidateError::ExtrapolateSelection(_) - )); - // Extrapolate::Error - let interp = Interp2D::new( - array![0.1, 1.1], - array![0.2, 1.2], - array![[0., 1.], [2., 3.]], - Linear, - Extrapolate::Error, - ) - .unwrap(); - assert!(matches!( - interp.interpolate(&[-1., -1.]).unwrap_err(), - InterpolateError::ExtrapolateError(_) - )); - assert!(matches!( - interp.interpolate(&[2., 2.]).unwrap_err(), - InterpolateError::ExtrapolateError(_) - )); - } - - #[test] - fn test_extrapolate_fill() { - let interp = Interp2D::new( - array![0.1, 1.1], - array![0.2, 1.2], - array![[0., 1.], [2., 3.]], - Linear, - Extrapolate::Fill(f64::NAN), - ) - .unwrap(); - assert_eq!(interp.interpolate(&[0.5, 0.5]).unwrap(), 1.1); - assert_eq!(interp.interpolate(&[0.1, 1.2]).unwrap(), 1.); - assert!(interp.interpolate(&[0., 0.]).unwrap().is_nan()); - assert!(interp.interpolate(&[0., 2.]).unwrap().is_nan()); - assert!(interp.interpolate(&[2., 0.]).unwrap().is_nan()); - assert!(interp.interpolate(&[2., 2.]).unwrap().is_nan()); - } - - #[test] - fn test_dyn_strategy() { - let mut interp = Interp2D::new( - array![0., 1.], - array![0., 1.], - array![[0., 1.], [2., 3.]], - Box::new(Linear) as Box>, - Extrapolate::Error, - ) - .unwrap(); - assert_eq!(interp.interpolate(&[0.2, 0.]).unwrap(), 0.4); - interp.set_strategy(Box::new(Nearest)).unwrap(); - assert_eq!(interp.interpolate(&[0.2, 0.]).unwrap(), 0.); - } - - #[test] - fn test_extrapolate_clamp() { - let interp = Interp2D::new( - array![0.1, 1.1], - array![0.2, 1.2], - array![[0., 1.], [2., 3.]], - Linear, - Extrapolate::Clamp, - ) - .unwrap(); - assert_eq!(interp.interpolate(&[-1., -1.]).unwrap(), 0.); - assert_eq!(interp.interpolate(&[2., 2.]).unwrap(), 3.); - } -} diff --git a/src/two/tests.rs b/src/two/tests.rs new file mode 100644 index 0000000..ef9a546 --- /dev/null +++ b/src/two/tests.rs @@ -0,0 +1,162 @@ +use super::*; + +#[test] +fn test_linear() { + let x = array![0.05, 0.10, 0.15]; + let y = array![0.10, 0.20, 0.30]; + let f_xy = array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]; + let interp = + Interp2D::new(x.view(), y.view(), f_xy.view(), Linear, Extrapolate::Error).unwrap(); + // Check that interpolating at grid points just retrieves the value + for (i, x_i) in x.iter().enumerate() { + for (j, y_j) in y.iter().enumerate() { + assert_eq!(interp.interpolate(&[*x_i, *y_j]).unwrap(), f_xy[[i, j]]); + } + } + assert_eq!(interp.interpolate(&[x[2], y[1]]).unwrap(), f_xy[[2, 1]]); + assert_eq!(interp.interpolate(&[0.075, 0.25]).unwrap(), 3.); +} + +#[test] +fn test_linear_offset() { + let interp = Interp2D::new( + array![0., 1.], + array![0., 1.], + array![[0., 1.], [2., 3.]], + Linear, + Extrapolate::Error, + ) + .unwrap(); + assert_approx_eq!(interp.interpolate(&[0.25, 0.65]).unwrap(), 1.15); +} + +#[test] +fn test_linear_extrapolation() { + let interp = Interp2D::new( + array![0.05, 0.10, 0.15], + array![0.10, 0.20, 0.30], + array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], + Linear, + Extrapolate::Enable, + ) + .unwrap(); + // RHS are coplanar neighboring data planes according to: + // https://www.ambrbit.com/TrigoCalc/Plan3D/PointsCoplanar.htm + // below x, below y + assert_approx_eq!(interp.interpolate(&[0.0, 0.0]).unwrap(), -4.); + assert_approx_eq!(interp.interpolate(&[0.03, 0.04]).unwrap(), -1.8); + // below x, above y + assert_approx_eq!(interp.interpolate(&[0.0, 0.32]).unwrap(), -0.8); + assert_approx_eq!(interp.interpolate(&[0.03, 0.36]).unwrap(), 1.4); + // above x, below y + assert_approx_eq!(interp.interpolate(&[0.17, 0.0]).unwrap(), 6.2); + assert_approx_eq!(interp.interpolate(&[0.19, 0.04]).unwrap(), 7.8); + // above x, above y + assert_approx_eq!(interp.interpolate(&[0.17, 0.32]).unwrap(), 9.4); + assert_approx_eq!(interp.interpolate(&[0.19, 0.36]).unwrap(), 11.); +} + +#[test] +fn test_nearest() { + let x = array![0.05, 0.10, 0.15]; + let y = array![0.10, 0.20, 0.30]; + let f_xy = array![[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]; + let interp = + Interp2D::new(x.view(), y.view(), f_xy.view(), Nearest, Extrapolate::Error).unwrap(); + // Check that interpolating at grid points just retrieves the value + for (i, x_i) in x.iter().enumerate() { + for (j, y_j) in y.iter().enumerate() { + assert_eq!(interp.interpolate(&[*x_i, *y_j]).unwrap(), f_xy[[i, j]]); + } + } + assert_eq!(interp.interpolate(&[0.05, 0.12]).unwrap(), f_xy[[0, 0]]); + assert_eq!( + // float imprecision + interp.interpolate(&[0.07, 0.15 + 0.0001]).unwrap(), + f_xy[[0, 1]] + ); + assert_eq!(interp.interpolate(&[0.08, 0.21]).unwrap(), f_xy[[1, 1]]); + assert_eq!(interp.interpolate(&[0.11, 0.26]).unwrap(), f_xy[[1, 2]]); + assert_eq!(interp.interpolate(&[0.13, 0.12]).unwrap(), f_xy[[2, 0]]); + assert_eq!(interp.interpolate(&[0.14, 0.29]).unwrap(), f_xy[[2, 2]]); +} + +#[test] +fn test_extrapolate_inputs() { + // Extrapolate::Extrapolate + assert!(matches!( + Interp2D::new( + array![0.1, 1.1], + array![0.2, 1.2], + array![[0., 1.], [2., 3.]], + Nearest, + Extrapolate::Enable, + ) + .unwrap_err(), + ValidateError::ExtrapolateSelection(_) + )); + // Extrapolate::Error + let interp = Interp2D::new( + array![0.1, 1.1], + array![0.2, 1.2], + array![[0., 1.], [2., 3.]], + Linear, + Extrapolate::Error, + ) + .unwrap(); + assert!(matches!( + interp.interpolate(&[-1., -1.]).unwrap_err(), + InterpolateError::ExtrapolateError(_) + )); + assert!(matches!( + interp.interpolate(&[2., 2.]).unwrap_err(), + InterpolateError::ExtrapolateError(_) + )); +} + +#[test] +fn test_extrapolate_fill() { + let interp = Interp2D::new( + array![0.1, 1.1], + array![0.2, 1.2], + array![[0., 1.], [2., 3.]], + Linear, + Extrapolate::Fill(f64::NAN), + ) + .unwrap(); + assert_eq!(interp.interpolate(&[0.5, 0.5]).unwrap(), 1.1); + assert_eq!(interp.interpolate(&[0.1, 1.2]).unwrap(), 1.); + assert!(interp.interpolate(&[0., 0.]).unwrap().is_nan()); + assert!(interp.interpolate(&[0., 2.]).unwrap().is_nan()); + assert!(interp.interpolate(&[2., 0.]).unwrap().is_nan()); + assert!(interp.interpolate(&[2., 2.]).unwrap().is_nan()); +} + +#[test] +fn test_dyn_strategy() { + let mut interp = Interp2D::new( + array![0., 1.], + array![0., 1.], + array![[0., 1.], [2., 3.]], + Box::new(Linear) as Box>, + Extrapolate::Error, + ) + .unwrap(); + assert_eq!(interp.interpolate(&[0.2, 0.]).unwrap(), 0.4); + interp.set_strategy(Box::new(Nearest)).unwrap(); + assert_eq!(interp.interpolate(&[0.2, 0.]).unwrap(), 0.); +} + +#[test] +fn test_extrapolate_clamp() { + let interp = Interp2D::new( + array![0.1, 1.1], + array![0.2, 1.2], + array![[0., 1.], [2., 3.]], + Linear, + Extrapolate::Clamp, + ) + .unwrap(); + assert_eq!(interp.interpolate(&[-1., -1.]).unwrap(), 0.); + assert_eq!(interp.interpolate(&[2., 2.]).unwrap(), 3.); +} From c28fb06ce8e5b823209a70f9b20a7275233a47c7 Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 10:43:14 -0600 Subject: [PATCH 14/18] mod reorg, API should be unchanged --- src/{ => interpolator}/data.rs | 0 src/interpolator/mod.rs | 15 ++++++++++++++ src/{ => interpolator}/n/mod.rs | 0 src/{ => interpolator}/n/strategies.rs | 0 src/{ => interpolator}/n/tests.rs | 0 src/{ => interpolator}/one/mod.rs | 0 src/{ => interpolator}/one/strategies.rs | 0 src/{ => interpolator}/one/tests.rs | 0 src/{ => interpolator}/three/mod.rs | 0 src/{ => interpolator}/three/strategies.rs | 0 src/{ => interpolator}/three/tests.rs | 0 src/{ => interpolator}/two/mod.rs | 0 src/{ => interpolator}/two/strategies.rs | 0 src/{ => interpolator}/two/tests.rs | 0 src/{ => interpolator}/zero/mod.rs | 0 src/lib.rs | 23 ++++++++-------------- 16 files changed, 23 insertions(+), 15 deletions(-) rename src/{ => interpolator}/data.rs (100%) create mode 100644 src/interpolator/mod.rs rename src/{ => interpolator}/n/mod.rs (100%) rename src/{ => interpolator}/n/strategies.rs (100%) rename src/{ => interpolator}/n/tests.rs (100%) rename src/{ => interpolator}/one/mod.rs (100%) rename src/{ => interpolator}/one/strategies.rs (100%) rename src/{ => interpolator}/one/tests.rs (100%) rename src/{ => interpolator}/three/mod.rs (100%) rename src/{ => interpolator}/three/strategies.rs (100%) rename src/{ => interpolator}/three/tests.rs (100%) rename src/{ => interpolator}/two/mod.rs (100%) rename src/{ => interpolator}/two/strategies.rs (100%) rename src/{ => interpolator}/two/tests.rs (100%) rename src/{ => interpolator}/zero/mod.rs (100%) diff --git a/src/data.rs b/src/interpolator/data.rs similarity index 100% rename from src/data.rs rename to src/interpolator/data.rs diff --git a/src/interpolator/mod.rs b/src/interpolator/mod.rs new file mode 100644 index 0000000..836befe --- /dev/null +++ b/src/interpolator/mod.rs @@ -0,0 +1,15 @@ +use super::*; + +pub mod data; + +pub mod n; +pub mod one; +pub mod three; +pub mod two; +pub mod zero; + +pub use n::{InterpND, InterpNDOwned, InterpNDViewed}; +pub use one::{Interp1D, Interp1DOwned, Interp1DViewed}; +pub use three::{Interp3D, Interp3DOwned, Interp3DViewed}; +pub use two::{Interp2D, Interp2DOwned, Interp2DViewed}; +pub use zero::Interp0D; diff --git a/src/n/mod.rs b/src/interpolator/n/mod.rs similarity index 100% rename from src/n/mod.rs rename to src/interpolator/n/mod.rs diff --git a/src/n/strategies.rs b/src/interpolator/n/strategies.rs similarity index 100% rename from src/n/strategies.rs rename to src/interpolator/n/strategies.rs diff --git a/src/n/tests.rs b/src/interpolator/n/tests.rs similarity index 100% rename from src/n/tests.rs rename to src/interpolator/n/tests.rs diff --git a/src/one/mod.rs b/src/interpolator/one/mod.rs similarity index 100% rename from src/one/mod.rs rename to src/interpolator/one/mod.rs diff --git a/src/one/strategies.rs b/src/interpolator/one/strategies.rs similarity index 100% rename from src/one/strategies.rs rename to src/interpolator/one/strategies.rs diff --git a/src/one/tests.rs b/src/interpolator/one/tests.rs similarity index 100% rename from src/one/tests.rs rename to src/interpolator/one/tests.rs diff --git a/src/three/mod.rs b/src/interpolator/three/mod.rs similarity index 100% rename from src/three/mod.rs rename to src/interpolator/three/mod.rs diff --git a/src/three/strategies.rs b/src/interpolator/three/strategies.rs similarity index 100% rename from src/three/strategies.rs rename to src/interpolator/three/strategies.rs diff --git a/src/three/tests.rs b/src/interpolator/three/tests.rs similarity index 100% rename from src/three/tests.rs rename to src/interpolator/three/tests.rs diff --git a/src/two/mod.rs b/src/interpolator/two/mod.rs similarity index 100% rename from src/two/mod.rs rename to src/interpolator/two/mod.rs diff --git a/src/two/strategies.rs b/src/interpolator/two/strategies.rs similarity index 100% rename from src/two/strategies.rs rename to src/interpolator/two/strategies.rs diff --git a/src/two/tests.rs b/src/interpolator/two/tests.rs similarity index 100% rename from src/two/tests.rs rename to src/interpolator/two/tests.rs diff --git a/src/zero/mod.rs b/src/interpolator/zero/mod.rs similarity index 100% rename from src/zero/mod.rs rename to src/interpolator/zero/mod.rs diff --git a/src/lib.rs b/src/lib.rs index d9504d8..8b1ac8b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -127,25 +127,18 @@ pub mod prelude { pub use crate::Interpolator; } -pub mod data; pub mod error; pub mod strategy; -pub mod n; -pub mod one; -pub mod three; -pub mod two; -pub mod zero; +pub mod interpolator; +pub use interpolator::data; +pub(crate) use interpolator::data::*; +pub use interpolator::n; +pub use interpolator::one; +pub use interpolator::three; +pub use interpolator::two; +pub use interpolator::zero; -pub mod interpolator { - pub use crate::n::{InterpND, InterpNDOwned, InterpNDViewed}; - pub use crate::one::{Interp1D, Interp1DOwned, Interp1DViewed}; - pub use crate::three::{Interp3D, Interp3DOwned, Interp3DViewed}; - pub use crate::two::{Interp2D, Interp2DOwned, Interp2DViewed}; - pub use crate::zero::Interp0D; -} - -pub(crate) use data::*; pub(crate) use error::*; pub(crate) use strategy::*; From 0fbf00f2870bc2601cf66a4c003072868f81c20c Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 10:50:26 -0600 Subject: [PATCH 15/18] API breaking mod reorg --- examples/custom_strategy.rs | 2 +- examples/dynamic_strategy.rs | 2 +- src/interpolator/one/mod.rs | 1 + src/lib.rs | 3 +- src/strategy/mod.rs | 176 +---------------------------------- src/strategy/traits.rs | 173 ++++++++++++++++++++++++++++++++++ 6 files changed, 180 insertions(+), 177 deletions(-) create mode 100644 src/strategy/traits.rs diff --git a/examples/custom_strategy.rs b/examples/custom_strategy.rs index c3f1f63..4a7534c 100644 --- a/examples/custom_strategy.rs +++ b/examples/custom_strategy.rs @@ -1,6 +1,6 @@ use ninterp::data::InterpData2D; use ninterp::prelude::*; -use ninterp::strategy::*; +use ninterp::strategy::traits::*; // Note: ninterp also re-exposes the internally used `ndarray` crate // `use ninterp::ndarray;` diff --git a/examples/dynamic_strategy.rs b/examples/dynamic_strategy.rs index 59255e5..614411d 100644 --- a/examples/dynamic_strategy.rs +++ b/examples/dynamic_strategy.rs @@ -1,7 +1,7 @@ use ndarray::prelude::*; use ninterp::prelude::*; -use ninterp::strategy::Strategy1D; +use ninterp::strategy::traits::Strategy1D; fn main() { // Create mutable interpolator diff --git a/src/interpolator/one/mod.rs b/src/interpolator/one/mod.rs index b035436..3fb2ba3 100644 --- a/src/interpolator/one/mod.rs +++ b/src/interpolator/one/mod.rs @@ -1,6 +1,7 @@ //! 1-dimensional interpolation use super::*; +use crate::strategy::cubic::*; mod strategies; #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index 8b1ac8b..2eb4cdb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -122,7 +122,7 @@ /// - The extrapolation setting enum: [`Extrapolate`] pub mod prelude { pub use crate::interpolator::*; - pub use crate::strategy::{Cubic, LeftNearest, Linear, Nearest, RightNearest}; + pub use crate::strategy::{cubic, LeftNearest, Linear, Nearest, RightNearest}; pub use crate::Extrapolate; pub use crate::Interpolator; } @@ -140,6 +140,7 @@ pub use interpolator::two; pub use interpolator::zero; pub(crate) use error::*; +pub(crate) use strategy::traits::*; pub(crate) use strategy::*; pub(crate) use std::fmt::Debug; diff --git a/src/strategy/mod.rs b/src/strategy/mod.rs index 5335dae..accac76 100644 --- a/src/strategy/mod.rs +++ b/src/strategy/mod.rs @@ -2,180 +2,8 @@ use super::*; -mod cubic; -pub use cubic::*; - -pub trait Strategy1D: Debug + DynClone -where - D: Data + RawDataClone + Clone, - D::Elem: PartialEq + Debug, -{ - fn init(&mut self, _data: &InterpData1D) -> Result<(), ValidateError> { - Ok(()) - } - - fn interpolate( - &self, - data: &InterpData1D, - point: &[D::Elem; 1], - ) -> Result; - - /// Does this type's [`Strategy1D::interpolate`] provision for extrapolation? - fn allow_extrapolate(&self) -> bool; -} - -clone_trait_object!( Strategy1D); - -impl Strategy1D for Box> -where - D: Data + RawDataClone + Clone, - D::Elem: PartialEq + Debug, -{ - fn init(&mut self, data: &InterpData1D) -> Result<(), ValidateError> { - (**self).init(data) - } - - fn interpolate( - &self, - data: &InterpData1D, - point: &[D::Elem; 1], - ) -> Result { - (**self).interpolate(data, point) - } - - fn allow_extrapolate(&self) -> bool { - (**self).allow_extrapolate() - } -} - -pub trait Strategy2D: Debug + DynClone -where - D: Data + RawDataClone + Clone, - D::Elem: PartialEq + Debug, -{ - fn init(&mut self, _data: &InterpData2D) -> Result<(), ValidateError> { - Ok(()) - } - - fn interpolate( - &self, - data: &InterpData2D, - point: &[D::Elem; 2], - ) -> Result; - - /// Does this type's [`Strategy2D::interpolate`] provision for extrapolation? - fn allow_extrapolate(&self) -> bool; -} - -clone_trait_object!( Strategy2D); - -impl Strategy2D for Box> -where - D: Data + RawDataClone + Clone, - D::Elem: PartialEq + Debug, -{ - fn init(&mut self, data: &InterpData2D) -> Result<(), ValidateError> { - (**self).init(data) - } - - fn interpolate( - &self, - data: &InterpData2D, - point: &[D::Elem; 2], - ) -> Result { - (**self).interpolate(data, point) - } - - fn allow_extrapolate(&self) -> bool { - (**self).allow_extrapolate() - } -} - -pub trait Strategy3D: Debug + DynClone -where - D: Data + RawDataClone + Clone, - D::Elem: PartialEq + Debug, -{ - fn init(&mut self, _data: &InterpData3D) -> Result<(), ValidateError> { - Ok(()) - } - - fn interpolate( - &self, - data: &InterpData3D, - point: &[D::Elem; 3], - ) -> Result; - - /// Does this type's [`Strategy3D::interpolate`] provision for extrapolation? - fn allow_extrapolate(&self) -> bool; -} - -clone_trait_object!( Strategy3D); - -impl Strategy3D for Box> -where - D: Data + RawDataClone + Clone, - D::Elem: PartialEq + Debug, -{ - fn init(&mut self, data: &InterpData3D) -> Result<(), ValidateError> { - (**self).init(data) - } - - fn interpolate( - &self, - data: &InterpData3D, - point: &[D::Elem; 3], - ) -> Result { - (**self).interpolate(data, point) - } - - fn allow_extrapolate(&self) -> bool { - (**self).allow_extrapolate() - } -} - -pub trait StrategyND: Debug + DynClone -where - D: Data + RawDataClone + Clone, - D::Elem: PartialEq + Debug, -{ - fn init(&mut self, _data: &InterpDataND) -> Result<(), ValidateError> { - Ok(()) - } - - fn interpolate( - &self, - data: &InterpDataND, - point: &[D::Elem], - ) -> Result; - - /// Does this type's [`StrategyND::interpolate`] provision for extrapolation? - fn allow_extrapolate(&self) -> bool; -} - -clone_trait_object!( StrategyND); - -impl StrategyND for Box> -where - D: Data + RawDataClone + Clone, - D::Elem: PartialEq + Debug, -{ - fn init(&mut self, data: &InterpDataND) -> Result<(), ValidateError> { - (**self).init(data) - } - - fn interpolate( - &self, - data: &InterpDataND, - point: &[D::Elem], - ) -> Result { - (**self).interpolate(data, point) - } - - fn allow_extrapolate(&self) -> bool { - (**self).allow_extrapolate() - } -} +pub mod cubic; +pub mod traits; // This method contains code from RouteE Compass, another open-source NREL-developed tool // diff --git a/src/strategy/traits.rs b/src/strategy/traits.rs new file mode 100644 index 0000000..7360d85 --- /dev/null +++ b/src/strategy/traits.rs @@ -0,0 +1,173 @@ +use super::*; + +pub trait Strategy1D: Debug + DynClone +where + D: Data + RawDataClone + Clone, + D::Elem: PartialEq + Debug, +{ + fn init(&mut self, _data: &InterpData1D) -> Result<(), ValidateError> { + Ok(()) + } + + fn interpolate( + &self, + data: &InterpData1D, + point: &[D::Elem; 1], + ) -> Result; + + /// Does this type's [`Strategy1D::interpolate`] provision for extrapolation? + fn allow_extrapolate(&self) -> bool; +} + +clone_trait_object!( Strategy1D); + +impl Strategy1D for Box> +where + D: Data + RawDataClone + Clone, + D::Elem: PartialEq + Debug, +{ + fn init(&mut self, data: &InterpData1D) -> Result<(), ValidateError> { + (**self).init(data) + } + + fn interpolate( + &self, + data: &InterpData1D, + point: &[D::Elem; 1], + ) -> Result { + (**self).interpolate(data, point) + } + + fn allow_extrapolate(&self) -> bool { + (**self).allow_extrapolate() + } +} + +pub trait Strategy2D: Debug + DynClone +where + D: Data + RawDataClone + Clone, + D::Elem: PartialEq + Debug, +{ + fn init(&mut self, _data: &InterpData2D) -> Result<(), ValidateError> { + Ok(()) + } + + fn interpolate( + &self, + data: &InterpData2D, + point: &[D::Elem; 2], + ) -> Result; + + /// Does this type's [`Strategy2D::interpolate`] provision for extrapolation? + fn allow_extrapolate(&self) -> bool; +} + +clone_trait_object!( Strategy2D); + +impl Strategy2D for Box> +where + D: Data + RawDataClone + Clone, + D::Elem: PartialEq + Debug, +{ + fn init(&mut self, data: &InterpData2D) -> Result<(), ValidateError> { + (**self).init(data) + } + + fn interpolate( + &self, + data: &InterpData2D, + point: &[D::Elem; 2], + ) -> Result { + (**self).interpolate(data, point) + } + + fn allow_extrapolate(&self) -> bool { + (**self).allow_extrapolate() + } +} + +pub trait Strategy3D: Debug + DynClone +where + D: Data + RawDataClone + Clone, + D::Elem: PartialEq + Debug, +{ + fn init(&mut self, _data: &InterpData3D) -> Result<(), ValidateError> { + Ok(()) + } + + fn interpolate( + &self, + data: &InterpData3D, + point: &[D::Elem; 3], + ) -> Result; + + /// Does this type's [`Strategy3D::interpolate`] provision for extrapolation? + fn allow_extrapolate(&self) -> bool; +} + +clone_trait_object!( Strategy3D); + +impl Strategy3D for Box> +where + D: Data + RawDataClone + Clone, + D::Elem: PartialEq + Debug, +{ + fn init(&mut self, data: &InterpData3D) -> Result<(), ValidateError> { + (**self).init(data) + } + + fn interpolate( + &self, + data: &InterpData3D, + point: &[D::Elem; 3], + ) -> Result { + (**self).interpolate(data, point) + } + + fn allow_extrapolate(&self) -> bool { + (**self).allow_extrapolate() + } +} + +pub trait StrategyND: Debug + DynClone +where + D: Data + RawDataClone + Clone, + D::Elem: PartialEq + Debug, +{ + fn init(&mut self, _data: &InterpDataND) -> Result<(), ValidateError> { + Ok(()) + } + + fn interpolate( + &self, + data: &InterpDataND, + point: &[D::Elem], + ) -> Result; + + /// Does this type's [`StrategyND::interpolate`] provision for extrapolation? + fn allow_extrapolate(&self) -> bool; +} + +clone_trait_object!( StrategyND); + +impl StrategyND for Box> +where + D: Data + RawDataClone + Clone, + D::Elem: PartialEq + Debug, +{ + fn init(&mut self, data: &InterpDataND) -> Result<(), ValidateError> { + (**self).init(data) + } + + fn interpolate( + &self, + data: &InterpDataND, + point: &[D::Elem], + ) -> Result { + (**self).interpolate(data, point) + } + + fn allow_extrapolate(&self) -> bool { + (**self).allow_extrapolate() + } +} From 3e205217af194d574fcb646762b1fcdfda3d0c71 Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 11:49:02 -0600 Subject: [PATCH 16/18] cubic z array is dyn --- src/interpolator/one/strategies.rs | 2 +- src/lib.rs | 3 ++- src/strategy/cubic.rs | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/interpolator/one/strategies.rs b/src/interpolator/one/strategies.rs index 1db77dc..ee2ff81 100644 --- a/src/interpolator/one/strategies.rs +++ b/src/interpolator/one/strategies.rs @@ -90,7 +90,7 @@ where _ => todo!(), }; - self.z = Self::thomas(sub.view(), v.view(), sup.view(), u.view()); + self.z = Self::thomas(sub.view(), v.view(), sup.view(), u.view()).into_dyn(); Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 2eb4cdb..0975e5a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,6 +94,7 @@ //! The following settings are applicable for all interpolators: //! - [`Extrapolate::Fill(T)`](`Extrapolate::Fill`) //! - [`Extrapolate::Clamp`] +//! - [`Extrapolate::Wrap`] //! - [`Extrapolate::Error`] //! //! [`Extrapolate::Enable`] is valid for [`Linear`] for all dimensionalities. @@ -122,7 +123,7 @@ /// - The extrapolation setting enum: [`Extrapolate`] pub mod prelude { pub use crate::interpolator::*; - pub use crate::strategy::{cubic, LeftNearest, Linear, Nearest, RightNearest}; + pub use crate::strategy::{cubic, cubic::Cubic, LeftNearest, Linear, Nearest, RightNearest}; pub use crate::Extrapolate; pub use crate::Interpolator; } diff --git a/src/strategy/cubic.rs b/src/strategy/cubic.rs index d117187..5cf6675 100644 --- a/src/strategy/cubic.rs +++ b/src/strategy/cubic.rs @@ -8,7 +8,7 @@ pub struct Cubic { /// Behavior of [`Extrapolate::Enable`]. pub extrapolate: CubicExtrapolate, /// Solved second derivatives. - pub z: Array1, + pub z: ArrayD, } /// Cubic spline boundary conditions. @@ -40,7 +40,7 @@ impl Cubic { Self { boundary_condition, extrapolate, - z: Array1::from_vec(Vec::new()), + z: Array1::from_vec(Vec::new()).into_dyn(), } } From 76943c32cd6a0691edfa5466e4c6fbb3b939f119 Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 17 Mar 2025 12:43:38 -0600 Subject: [PATCH 17/18] cubic init cleanup --- src/interpolator/one/strategies.rs | 70 +++++++++++++++--------------- src/interpolator/one/tests.rs | 8 ++-- 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/src/interpolator/one/strategies.rs b/src/interpolator/one/strategies.rs index ee2ff81..0de6701 100644 --- a/src/interpolator/one/strategies.rs +++ b/src/interpolator/one/strategies.rs @@ -51,46 +51,48 @@ where let six = ::from(6.).unwrap(); let h = Array1::from_shape_fn(n, |i| data.grid[0][i + 1] - data.grid[0][i]); - let v = Array1::from_shape_fn(n + 1, |i| { - if i == 0 || i == n { - match &self.boundary_condition { - CubicBC::Natural => one, - CubicBC::Clamped(_, _) => two * h[0], - _ => todo!(), - } - } else { - two * (h[i - 1] + h[i]) - } - }); + let v = Array1::from_shape_fn(n - 1, |i| two * (h[i + 1] + h[i])); let b = Array1::from_shape_fn(n, |i| (data.values[i + 1] - data.values[i]) / h[i]); - let u = Array1::from_shape_fn(n + 1, |i| { - if i == 0 || i == n { - match &self.boundary_condition { - CubicBC::Natural => zero, - CubicBC::Clamped(l, r) => { - if i == 0 { - six * (b[i] - *l) - } else { - six * (*r - b[i - 1]) - } - } - _ => todo!(), - } - } else { - six * (b[i] - b[i - 1]) + let u = Array1::from_shape_fn(n - 1, |i| six * (b[i + 1] - b[i])); + + let (sub, diag, sup, rhs) = match &self.boundary_condition { + CubicBC::Natural => { + let zero = array![zero]; + let one = array![one]; + ( + &ndarray::concatenate(Axis(0), &[h.slice(s![0..n - 1]), zero.view()]).unwrap(), + &ndarray::concatenate(Axis(0), &[one.view(), v.view(), one.view()]).unwrap(), + &ndarray::concatenate(Axis(0), &[zero.view(), h.slice(s![1..n])]).unwrap(), + &ndarray::concatenate(Axis(0), &[zero.view(), u.view(), zero.view()]).unwrap(), + ) } - }); - - let (sub, sup) = match &self.boundary_condition { - CubicBC::Natural => ( - &Array1::from_shape_fn(n, |i| if i == n - 1 { zero } else { h[i] }), - &Array1::from_shape_fn(n, |i| if i == 0 { zero } else { h[i] }), + CubicBC::Clamped(l, r) => ( + &h, + &ndarray::concatenate( + Axis(0), + &[ + array![two * h[0]].view(), + v.view(), + array![two * h[n - 1]].view(), + ], + ) + .unwrap(), + &h, + &ndarray::concatenate( + Axis(0), + &[ + array![six * (b[0] - *l)].view(), + u.view(), + array![six * (*r - b[n - 1])].view(), + ], + ) + .unwrap(), ), - CubicBC::Clamped(_, _) => (&h, &h), + // CubicBC::NotAKnot => (), _ => todo!(), }; - self.z = Self::thomas(sub.view(), v.view(), sup.view(), u.view()).into_dyn(); + self.z = Self::thomas(sub.view(), diag.view(), sup.view(), rhs.view()).into_dyn(); Ok(()) } diff --git a/src/interpolator/one/tests.rs b/src/interpolator/one/tests.rs index 4f989ea..058393f 100644 --- a/src/interpolator/one/tests.rs +++ b/src/interpolator/one/tests.rs @@ -152,8 +152,8 @@ fn test_extrapolate() { #[test] fn test_cubic_natural() { - let x = array![1., 2., 3., 5., 7., 8.]; - let f_x = array![3., 6., 19., 99., 291., 444.]; + let x = array![1., 2., 3., 5., 7., 8., 10.]; + let f_x = array![3., 6., 19., 99., 291., 444., 222.]; let interp = Interp1D::new(x.view(), f_x.view(), Cubic::natural(), Extrapolate::Enable).unwrap(); @@ -210,8 +210,8 @@ fn test_cubic_natural() { #[test] fn test_cubic_clamped() { - let x = array![1., 2., 3., 5., 7., 8.]; - let f_x = array![3., -90., 19., 99., 291., 444.]; + let x = array![1., 2., 3., 5., 7., 8., 10.]; + let f_x = array![3., -90., 19., 99., 291., 444., 222.]; let xs_left = Array1::linspace(x.first().unwrap() - 1e-6, x.first().unwrap() + 1e-6, 50); let xs_right = Array1::linspace(x.last().unwrap() - 1e-6, x.last().unwrap() + 1e-6, 50); From 544496ad0993dbe5a8c14bc3655ed54f79282c6f Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Wed, 19 Mar 2025 11:25:15 -0600 Subject: [PATCH 18/18] notaknot broken --- src/interpolator/one/strategies.rs | 72 +++++++++++++++++++++--------- src/interpolator/one/tests.rs | 56 +++++++++++++++++++++-- 2 files changed, 102 insertions(+), 26 deletions(-) diff --git a/src/interpolator/one/strategies.rs b/src/interpolator/one/strategies.rs index 0de6701..a78e826 100644 --- a/src/interpolator/one/strategies.rs +++ b/src/interpolator/one/strategies.rs @@ -66,30 +66,58 @@ where &ndarray::concatenate(Axis(0), &[zero.view(), u.view(), zero.view()]).unwrap(), ) } - CubicBC::Clamped(l, r) => ( - &h, - &ndarray::concatenate( - Axis(0), - &[ - array![two * h[0]].view(), - v.view(), - array![two * h[n - 1]].view(), - ], + CubicBC::Clamped(l, r) => { + let diag_0 = array![two * h[0]]; + let diag_n = array![two * h[n - 1]]; + let rhs_0 = array![six * (b[0] - *l)]; + let rhs_n = array![six * (*r - b[n - 1])]; + ( + &h, + &ndarray::concatenate(Axis(0), &[diag_0.view(), v.view(), diag_n.view()]) + .unwrap(), + &h, + &ndarray::concatenate(Axis(0), &[rhs_0.view(), u.view(), rhs_n.view()]) + .unwrap(), ) - .unwrap(), - &h, - &ndarray::concatenate( - Axis(0), - &[ - array![six * (b[0] - *l)].view(), - u.view(), - array![six * (*r - b[n - 1])].view(), - ], + } + CubicBC::NotAKnot => { + let three = two + one; + let sub_n = + array![two * h[n - 1].powi(2) + three * h[n - 1] * h[n - 2] + h[n - 2].powi(2)]; + let diag_0 = array![h[0].powi(2) - h[1].powi(2)]; + let diag_n = array![h[n - 1].powi(2) - h[n - 2].powi(2)]; + let sup_0 = array![two * h[0].powi(2) + three * h[0] * h[1] + h[1].powi(2)]; + let rhs_0 = array![h[0] * u[0]]; + let rhs_n = array![h[n - 1] * u[n - 2]]; + + println!( + "sub {:?}", + &ndarray::concatenate(Axis(0), &[h.slice(s![0..n - 1]), sub_n.view()]).unwrap() + ); + println!( + "dia {:?}", + &ndarray::concatenate(Axis(0), &[diag_0.view(), v.view(), diag_n.view()]) + .unwrap() + ); + println!( + "sup {:?}", + &ndarray::concatenate(Axis(0), &[sup_0.view(), h.slice(s![1..n])]).unwrap() + ); + println!( + "rhs {:?}", + &ndarray::concatenate(Axis(0), &[rhs_0.view(), u.view(), rhs_n.view()]) + .unwrap() + ); + ( + &ndarray::concatenate(Axis(0), &[h.slice(s![0..n - 1]), sub_n.view()]).unwrap(), + &ndarray::concatenate(Axis(0), &[diag_0.view(), v.view(), diag_n.view()]) + .unwrap(), + &ndarray::concatenate(Axis(0), &[sup_0.view(), h.slice(s![1..n])]).unwrap(), + &ndarray::concatenate(Axis(0), &[rhs_0.view(), u.view(), rhs_n.view()]) + .unwrap(), ) - .unwrap(), - ), - // CubicBC::NotAKnot => (), - _ => todo!(), + } + _ => unreachable!(), }; self.z = Self::thomas(sub.view(), diag.view(), sup.view(), rhs.view()).into_dyn(); diff --git a/src/interpolator/one/tests.rs b/src/interpolator/one/tests.rs index 058393f..a4f9159 100644 --- a/src/interpolator/one/tests.rs +++ b/src/interpolator/one/tests.rs @@ -152,8 +152,8 @@ fn test_extrapolate() { #[test] fn test_cubic_natural() { - let x = array![1., 2., 3., 5., 7., 8., 10.]; - let f_x = array![3., 6., 19., 99., 291., 444., 222.]; + let x = array![1., 2.4, 3.1, 5., 7.6, 8.3, 10., 10.1]; + let f_x = array![3., -90., 19., 99., 291., 444., 222., 250.]; let interp = Interp1D::new(x.view(), f_x.view(), Cubic::natural(), Extrapolate::Enable).unwrap(); @@ -210,8 +210,8 @@ fn test_cubic_natural() { #[test] fn test_cubic_clamped() { - let x = array![1., 2., 3., 5., 7., 8., 10.]; - let f_x = array![3., -90., 19., 99., 291., 444., 222.]; + let x = array![1., 2.4, 3.1, 5., 7.6, 8.3, 10., 10.1]; + let f_x = array![3., -90., 19., 99., 291., 444., 222., 250.]; let xs_left = Array1::linspace(x.first().unwrap() - 1e-6, x.first().unwrap() + 1e-6, 50); let xs_right = Array1::linspace(x.last().unwrap() - 1e-6, x.last().unwrap() + 1e-6, 50); @@ -258,6 +258,54 @@ fn test_cubic_clamped() { } } +#[test] +fn test_cubic_not_a_knot() { + let x = array![1., 2.4, 3.1, 5., 7.6, 8.3, 10., 10.1]; + let f_x = array![3., -90., 19., 99., 291., 444., 222., 250.]; + + let x = array![1., 2., 3., 5., 7., 8., 10.]; + let f_x = array![3., -90., 19., 99., 291., 444., 222.]; + + let interp = Interp1D::new( + x.view(), + f_x.view(), + Cubic::not_a_knot(), + Extrapolate::Enable, + ) + .unwrap(); + + // Interpolating at knots returns values + for i in 0..x.len() { + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]); + } + + // // Left slope = a + // let ys: Array1 = xs_left + // .iter() + // .map(|&x| interp.interpolate(&[x]).unwrap()) + // .collect(); + // let slopes: Array1 = xs_left + // .windows(2) + // .into_iter() + // .zip(ys.windows(2)) + // .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + // .collect(); + // assert_approx_eq!(slopes.mean().unwrap(), a); + + // // Right slope = b + // let ys: Array1 = xs_right + // .iter() + // .map(|&x| interp.interpolate(&[x]).unwrap()) + // .collect(); + // let slopes: Array1 = xs_right + // .windows(2) + // .into_iter() + // .zip(ys.windows(2)) + // .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + // .collect(); + // assert_approx_eq!(slopes.mean().unwrap(), b); +} + // #[test] // fn test_cubic_periodic() { // let x = array![1., 2., 3., 5., 7., 8.];