From 3750f74b4d498297f9dabc0a4340767c85b8190a Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 5 May 2025 14:06:57 -0600 Subject: [PATCH 1/6] partway through making serde format tidier --- Cargo.toml | 16 +++++++++++++++- src/interpolator/data.rs | 4 +++- src/interpolator/n/mod.rs | 1 + src/interpolator/n/tests.rs | 2 ++ src/interpolator/one/tests.rs | 2 ++ src/interpolator/three/tests.rs | 2 ++ src/interpolator/two/tests.rs | 2 ++ src/lib.rs | 2 +- 8 files changed, 28 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 455a68d..f73faea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,13 @@ ndarray = ">=0.15.3, <0.17" num-traits = "0.2.15" serde = { version = "1.0.103", optional = true, features = ["derive"] } serde_unit_struct = { version = "0.1.3", optional = true } +# TODO: modify when https://github.com/RReverser/serde-ndim/pull/2 merged +# serde-ndim = { version = "2.0.2", optional = true, features = ["ndarray"] } +# serde-ndim = { path = "../serde-ndim", optional = true, features = ["ndarray"] } +serde-ndim = { git = "https://github.com/kylecarow/serde-ndim.git", optional = true, features = [ + "ndarray", +] } +serde_with = { version = "3.12.0", optional = true } thiserror = "1.0.1" [dev-dependencies] @@ -35,4 +42,11 @@ name = "benchmark" harness = false [features] -serde = ["dep:serde", "ndarray/serde", "dep:serde_unit_struct"] +default = ["serde"] # TODO: remove +serde = [ + "dep:serde", + "ndarray/serde", + "dep:serde_unit_struct", + "dep:serde-ndim", + "dep:serde_with", +] diff --git a/src/interpolator/data.rs b/src/interpolator/data.rs index 4a2b474..e0c27f7 100644 --- a/src/interpolator/data.rs +++ b/src/interpolator/data.rs @@ -21,7 +21,8 @@ pub use two::{InterpData2D, InterpData2DOwned, InterpData2DViewed}; deserialize = " D: DataOwned, D::Elem: Deserialize<'de>, - Dim<[usize; N]>: Deserialize<'de>, + Dim<[usize; N]>: Deserialize<'de> + Dimension, + [usize; N]: IntoDimension>, [ArrayBase; N]: Deserialize<'de>, " )) @@ -38,6 +39,7 @@ where /// - 3-D: `[x, y, z]` pub grid: [ArrayBase; N], /// Function values at coordinates: a single `N`-dimensional [`ArrayBase`]. + #[cfg_attr(feature = "serde", serde(with = "serde_ndim"))] pub values: ArrayBase>, } pub type InterpDataViewed = InterpData, N>; diff --git a/src/interpolator/n/mod.rs b/src/interpolator/n/mod.rs index 4d0f09b..316901e 100644 --- a/src/interpolator/n/mod.rs +++ b/src/interpolator/n/mod.rs @@ -31,6 +31,7 @@ where /// Coordinate grid: a vector of 1-dimensional [`ArrayBase`]. pub grid: Vec>, /// Function values at coordinates: a single dynamic-dimensional [`ArrayBase`]. + #[cfg_attr(feature = "serde", serde(with = "serde_ndim"))] pub values: ArrayBase, } /// [`InterpDataND`] that views data. diff --git a/src/interpolator/n/tests.rs b/src/interpolator/n/tests.rs index 0eee454..0ae13b0 100644 --- a/src/interpolator/n/tests.rs +++ b/src/interpolator/n/tests.rs @@ -395,6 +395,8 @@ fn test_serde() { .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); + // TODO: remove + println!("{ser}"); let de: InterpNDOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); } diff --git a/src/interpolator/one/tests.rs b/src/interpolator/one/tests.rs index ca06dbd..5eb586d 100644 --- a/src/interpolator/one/tests.rs +++ b/src/interpolator/one/tests.rs @@ -186,6 +186,8 @@ fn test_serde() { .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); + // TODO: remove + println!("{ser}"); let de: Interp1DOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); } diff --git a/src/interpolator/three/tests.rs b/src/interpolator/three/tests.rs index 2e2fe13..50c7013 100644 --- a/src/interpolator/three/tests.rs +++ b/src/interpolator/three/tests.rs @@ -226,6 +226,8 @@ fn test_serde() { .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); + // TODO: remove + println!("{ser}"); let de: Interp3DOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); } diff --git a/src/interpolator/two/tests.rs b/src/interpolator/two/tests.rs index 8f65026..b6a2602 100644 --- a/src/interpolator/two/tests.rs +++ b/src/interpolator/two/tests.rs @@ -197,6 +197,8 @@ fn test_serde() { .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); + // TODO: remove + println!("{ser}"); let de: Interp2DOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); } diff --git a/src/lib.rs b/src/lib.rs index 96819cc..ab0bb9b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -212,7 +212,7 @@ pub(crate) use num_traits::{clamp, Euclid, Num, One}; pub(crate) use dyn_clone::*; #[cfg(feature = "serde")] -pub(crate) use ndarray::DataOwned; +pub(crate) use ndarray::{DataOwned, IntoDimension}; #[cfg(feature = "serde")] pub(crate) use serde::{Deserialize, Serialize}; #[cfg(feature = "serde")] From 189a2b221315e94359df7a428496eee3e2e8c69a Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Tue, 6 May 2025 09:57:48 -0600 Subject: [PATCH 2/6] grid is working now --- Cargo.toml | 2 - src/interpolator/data.rs | 5 +- src/interpolator/n/mod.rs | 14 ++-- src/interpolator/n/tests.rs | 2 - src/interpolator/one/mod.rs | 4 +- src/interpolator/three/mod.rs | 4 +- src/interpolator/two/mod.rs | 4 +- src/lib.rs | 121 ++++++++++++++++++++++++++++++++++ 8 files changed, 139 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f73faea..dfdf74a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ serde_unit_struct = { version = "0.1.3", optional = true } serde-ndim = { git = "https://github.com/kylecarow/serde-ndim.git", optional = true, features = [ "ndarray", ] } -serde_with = { version = "3.12.0", optional = true } thiserror = "1.0.1" [dev-dependencies] @@ -48,5 +47,4 @@ serde = [ "ndarray/serde", "dep:serde_unit_struct", "dep:serde-ndim", - "dep:serde_with", ] diff --git a/src/interpolator/data.rs b/src/interpolator/data.rs index e0c27f7..d325483 100644 --- a/src/interpolator/data.rs +++ b/src/interpolator/data.rs @@ -14,13 +14,13 @@ pub use two::{InterpData2D, InterpData2DOwned, InterpData2DViewed}; feature = "serde", serde(bound( serialize = " - D::Elem: Serialize, + D::Elem: Serialize + Clone, Dim<[usize; N]>: Serialize, [ArrayBase; N]: Serialize, ", deserialize = " D: DataOwned, - D::Elem: Deserialize<'de>, + D::Elem: Deserialize<'de> + Clone, Dim<[usize; N]>: Deserialize<'de> + Dimension, [usize; N]: IntoDimension>, [ArrayBase; N]: Deserialize<'de>, @@ -37,6 +37,7 @@ where /// - 1-D: `[x]` /// - 2-D: `[x, y]` /// - 3-D: `[x, y, z]` + #[cfg_attr(feature = "serde", serde(with = "serde_arrays_2"))] pub grid: [ArrayBase; N], /// Function values at coordinates: a single `N`-dimensional [`ArrayBase`]. #[cfg_attr(feature = "serde", serde(with = "serde_ndim"))] diff --git a/src/interpolator/n/mod.rs b/src/interpolator/n/mod.rs index 316901e..f78ab14 100644 --- a/src/interpolator/n/mod.rs +++ b/src/interpolator/n/mod.rs @@ -12,14 +12,16 @@ mod tests; /// /// See [`InterpData`] and its aliases for concrete-dimensionality interpolator data structs. #[derive(Debug, Clone)] +// #[cfg_attr(feature = "serde", serde_as)] +// #[cfg_attr(feature = "serde", cfg_eval::cfg_eval, serde_as)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[cfg_attr( feature = "serde", serde(bound( - serialize = "D::Elem: Serialize", + serialize = "D::Elem: Serialize + Clone", deserialize = " D: DataOwned, - D::Elem: Deserialize<'de>, + D::Elem: Deserialize<'de> + Clone, " )) )] @@ -29,6 +31,8 @@ where D::Elem: PartialEq + Debug, { /// Coordinate grid: a vector of 1-dimensional [`ArrayBase`]. + #[cfg_attr(feature = "serde", serde(with = "crate::serde_arrays"))] + // #[cfg_attr(feature = "serde", serde(serialize_with = "serde_arrs::serialize"))] pub grid: Vec>, /// Function values at coordinates: a single dynamic-dimensional [`ArrayBase`]. #[cfg_attr(feature = "serde", serde(with = "serde_ndim"))] @@ -109,13 +113,13 @@ where feature = "serde", serde(bound( serialize = " - D::Elem: Serialize, + D::Elem: Serialize + Clone, S: Serialize, ", deserialize = " D: DataOwned, - D::Elem: Deserialize<'de>, - S: Deserialize<'de> + D::Elem: Deserialize<'de> + Clone, + S: Deserialize<'de>, " )) )] diff --git a/src/interpolator/n/tests.rs b/src/interpolator/n/tests.rs index 0ae13b0..0eee454 100644 --- a/src/interpolator/n/tests.rs +++ b/src/interpolator/n/tests.rs @@ -395,8 +395,6 @@ fn test_serde() { .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); - // TODO: remove - println!("{ser}"); let de: InterpNDOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); } diff --git a/src/interpolator/one/mod.rs b/src/interpolator/one/mod.rs index 2fadd29..03d810b 100644 --- a/src/interpolator/one/mod.rs +++ b/src/interpolator/one/mod.rs @@ -36,12 +36,12 @@ where feature = "serde", serde(bound( serialize = " - D::Elem: Serialize, + D::Elem: Serialize + Clone, S: Serialize, ", deserialize = " D: DataOwned, - D::Elem: Deserialize<'de>, + D::Elem: Deserialize<'de> + Clone, S: Deserialize<'de>, " )) diff --git a/src/interpolator/three/mod.rs b/src/interpolator/three/mod.rs index 596dbf2..92019c3 100644 --- a/src/interpolator/three/mod.rs +++ b/src/interpolator/three/mod.rs @@ -41,12 +41,12 @@ where feature = "serde", serde(bound( serialize = " - D::Elem: Serialize, + D::Elem: Serialize + Clone, S: Serialize, ", deserialize = " D: DataOwned, - D::Elem: Deserialize<'de>, + D::Elem: Deserialize<'de> + Clone, S: Deserialize<'de>, " )) diff --git a/src/interpolator/two/mod.rs b/src/interpolator/two/mod.rs index a1dd89a..2d995cc 100644 --- a/src/interpolator/two/mod.rs +++ b/src/interpolator/two/mod.rs @@ -40,12 +40,12 @@ where feature = "serde", serde(bound( serialize = " - D::Elem: Serialize, + D::Elem: Serialize + Clone, S: Serialize, ", deserialize = " D: DataOwned, - D::Elem: Deserialize<'de>, + D::Elem: Deserialize<'de> + Clone, S: Deserialize<'de>, " )) diff --git a/src/lib.rs b/src/lib.rs index ab0bb9b..8dd1ba5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -259,3 +259,124 @@ mod tests { assert_eq!(wrap(0.8, -1., 1.), 0.8); } } + +#[cfg(feature = "serde")] +mod serde_arrays { + use super::*; + use serde::{Deserializer, Serializer}; + + pub fn serialize(grid: &[ArrayBase], serializer: S) -> Result + where + S: Serializer, + D: Data + RawDataClone + Clone, + D::Elem: Serialize + Clone, + { + let vecs: Vec> = grid.iter().map(|arr| arr.to_vec()).collect(); + vecs.serialize(serializer) + } + + pub fn deserialize<'de, D, De>(deserializer: De) -> Result>, De::Error> + where + De: Deserializer<'de>, + D: DataOwned + RawDataClone, + D::Elem: Deserialize<'de> + Clone, + { + let vecs: Vec> = Vec::deserialize(deserializer)?; + let arrays = vecs + .into_iter() + .map(|v| ArrayBase::::from_vec(v)) + .collect(); + Ok(arrays) + } +} + +#[cfg(feature = "serde")] +mod serde_arrays_2 { + use super::*; + use serde::de::{Deserializer, Error as DeError, SeqAccess, Visitor}; + use serde::ser::{SerializeSeq, Serializer}; + use std::marker::PhantomData; + + pub fn serialize( + grid: &[ArrayBase; N], + serializer: S, + ) -> Result + where + S: Serializer, + D: Data + RawDataClone + Clone, + D::Elem: Serialize + Clone, + { + let vecs: [Vec; N] = std::array::from_fn(|i| grid[i].to_vec()); + let mut seq = serializer.serialize_seq(Some(N))?; + for vec in &vecs { + seq.serialize_element(vec)?; + } + seq.end() + } + + pub fn deserialize<'de, D, De, const N: usize>( + deserializer: De, + ) -> Result<[ArrayBase; N], De::Error> + where + De: Deserializer<'de>, + D: DataOwned, + D::Elem: Deserialize<'de> + Clone, + { + struct ArrayVisitor(PhantomData); + + impl<'de, D, const N: usize> Visitor<'de> for ArrayVisitor + where + D: DataOwned, + D::Elem: Deserialize<'de> + Clone, + { + type Value = [ArrayBase; N]; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str(&format!("an array of {} arrays", N)) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + // Create a Vec and then try to convert to array + let mut arrays = Vec::with_capacity(N); + + // Handle either format (Vec> or Vec) + for _ in 0..N { + // Try to deserialize as Vec first + if let Ok(vec) = seq.next_element::>() { + if let Some(vec) = vec { + arrays.push(ArrayBase::::from_vec(vec)); + continue; + } + } + + // Then try as ArrayBase + if let Ok(arr) = seq.next_element::>() { + if let Some(arr) = arr { + arrays.push( + ArrayBase::::from_shape_vec(arr.len(), arr.to_vec()) + .map_err(|e| DeError::custom(format!("Shape error: {}", e)))?, + ); + continue; + } + } + + // If we get here, we didn't find a valid element + return Err(DeError::custom(format!( + "Expected {} arrays, found fewer", + N + ))); + } + + // Convert Vec to fixed-size array + arrays + .try_into() + .map_err(|_| DeError::custom(format!("Expected array of length {}", N))) + } + } + + deserializer.deserialize_seq(ArrayVisitor::(PhantomData)) + } +} From 328d21f72d05c3e770a81001324005b0e8fd1e6a Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Wed, 7 May 2025 12:42:04 -0600 Subject: [PATCH 3/6] feature gate behind serde-simple --- Cargo.toml | 10 +-- README.md | 4 + src/interpolator/data.rs | 4 +- src/interpolator/n/mod.rs | 6 +- src/interpolator/n/tests.rs | 1 + src/interpolator/one/tests.rs | 2 - src/interpolator/three/tests.rs | 23 ++++-- src/interpolator/two/tests.rs | 2 - src/lib.rs | 132 ++------------------------------ src/serde.rs | 71 +++++++++++++++++ 10 files changed, 105 insertions(+), 150 deletions(-) create mode 100644 src/serde.rs diff --git a/Cargo.toml b/Cargo.toml index dfdf74a..450697a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,6 @@ serde = { version = "1.0.103", optional = true, features = ["derive"] } serde_unit_struct = { version = "0.1.3", optional = true } # TODO: modify when https://github.com/RReverser/serde-ndim/pull/2 merged # serde-ndim = { version = "2.0.2", optional = true, features = ["ndarray"] } -# serde-ndim = { path = "../serde-ndim", optional = true, features = ["ndarray"] } serde-ndim = { git = "https://github.com/kylecarow/serde-ndim.git", optional = true, features = [ "ndarray", ] } @@ -41,10 +40,5 @@ name = "benchmark" harness = false [features] -default = ["serde"] # TODO: remove -serde = [ - "dep:serde", - "ndarray/serde", - "dep:serde_unit_struct", - "dep:serde-ndim", -] +serde = ["dep:serde", "ndarray/serde", "dep:serde_unit_struct"] +serde-simple = ["serde", "dep:serde-ndim"] diff --git a/README.md b/README.md index 16aa88e..80852ca 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,10 @@ cargo add ninterp ``` cargo add ninterp --features serde ``` +- `serde-simple`: same as `serde` feature, with alternate simplified array serde format + ``` + cargo add ninterp --features serde-simple + ``` ## Examples See examples in `new` method documentation: diff --git a/src/interpolator/data.rs b/src/interpolator/data.rs index d325483..7cefd2f 100644 --- a/src/interpolator/data.rs +++ b/src/interpolator/data.rs @@ -37,10 +37,10 @@ where /// - 1-D: `[x]` /// - 2-D: `[x, y]` /// - 3-D: `[x, y, z]` - #[cfg_attr(feature = "serde", serde(with = "serde_arrays_2"))] + #[cfg_attr(feature = "serde-simple", serde(with = "serde_arr_array"))] pub grid: [ArrayBase; N], /// Function values at coordinates: a single `N`-dimensional [`ArrayBase`]. - #[cfg_attr(feature = "serde", serde(with = "serde_ndim"))] + #[cfg_attr(feature = "serde-simple", serde(with = "serde_ndim"))] pub values: ArrayBase>, } pub type InterpDataViewed = InterpData, N>; diff --git a/src/interpolator/n/mod.rs b/src/interpolator/n/mod.rs index f78ab14..6e10be2 100644 --- a/src/interpolator/n/mod.rs +++ b/src/interpolator/n/mod.rs @@ -12,8 +12,6 @@ mod tests; /// /// See [`InterpData`] and its aliases for concrete-dimensionality interpolator data structs. #[derive(Debug, Clone)] -// #[cfg_attr(feature = "serde", serde_as)] -// #[cfg_attr(feature = "serde", cfg_eval::cfg_eval, serde_as)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[cfg_attr( feature = "serde", @@ -31,11 +29,11 @@ where D::Elem: PartialEq + Debug, { /// Coordinate grid: a vector of 1-dimensional [`ArrayBase`]. - #[cfg_attr(feature = "serde", serde(with = "crate::serde_arrays"))] + #[cfg_attr(feature = "serde-simple", serde(with = "serde_vec_array"))] // #[cfg_attr(feature = "serde", serde(serialize_with = "serde_arrs::serialize"))] pub grid: Vec>, /// Function values at coordinates: a single dynamic-dimensional [`ArrayBase`]. - #[cfg_attr(feature = "serde", serde(with = "serde_ndim"))] + #[cfg_attr(feature = "serde-simple", serde(with = "serde_ndim"))] pub values: ArrayBase, } /// [`InterpDataND`] that views data. diff --git a/src/interpolator/n/tests.rs b/src/interpolator/n/tests.rs index 0eee454..adfeb2d 100644 --- a/src/interpolator/n/tests.rs +++ b/src/interpolator/n/tests.rs @@ -395,6 +395,7 @@ fn test_serde() { .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); + println!("{ser}"); let de: InterpNDOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); } diff --git a/src/interpolator/one/tests.rs b/src/interpolator/one/tests.rs index 5eb586d..ca06dbd 100644 --- a/src/interpolator/one/tests.rs +++ b/src/interpolator/one/tests.rs @@ -186,8 +186,6 @@ fn test_serde() { .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); - // TODO: remove - println!("{ser}"); let de: Interp1DOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); } diff --git a/src/interpolator/three/tests.rs b/src/interpolator/three/tests.rs index 50c7013..fd56d9d 100644 --- a/src/interpolator/three/tests.rs +++ b/src/interpolator/three/tests.rs @@ -217,17 +217,26 @@ fn test_partialeq() { fn test_serde() { let interp = Interp3D::new( array![0., 1.], - array![0., 1.], - array![0., 1.], - array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],], - strategy::Nearest, + array![0., 1., 2.], + array![0., 1., 2., 3.], + array![ + [ + [0.6, 0.8, 1.0, 1.2], + [0.8, 1.0, 1.2, 1.4], + [1.0, 1.2, 1.4, 1.6], + ], + [ + [0.8, 1.0, 1.2, 1.4], + [1.0, 1.2, 1.4, 1.6], + [1.2, 1.4, 1.6, 1.8], + ], + ], + strategy::Linear, Extrapolate::Error, ) .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); - // TODO: remove - println!("{ser}"); - let de: Interp3DOwned = serde_json::from_str(&ser).unwrap(); + let de: Interp3D<_, _> = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); } diff --git a/src/interpolator/two/tests.rs b/src/interpolator/two/tests.rs index b6a2602..8f65026 100644 --- a/src/interpolator/two/tests.rs +++ b/src/interpolator/two/tests.rs @@ -197,8 +197,6 @@ fn test_serde() { .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); - // TODO: remove - println!("{ser}"); let de: Interp2DOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); } diff --git a/src/lib.rs b/src/lib.rs index 8dd1ba5..87abe43 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,10 @@ //! ```text //! cargo add ninterp --features serde //! ``` +//! - `serde-simple`: same as `serde` feature, with alternate simplified array serde format +//! ```text +//! cargo add ninterp --features serde-simple +//! ``` //! //! # Examples //! See examples in `new` method documentation: @@ -212,11 +216,10 @@ pub(crate) use num_traits::{clamp, Euclid, Num, One}; pub(crate) use dyn_clone::*; #[cfg(feature = "serde")] -pub(crate) use ndarray::{DataOwned, IntoDimension}; +#[path = "serde.rs"] +mod custom_serde; #[cfg(feature = "serde")] -pub(crate) use serde::{Deserialize, Serialize}; -#[cfg(feature = "serde")] -pub(crate) use serde_unit_struct::{Deserialize_unit_struct, Serialize_unit_struct}; +pub(crate) use crate::custom_serde::*; #[cfg(test)] /// Alias for [`approx::assert_abs_diff_eq`] with `epsilon = 1e-6` @@ -259,124 +262,3 @@ mod tests { assert_eq!(wrap(0.8, -1., 1.), 0.8); } } - -#[cfg(feature = "serde")] -mod serde_arrays { - use super::*; - use serde::{Deserializer, Serializer}; - - pub fn serialize(grid: &[ArrayBase], serializer: S) -> Result - where - S: Serializer, - D: Data + RawDataClone + Clone, - D::Elem: Serialize + Clone, - { - let vecs: Vec> = grid.iter().map(|arr| arr.to_vec()).collect(); - vecs.serialize(serializer) - } - - pub fn deserialize<'de, D, De>(deserializer: De) -> Result>, De::Error> - where - De: Deserializer<'de>, - D: DataOwned + RawDataClone, - D::Elem: Deserialize<'de> + Clone, - { - let vecs: Vec> = Vec::deserialize(deserializer)?; - let arrays = vecs - .into_iter() - .map(|v| ArrayBase::::from_vec(v)) - .collect(); - Ok(arrays) - } -} - -#[cfg(feature = "serde")] -mod serde_arrays_2 { - use super::*; - use serde::de::{Deserializer, Error as DeError, SeqAccess, Visitor}; - use serde::ser::{SerializeSeq, Serializer}; - use std::marker::PhantomData; - - pub fn serialize( - grid: &[ArrayBase; N], - serializer: S, - ) -> Result - where - S: Serializer, - D: Data + RawDataClone + Clone, - D::Elem: Serialize + Clone, - { - let vecs: [Vec; N] = std::array::from_fn(|i| grid[i].to_vec()); - let mut seq = serializer.serialize_seq(Some(N))?; - for vec in &vecs { - seq.serialize_element(vec)?; - } - seq.end() - } - - pub fn deserialize<'de, D, De, const N: usize>( - deserializer: De, - ) -> Result<[ArrayBase; N], De::Error> - where - De: Deserializer<'de>, - D: DataOwned, - D::Elem: Deserialize<'de> + Clone, - { - struct ArrayVisitor(PhantomData); - - impl<'de, D, const N: usize> Visitor<'de> for ArrayVisitor - where - D: DataOwned, - D::Elem: Deserialize<'de> + Clone, - { - type Value = [ArrayBase; N]; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str(&format!("an array of {} arrays", N)) - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: SeqAccess<'de>, - { - // Create a Vec and then try to convert to array - let mut arrays = Vec::with_capacity(N); - - // Handle either format (Vec> or Vec) - for _ in 0..N { - // Try to deserialize as Vec first - if let Ok(vec) = seq.next_element::>() { - if let Some(vec) = vec { - arrays.push(ArrayBase::::from_vec(vec)); - continue; - } - } - - // Then try as ArrayBase - if let Ok(arr) = seq.next_element::>() { - if let Some(arr) = arr { - arrays.push( - ArrayBase::::from_shape_vec(arr.len(), arr.to_vec()) - .map_err(|e| DeError::custom(format!("Shape error: {}", e)))?, - ); - continue; - } - } - - // If we get here, we didn't find a valid element - return Err(DeError::custom(format!( - "Expected {} arrays, found fewer", - N - ))); - } - - // Convert Vec to fixed-size array - arrays - .try_into() - .map_err(|_| DeError::custom(format!("Expected array of length {}", N))) - } - } - - deserializer.deserialize_seq(ArrayVisitor::(PhantomData)) - } -} diff --git a/src/serde.rs b/src/serde.rs new file mode 100644 index 0000000..f28e06a --- /dev/null +++ b/src/serde.rs @@ -0,0 +1,71 @@ +use super::*; + +pub(crate) use ndarray::{DataOwned, IntoDimension}; +pub(crate) use serde::{Deserialize, Serialize}; +pub(crate) use serde_unit_struct::{Deserialize_unit_struct, Serialize_unit_struct}; + +#[cfg(feature = "serde-simple")] +pub(crate) mod serde_arr_array { + use super::*; + use serde::de::{Deserializer, Error}; + use serde::ser::{SerializeSeq, Serializer}; + + pub fn serialize( + grid: &[ArrayBase; N], + serializer: S, + ) -> Result + where + S: Serializer, + D: Data + RawDataClone + Clone, + D::Elem: Serialize + Clone, + { + let vecs: [Vec; N] = std::array::from_fn(|i| grid[i].to_vec()); + let mut seq = serializer.serialize_seq(Some(N))?; + for vec in &vecs { + seq.serialize_element(vec)?; + } + seq.end() + } + + pub fn deserialize<'de, D, De, const N: usize>( + deserializer: De, + ) -> Result<[ArrayBase; N], De::Error> + where + De: Deserializer<'de>, + D: DataOwned + RawDataClone, + D::Elem: Deserialize<'de> + Clone, + { + let items: Vec> = Deserialize::deserialize(deserializer)?; + let arrays: Vec> = items.into_iter().map(|v| v.into()).collect(); + arrays + .try_into() + .map_err(|_| De::Error::custom(format_args!("Expected {} arrays", N))) + } +} + +#[cfg(feature = "serde-simple")] +pub(crate) mod serde_vec_array { + use super::*; + use serde::{Deserializer, Serializer}; + + pub fn serialize(grid: &[ArrayBase], serializer: S) -> Result + where + S: Serializer, + D: Data + RawDataClone + Clone, + D::Elem: Serialize + Clone, + { + let vecs: Vec> = grid.iter().map(|arr| arr.to_vec()).collect(); + vecs.serialize(serializer) + } + + pub fn deserialize<'de, D, De>(deserializer: De) -> Result>, De::Error> + where + De: Deserializer<'de>, + D: DataOwned + RawDataClone, + D::Elem: Deserialize<'de> + Clone, + { + let items = Vec::>::deserialize(deserializer)?; + let arrays = items.into_iter().map(|v| v.into()).collect(); + Ok(arrays) + } +} From 9e8804440b2746054e7352689ee038269394fe2a Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Wed, 7 May 2025 14:44:38 -0600 Subject: [PATCH 4/6] remove unnecessary trait bounds --- README.md | 2 +- src/interpolator/n/mod.rs | 1 - src/interpolator/n/tests.rs | 1 - src/lib.rs | 6 +++--- src/serde.rs | 10 +++++----- 5 files changed, 9 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 80852ca..c2b2472 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ cargo add ninterp ``` cargo add ninterp --features serde ``` -- `serde-simple`: same as `serde` feature, with alternate simplified array serde format +- `serde-simple`: same as `serde` feature, with alternate simplified array format ``` cargo add ninterp --features serde-simple ``` diff --git a/src/interpolator/n/mod.rs b/src/interpolator/n/mod.rs index 6e10be2..168b3ed 100644 --- a/src/interpolator/n/mod.rs +++ b/src/interpolator/n/mod.rs @@ -30,7 +30,6 @@ where { /// Coordinate grid: a vector of 1-dimensional [`ArrayBase`]. #[cfg_attr(feature = "serde-simple", serde(with = "serde_vec_array"))] - // #[cfg_attr(feature = "serde", serde(serialize_with = "serde_arrs::serialize"))] pub grid: Vec>, /// Function values at coordinates: a single dynamic-dimensional [`ArrayBase`]. #[cfg_attr(feature = "serde-simple", serde(with = "serde_ndim"))] diff --git a/src/interpolator/n/tests.rs b/src/interpolator/n/tests.rs index adfeb2d..0eee454 100644 --- a/src/interpolator/n/tests.rs +++ b/src/interpolator/n/tests.rs @@ -395,7 +395,6 @@ fn test_serde() { .unwrap(); let ser = serde_json::to_string(&interp).unwrap(); - println!("{ser}"); let de: InterpNDOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); } diff --git a/src/lib.rs b/src/lib.rs index 87abe43..98a32bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,7 @@ //! ```text //! cargo add ninterp --features serde //! ``` -//! - `serde-simple`: same as `serde` feature, with alternate simplified array serde format +//! - `serde-simple`: same as `serde` feature, with alternate simplified array format //! ```text //! cargo add ninterp --features serde-simple //! ``` @@ -217,9 +217,9 @@ pub(crate) use dyn_clone::*; #[cfg(feature = "serde")] #[path = "serde.rs"] -mod custom_serde; +mod serde_mod; #[cfg(feature = "serde")] -pub(crate) use crate::custom_serde::*; +pub(crate) use serde_mod::*; #[cfg(test)] /// Alias for [`approx::assert_abs_diff_eq`] with `epsilon = 1e-6` diff --git a/src/serde.rs b/src/serde.rs index f28e06a..bbe034f 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -32,14 +32,14 @@ pub(crate) mod serde_arr_array { ) -> Result<[ArrayBase; N], De::Error> where De: Deserializer<'de>, - D: DataOwned + RawDataClone, - D::Elem: Deserialize<'de> + Clone, + D: DataOwned, + D::Elem: Deserialize<'de>, { let items: Vec> = Deserialize::deserialize(deserializer)?; let arrays: Vec> = items.into_iter().map(|v| v.into()).collect(); arrays .try_into() - .map_err(|_| De::Error::custom(format_args!("Expected {} arrays", N))) + .map_err(|_| De::Error::custom(format_args!("expected {} arrays", N))) } } @@ -61,8 +61,8 @@ pub(crate) mod serde_vec_array { pub fn deserialize<'de, D, De>(deserializer: De) -> Result>, De::Error> where De: Deserializer<'de>, - D: DataOwned + RawDataClone, - D::Elem: Deserialize<'de> + Clone, + D: DataOwned, + D::Elem: Deserialize<'de>, { let items = Vec::>::deserialize(deserializer)?; let arrays = items.into_iter().map(|v| v.into()).collect(); From a342effe19030a2521ca8f6d9f9f577c0bc6cc70 Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Thu, 8 May 2025 15:20:22 -0600 Subject: [PATCH 5/6] simple serde format as default, with ability to deserialize both --- Cargo.toml | 8 ++- README.md | 4 -- src/interpolator/data.rs | 7 ++- src/interpolator/n/mod.rs | 7 ++- src/interpolator/n/tests.rs | 17 +++++ src/interpolator/one/tests.rs | 17 +++++ src/interpolator/three/tests.rs | 17 +++++ src/interpolator/two/tests.rs | 17 +++++ src/serde.rs | 107 +++++++++++++++++++++++++------- 9 files changed, 166 insertions(+), 35 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 450697a..2eb2352 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,5 +40,9 @@ name = "benchmark" harness = false [features] -serde = ["dep:serde", "ndarray/serde", "dep:serde_unit_struct"] -serde-simple = ["serde", "dep:serde-ndim"] +serde = [ + "dep:serde", + "ndarray/serde", + "dep:serde_unit_struct", + "dep:serde-ndim", +] diff --git a/README.md b/README.md index 7a6bfb5..3ef4c7a 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,6 @@ cargo add ninterp ``` cargo add ninterp --features serde ``` -- `serde-simple`: same as `serde` feature, with alternate simplified array format - ``` - cargo add ninterp --features serde-simple - ``` ## Examples See examples in `new` method documentation: diff --git a/src/interpolator/data.rs b/src/interpolator/data.rs index 7cefd2f..4df0b56 100644 --- a/src/interpolator/data.rs +++ b/src/interpolator/data.rs @@ -20,7 +20,7 @@ pub use two::{InterpData2D, InterpData2DOwned, InterpData2DViewed}; ", deserialize = " D: DataOwned, - D::Elem: Deserialize<'de> + Clone, + D::Elem: Deserialize<'de>, Dim<[usize; N]>: Deserialize<'de> + Dimension, [usize; N]: IntoDimension>, [ArrayBase; N]: Deserialize<'de>, @@ -37,10 +37,11 @@ where /// - 1-D: `[x]` /// - 2-D: `[x, y]` /// - 3-D: `[x, y, z]` - #[cfg_attr(feature = "serde-simple", serde(with = "serde_arr_array"))] + #[cfg_attr(feature = "serde", serde(with = "serde_arr_array"))] pub grid: [ArrayBase; N], /// Function values at coordinates: a single `N`-dimensional [`ArrayBase`]. - #[cfg_attr(feature = "serde-simple", serde(with = "serde_ndim"))] + #[cfg_attr(feature = "serde", serde(serialize_with = "serde_ndim::serialize"))] + #[cfg_attr(feature = "serde", serde(deserialize_with = "deserialize_fixed"))] pub values: ArrayBase>, } pub type InterpDataViewed = InterpData, N>; diff --git a/src/interpolator/n/mod.rs b/src/interpolator/n/mod.rs index 168b3ed..01b3fc1 100644 --- a/src/interpolator/n/mod.rs +++ b/src/interpolator/n/mod.rs @@ -19,7 +19,7 @@ mod tests; serialize = "D::Elem: Serialize + Clone", deserialize = " D: DataOwned, - D::Elem: Deserialize<'de> + Clone, + D::Elem: Deserialize<'de>, " )) )] @@ -29,10 +29,11 @@ where D::Elem: PartialEq + Debug, { /// Coordinate grid: a vector of 1-dimensional [`ArrayBase`]. - #[cfg_attr(feature = "serde-simple", serde(with = "serde_vec_array"))] + #[cfg_attr(feature = "serde", serde(with = "serde_vec_array"))] pub grid: Vec>, /// Function values at coordinates: a single dynamic-dimensional [`ArrayBase`]. - #[cfg_attr(feature = "serde-simple", serde(with = "serde_ndim"))] + #[cfg_attr(feature = "serde", serde(serialize_with = "serde_ndim::serialize"))] + #[cfg_attr(feature = "serde", serde(deserialize_with = "deserialize_dyn"))] pub values: ArrayBase, } /// [`InterpDataND`] that views data. diff --git a/src/interpolator/n/tests.rs b/src/interpolator/n/tests.rs index 0eee454..5d2fbd0 100644 --- a/src/interpolator/n/tests.rs +++ b/src/interpolator/n/tests.rs @@ -397,4 +397,21 @@ fn test_serde() { let ser = serde_json::to_string(&interp).unwrap(); let de: InterpNDOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); + + // simple format (new serialization output) + let ser0 = "{\"grid\":[[0.1,1.1],[0.2,1.2],[0.3,1.3]],\"values\":[[[0.0,1.0],[2.0,3.0]],[[4.0,5.0],[6.0,7.0]]]}"; + let de0: InterpDataND<_> = serde_json::from_str(&ser0).unwrap(); + assert_eq!(interp.data, de0); + // mixed format (simple grid) + let ser1 = "{\"grid\":[[0.1,1.1],[0.2,1.2],[0.3,1.3]],\"values\":{\"v\":1,\"dim\":[2,2,2],\"data\":[0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0]}}"; + let de1: InterpDataND<_> = serde_json::from_str(&ser1).unwrap(); + assert_eq!(interp.data, de1); + // mixed format (simple values) + let ser2 = "{\"grid\":[{\"v\":1,\"dim\":[2],\"data\":[0.1,1.1]},{\"v\":1,\"dim\":[2],\"data\":[0.2,1.2]},{\"v\":1,\"dim\":[2],\"data\":[0.3,1.3]}],\"values\":[[[0.0,1.0],[2.0,3.0]],[[4.0,5.0],[6.0,7.0]]]}"; + let de2: InterpDataND<_> = serde_json::from_str(&ser2).unwrap(); + assert_eq!(interp.data, de2); + // complex format (legacy serialization output) + let ser3 = "{\"grid\":[{\"v\":1,\"dim\":[2],\"data\":[0.1,1.1]},{\"v\":1,\"dim\":[2],\"data\":[0.2,1.2]},{\"v\":1,\"dim\":[2],\"data\":[0.3,1.3]}],\"values\":{\"v\":1,\"dim\":[2,2,2],\"data\":[0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0]}}"; + let de3: InterpDataND<_> = serde_json::from_str(&ser3).unwrap(); + assert_eq!(interp.data, de3); } diff --git a/src/interpolator/one/tests.rs b/src/interpolator/one/tests.rs index ca06dbd..6d956e6 100644 --- a/src/interpolator/one/tests.rs +++ b/src/interpolator/one/tests.rs @@ -188,4 +188,21 @@ fn test_serde() { let ser = serde_json::to_string(&interp).unwrap(); let de: Interp1DOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); + + // simple format (new serialization output) + let ser0 = "{\"grid\":[[0.0,1.0,2.0,3.0,4.0]],\"values\":[0.2,0.4,0.6,0.8,1.0]}"; + let de0: InterpData1D<_> = serde_json::from_str(&ser0).unwrap(); + assert_eq!(interp.data, de0); + // mixed format (simple grid) + let ser1 = "{\"grid\":[[0.0,1.0,2.0,3.0,4.0]],\"values\":{\"v\":1,\"dim\":[5],\"data\":[0.2,0.4,0.6,0.8,1.0]}}"; + let de1: InterpData1D<_> = serde_json::from_str(&ser1).unwrap(); + assert_eq!(interp.data, de1); + // mixed format (simple values) + let ser2 = "{\"grid\":[{\"v\":1,\"dim\":[5],\"data\":[0.0,1.0,2.0,3.0,4.0]}],\"values\":[0.2,0.4,0.6,0.8,1.0]}"; + let de2: InterpData1D<_> = serde_json::from_str(&ser2).unwrap(); + assert_eq!(interp.data, de2); + // complex format (legacy serialization output) + let ser3 = "{\"grid\":[{\"v\":1,\"dim\":[5],\"data\":[0.0,1.0,2.0,3.0,4.0]}],\"values\":{\"v\":1,\"dim\":[5],\"data\":[0.2,0.4,0.6,0.8,1.0]}}"; + let de3: InterpData1D<_> = serde_json::from_str(&ser3).unwrap(); + assert_eq!(interp.data, de3); } diff --git a/src/interpolator/three/tests.rs b/src/interpolator/three/tests.rs index fd56d9d..88d917c 100644 --- a/src/interpolator/three/tests.rs +++ b/src/interpolator/three/tests.rs @@ -239,4 +239,21 @@ fn test_serde() { let ser = serde_json::to_string(&interp).unwrap(); let de: Interp3D<_, _> = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); + + // simple format (new serialization output) + let ser0 = "{\"grid\":[[0.0,1.0],[0.0,1.0,2.0],[0.0,1.0,2.0,3.0]],\"values\":[[[0.6,0.8,1.0,1.2],[0.8,1.0,1.2,1.4],[1.0,1.2,1.4,1.6]],[[0.8,1.0,1.2,1.4],[1.0,1.2,1.4,1.6],[1.2,1.4,1.6,1.8]]]}"; + let de0: InterpData3D<_> = serde_json::from_str(&ser0).unwrap(); + assert_eq!(interp.data, de0); + // mixed format (simple grid) + let ser1 = "{\"grid\":[[0.0,1.0],[0.0,1.0,2.0],[0.0,1.0,2.0,3.0]],\"values\":{\"v\":1,\"dim\":[2,3,4],\"data\":[0.6,0.8,1.0,1.2,0.8,1.0,1.2,1.4,1.0,1.2,1.4,1.6,0.8,1.0,1.2,1.4,1.0,1.2,1.4,1.6,1.2,1.4,1.6,1.8]}}"; + let de1: InterpData3D<_> = serde_json::from_str(&ser1).unwrap(); + assert_eq!(interp.data, de1); + // mixed format (simple values) + let ser2 = "{\"grid\":[{\"v\":1,\"dim\":[2],\"data\":[0.0,1.0]},{\"v\":1,\"dim\":[3],\"data\":[0.0,1.0,2.0]},{\"v\":1,\"dim\":[4],\"data\":[0.0,1.0,2.0,3.0]}],\"values\":[[[0.6,0.8,1.0,1.2],[0.8,1.0,1.2,1.4],[1.0,1.2,1.4,1.6]],[[0.8,1.0,1.2,1.4],[1.0,1.2,1.4,1.6],[1.2,1.4,1.6,1.8]]]}"; + let de2: InterpData3D<_> = serde_json::from_str(&ser2).unwrap(); + assert_eq!(interp.data, de2); + // complex format (legacy serialization output) + let ser3 = "{\"grid\":[{\"v\":1,\"dim\":[2],\"data\":[0.0,1.0]},{\"v\":1,\"dim\":[3],\"data\":[0.0,1.0,2.0]},{\"v\":1,\"dim\":[4],\"data\":[0.0,1.0,2.0,3.0]}],\"values\":{\"v\":1,\"dim\":[2,3,4],\"data\":[0.6,0.8,1.0,1.2,0.8,1.0,1.2,1.4,1.0,1.2,1.4,1.6,0.8,1.0,1.2,1.4,1.0,1.2,1.4,1.6,1.2,1.4,1.6,1.8]}}"; + let de3: InterpData3D<_> = serde_json::from_str(&ser3).unwrap(); + assert_eq!(interp.data, de3); } diff --git a/src/interpolator/two/tests.rs b/src/interpolator/two/tests.rs index 8f65026..4f9767c 100644 --- a/src/interpolator/two/tests.rs +++ b/src/interpolator/two/tests.rs @@ -199,4 +199,21 @@ fn test_serde() { let ser = serde_json::to_string(&interp).unwrap(); let de: Interp2DOwned = serde_json::from_str(&ser).unwrap(); assert_eq!(interp, de); + + // simple format (new serialization output) + let ser0 = "{\"grid\":[[0.05,0.1,0.15],[0.1,0.2,0.3]],\"values\":[[0.0,1.0,2.0],[3.0,4.0,5.0],[6.0,7.0,8.0]]}"; + let de0: InterpData2D<_> = serde_json::from_str(&ser0).unwrap(); + assert_eq!(interp.data, de0); + // mixed format (simple grid) + let ser1 = "{\"grid\":[[0.05,0.1,0.15],[0.1,0.2,0.3]],\"values\":{\"v\":1,\"dim\":[3,3],\"data\":[0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0]}}"; + let de1: InterpData2D<_> = serde_json::from_str(&ser1).unwrap(); + assert_eq!(interp.data, de1); + // mixed format (simple values) + let ser2 = "{\"grid\":[{\"v\":1,\"dim\":[3],\"data\":[0.05,0.1,0.15]},{\"v\":1,\"dim\":[3],\"data\":[0.1,0.2,0.3]}],\"values\":[[0.0,1.0,2.0],[3.0,4.0,5.0],[6.0,7.0,8.0]]}"; + let de2: InterpData2D<_> = serde_json::from_str(&ser2).unwrap(); + assert_eq!(interp.data, de2); + // complex format (legacy serialization output) + let ser3 = "{\"grid\":[{\"v\":1,\"dim\":[3],\"data\":[0.05,0.1,0.15]},{\"v\":1,\"dim\":[3],\"data\":[0.1,0.2,0.3]}],\"values\":{\"v\":1,\"dim\":[3,3],\"data\":[0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0]}}"; + let de3: InterpData2D<_> = serde_json::from_str(&ser3).unwrap(); + assert_eq!(interp.data, de3); } diff --git a/src/serde.rs b/src/serde.rs index bbe034f..3704e92 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -4,20 +4,31 @@ pub(crate) use ndarray::{DataOwned, IntoDimension}; pub(crate) use serde::{Deserialize, Serialize}; pub(crate) use serde_unit_struct::{Deserialize_unit_struct, Serialize_unit_struct}; -#[cfg(feature = "serde-simple")] +use serde::de::{Deserializer, Error}; +use serde::ser::{SerializeSeq, Serializer}; +use serde_ndim::de::MakeNDim; + +#[derive(Deserialize)] +#[serde(untagged)] +#[serde(bound = " + D::Elem: Deserialize<'de>, +")] +enum GridType { + VecVec(Vec>), + VecArray(Vec>), +} + pub(crate) mod serde_arr_array { use super::*; - use serde::de::{Deserializer, Error}; - use serde::ser::{SerializeSeq, Serializer}; - pub fn serialize( + pub fn serialize( grid: &[ArrayBase; N], - serializer: S, - ) -> Result + serializer: Ser, + ) -> Result where - S: Serializer, D: Data + RawDataClone + Clone, D::Elem: Serialize + Clone, + Ser: Serializer, { let vecs: [Vec; N] = std::array::from_fn(|i| grid[i].to_vec()); let mut seq = serializer.serialize_seq(Some(N))?; @@ -27,32 +38,34 @@ pub(crate) mod serde_arr_array { seq.end() } - pub fn deserialize<'de, D, De, const N: usize>( + pub fn deserialize<'de, D, const N: usize, De>( deserializer: De, ) -> Result<[ArrayBase; N], De::Error> where - De: Deserializer<'de>, D: DataOwned, - D::Elem: Deserialize<'de>, + D::Elem: Deserialize<'de> + Debug, + De: Deserializer<'de>, { - let items: Vec> = Deserialize::deserialize(deserializer)?; - let arrays: Vec> = items.into_iter().map(|v| v.into()).collect(); - arrays - .try_into() - .map_err(|_| De::Error::custom(format_args!("expected {} arrays", N))) + match GridType::deserialize(deserializer)? { + GridType::VecVec(vecs) => vecs.into_iter().map(|v| v.into()).collect(), + GridType::VecArray(arrays) => arrays, + } + .try_into() + .map_err(|e| De::Error::custom(format_args!("expected {N} array(s): {e:?}"))) } } -#[cfg(feature = "serde-simple")] pub(crate) mod serde_vec_array { use super::*; - use serde::{Deserializer, Serializer}; - pub fn serialize(grid: &[ArrayBase], serializer: S) -> Result + pub fn serialize( + grid: &[ArrayBase], + serializer: Ser, + ) -> Result where - S: Serializer, D: Data + RawDataClone + Clone, D::Elem: Serialize + Clone, + Ser: Serializer, { let vecs: Vec> = grid.iter().map(|arr| arr.to_vec()).collect(); vecs.serialize(serializer) @@ -60,12 +73,60 @@ pub(crate) mod serde_vec_array { pub fn deserialize<'de, D, De>(deserializer: De) -> Result>, De::Error> where - De: Deserializer<'de>, D: DataOwned, D::Elem: Deserialize<'de>, + De: Deserializer<'de>, { - let items = Vec::>::deserialize(deserializer)?; - let arrays = items.into_iter().map(|v| v.into()).collect(); - Ok(arrays) + Ok(match GridType::deserialize(deserializer)? { + GridType::VecVec(vecs) => vecs.into_iter().map(|v| v.into()).collect(), + GridType::VecArray(arrays) => arrays, + }) } } + +#[derive(Deserialize)] +#[serde(untagged)] +#[serde(bound = " + D::Elem: Deserialize<'de>, + ArrayBase: Deserialize<'de>, +")] +enum ValuesType +where + D: DataOwned, + Dim: Dimension, + ArrayBase: MakeNDim, +{ + #[serde(deserialize_with = "serde_ndim::deserialize")] + NDimArray(ArrayBase), + Array(ArrayBase), +} + +pub fn deserialize_fixed<'de, D, const N: usize, De>( + deserializer: De, +) -> Result>, De::Error> +where + D: DataOwned, + Dim<[Ix; N]>: Dimension, + ArrayBase>: Deserialize<'de>, + D::Elem: Deserialize<'de>, + ArrayBase>: MakeNDim, + De: Deserializer<'de>, +{ + Ok(match ValuesType::deserialize(deserializer)? { + ValuesType::NDimArray(values) => values, + ValuesType::Array(values) => values, + }) +} + +pub fn deserialize_dyn<'de, D, De>(deserializer: De) -> Result, De::Error> +where + D: DataOwned, + D::Elem: Deserialize<'de>, + ArrayBase: MakeNDim, + De: Deserializer<'de>, +{ + Ok(match ValuesType::deserialize(deserializer)? { + ValuesType::NDimArray(values) => values, + ValuesType::Array(values) => values, + }) +} From 493318c61037714739d5f62e6acaaab7b88673ea Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Thu, 8 May 2025 20:10:32 -0600 Subject: [PATCH 6/6] avoid clone in serialization --- src/interpolator/data.rs | 2 +- src/interpolator/n/mod.rs | 4 ++-- src/interpolator/one/mod.rs | 2 +- src/interpolator/three/mod.rs | 2 +- src/interpolator/two/mod.rs | 2 +- src/serde.rs | 35 +++++++++++++++++++++-------------- 6 files changed, 27 insertions(+), 20 deletions(-) diff --git a/src/interpolator/data.rs b/src/interpolator/data.rs index 96252f6..f66f8eb 100644 --- a/src/interpolator/data.rs +++ b/src/interpolator/data.rs @@ -16,7 +16,7 @@ pub use two::{InterpData2D, InterpData2DOwned, InterpData2DViewed}; feature = "serde", serde(bound( serialize = " - D::Elem: Serialize + Clone, + D::Elem: Serialize, Dim<[usize; N]>: Serialize, [ArrayBase; N]: Serialize, ", diff --git a/src/interpolator/n/mod.rs b/src/interpolator/n/mod.rs index 02d3fee..438991f 100644 --- a/src/interpolator/n/mod.rs +++ b/src/interpolator/n/mod.rs @@ -16,7 +16,7 @@ mod tests; #[cfg_attr( feature = "serde", serde(bound( - serialize = "D::Elem: Serialize + Clone", + serialize = "D::Elem: Serialize", deserialize = " D: DataOwned, D::Elem: Deserialize<'de>, @@ -114,7 +114,7 @@ where feature = "serde", serde(bound( serialize = " - D::Elem: Serialize + Clone, + D::Elem: Serialize, S: Serialize, ", deserialize = " diff --git a/src/interpolator/one/mod.rs b/src/interpolator/one/mod.rs index 37519d0..e494127 100644 --- a/src/interpolator/one/mod.rs +++ b/src/interpolator/one/mod.rs @@ -38,7 +38,7 @@ where feature = "serde", serde(bound( serialize = " - D::Elem: Serialize + Clone, + D::Elem: Serialize, S: Serialize, ", deserialize = " diff --git a/src/interpolator/three/mod.rs b/src/interpolator/three/mod.rs index 2da7ccc..9ec8584 100644 --- a/src/interpolator/three/mod.rs +++ b/src/interpolator/three/mod.rs @@ -43,7 +43,7 @@ where feature = "serde", serde(bound( serialize = " - D::Elem: Serialize + Clone, + D::Elem: Serialize, S: Serialize, ", deserialize = " diff --git a/src/interpolator/two/mod.rs b/src/interpolator/two/mod.rs index f3023ed..4952ec1 100644 --- a/src/interpolator/two/mod.rs +++ b/src/interpolator/two/mod.rs @@ -42,7 +42,7 @@ where feature = "serde", serde(bound( serialize = " - D::Elem: Serialize + Clone, + D::Elem: Serialize, S: Serialize, ", deserialize = " diff --git a/src/serde.rs b/src/serde.rs index 3704e92..83b2deb 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -8,11 +8,17 @@ use serde::de::{Deserializer, Error}; use serde::ser::{SerializeSeq, Serializer}; use serde_ndim::de::MakeNDim; +#[derive(Serialize)] +struct ArrayWrapper<'a, D>( + #[serde(serialize_with = "serde_ndim::serialize")] &'a ArrayBase, +) +where + D: Data, + D::Elem: Serialize; + #[derive(Deserialize)] #[serde(untagged)] -#[serde(bound = " - D::Elem: Deserialize<'de>, -")] +#[serde(bound = "D::Elem: Deserialize<'de>")] enum GridType { VecVec(Vec>), VecArray(Vec>), @@ -26,14 +32,13 @@ pub(crate) mod serde_arr_array { serializer: Ser, ) -> Result where - D: Data + RawDataClone + Clone, - D::Elem: Serialize + Clone, + D: Data, + D::Elem: Serialize, Ser: Serializer, { - let vecs: [Vec; N] = std::array::from_fn(|i| grid[i].to_vec()); let mut seq = serializer.serialize_seq(Some(N))?; - for vec in &vecs { - seq.serialize_element(vec)?; + for arr in grid { + seq.serialize_element(&ArrayWrapper(arr))?; } seq.end() } @@ -63,12 +68,15 @@ pub(crate) mod serde_vec_array { serializer: Ser, ) -> Result where - D: Data + RawDataClone + Clone, - D::Elem: Serialize + Clone, + D: Data, + D::Elem: Serialize, Ser: Serializer, { - let vecs: Vec> = grid.iter().map(|arr| arr.to_vec()).collect(); - vecs.serialize(serializer) + let mut seq = serializer.serialize_seq(Some(grid.len()))?; + for arr in grid { + seq.serialize_element(&ArrayWrapper(arr))?; + } + seq.end() } pub fn deserialize<'de, D, De>(deserializer: De) -> Result>, De::Error> @@ -106,9 +114,8 @@ pub fn deserialize_fixed<'de, D, const N: usize, De>( ) -> Result>, De::Error> where D: DataOwned, - Dim<[Ix; N]>: Dimension, - ArrayBase>: Deserialize<'de>, D::Elem: Deserialize<'de>, + Dim<[Ix; N]>: Dimension + Deserialize<'de>, ArrayBase>: MakeNDim, De: Deserializer<'de>, {