diff --git a/Cargo.toml b/Cargo.toml index 455a68d..2eb2352 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,11 @@ ndarray = ">=0.15.3, <0.17" num-traits = "0.2.15" serde = { version = "1.0.103", optional = true, features = ["derive"] } serde_unit_struct = { version = "0.1.3", optional = true } +# TODO: modify when https://github.com/RReverser/serde-ndim/pull/2 merged +# serde-ndim = { version = "2.0.2", optional = true, features = ["ndarray"] } +serde-ndim = { git = "https://github.com/kylecarow/serde-ndim.git", optional = true, features = [ + "ndarray", +] } thiserror = "1.0.1" [dev-dependencies] @@ -35,4 +40,9 @@ name = "benchmark" harness = false [features] -serde = ["dep:serde", "ndarray/serde", "dep:serde_unit_struct"] +serde = [ + "dep:serde", + "ndarray/serde", + "dep:serde_unit_struct", + "dep:serde-ndim", +] diff --git a/src/interpolator/data.rs b/src/interpolator/data.rs index f22a9d2..f66f8eb 100644 --- a/src/interpolator/data.rs +++ b/src/interpolator/data.rs @@ -23,7 +23,8 @@ pub use two::{InterpData2D, InterpData2DOwned, InterpData2DViewed}; deserialize = " D: DataOwned, D::Elem: Deserialize<'de>, - Dim<[usize; N]>: Deserialize<'de>, + Dim<[usize; N]>: Deserialize<'de> + Dimension, + [usize; N]: IntoDimension>, [ArrayBase; N]: Deserialize<'de>, " )) @@ -38,8 +39,11 @@ where /// - 1-D: `[x]` /// - 2-D: `[x, y]` /// - 3-D: `[x, y, z]` + #[cfg_attr(feature = "serde", serde(with = "serde_arr_array"))] pub grid: [ArrayBase; N], /// Function values at coordinates: a single `N`-dimensional [`ArrayBase`]. + #[cfg_attr(feature = "serde", serde(serialize_with = "serde_ndim::serialize"))] + #[cfg_attr(feature = "serde", serde(deserialize_with = "deserialize_fixed"))] pub values: ArrayBase>, } /// [`InterpData`] that views data. diff --git a/src/interpolator/n/mod.rs b/src/interpolator/n/mod.rs index 3994cc2..438991f 100644 --- a/src/interpolator/n/mod.rs +++ b/src/interpolator/n/mod.rs @@ -29,8 +29,11 @@ where D::Elem: PartialEq + Debug, { /// Coordinate grid: a vector of 1-dimensional [`ArrayBase`]. + #[cfg_attr(feature = "serde", serde(with = "serde_vec_array"))] pub grid: Vec>, /// Function values at coordinates: a single dynamic-dimensional [`ArrayBase`]. + #[cfg_attr(feature = "serde", serde(serialize_with = "serde_ndim::serialize"))] + #[cfg_attr(feature = "serde", serde(deserialize_with = "deserialize_dyn"))] pub values: ArrayBase, } /// [`InterpDataND`] that views data. @@ -117,7 +120,7 @@ where deserialize = " D: DataOwned, D::Elem: Deserialize<'de>, - S: Deserialize<'de> + S: Deserialize<'de>, " )) )] diff --git a/src/interpolator/n/tests.rs b/src/interpolator/n/tests.rs index 0eee454..5d2fbd0 100644 --- a/src/interpolator/n/tests.rs +++ b/src/interpolator/n/tests.rs @@ -397,4 +397,21 @@ fn test_serde() { let ser = serde_json::to_string(&interp).unwrap(); let de: InterpNDOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); + + // simple format (new serialization output) + let ser0 = "{\"grid\":[[0.1,1.1],[0.2,1.2],[0.3,1.3]],\"values\":[[[0.0,1.0],[2.0,3.0]],[[4.0,5.0],[6.0,7.0]]]}"; + let de0: InterpDataND<_> = serde_json::from_str(&ser0).unwrap(); + assert_eq!(interp.data, de0); + // mixed format (simple grid) + let ser1 = "{\"grid\":[[0.1,1.1],[0.2,1.2],[0.3,1.3]],\"values\":{\"v\":1,\"dim\":[2,2,2],\"data\":[0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0]}}"; + let de1: InterpDataND<_> = serde_json::from_str(&ser1).unwrap(); + assert_eq!(interp.data, de1); + // mixed format (simple values) + let ser2 = "{\"grid\":[{\"v\":1,\"dim\":[2],\"data\":[0.1,1.1]},{\"v\":1,\"dim\":[2],\"data\":[0.2,1.2]},{\"v\":1,\"dim\":[2],\"data\":[0.3,1.3]}],\"values\":[[[0.0,1.0],[2.0,3.0]],[[4.0,5.0],[6.0,7.0]]]}"; + let de2: InterpDataND<_> = serde_json::from_str(&ser2).unwrap(); + assert_eq!(interp.data, de2); + // complex format (legacy serialization output) + let ser3 = "{\"grid\":[{\"v\":1,\"dim\":[2],\"data\":[0.1,1.1]},{\"v\":1,\"dim\":[2],\"data\":[0.2,1.2]},{\"v\":1,\"dim\":[2],\"data\":[0.3,1.3]}],\"values\":{\"v\":1,\"dim\":[2,2,2],\"data\":[0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0]}}"; + let de3: InterpDataND<_> = serde_json::from_str(&ser3).unwrap(); + assert_eq!(interp.data, de3); } diff --git a/src/interpolator/one/tests.rs b/src/interpolator/one/tests.rs index ca06dbd..6d956e6 100644 --- a/src/interpolator/one/tests.rs +++ b/src/interpolator/one/tests.rs @@ -188,4 +188,21 @@ fn test_serde() { let ser = serde_json::to_string(&interp).unwrap(); let de: Interp1DOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); + + // simple format (new serialization output) + let ser0 = "{\"grid\":[[0.0,1.0,2.0,3.0,4.0]],\"values\":[0.2,0.4,0.6,0.8,1.0]}"; + let de0: InterpData1D<_> = serde_json::from_str(&ser0).unwrap(); + assert_eq!(interp.data, de0); + // mixed format (simple grid) + let ser1 = "{\"grid\":[[0.0,1.0,2.0,3.0,4.0]],\"values\":{\"v\":1,\"dim\":[5],\"data\":[0.2,0.4,0.6,0.8,1.0]}}"; + let de1: InterpData1D<_> = serde_json::from_str(&ser1).unwrap(); + assert_eq!(interp.data, de1); + // mixed format (simple values) + let ser2 = "{\"grid\":[{\"v\":1,\"dim\":[5],\"data\":[0.0,1.0,2.0,3.0,4.0]}],\"values\":[0.2,0.4,0.6,0.8,1.0]}"; + let de2: InterpData1D<_> = serde_json::from_str(&ser2).unwrap(); + assert_eq!(interp.data, de2); + // complex format (legacy serialization output) + let ser3 = "{\"grid\":[{\"v\":1,\"dim\":[5],\"data\":[0.0,1.0,2.0,3.0,4.0]}],\"values\":{\"v\":1,\"dim\":[5],\"data\":[0.2,0.4,0.6,0.8,1.0]}}"; + let de3: InterpData1D<_> = serde_json::from_str(&ser3).unwrap(); + assert_eq!(interp.data, de3); } diff --git a/src/interpolator/three/tests.rs b/src/interpolator/three/tests.rs index 2e2fe13..88d917c 100644 --- a/src/interpolator/three/tests.rs +++ b/src/interpolator/three/tests.rs @@ -217,15 +217,43 @@ fn test_partialeq() { fn test_serde() { let interp = Interp3D::new( array![0., 1.], - array![0., 1.], - array![0., 1.], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], - strategy::Nearest, + array![0., 1., 2.], + array![0., 1., 2., 3.], + array![ + [ + [0.6, 0.8, 1.0, 1.2], + [0.8, 1.0, 1.2, 1.4], + [1.0, 1.2, 1.4, 1.6], + ], + [ + [0.8, 1.0, 1.2, 1.4], + [1.0, 1.2, 1.4, 1.6], + [1.2, 1.4, 1.6, 1.8], + ], + ], + strategy::Linear, Extrapolate::Error, ) .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); - let de: Interp3DOwned = serde_json::from_str(&ser).unwrap(); + let de: Interp3D<_, _> = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); + + // simple format (new serialization output) + let ser0 = "{\"grid\":[[0.0,1.0],[0.0,1.0,2.0],[0.0,1.0,2.0,3.0]],\"values\":[[[0.6,0.8,1.0,1.2],[0.8,1.0,1.2,1.4],[1.0,1.2,1.4,1.6]],[[0.8,1.0,1.2,1.4],[1.0,1.2,1.4,1.6],[1.2,1.4,1.6,1.8]]]}"; + let de0: InterpData3D<_> = serde_json::from_str(&ser0).unwrap(); + assert_eq!(interp.data, de0); + // mixed format (simple grid) + let ser1 = "{\"grid\":[[0.0,1.0],[0.0,1.0,2.0],[0.0,1.0,2.0,3.0]],\"values\":{\"v\":1,\"dim\":[2,3,4],\"data\":[0.6,0.8,1.0,1.2,0.8,1.0,1.2,1.4,1.0,1.2,1.4,1.6,0.8,1.0,1.2,1.4,1.0,1.2,1.4,1.6,1.2,1.4,1.6,1.8]}}"; + let de1: InterpData3D<_> = serde_json::from_str(&ser1).unwrap(); + assert_eq!(interp.data, de1); + // mixed format (simple values) + let ser2 = "{\"grid\":[{\"v\":1,\"dim\":[2],\"data\":[0.0,1.0]},{\"v\":1,\"dim\":[3],\"data\":[0.0,1.0,2.0]},{\"v\":1,\"dim\":[4],\"data\":[0.0,1.0,2.0,3.0]}],\"values\":[[[0.6,0.8,1.0,1.2],[0.8,1.0,1.2,1.4],[1.0,1.2,1.4,1.6]],[[0.8,1.0,1.2,1.4],[1.0,1.2,1.4,1.6],[1.2,1.4,1.6,1.8]]]}"; + let de2: InterpData3D<_> = serde_json::from_str(&ser2).unwrap(); + assert_eq!(interp.data, de2); + // complex format (legacy serialization output) + let ser3 = "{\"grid\":[{\"v\":1,\"dim\":[2],\"data\":[0.0,1.0]},{\"v\":1,\"dim\":[3],\"data\":[0.0,1.0,2.0]},{\"v\":1,\"dim\":[4],\"data\":[0.0,1.0,2.0,3.0]}],\"values\":{\"v\":1,\"dim\":[2,3,4],\"data\":[0.6,0.8,1.0,1.2,0.8,1.0,1.2,1.4,1.0,1.2,1.4,1.6,0.8,1.0,1.2,1.4,1.0,1.2,1.4,1.6,1.2,1.4,1.6,1.8]}}"; + let de3: InterpData3D<_> = serde_json::from_str(&ser3).unwrap(); + assert_eq!(interp.data, de3); } diff --git a/src/interpolator/two/tests.rs b/src/interpolator/two/tests.rs index 8f65026..4f9767c 100644 --- a/src/interpolator/two/tests.rs +++ b/src/interpolator/two/tests.rs @@ -199,4 +199,21 @@ fn test_serde() { let ser = serde_json::to_string(&interp).unwrap(); let de: Interp2DOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); + + // simple format (new serialization output) + let ser0 = "{\"grid\":[[0.05,0.1,0.15],[0.1,0.2,0.3]],\"values\":[[0.0,1.0,2.0],[3.0,4.0,5.0],[6.0,7.0,8.0]]}"; + let de0: InterpData2D<_> = serde_json::from_str(&ser0).unwrap(); + assert_eq!(interp.data, de0); + // mixed format (simple grid) + let ser1 = "{\"grid\":[[0.05,0.1,0.15],[0.1,0.2,0.3]],\"values\":{\"v\":1,\"dim\":[3,3],\"data\":[0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0]}}"; + let de1: InterpData2D<_> = serde_json::from_str(&ser1).unwrap(); + assert_eq!(interp.data, de1); + // mixed format (simple values) + let ser2 = "{\"grid\":[{\"v\":1,\"dim\":[3],\"data\":[0.05,0.1,0.15]},{\"v\":1,\"dim\":[3],\"data\":[0.1,0.2,0.3]}],\"values\":[[0.0,1.0,2.0],[3.0,4.0,5.0],[6.0,7.0,8.0]]}"; + let de2: InterpData2D<_> = serde_json::from_str(&ser2).unwrap(); + assert_eq!(interp.data, de2); + // complex format (legacy serialization output) + let ser3 = "{\"grid\":[{\"v\":1,\"dim\":[3],\"data\":[0.05,0.1,0.15]},{\"v\":1,\"dim\":[3],\"data\":[0.1,0.2,0.3]}],\"values\":{\"v\":1,\"dim\":[3,3],\"data\":[0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0]}}"; + let de3: InterpData2D<_> = serde_json::from_str(&ser3).unwrap(); + assert_eq!(interp.data, de3); } diff --git a/src/lib.rs b/src/lib.rs index b2ed07a..95c3011 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,11 +57,10 @@ pub(crate) use num_traits::{clamp, Euclid, Num, One}; pub(crate) use dyn_clone::*; #[cfg(feature = "serde")] -pub(crate) use ndarray::DataOwned; +#[path = "serde.rs"] +mod serde_mod; #[cfg(feature = "serde")] -pub(crate) use serde::{Deserialize, Serialize}; -#[cfg(feature = "serde")] -pub(crate) use serde_unit_struct::{Deserialize_unit_struct, Serialize_unit_struct}; +pub(crate) use serde_mod::*; #[cfg(test)] /// Alias for [`approx::assert_abs_diff_eq`] with `epsilon = 1e-6` diff --git a/src/serde.rs b/src/serde.rs new file mode 100644 index 0000000..83b2deb --- /dev/null +++ b/src/serde.rs @@ -0,0 +1,139 @@ +use super::*; + +pub(crate) use ndarray::{DataOwned, IntoDimension}; +pub(crate) use serde::{Deserialize, Serialize}; +pub(crate) use serde_unit_struct::{Deserialize_unit_struct, Serialize_unit_struct}; + +use serde::de::{Deserializer, Error}; +use serde::ser::{SerializeSeq, Serializer}; +use serde_ndim::de::MakeNDim; + +#[derive(Serialize)] +struct ArrayWrapper<'a, D>( + #[serde(serialize_with = "serde_ndim::serialize")] &'a ArrayBase, +) +where + D: Data, + D::Elem: Serialize; + +#[derive(Deserialize)] +#[serde(untagged)] +#[serde(bound = "D::Elem: Deserialize<'de>")] +enum GridType { + VecVec(Vec>), + VecArray(Vec>), +} + +pub(crate) mod serde_arr_array { + use super::*; + + pub fn serialize( + grid: &[ArrayBase; N], + serializer: Ser, + ) -> Result + where + D: Data, + D::Elem: Serialize, + Ser: Serializer, + { + let mut seq = serializer.serialize_seq(Some(N))?; + for arr in grid { + seq.serialize_element(&ArrayWrapper(arr))?; + } + seq.end() + } + + pub fn deserialize<'de, D, const N: usize, De>( + deserializer: De, + ) -> Result<[ArrayBase; N], De::Error> + where + D: DataOwned, + D::Elem: Deserialize<'de> + Debug, + De: Deserializer<'de>, + { + match GridType::deserialize(deserializer)? { + GridType::VecVec(vecs) => vecs.into_iter().map(|v| v.into()).collect(), + GridType::VecArray(arrays) => arrays, + } + .try_into() + .map_err(|e| De::Error::custom(format_args!("expected {N} array(s): {e:?}"))) + } +} + +pub(crate) mod serde_vec_array { + use super::*; + + pub fn serialize( + grid: &[ArrayBase], + serializer: Ser, + ) -> Result + where + D: Data, + D::Elem: Serialize, + Ser: Serializer, + { + let mut seq = serializer.serialize_seq(Some(grid.len()))?; + for arr in grid { + seq.serialize_element(&ArrayWrapper(arr))?; + } + seq.end() + } + + pub fn deserialize<'de, D, De>(deserializer: De) -> Result>, De::Error> + where + D: DataOwned, + D::Elem: Deserialize<'de>, + De: Deserializer<'de>, + { + Ok(match GridType::deserialize(deserializer)? { + GridType::VecVec(vecs) => vecs.into_iter().map(|v| v.into()).collect(), + GridType::VecArray(arrays) => arrays, + }) + } +} + +#[derive(Deserialize)] +#[serde(untagged)] +#[serde(bound = " + D::Elem: Deserialize<'de>, + ArrayBase: Deserialize<'de>, +")] +enum ValuesType +where + D: DataOwned, + Dim: Dimension, + ArrayBase: MakeNDim, +{ + #[serde(deserialize_with = "serde_ndim::deserialize")] + NDimArray(ArrayBase), + Array(ArrayBase), +} + +pub fn deserialize_fixed<'de, D, const N: usize, De>( + deserializer: De, +) -> Result>, De::Error> +where + D: DataOwned, + D::Elem: Deserialize<'de>, + Dim<[Ix; N]>: Dimension + Deserialize<'de>, + ArrayBase>: MakeNDim, + De: Deserializer<'de>, +{ + Ok(match ValuesType::deserialize(deserializer)? { + ValuesType::NDimArray(values) => values, + ValuesType::Array(values) => values, + }) +} + +pub fn deserialize_dyn<'de, D, De>(deserializer: De) -> Result, De::Error> +where + D: DataOwned, + D::Elem: Deserialize<'de>, + ArrayBase: MakeNDim, + De: Deserializer<'de>, +{ + Ok(match ValuesType::deserialize(deserializer)? { + ValuesType::NDimArray(values) => values, + ValuesType::Array(values) => values, + }) +}