From 4c3f21959b16db9a3495d5c06ad2c5359a08a134 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sat, 9 May 2026 21:58:59 +0200 Subject: [PATCH 1/9] feat: add v3 typed float storage --- atompack-py/python/atompack/__init__.pyi | 64 +- atompack-py/src/database.rs | 41 +- atompack-py/src/database_batch.rs | 636 +++++++++++------ atompack-py/src/database_flat.rs | 226 +++--- atompack-py/src/lib.rs | 296 ++++++-- atompack-py/src/molecule.rs | 347 ++------- atompack-py/src/molecule_helpers.rs | 852 +++++++++++------------ atompack-py/tests/test_atom_molecule.py | 18 +- atompack-py/tests/test_database.py | 63 ++ atompack/examples/basic_usage.rs | 6 +- atompack/src/atom.rs | 248 +++++-- atompack/src/bin/atompack-bench.rs | 10 +- atompack/src/lib.rs | 4 +- atompack/src/storage/mod.rs | 209 +++++- atompack/src/storage/soa.rs | 842 +++++++++++++++++++++- atompack/tests/throughput_smoke.rs | 12 +- 16 files changed, 2591 insertions(+), 1283 deletions(-) diff --git a/atompack-py/python/atompack/__init__.pyi b/atompack-py/python/atompack/__init__.pyi index 67dddae..a28c97e 100644 --- a/atompack-py/python/atompack/__init__.pyi +++ b/atompack-py/python/atompack/__init__.pyi @@ -7,6 +7,10 @@ import numpy.typing as npt from . import hub as hub +Float1D = npt.NDArray[np.float32] | npt.NDArray[np.float64] +Float2D = npt.NDArray[np.float32] | npt.NDArray[np.float64] +Float3D = npt.NDArray[np.float32] | npt.NDArray[np.float64] + class Atom: """ Represents a single atom with 3D coordinates and atomic number. @@ -115,29 +119,29 @@ class Molecule: def __init__( self, - positions: npt.NDArray[np.float32], + positions: Float2D, atomic_numbers: npt.NDArray[np.uint8], *, energy: float | None = ..., - forces: npt.NDArray[np.float32] | None = ..., - charges: npt.NDArray[np.float64] | None = ..., - velocities: npt.NDArray[np.float32] | None = ..., - cell: npt.NDArray[np.float64] | None = ..., - stress: npt.NDArray[np.float64] | npt.NDArray[np.float32] | None = ..., + forces: Float2D | None = ..., + charges: Float1D | None = ..., + velocities: Float2D | None = ..., + cell: Float2D | None = ..., + stress: Float2D | None = ..., pbc: tuple[bool, bool, bool] | None = ..., name: str | None = ..., ) -> None: ... @staticmethod def from_arrays( - positions: npt.NDArray[np.float32], + positions: Float2D, atomic_numbers: npt.NDArray[np.uint8], *, energy: float | None = ..., - forces: npt.NDArray[np.float32] | None = ..., - charges: npt.NDArray[np.float64] | None = ..., - velocities: npt.NDArray[np.float32] | None = ..., - cell: npt.NDArray[np.float64] | None = ..., - stress: npt.NDArray[np.float64] | npt.NDArray[np.float32] | None = ..., + forces: Float2D | None = ..., + charges: Float1D | None = ..., + velocities: Float2D | None = ..., + cell: Float2D | None = ..., + stress: Float2D | None = ..., pbc: tuple[bool, bool, bool] | None = ..., name: str | None = ..., ) -> Molecule: @@ -204,7 +208,7 @@ class Molecule: ... @property - def forces(self) -> npt.NDArray[np.float32] | None: + def forces(self) -> Float2D | None: """ Per-atom forces. @@ -216,7 +220,7 @@ class Molecule: ... @forces.setter - def forces(self, value: npt.NDArray[np.float32]) -> None: ... + def forces(self, value: Float2D) -> None: ... @property def energy(self) -> float | None: """ @@ -232,7 +236,7 @@ class Molecule: @energy.setter def energy(self, value: float | None) -> None: ... @property - def charges(self) -> npt.NDArray[np.float64] | None: + def charges(self) -> Float1D | None: """ Per-atom partial charges. @@ -244,9 +248,9 @@ class Molecule: ... @charges.setter - def charges(self, value: npt.NDArray[np.float64]) -> None: ... + def charges(self, value: Float1D) -> None: ... @property - def velocities(self) -> npt.NDArray[np.float32] | None: + def velocities(self) -> Float2D | None: """ Per-atom velocities. @@ -258,9 +262,9 @@ class Molecule: ... @velocities.setter - def velocities(self, value: npt.NDArray[np.float32]) -> None: ... + def velocities(self, value: Float2D) -> None: ... @property - def cell(self) -> npt.NDArray[np.float64] | None: + def cell(self) -> Float2D | None: """ Unit cell for periodic systems. @@ -272,9 +276,9 @@ class Molecule: ... @cell.setter - def cell(self, value: npt.NDArray[np.float64]) -> None: ... + def cell(self, value: Float2D) -> None: ... @property - def stress(self) -> npt.NDArray[np.float64] | None: + def stress(self) -> Float2D | None: """ Virial stress tensor. @@ -286,9 +290,9 @@ class Molecule: ... @stress.setter - def stress(self, value: npt.NDArray[np.float32] | npt.NDArray[np.float64]) -> None: ... + def stress(self, value: Float2D) -> None: ... @property - def positions(self) -> npt.NDArray[np.float32]: + def positions(self) -> Float2D: """ Atomic positions (read-only). @@ -505,15 +509,15 @@ class Database: ... def add_arrays_batch( self, - positions: npt.NDArray[np.float32], + positions: Float3D, atomic_numbers: npt.NDArray[np.uint8], *, - energy: npt.NDArray[np.float64] | None = ..., - forces: npt.NDArray[np.float32] | None = ..., - charges: npt.NDArray[np.float64] | None = ..., - velocities: npt.NDArray[np.float32] | None = ..., - cell: npt.NDArray[np.float64] | None = ..., - stress: npt.NDArray[np.float64] | None = ..., + energy: Float1D | None = ..., + forces: Float3D | None = ..., + charges: Float2D | None = ..., + velocities: Float3D | None = ..., + cell: Float3D | None = ..., + stress: Float3D | None = ..., pbc: npt.NDArray[np.bool_] | None = ..., name: Sequence[str] | None = ..., properties: dict[str, Any] | None = ..., diff --git a/atompack-py/src/database.rs b/atompack-py/src/database.rs index f472cc9..3cb30e2 100644 --- a/atompack-py/src/database.rs +++ b/atompack-py/src/database.rs @@ -13,6 +13,7 @@ pub(crate) struct PyAtomDatabase { impl PyAtomDatabase { fn single_molecule_view(&self, py: Python<'_>, index: usize) -> PyResult { let compression = self.inner.compression(); + let record_format = self.inner.record_format(); let use_mmap = self.inner.get_compressed_slice(0).is_some(); if use_mmap { @@ -22,7 +23,7 @@ impl PyAtomDatabase { let bytes = self.inner.get_shared_mmap_bytes(index).ok_or_else(|| { invalid_data(format!("Missing mmap bytes for molecule {}", index)) })?; - SoaMoleculeView::from_shared_bytes_inner(bytes) + SoaMoleculeView::from_shared_bytes_inner(bytes, record_format) } else { let compressed = self.inner.get_compressed_slice(index).ok_or_else(|| { @@ -43,7 +44,7 @@ impl PyAtomDatabase { compression, Some(uncompressed_size), )?; - SoaMoleculeView::from_bytes_inner(decompressed) + SoaMoleculeView::from_bytes_inner(decompressed, record_format) } }) .map_err(|e| PyValueError::new_err(format!("{}", e))); @@ -55,7 +56,7 @@ impl PyAtomDatabase { let raw = raw_bytes.pop().ok_or_else(|| { PyValueError::new_err(format!("Missing raw bytes for molecule {}", index)) })?; - SoaMoleculeView::from_bytes(raw) + SoaMoleculeView::from_bytes(raw, record_format) } } @@ -117,7 +118,11 @@ impl PyAtomDatabase { /// Add a molecule to the database fn add_molecule(&mut self, molecule: &PyMolecule) -> PyResult<()> { - if let Some((soa_bytes, n_atoms)) = molecule.soa_bytes() { + if let Some(view) = molecule.as_view() + && view.record_format == self.inner.record_format() + { + let soa_bytes = view.bytes.as_slice(); + let n_atoms = view.n_atoms as u32; return self .inner .add_raw_soa_records(&[(soa_bytes, n_atoms)]) @@ -134,10 +139,13 @@ impl PyAtomDatabase { // Split into view-backed (fast path) and owned molecules let mut raw_records: Vec<(&[u8], u32)> = Vec::new(); let mut owned_molecules: Vec = Vec::new(); + let record_format = self.inner.record_format(); for m in &molecules { - if let Some((soa_bytes, n_atoms)) = m.soa_bytes() { - raw_records.push((soa_bytes, n_atoms)); + if let Some(view) = m.as_view() + && view.record_format == record_format + { + raw_records.push((view.bytes.as_slice(), view.n_atoms as u32)); } else { owned_molecules.push(m.clone_as_owned()?); } @@ -179,14 +187,14 @@ impl PyAtomDatabase { fn add_arrays_batch( &mut self, py: Python<'_>, - positions: &Bound<'_, PyArray3>, + positions: &Bound<'_, PyAny>, atomic_numbers: &Bound<'_, PyArray2>, - energy: Option<&Bound<'_, PyArray1>>, - forces: Option<&Bound<'_, PyArray3>>, - charges: Option<&Bound<'_, PyArray2>>, - velocities: Option<&Bound<'_, PyArray3>>, - cell: Option<&Bound<'_, PyArray3>>, - stress: Option<&Bound<'_, PyArray3>>, + energy: Option<&Bound<'_, PyAny>>, + forces: Option<&Bound<'_, PyAny>>, + charges: Option<&Bound<'_, PyAny>>, + velocities: Option<&Bound<'_, PyAny>>, + cell: Option<&Bound<'_, PyAny>>, + stress: Option<&Bound<'_, PyAny>>, pbc: Option<&Bound<'_, PyArray2>>, name: Option>, properties: Option<&Bound<'_, PyDict>>, @@ -238,6 +246,7 @@ impl PyAtomDatabase { } let compression = self.inner.compression(); + let record_format = self.inner.record_format(); let use_mmap = self.inner.get_compressed_slice(0).is_some(); let views: Vec = if use_mmap { @@ -250,7 +259,7 @@ impl PyAtomDatabase { let bytes = self.inner.get_shared_mmap_bytes(idx).ok_or_else(|| { invalid_data(format!("Missing mmap bytes for molecule {}", idx)) })?; - SoaMoleculeView::from_shared_bytes_inner(bytes) + SoaMoleculeView::from_shared_bytes_inner(bytes, record_format) } else { let compressed = self.inner.get_compressed_slice(idx).ok_or_else(|| { @@ -271,7 +280,7 @@ impl PyAtomDatabase { compression, Some(uncompressed_size), )?; - SoaMoleculeView::from_bytes_inner(decompressed) + SoaMoleculeView::from_bytes_inner(decompressed, record_format) } }) .collect() @@ -283,7 +292,7 @@ impl PyAtomDatabase { .map_err(|e| PyValueError::new_err(format!("{}", e)))?; raw_bytes .into_iter() - .map(SoaMoleculeView::from_bytes) + .map(|bytes| SoaMoleculeView::from_bytes(bytes, record_format)) .collect::>>() .map_err(|e| invalid_data(format!("{}", e))) } diff --git a/atompack-py/src/database_batch.rs b/atompack-py/src/database_batch.rs index 477e600..9433a35 100644 --- a/atompack-py/src/database_batch.rs +++ b/atompack-py/src/database_batch.rs @@ -1,7 +1,7 @@ use super::*; -use crate::molecule::{SoaBuiltinPayloads, SoaCustomSection, build_soa_record_with_custom}; +use crate::molecule::{SoaRecord, SoaSection, build_soa_record}; -struct BatchCustomColumn { +struct BatchSectionColumn { key: String, kind: u8, type_tag: u8, @@ -10,10 +10,10 @@ struct BatchCustomColumn { strings: Option>, } -impl BatchCustomColumn { - fn section_for<'a>(&'a self, index: usize) -> SoaCustomSection<'a> { +impl BatchSectionColumn { + fn section_for<'a>(&'a self, index: usize) -> SoaSection<'a> { if let Some(strings) = &self.strings { - return SoaCustomSection { + return SoaSection { kind: self.kind, key: self.key.as_str(), type_tag: self.type_tag, @@ -22,7 +22,7 @@ impl BatchCustomColumn { } let start = index * self.slot_bytes; let end = start + self.slot_bytes; - SoaCustomSection { + SoaSection { kind: self.kind, key: self.key.as_str(), type_tag: self.type_tag, @@ -67,7 +67,7 @@ fn extract_string_column( batch: usize, key: &str, kind: u8, -) -> PyResult> { +) -> PyResult> { let Ok(strings) = value.extract::>() else { return Ok(None); }; @@ -79,7 +79,7 @@ fn extract_string_column( batch ))); } - Ok(Some(BatchCustomColumn { + Ok(Some(BatchSectionColumn { key: key.to_string(), kind, type_tag: TYPE_STRING, @@ -94,7 +94,7 @@ fn extract_scalar_column_f64>( batch: usize, key: &str, kind: u8, -) -> PyResult { +) -> PyResult { let readonly = arr.readonly(); let view = readonly.as_array(); if view.len() != batch { @@ -106,7 +106,7 @@ fn extract_scalar_column_f64>( ))); } let values: Vec = view.iter().copied().map(Into::into).collect(); - Ok(BatchCustomColumn { + Ok(BatchSectionColumn { key: key.to_string(), kind, type_tag: TYPE_FLOAT, @@ -121,7 +121,7 @@ fn extract_scalar_column_i64>( batch: usize, key: &str, kind: u8, -) -> PyResult { +) -> PyResult { let readonly = arr.readonly(); let view = readonly.as_array(); if view.len() != batch { @@ -133,7 +133,7 @@ fn extract_scalar_column_i64>( ))); } let values: Vec = view.iter().copied().map(Into::into).collect(); - Ok(BatchCustomColumn { + Ok(BatchSectionColumn { key: key.to_string(), kind, type_tag: TYPE_INT, @@ -149,7 +149,7 @@ fn extract_matrix_column( key: &str, kind: u8, type_tag: u8, -) -> PyResult { +) -> PyResult { let readonly = arr.readonly(); let view = readonly.as_array(); let shape = view.shape(); @@ -162,7 +162,7 @@ fn extract_matrix_column( let slice = readonly.as_slice().map_err(|_| { PyValueError::new_err(format!("custom property '{}' must be C-contiguous", key)) })?; - Ok(BatchCustomColumn { + Ok(BatchSectionColumn { key: key.to_string(), kind, type_tag, @@ -180,7 +180,7 @@ fn extract_vec3_column( kind: u8, type_tag: u8, shape_label: &str, -) -> PyResult { +) -> PyResult { let readonly = arr.readonly(); let view = readonly.as_array(); let shape = view.shape(); @@ -193,7 +193,7 @@ fn extract_vec3_column( let slice = readonly.as_slice().map_err(|_| { PyValueError::new_err(format!("custom property '{}' must be C-contiguous", key)) })?; - Ok(BatchCustomColumn { + Ok(BatchSectionColumn { key: key.to_string(), kind, type_tag, @@ -208,7 +208,7 @@ fn extract_property_column( batch: usize, key: &str, kind: u8, -) -> PyResult> { +) -> PyResult> { if let Some(column) = extract_string_column(value, batch, key, kind)? { return Ok(Some(column)); } @@ -300,7 +300,7 @@ fn extract_atom_property_column( batch: usize, n_atoms: usize, key: &str, -) -> PyResult> { +) -> PyResult> { if let Ok(arr) = value.cast::>() { let column = extract_matrix_column(arr, batch, key, KIND_ATOM_PROP, TYPE_F64_ARRAY)?; if column.slot_bytes != n_atoms * std::mem::size_of::() { @@ -371,7 +371,7 @@ fn extract_custom_columns( atom_properties: Option<&Bound<'_, PyDict>>, batch: usize, n_atoms: usize, -) -> PyResult> { +) -> PyResult> { let mut columns = Vec::new(); let mut seen_keys = std::collections::HashSet::new(); @@ -408,244 +408,418 @@ fn extract_custom_columns( Ok(columns) } -#[allow(clippy::too_many_arguments)] -pub(super) fn add_arrays_batch_impl( - inner: &mut AtomDatabase, - py: Python<'_>, - positions: &Bound<'_, PyArray3>, - atomic_numbers: &Bound<'_, PyArray2>, - energy: Option<&Bound<'_, PyArray1>>, - forces: Option<&Bound<'_, PyArray3>>, - charges: Option<&Bound<'_, PyArray2>>, - velocities: Option<&Bound<'_, PyArray3>>, - cell: Option<&Bound<'_, PyArray3>>, - stress: Option<&Bound<'_, PyArray3>>, - pbc: Option<&Bound<'_, PyArray2>>, - name: Option>, - properties: Option<&Bound<'_, PyDict>>, - atom_properties: Option<&Bound<'_, PyDict>>, -) -> PyResult<()> { - let pos = positions.readonly(); - let pos_arr = pos.as_array(); - let pos_shape = pos_arr.shape(); - if pos_shape.len() != 3 || pos_shape[2] != 3 { - return Err(PyValueError::new_err( - "positions must have shape (batch, n_atoms, 3)", +fn extract_positions_payload(value: &Bound<'_, PyAny>) -> PyResult<(usize, usize, u8, Vec)> { + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape().len() != 3 || view.shape()[2] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (batch, n_atoms, 3)", + )); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err("positions must be C-contiguous"))?; + return Ok(( + view.shape()[0], + view.shape()[1], + TYPE_VEC3_F32, + bytemuck::cast_slice::(slice).to_vec(), )); } - let batch = pos_shape[0]; - let n_atoms = pos_shape[1]; - let pos_slice = pos_arr - .as_slice() - .ok_or_else(|| PyValueError::new_err("positions must be C-contiguous"))?; + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape().len() != 3 || view.shape()[2] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (batch, n_atoms, 3)", + )); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err("positions must be C-contiguous"))?; + return Ok(( + view.shape()[0], + view.shape()[1], + TYPE_VEC3_F64, + bytemuck::cast_slice::(slice).to_vec(), + )); + } + Err(PyValueError::new_err( + "positions must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", + )) +} - let z = atomic_numbers.readonly(); - let z_arr = z.as_array(); - if z_arr.shape() != [batch, n_atoms] { +fn extract_atomic_numbers_payload( + atomic_numbers: &Bound<'_, PyArray2>, + batch: usize, + n_atoms: usize, +) -> PyResult> { + let readonly = atomic_numbers.readonly(); + let view = readonly.as_array(); + if view.shape() != [batch, n_atoms] { return Err(PyValueError::new_err(format!( "atomic_numbers must have shape ({}, {})", batch, n_atoms ))); } - let z_slice = z_arr + Ok(readonly .as_slice() - .ok_or_else(|| PyValueError::new_err("atomic_numbers must be C-contiguous"))?; - - let energy_ro = energy.map(|arr| arr.readonly()); - let energy_slice = if let Some(ro) = energy_ro.as_ref() { - let view = ro.as_array(); - if view.len() != batch { - return Err(PyValueError::new_err(format!( - "energy length ({}) doesn't match batch size ({})", - view.len(), - batch - ))); - } - Some( - ro.as_slice() - .map_err(|_| PyValueError::new_err("energy must be C-contiguous"))?, - ) - } else { - None - }; + .map_err(|_| PyValueError::new_err("atomic_numbers must be C-contiguous"))? + .to_vec()) +} - let forces_ro = forces.map(|arr| arr.readonly()); - let forces_slice = if let Some(ro) = forces_ro.as_ref() { - let view = ro.as_array(); - if view.shape() != [batch, n_atoms, 3] { - return Err(PyValueError::new_err(format!( - "forces must have shape ({}, {}, 3)", - batch, n_atoms - ))); - } - Some( - ro.as_slice() - .map_err(|_| PyValueError::new_err("forces must be C-contiguous"))?, - ) - } else { - None - }; +fn extract_builtin_scalar_column( + arr: &Bound<'_, PyArray1>, + batch: usize, + key: &str, + type_tag: u8, +) -> PyResult { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.len() != batch { + return Err(PyValueError::new_err(format!( + "{} length ({}) doesn't match batch size ({})", + key, + view.len(), + batch + ))); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", key)))?; + Ok(BatchSectionColumn { + key: key.to_string(), + kind: KIND_BUILTIN, + type_tag, + slot_bytes: std::mem::size_of::(), + payload: bytemuck::cast_slice::(slice).to_vec(), + strings: None, + }) +} - let charges_ro = charges.map(|arr| arr.readonly()); - let charges_slice = if let Some(ro) = charges_ro.as_ref() { - let view = ro.as_array(); - if view.shape() != [batch, n_atoms] { - return Err(PyValueError::new_err(format!( - "charges must have shape ({}, {})", - batch, n_atoms - ))); - } - Some( - ro.as_slice() - .map_err(|_| PyValueError::new_err("charges must be C-contiguous"))?, - ) - } else { - None - }; +fn extract_builtin_float_array_column( + arr: &Bound<'_, PyArray2>, + batch: usize, + n_atoms: usize, + key: &str, + type_tag: u8, +) -> PyResult { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [batch, n_atoms] { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, {})", + key, batch, n_atoms + ))); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", key)))?; + Ok(BatchSectionColumn { + key: key.to_string(), + kind: KIND_BUILTIN, + type_tag, + slot_bytes: n_atoms * std::mem::size_of::(), + payload: bytemuck::cast_slice::(slice).to_vec(), + strings: None, + }) +} - let velocities_ro = velocities.map(|arr| arr.readonly()); - let velocities_slice = if let Some(ro) = velocities_ro.as_ref() { - let view = ro.as_array(); - if view.shape() != [batch, n_atoms, 3] { - return Err(PyValueError::new_err(format!( - "velocities must have shape ({}, {}, 3)", - batch, n_atoms - ))); - } - Some( - ro.as_slice() - .map_err(|_| PyValueError::new_err("velocities must be C-contiguous"))?, - ) - } else { - None - }; +fn extract_builtin_vec3_column( + arr: &Bound<'_, PyArray3>, + batch: usize, + n_atoms: usize, + key: &str, + type_tag: u8, +) -> PyResult { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [batch, n_atoms, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, {}, 3)", + key, batch, n_atoms + ))); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", key)))?; + Ok(BatchSectionColumn { + key: key.to_string(), + kind: KIND_BUILTIN, + type_tag, + slot_bytes: n_atoms * 3 * std::mem::size_of::(), + payload: bytemuck::cast_slice::(slice).to_vec(), + strings: None, + }) +} - let cell_ro = cell.map(|arr| arr.readonly()); - let cell_slice = if let Some(ro) = cell_ro.as_ref() { - let view = ro.as_array(); - if view.shape() != [batch, 3, 3] { - return Err(PyValueError::new_err(format!( - "cell must have shape ({}, 3, 3)", - batch - ))); - } - Some( - ro.as_slice() - .map_err(|_| PyValueError::new_err("cell must be C-contiguous"))?, - ) - } else { - None - }; +fn extract_builtin_mat3_column( + arr: &Bound<'_, PyArray3>, + batch: usize, + key: &str, + type_tag: u8, +) -> PyResult { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [batch, 3, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, 3, 3)", + key, batch + ))); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", key)))?; + Ok(BatchSectionColumn { + key: key.to_string(), + kind: KIND_BUILTIN, + type_tag, + slot_bytes: 9 * std::mem::size_of::(), + payload: bytemuck::cast_slice::(slice).to_vec(), + strings: None, + }) +} - let stress_ro = stress.map(|arr| arr.readonly()); - let stress_slice = if let Some(ro) = stress_ro.as_ref() { - let view = ro.as_array(); - if view.shape() != [batch, 3, 3] { - return Err(PyValueError::new_err(format!( - "stress must have shape ({}, 3, 3)", - batch - ))); - } - Some( - ro.as_slice() - .map_err(|_| PyValueError::new_err("stress must be C-contiguous"))?, - ) - } else { - None - }; +fn extract_builtin_pbc_column( + pbc: &Bound<'_, PyArray2>, + batch: usize, +) -> PyResult { + let readonly = pbc.readonly(); + let view = readonly.as_array(); + if view.shape() != [batch, 3] { + return Err(PyValueError::new_err(format!( + "pbc must have shape ({}, 3)", + batch + ))); + } + Ok(BatchSectionColumn { + key: "pbc".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_BOOL3, + slot_bytes: 3, + payload: view.iter().map(|value| u8::from(*value)).collect(), + strings: None, + }) +} - let pbc_ro = pbc.map(|arr| arr.readonly()); - let pbc_slice = if let Some(ro) = pbc_ro.as_ref() { - let view = ro.as_array(); - if view.shape() != [batch, 3] { - return Err(PyValueError::new_err(format!( - "pbc must have shape ({}, 3)", - batch - ))); - } - Some( - ro.as_slice() - .map_err(|_| PyValueError::new_err("pbc must be C-contiguous"))?, - ) - } else { - None +fn extract_builtin_name_column( + name: Option>, + batch: usize, +) -> PyResult> { + let Some(names) = name else { + return Ok(None); }; - - if let Some(names) = &name - && names.len() != batch - { + if names.len() != batch { return Err(PyValueError::new_err(format!( "name length ({}) doesn't match batch size ({})", names.len(), batch ))); } + Ok(Some(BatchSectionColumn { + key: "name".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_STRING, + slot_bytes: 0, + payload: Vec::new(), + strings: Some(names), + })) +} - let custom_columns = extract_custom_columns(properties, atom_properties, batch, n_atoms)?; +#[allow(clippy::too_many_arguments)] +pub(super) fn add_arrays_batch_impl( + inner: &mut AtomDatabase, + py: Python<'_>, + positions: &Bound<'_, PyAny>, + atomic_numbers: &Bound<'_, PyArray2>, + energy: Option<&Bound<'_, PyAny>>, + forces: Option<&Bound<'_, PyAny>>, + charges: Option<&Bound<'_, PyAny>>, + velocities: Option<&Bound<'_, PyAny>>, + cell: Option<&Bound<'_, PyAny>>, + stress: Option<&Bound<'_, PyAny>>, + pbc: Option<&Bound<'_, PyArray2>>, + name: Option>, + properties: Option<&Bound<'_, PyDict>>, + atom_properties: Option<&Bound<'_, PyDict>>, +) -> PyResult<()> { + let (batch, n_atoms, positions_type, positions_payload) = extract_positions_payload(positions)?; + let atomic_numbers_payload = extract_atomic_numbers_payload(atomic_numbers, batch, n_atoms)?; + + let mut builtin_columns = Vec::new(); + if let Some(energy) = energy { + if let Ok(arr) = energy.cast::>() { + builtin_columns.push(extract_builtin_scalar_column( + arr, + batch, + "energy", + TYPE_FLOAT32, + )?); + } else if let Ok(arr) = energy.cast::>() { + builtin_columns.push(extract_builtin_scalar_column( + arr, batch, "energy", TYPE_FLOAT, + )?); + } else { + return Err(PyValueError::new_err( + "energy must be a float32 or float64 ndarray with shape (batch,)", + )); + } + } + if let Some(forces) = forces { + if let Ok(arr) = forces.cast::>() { + builtin_columns.push(extract_builtin_vec3_column( + arr, + batch, + n_atoms, + "forces", + TYPE_VEC3_F32, + )?); + } else if let Ok(arr) = forces.cast::>() { + builtin_columns.push(extract_builtin_vec3_column( + arr, + batch, + n_atoms, + "forces", + TYPE_VEC3_F64, + )?); + } else { + return Err(PyValueError::new_err( + "forces must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", + )); + } + } + if let Some(charges) = charges { + if let Ok(arr) = charges.cast::>() { + builtin_columns.push(extract_builtin_float_array_column( + arr, + batch, + n_atoms, + "charges", + TYPE_F32_ARRAY, + )?); + } else if let Ok(arr) = charges.cast::>() { + builtin_columns.push(extract_builtin_float_array_column( + arr, + batch, + n_atoms, + "charges", + TYPE_F64_ARRAY, + )?); + } else { + return Err(PyValueError::new_err( + "charges must be a float32 or float64 ndarray with shape (batch, n_atoms)", + )); + } + } + if let Some(velocities) = velocities { + if let Ok(arr) = velocities.cast::>() { + builtin_columns.push(extract_builtin_vec3_column( + arr, + batch, + n_atoms, + "velocities", + TYPE_VEC3_F32, + )?); + } else if let Ok(arr) = velocities.cast::>() { + builtin_columns.push(extract_builtin_vec3_column( + arr, + batch, + n_atoms, + "velocities", + TYPE_VEC3_F64, + )?); + } else { + return Err(PyValueError::new_err( + "velocities must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", + )); + } + } + if let Some(cell) = cell { + if let Ok(arr) = cell.cast::>() { + builtin_columns.push(extract_builtin_mat3_column( + arr, + batch, + "cell", + TYPE_MAT3X3_F32, + )?); + } else if let Ok(arr) = cell.cast::>() { + builtin_columns.push(extract_builtin_mat3_column( + arr, + batch, + "cell", + TYPE_MAT3X3_F64, + )?); + } else { + return Err(PyValueError::new_err( + "cell must be a float32 or float64 ndarray with shape (batch, 3, 3)", + )); + } + } + if let Some(stress) = stress { + if let Ok(arr) = stress.cast::>() { + builtin_columns.push(extract_builtin_mat3_column( + arr, + batch, + "stress", + TYPE_MAT3X3_F32, + )?); + } else if let Ok(arr) = stress.cast::>() { + builtin_columns.push(extract_builtin_mat3_column( + arr, + batch, + "stress", + TYPE_MAT3X3_F64, + )?); + } else { + return Err(PyValueError::new_err( + "stress must be a float32 or float64 ndarray with shape (batch, 3, 3)", + )); + } + } + if let Some(pbc) = pbc { + builtin_columns.push(extract_builtin_pbc_column(pbc, batch)?); + } + if let Some(name) = extract_builtin_name_column(name, batch)? { + builtin_columns.push(name); + } - let build_record = |i: usize| { - let pos_start = i * n_atoms * 3; - let pos_end = pos_start + n_atoms * 3; - let z_start = i * n_atoms; + let custom_columns = extract_custom_columns(properties, atom_properties, batch, n_atoms)?; + let positions_slot_bytes = n_atoms + .checked_mul(type_tag_elem_bytes(positions_type)) + .ok_or_else(|| PyValueError::new_err("positions byte length overflow"))?; + let record_format = inner.record_format(); + + let mut records = Vec::with_capacity(batch); + for index in 0..batch { + let pos_start = index * positions_slot_bytes; + let pos_end = pos_start + positions_slot_bytes; + let z_start = index * n_atoms; let z_end = z_start + n_atoms; - let forces_payload = forces_slice - .as_ref() - .map(|slice| bytemuck::cast_slice::(&slice[pos_start..pos_end])); - let charges_payload = charges_slice - .as_ref() - .map(|slice| bytemuck::cast_slice::(&slice[z_start..z_end])); - let velocities_payload = velocities_slice - .as_ref() - .map(|slice| bytemuck::cast_slice::(&slice[pos_start..pos_end])); - let mat_start = i * 9; - let mat_end = mat_start + 9; - let cell_payload = cell_slice - .as_ref() - .map(|slice| bytemuck::cast_slice::(&slice[mat_start..mat_end])); - let stress_payload = stress_slice - .as_ref() - .map(|slice| bytemuck::cast_slice::(&slice[mat_start..mat_end])); - let pbc_value = pbc_slice - .as_ref() - .map(|slice| [slice[i * 3], slice[i * 3 + 1], slice[i * 3 + 2]]); - let name_value = name.as_ref().map(|names| names[i].as_str()); - let custom_sections: Vec> = custom_columns - .iter() - .map(|column| column.section_for(i)) - .collect(); - - build_soa_record_with_custom( - &pos_slice[pos_start..pos_end], - &z_slice[z_start..z_end], - SoaBuiltinPayloads { - energy: energy_slice.as_ref().map(|slice| slice[i]), - forces: forces_payload, - charges: charges_payload, - velocities: velocities_payload, - cell: cell_payload, - stress: stress_payload, - pbc: pbc_value, - name: name_value, - }, - &custom_sections, - ) - .map(|record| record.into_parts()) - }; - let serialized: Vec<(Vec, u32)> = if batch >= 1024 { - use rayon::prelude::*; - (0..batch) - .into_par_iter() - .map(build_record) - .collect::, _>>() - .map_err(PyValueError::new_err)? - } else { - (0..batch) - .map(build_record) - .collect::, _>>() - .map_err(PyValueError::new_err)? - }; + let mut sections = Vec::with_capacity(builtin_columns.len() + custom_columns.len()); + sections.extend( + builtin_columns + .iter() + .map(|column| column.section_for(index)), + ); + sections.extend( + custom_columns + .iter() + .map(|column| column.section_for(index)), + ); + + let record = build_soa_record(SoaRecord { + record_format, + positions_type, + positions: &positions_payload[pos_start..pos_end], + atomic_numbers: &atomic_numbers_payload[z_start..z_end], + sections: §ions, + }) + .map_err(PyValueError::new_err)?; + records.push((record, n_atoms as u32)); + } - py.detach(|| inner.add_owned_soa_records(serialized)) + py.detach(move || inner.add_owned_soa_records(records)) .map_err(|e| PyValueError::new_err(format!("{}", e))) } diff --git a/atompack-py/src/database_flat.rs b/atompack-py/src/database_flat.rs index c02b187..0ea2722 100644 --- a/atompack-py/src/database_flat.rs +++ b/atompack-py/src/database_flat.rs @@ -43,68 +43,59 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .copied() .ok_or_else(|| invalid_data("missing final atom offset"))?; - let compression = inner.compression(); - let use_mmap = inner.get_compressed_slice(0).is_some(); + let record_format = inner.record_format(); + let (raw_bytes, _) = inner.read_decompress_parallel(&indices)?; - let raw_bytes_owned: Option>>; - let schema: Vec; + let mut positions_type: Option = None; + let mut schema: Vec = Vec::new(); + for bytes in &raw_bytes { + let md = parse_mol_fast_soa(bytes, record_format)?; + match positions_type { + None => positions_type = Some(md.positions_type), + Some(expected) if expected != md.positions_type => { + return Err(invalid_data(format!( + "Position dtype mismatch across selected molecules: expected type tag {}, got {}", + expected, md.positions_type + ))); + } + _ => {} + } - if use_mmap { - if compression == CompressionType::None { - let shared = inner.get_shared_mmap_bytes(indices[0]).ok_or_else(|| { - invalid_data(format!("Missing mmap bytes for molecule {}", indices[0])) - })?; - let first_md = parse_mol_fast_soa(shared.as_slice())?; - let n = first_md.n_atoms; - schema = first_md - .sections - .iter() - .map(|s| section_schema_from_ref(s, n)) - .collect::>()?; - } else { - let compressed = inner.get_compressed_slice(indices[0]).ok_or_else(|| { - invalid_data(format!( - "Missing compressed bytes for molecule {}", - indices[0] - )) - })?; - let uncompressed_size = - inner.uncompressed_size(indices[0]).ok_or_else(|| { - invalid_data(format!( - "Missing uncompressed size for molecule {}", - indices[0] - )) - })? as usize; - let first_bytes = atompack::decompress_bytes( - compressed, - compression, - Some(uncompressed_size), - )?; - let first_md = parse_mol_fast_soa(&first_bytes)?; - let n = first_md.n_atoms; - schema = first_md - .sections + for section in &md.sections { + let incoming = section_schema_from_ref(section, md.n_atoms)?; + if let Some(existing) = schema .iter() - .map(|s| section_schema_from_ref(s, n)) - .collect::>()?; + .find(|candidate| candidate.kind == incoming.kind && candidate.key == incoming.key) + { + if existing.type_tag != incoming.type_tag + || existing.per_atom != incoming.per_atom + || existing.elem_bytes != incoming.elem_bytes + || existing.slot_bytes != incoming.slot_bytes + { + return Err(invalid_data(format!( + "SOA schema mismatch for section '{}'", + incoming.key + ))); + } + } else { + schema.push(incoming); + } } - raw_bytes_owned = None; - } else { - let (raw_bytes, _) = inner.read_decompress_parallel(&indices)?; - let first_md = parse_mol_fast_soa(&raw_bytes[0])?; - let n = first_md.n_atoms; - schema = first_md - .sections - .iter() - .map(|s| section_schema_from_ref(s, n)) - .collect::>()?; - raw_bytes_owned = Some(raw_bytes); } + let positions_type = + positions_type.ok_or_else(|| invalid_data("Missing position dtype for batch"))?; - let schema_keys: Vec<(u8, &[u8])> = - schema.iter().map(|s| (s.kind, s.key.as_bytes())).collect(); - - let mut positions = vec![0f32; total_atoms * 3]; + let positions_stride = match positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + _ => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + positions_type + ))); + } + }; + let mut positions = vec![0u8; total_atoms * positions_stride]; let mut atomic_numbers = vec![0u8; total_atoms]; let mut section_buffers: Vec> = schema @@ -153,23 +144,21 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .collect(); let process_mol = |i: usize, mol_bytes: &[u8]| -> atompack::Result<()> { - let md = parse_mol_fast_soa(mol_bytes)?; + let md = parse_mol_fast_soa(mol_bytes, record_format)?; let atom_off = offsets[i]; let n = md.n_atoms; - if md.sections.len() != schema.len() { + if md.positions_type != positions_type { return Err(invalid_data(format!( - "SOA schema mismatch for molecule {}: expected {} sections, got {}", - i, - schema.len(), - md.sections.len() + "Position dtype mismatch for molecule {}: expected type tag {}, got {}", + i, positions_type, md.positions_type ))); } unsafe { std::ptr::copy_nonoverlapping( md.positions_bytes.as_ptr(), - pos_buf.at(atom_off * 3) as *mut u8, - n * 12, + pos_buf.at(atom_off * positions_stride), + n * positions_stride, ); std::ptr::copy_nonoverlapping( md.atomic_numbers_bytes.as_ptr(), @@ -178,15 +167,21 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( ); } - for (section_idx, sec) in md.sections.iter().enumerate() { - let schema_entry = &schema[section_idx]; - let expected_key = &schema_keys[section_idx]; - if sec.kind != expected_key.0 || sec.key.as_bytes() != expected_key.1 { + for (section_idx, schema_entry) in schema.iter().enumerate() { + let sec = md.sections.iter().find(|sec| { + sec.kind == schema_entry.kind && sec.key == schema_entry.key + }); + let Some(sec) = sec else { + continue; + }; + + if sec.type_tag != schema_entry.type_tag { return Err(invalid_data(format!( - "SOA schema order mismatch at molecule {} for section '{}'", + "SOA schema mismatch at molecule {} for section '{}'", i, sec.key ))); } + if schema_entry.per_atom { let expected = n.checked_mul(schema_entry.elem_bytes).ok_or_else(|| { invalid_data(format!("Section '{}' payload length overflow", sec.key)) @@ -199,6 +194,15 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( expected ))); } + } else if schema_entry.slot_bytes != 0 + && sec.payload.len() != schema_entry.slot_bytes + { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + sec.key, + sec.payload.len(), + schema_entry.slot_bytes + ))); } if schema_entry.slot_bytes == 0 { @@ -237,51 +241,16 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( Ok(()) }; - let results: Vec> = if use_mmap { - (0..n_mols) - .into_par_iter() - .map(|i| { - let idx = indices[i]; - if compression == CompressionType::None { - let shared = inner.get_shared_mmap_bytes(idx).ok_or_else(|| { - invalid_data(format!("Missing mmap bytes for molecule {}", idx)) - })?; - process_mol(i, shared.as_slice()) - } else { - let compressed = inner.get_compressed_slice(idx).ok_or_else(|| { - invalid_data(format!( - "Missing compressed bytes for molecule {}", - idx - )) - })?; - let uncompressed_size = - inner.uncompressed_size(idx).ok_or_else(|| { - invalid_data(format!( - "Missing uncompressed size for molecule {}", - idx - )) - })? as usize; - let decompressed = atompack::decompress_bytes( - compressed, - compression, - Some(uncompressed_size), - )?; - process_mol(i, &decompressed) - } - }) - .collect() - } else { - let raw_bytes = raw_bytes_owned.unwrap(); - raw_bytes - .par_iter() - .enumerate() - .map(|(i, bytes)| process_mol(i, bytes)) - .collect() - }; + let results: Vec> = raw_bytes + .par_iter() + .enumerate() + .map(|(i, bytes)| process_mol(i, bytes)) + .collect(); results.into_iter().collect::>>()?; Ok(Some(( n_atoms_vec, + positions_type, positions, atomic_numbers, schema, @@ -295,6 +264,7 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let ( n_atoms_vec, + positions_type, positions, atomic_numbers, schema, @@ -320,12 +290,22 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let dict = PyDict::new(py); dict.set_item("n_atoms", PyArray1::from_vec(py, n_atoms_vec))?; - dict.set_item( - "positions", - PyArray1::from_vec(py, positions) - .reshape([total_atoms, 3]) - .map_err(|e| PyValueError::new_err(format!("{}", e)))?, - )?; + match positions_type { + TYPE_VEC3_F32 => { + let arr = cast_or_decode_f32(&positions)?; + dict.set_item("positions", pyarray2_from_cow(py, arr, total_atoms, 3)?)?; + } + TYPE_VEC3_F64 => { + let arr = cast_or_decode_f64(&positions)?; + dict.set_item("positions", pyarray2_from_cow(py, arr, total_atoms, 3)?)?; + } + other => { + return Err(PyValueError::new_err(format!( + "Unsupported positions type tag {}", + other + ))); + } + } dict.set_item("atomic_numbers", PyArray1::from_vec(py, atomic_numbers))?; let atom_props_dict = PyDict::new(py); @@ -361,6 +341,10 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let arr = cast_or_decode_f64(&buf)?; target.set_item(&s.key, pyarray1_from_cow(py, arr))?; } + TYPE_FLOAT32 => { + let arr = cast_or_decode_f32(&buf)?; + target.set_item(&s.key, pyarray1_from_cow(py, arr))?; + } TYPE_INT => { let arr = cast_or_decode_i64(&buf)?; target.set_item(&s.key, pyarray1_from_cow(py, arr))?; @@ -411,6 +395,16 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .map_err(|e| PyValueError::new_err(format!("{}", e)))?, )?; } + TYPE_MAT3X3_F32 => { + let arr = cast_or_decode_f32(&buf)?; + let n = arr.len() / 9; + target.set_item( + &s.key, + pyarray1_from_cow(py, arr) + .reshape([n, 3, 3]) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?, + )?; + } _ => {} } } diff --git a/atompack-py/src/lib.rs b/atompack-py/src/lib.rs index dcca0c5..f62eaa3 100644 --- a/atompack-py/src/lib.rs +++ b/atompack-py/src/lib.rs @@ -8,8 +8,8 @@ #![allow(unsafe_op_in_unsafe_fn)] use atompack::{ - Atom, AtomDatabase, Molecule, SharedMmapBytes, atom::PropertyValue, - compression::CompressionType, + Atom, AtomDatabase, FloatArrayData, FloatScalarData, Mat3Data, Molecule, SharedMmapBytes, + Vec3Data, atom::PropertyValue, compression::CompressionType, }; use numpy::{Element, PyArray1, PyArray2, PyArray3, PyArrayMethods}; use pyo3::exceptions::{PyFileExistsError, PyIndexError, PyKeyError, PyTypeError, PyValueError}; @@ -111,6 +111,11 @@ const TYPE_VEC3_F64: u8 = 7; const TYPE_I32_ARRAY: u8 = 8; const TYPE_BOOL3: u8 = 9; const TYPE_MAT3X3_F64: u8 = 10; +const TYPE_FLOAT32: u8 = 11; +const TYPE_MAT3X3_F32: u8 = 12; + +const RECORD_FORMAT_SOA_V2: u32 = 2; +const RECORD_FORMAT_SOA_V3: u32 = 3; /// A single parsed section reference (zero-copy into decompressed bytes). #[derive(Clone)] @@ -124,7 +129,8 @@ struct SectionRef<'a> { /// Per-molecule extracted data (references into decompressed bytes). struct MolData<'a> { n_atoms: usize, - positions_bytes: &'a [u8], // n_atoms * 12 + positions_type: u8, + positions_bytes: &'a [u8], atomic_numbers_bytes: &'a [u8], // n_atoms sections: Vec>, } @@ -145,11 +151,31 @@ struct SectionSchema { /// [n_atoms:u32][positions:n*12][atomic_numbers:n] /// [n_sections:u16] /// per section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload] -fn parse_mol_fast_soa(bytes: &[u8]) -> atompack::Result> { +fn parse_mol_fast_soa(bytes: &[u8], record_format: u32) -> atompack::Result> { let mut pos = 0usize; let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; + let positions_type = match record_format { + RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, + RECORD_FORMAT_SOA_V3 => read_u8_at(bytes, &mut pos, "SOA positions type")?, + _ => { + return Err(invalid_data(format!( + "Unsupported record format {}", + record_format + ))); + } + }; + let positions_stride = match positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + _ => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + positions_type + ))); + } + }; let positions_len = n_atoms - .checked_mul(12) + .checked_mul(positions_stride) .ok_or_else(|| invalid_data("SOA positions byte length overflow"))?; let positions_bytes = read_bytes_at(bytes, &mut pos, positions_len, "SOA positions")?; let atomic_numbers_bytes = read_bytes_at(bytes, &mut pos, n_atoms, "SOA atomic_numbers")?; @@ -175,6 +201,7 @@ fn parse_mol_fast_soa(bytes: &[u8]) -> atompack::Result> { Ok(MolData { n_atoms, + positions_type, positions_bytes, atomic_numbers_bytes, sections, @@ -199,7 +226,9 @@ fn section_schema_from_ref( elem_bytes } TYPE_FLOAT | TYPE_INT => 8, + TYPE_FLOAT32 => 4, TYPE_BOOL3 => 3, + TYPE_MAT3X3_F32 => 36, TYPE_MAT3X3_F64 => 72, _ => section.payload.len(), }; @@ -241,7 +270,7 @@ fn validate_section_payload( ))); } } - TYPE_FLOAT | TYPE_INT | TYPE_BOOL3 | TYPE_MAT3X3_F64 => { + TYPE_FLOAT | TYPE_INT | TYPE_FLOAT32 | TYPE_BOOL3 | TYPE_MAT3X3_F32 | TYPE_MAT3X3_F64 => { if section.payload.len() != slot_bytes { return Err(invalid_data(format!( "Section '{}' has invalid payload length {} (expected {})", @@ -299,6 +328,8 @@ fn type_tag_elem_bytes(tag: u8) -> usize { TYPE_VEC3_F64 => 24, TYPE_I32_ARRAY => 4, TYPE_BOOL3 => 3, + TYPE_FLOAT32 => 4, + TYPE_MAT3X3_F32 => 36, TYPE_MAT3X3_F64 => 72, _ => 0, } @@ -354,8 +385,11 @@ impl std::ops::Deref for SoaBytes { struct SoaMoleculeView { bytes: SoaBytes, + record_format: u32, n_atoms: usize, + positions_type: u8, positions_start: usize, + positions_len: usize, atomic_numbers_start: usize, // Known builtins — zero-alloc, set during from_bytes forces: Option, @@ -372,17 +406,46 @@ struct SoaMoleculeView { impl SoaMoleculeView { /// Pure-Rust parser — no Python dependency, safe to call from rayon threads. - fn from_storage_inner(bytes: SoaBytes) -> atompack::Result { + fn from_storage_inner(bytes: SoaBytes, record_format: u32) -> atompack::Result { if bytes.len() < 6 { return Err(invalid_data("SOA record too small")); } let n_atoms = u32::from_le_bytes(slice_to_array(&bytes[0..4], "SOA atom count")?) as usize; let mut pos = 4usize; + let positions_type = match record_format { + RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, + RECORD_FORMAT_SOA_V3 => { + if pos + 1 > bytes.len() { + return Err(invalid_data("SOA record truncated at positions type")); + } + let tag = bytes[pos]; + pos += 1; + tag + } + _ => { + return Err(invalid_data(format!( + "Unsupported record format {}", + record_format + ))); + } + }; + let positions_stride = match positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + _ => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + positions_type + ))); + } + }; let positions_start = pos; - pos = n_atoms - .checked_mul(12) - .and_then(|n| pos.checked_add(n)) + let positions_len = n_atoms + .checked_mul(positions_stride) + .ok_or_else(|| invalid_data("SOA positions overflow"))?; + pos = pos + .checked_add(positions_len) .ok_or_else(|| invalid_data("SOA positions overflow"))?; if pos > bytes.len() { return Err(invalid_data("SOA record truncated at positions")); @@ -478,8 +541,11 @@ impl SoaMoleculeView { Ok(Self { bytes, + record_format, n_atoms, + positions_type, positions_start, + positions_len, atomic_numbers_start, forces, energy, @@ -493,21 +559,25 @@ impl SoaMoleculeView { }) } - fn from_bytes_inner(bytes: Vec) -> atompack::Result { - Self::from_storage_inner(SoaBytes::Owned(bytes)) + fn from_bytes_inner(bytes: Vec, record_format: u32) -> atompack::Result { + Self::from_storage_inner(SoaBytes::Owned(bytes), record_format) } - fn from_shared_bytes_inner(bytes: SharedMmapBytes) -> atompack::Result { - Self::from_storage_inner(SoaBytes::Shared(bytes)) + fn from_shared_bytes_inner( + bytes: SharedMmapBytes, + record_format: u32, + ) -> atompack::Result { + Self::from_storage_inner(SoaBytes::Shared(bytes), record_format) } /// Thin wrapper for call sites that need PyResult. - fn from_bytes(bytes: Vec) -> PyResult { - Self::from_bytes_inner(bytes).map_err(|e| PyValueError::new_err(format!("{}", e))) + fn from_bytes(bytes: Vec, record_format: u32) -> PyResult { + Self::from_bytes_inner(bytes, record_format) + .map_err(|e| PyValueError::new_err(format!("{}", e))) } fn positions_bytes(&self) -> &[u8] { - &self.bytes[self.positions_start..self.positions_start + self.n_atoms * 12] + &self.bytes[self.positions_start..self.positions_start + self.positions_len] } fn atomic_numbers_bytes(&self) -> &[u8] { @@ -549,18 +619,45 @@ impl SoaMoleculeView { if index >= self.n_atoms { return Ok(None); } - let pos = &self.positions_bytes()[index * 12..(index + 1) * 12]; - Ok(Some(Atom::new( - f32::from_le_bytes(py_slice_to_array(&pos[0..4], "atom x")?), - f32::from_le_bytes(py_slice_to_array(&pos[4..8], "atom y")?), - f32::from_le_bytes(py_slice_to_array(&pos[8..12], "atom z")?), - self.atomic_numbers_bytes()[index], - ))) + let atomic_number = self.atomic_numbers_bytes()[index]; + Ok(Some(match self.positions_type { + TYPE_VEC3_F32 => { + let pos = &self.positions_bytes()[index * 12..(index + 1) * 12]; + Atom::new( + f32::from_le_bytes(py_slice_to_array(&pos[0..4], "atom x")?), + f32::from_le_bytes(py_slice_to_array(&pos[4..8], "atom y")?), + f32::from_le_bytes(py_slice_to_array(&pos[8..12], "atom z")?), + atomic_number, + ) + } + TYPE_VEC3_F64 => { + let pos = &self.positions_bytes()[index * 24..(index + 1) * 24]; + Atom::new( + f64::from_le_bytes(py_slice_to_array(&pos[0..8], "atom x")?) as f32, + f64::from_le_bytes(py_slice_to_array(&pos[8..16], "atom y")?) as f32, + f64::from_le_bytes(py_slice_to_array(&pos[16..24], "atom z")?) as f32, + atomic_number, + ) + } + other => { + return Err(PyValueError::new_err(format!( + "Unsupported positions type tag {}", + other + ))); + } + })) } fn energy(&self) -> PyResult> { match self.energy { - Some(slot) => Ok(Some(read_f64_scalar(self.builtin_payload(slot))?)), + Some(slot) => match slot.2 { + TYPE_FLOAT => Ok(Some(read_f64_scalar(self.builtin_payload(slot))?)), + TYPE_FLOAT32 => Ok(Some(read_f32_scalar(self.builtin_payload(slot))? as f64)), + other => Err(PyValueError::new_err(format!( + "Unsupported energy type tag {}", + other + ))), + }, None => Ok(None), } } @@ -579,32 +676,94 @@ impl SoaMoleculeView { } fn materialize(&self) -> PyResult { - let positions: Vec<[f32; 3]> = (0..self.n_atoms) - .map(|i| { - let pos = &self.positions_bytes()[i * 12..(i + 1) * 12]; - Ok([ - f32::from_le_bytes(py_slice_to_array(&pos[0..4], "position x")?), - f32::from_le_bytes(py_slice_to_array(&pos[4..8], "position y")?), - f32::from_le_bytes(py_slice_to_array(&pos[8..12], "position z")?), - ]) - }) - .collect::>()?; let atomic_numbers = self.atomic_numbers_bytes().to_vec(); - let mut molecule = - Molecule::new(positions, atomic_numbers).map_err(PyValueError::new_err)?; + let mut molecule = match self.positions_type { + TYPE_VEC3_F32 => { + let positions: Vec<[f32; 3]> = (0..self.n_atoms) + .map(|i| { + let pos = &self.positions_bytes()[i * 12..(i + 1) * 12]; + Ok([ + f32::from_le_bytes(py_slice_to_array(&pos[0..4], "position x")?), + f32::from_le_bytes(py_slice_to_array(&pos[4..8], "position y")?), + f32::from_le_bytes(py_slice_to_array(&pos[8..12], "position z")?), + ]) + }) + .collect::>()?; + Molecule::new(positions, atomic_numbers).map_err(PyValueError::new_err)? + } + TYPE_VEC3_F64 => { + let positions: Vec<[f64; 3]> = (0..self.n_atoms) + .map(|i| { + let pos = &self.positions_bytes()[i * 24..(i + 1) * 24]; + Ok([ + f64::from_le_bytes(py_slice_to_array(&pos[0..8], "position x")?), + f64::from_le_bytes(py_slice_to_array(&pos[8..16], "position y")?), + f64::from_le_bytes(py_slice_to_array(&pos[16..24], "position z")?), + ]) + }) + .collect::>()?; + Molecule::new_f64(positions, atomic_numbers).map_err(PyValueError::new_err)? + } + other => { + return Err(PyValueError::new_err(format!( + "Unsupported positions type tag {}", + other + ))); + } + }; // Builtins if let Some(slot) = self.charges { - molecule.charges = Some(decode_f64_array(self.builtin_payload(slot))?); + molecule.charges = Some(match slot.2 { + TYPE_F32_ARRAY => { + FloatArrayData::F32(decode_f32_array(self.builtin_payload(slot))?) + } + TYPE_F64_ARRAY => { + FloatArrayData::F64(decode_f64_array(self.builtin_payload(slot))?) + } + other => { + return Err(PyValueError::new_err(format!( + "Unsupported charges type tag {}", + other + ))); + } + }); } if let Some(slot) = self.cell { - molecule.cell = Some(decode_mat3x3_f64(self.builtin_payload(slot))?); + molecule.cell = Some(match slot.2 { + TYPE_MAT3X3_F32 => Mat3Data::F32(decode_mat3x3_f32(self.builtin_payload(slot))?), + TYPE_MAT3X3_F64 => Mat3Data::F64(decode_mat3x3_f64(self.builtin_payload(slot))?), + other => { + return Err(PyValueError::new_err(format!( + "Unsupported cell type tag {}", + other + ))); + } + }); } if let Some(slot) = self.energy { - molecule.energy = Some(read_f64_scalar(self.builtin_payload(slot))?); + molecule.energy = Some(match slot.2 { + TYPE_FLOAT => FloatScalarData::F64(read_f64_scalar(self.builtin_payload(slot))?), + TYPE_FLOAT32 => FloatScalarData::F32(read_f32_scalar(self.builtin_payload(slot))?), + other => { + return Err(PyValueError::new_err(format!( + "Unsupported energy type tag {}", + other + ))); + } + }); } if let Some(slot) = self.forces { - molecule.forces = Some(decode_vec3_f32(self.builtin_payload(slot))?); + molecule.forces = Some(match slot.2 { + TYPE_VEC3_F32 => Vec3Data::F32(decode_vec3_f32(self.builtin_payload(slot))?), + TYPE_VEC3_F64 => Vec3Data::F64(decode_vec3_f64(self.builtin_payload(slot))?), + other => { + return Err(PyValueError::new_err(format!( + "Unsupported forces type tag {}", + other + ))); + } + }); } if let Some(slot) = self.name { let payload = self.builtin_payload(slot); @@ -622,10 +781,28 @@ impl SoaMoleculeView { molecule.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); } if let Some(slot) = self.stress { - molecule.stress = Some(decode_mat3x3_f64(self.builtin_payload(slot))?); + molecule.stress = Some(match slot.2 { + TYPE_MAT3X3_F32 => Mat3Data::F32(decode_mat3x3_f32(self.builtin_payload(slot))?), + TYPE_MAT3X3_F64 => Mat3Data::F64(decode_mat3x3_f64(self.builtin_payload(slot))?), + other => { + return Err(PyValueError::new_err(format!( + "Unsupported stress type tag {}", + other + ))); + } + }); } if let Some(slot) = self.velocities { - molecule.velocities = Some(decode_vec3_f32(self.builtin_payload(slot))?); + molecule.velocities = Some(match slot.2 { + TYPE_VEC3_F32 => Vec3Data::F32(decode_vec3_f32(self.builtin_payload(slot))?), + TYPE_VEC3_F64 => Vec3Data::F64(decode_vec3_f64(self.builtin_payload(slot))?), + other => { + return Err(PyValueError::new_err(format!( + "Unsupported velocities type tag {}", + other + ))); + } + }); } // Custom properties (lazy — key parsed here only) @@ -661,6 +838,16 @@ fn read_f64_scalar(payload: &[u8]) -> PyResult { )?)) } +fn read_f32_scalar(payload: &[u8]) -> PyResult { + if payload.len() != 4 { + return Err(PyValueError::new_err("Invalid f32 payload length")); + } + Ok(f32::from_le_bytes(py_slice_to_array( + payload, + "f32 payload", + )?)) +} + fn read_i64_scalar(payload: &[u8]) -> PyResult { if payload.len() != 8 { return Err(PyValueError::new_err("Invalid i64 payload length")); @@ -786,6 +973,29 @@ fn decode_mat3x3_f64(payload: &[u8]) -> PyResult<[[f64; 3]; 3]> { ]) } +fn decode_mat3x3_f32(payload: &[u8]) -> PyResult<[[f32; 3]; 3]> { + if payload.len() != 36 { + return Err(PyValueError::new_err("Invalid mat3x3 payload length")); + } + Ok([ + [ + f32::from_le_bytes(py_slice_to_array(&payload[0..4], "mat3x3 [0][0]")?), + f32::from_le_bytes(py_slice_to_array(&payload[4..8], "mat3x3 [0][1]")?), + f32::from_le_bytes(py_slice_to_array(&payload[8..12], "mat3x3 [0][2]")?), + ], + [ + f32::from_le_bytes(py_slice_to_array(&payload[12..16], "mat3x3 [1][0]")?), + f32::from_le_bytes(py_slice_to_array(&payload[16..20], "mat3x3 [1][1]")?), + f32::from_le_bytes(py_slice_to_array(&payload[20..24], "mat3x3 [1][2]")?), + ], + [ + f32::from_le_bytes(py_slice_to_array(&payload[24..28], "mat3x3 [2][0]")?), + f32::from_le_bytes(py_slice_to_array(&payload[28..32], "mat3x3 [2][1]")?), + f32::from_le_bytes(py_slice_to_array(&payload[32..36], "mat3x3 [2][2]")?), + ], + ]) +} + fn decode_property_value(type_tag: u8, payload: &[u8]) -> PyResult { Ok(match type_tag { TYPE_FLOAT => PropertyValue::Float(read_f64_scalar(payload)?), diff --git a/atompack-py/src/molecule.rs b/atompack-py/src/molecule.rs index 78a83e9..6fce82a 100644 --- a/atompack-py/src/molecule.rs +++ b/atompack-py/src/molecule.rs @@ -59,13 +59,11 @@ enum MoleculeBacking { mod helpers; pub(crate) use self::helpers::{ - SoaBuiltinPayloads, SoaCustomSection, build_soa_record_with_custom, cast_or_decode_f32, - cast_or_decode_f64, cast_or_decode_i32, cast_or_decode_i64, pyarray1_from_cow, - pyarray2_from_cow, -}; -use self::helpers::{ - into_py_any, property_section_to_pyobject, property_value_to_pyobject, pyarray2_from_flat, + SoaRecord, SoaSection, build_soa_record, cast_or_decode_f32, cast_or_decode_f64, + cast_or_decode_i32, cast_or_decode_i64, parse_float_array_field, parse_mat3_field, + parse_vec3_field, pyarray1_from_cow, pyarray2_from_cow, }; +use self::helpers::{into_py_any, property_section_to_pyobject, property_value_to_pyobject}; #[pymethods] impl PyMolecule { @@ -87,13 +85,13 @@ impl PyMolecule { #[allow(clippy::too_many_arguments)] fn new( py: Python<'_>, - positions: &Bound<'_, PyArray2>, + positions: &Bound<'_, PyAny>, atomic_numbers: &Bound<'_, PyArray1>, energy: Option, - forces: Option<&Bound<'_, PyArray2>>, - charges: Option<&Bound<'_, PyArray1>>, - velocities: Option<&Bound<'_, PyArray2>>, - cell: Option<&Bound<'_, PyArray2>>, + forces: Option>, + charges: Option>, + velocities: Option>, + cell: Option>, stress: Option>, pbc: Option<(bool, bool, bool)>, name: Option, @@ -113,17 +111,7 @@ impl PyMolecule { ) } - /// Create a molecule from numpy arrays (fast path). - /// - /// Parameters: - /// - positions: float32 array of shape (n_atoms, 3) - /// - atomic_numbers: uint8 array of shape (n_atoms,) - /// - builtins: optional keyword arguments such as energy, forces, charges, - /// velocities, cell, stress, pbc, and name - /// - /// Builds an SOA view directly from the numpy buffers — no intermediate - /// Atom structs are created. If you later mutate the molecule (e.g. set - /// energy), it will be materialized on demand. + /// Create a molecule from numpy arrays. #[staticmethod] #[pyo3(signature = ( positions, @@ -141,13 +129,13 @@ impl PyMolecule { #[allow(clippy::too_many_arguments)] fn from_arrays( py: Python<'_>, - positions: &Bound<'_, PyArray2>, + positions: &Bound<'_, PyAny>, atomic_numbers: &Bound<'_, PyArray1>, energy: Option, - forces: Option<&Bound<'_, PyArray2>>, - charges: Option<&Bound<'_, PyArray1>>, - velocities: Option<&Bound<'_, PyArray2>>, - cell: Option<&Bound<'_, PyArray2>>, + forces: Option>, + charges: Option>, + velocities: Option>, + cell: Option>, stress: Option>, pbc: Option<(bool, bool, bool)>, name: Option, @@ -214,9 +202,9 @@ impl PyMolecule { copy_arrays: bool, ) -> PyResult> { let numbers = self.atomic_numbers_py(py)?.into_any().unbind(); - let positions = self.positions_py(py)?.into_any().unbind(); + let positions = self.positions_py(py)?; let cell = match self.cell_py(py)? { - Some(value) => value.into_any().unbind(), + Some(value) => value, None => py.None(), }; let pbc = match self.pbc()? { @@ -224,7 +212,7 @@ impl PyMolecule { None => py.None(), }; let velocities = match self.velocities_py(py)? { - Some(value) => value.into_any().unbind(), + Some(value) => value, None => py.None(), }; let energy = match self.energy()? { @@ -232,15 +220,15 @@ impl PyMolecule { None => py.None(), }; let forces = match self.forces_py(py)? { - Some(value) => value.into_any().unbind(), + Some(value) => value, None => py.None(), }; let stress = match self.stress_py(py)? { - Some(value) => value.into_any().unbind(), + Some(value) => value, None => py.None(), }; let charges = match self.charges_py(py)? { - Some(value) => value.into_any().unbind(), + Some(value) => value, None => py.None(), }; @@ -299,59 +287,16 @@ impl PyMolecule { /// forces property (getter) #[getter] - fn forces<'py>(slf: Bound<'py, Self>) -> PyResult>>> { + fn forces<'py>(slf: Bound<'py, Self>) -> PyResult>> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - return inner - .forces - .as_ref() - .map(|forces| { - let n_atoms = forces.len(); - let flat: Vec = forces.iter().flat_map(|f| [f[0], f[1], f[2]]).collect(); - pyarray2_from_flat(py, flat, n_atoms, 3) - }) - .transpose(); - } - let Some(view) = molecule.as_view() else { - return Ok(None); - }; - let Some(slot) = view.forces else { - return Ok(None); - }; - if slot.2 != TYPE_VEC3_F32 || slot.1 != view.n_atoms * 12 { - return Err(PyValueError::new_err("Invalid forces section")); - } - let payload = view.builtin_payload(slot); - let data = cast_or_decode_f32(payload)?; - Ok(Some(pyarray2_from_cow(py, data, view.n_atoms, 3)?)) + slf.borrow().forces_py(py) } /// forces property (setter) #[setter] - fn set_forces(&mut self, forces: &Bound<'_, PyArray2>) -> PyResult<()> { - let readonly = forces.readonly(); - let arr = readonly.as_array(); - let shape = arr.shape(); - - if shape[1] != 3 { - return Err(PyValueError::new_err("Forces must have shape (n_atoms, 3)")); - } - - let forces_vec: Vec<[f32; 3]> = arr - .outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(); - - if forces_vec.len() != self.len() { - return Err(PyValueError::new_err(format!( - "Forces length ({}) doesn't match atom count ({})", - forces_vec.len(), - self.len() - ))); - } - - self.ensure_owned()?.forces = Some(forces_vec); + fn set_forces(&mut self, py: Python<'_>, forces: Py) -> PyResult<()> { + let n_atoms = self.len(); + self.ensure_owned()?.forces = Some(parse_vec3_field(forces.bind(py), "forces", n_atoms)?); Ok(()) } @@ -359,7 +304,7 @@ impl PyMolecule { #[getter] fn energy(&self) -> PyResult> { if let Some(inner) = self.as_owned() { - Ok(inner.energy) + Ok(inner.energy.as_ref().map(FloatScalarData::as_f64)) } else if let Some(view) = self.as_view() { view.energy() } else { @@ -370,242 +315,76 @@ impl PyMolecule { /// Set energy #[setter] fn set_energy(&mut self, energy: Option) -> PyResult<()> { - self.ensure_owned()?.energy = energy; + self.ensure_owned()?.energy = energy.map(FloatScalarData::F64); Ok(()) } /// charges property (getter) #[getter] - fn charges<'py>(slf: Bound<'py, Self>) -> PyResult>>> { + fn charges<'py>(slf: Bound<'py, Self>) -> PyResult>> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - return Ok(inner - .charges - .as_ref() - .map(|charges| PyArray1::from_slice(py, charges))); - } - let Some(view) = molecule.as_view() else { - return Ok(None); - }; - let Some(slot) = view.charges else { - return Ok(None); - }; - if slot.2 != TYPE_F64_ARRAY || slot.1 != view.n_atoms * 8 { - return Err(PyValueError::new_err("Invalid charges section")); - } - let payload = view.builtin_payload(slot); - let data = cast_or_decode_f64(payload)?; - Ok(Some(pyarray1_from_cow(py, data))) + slf.borrow().charges_py(py) } /// charges property (setter) #[setter] - fn set_charges(&mut self, charges: &Bound<'_, PyArray1>) -> PyResult<()> { - let charges_vec: Vec = charges.readonly().as_array().to_vec(); - - if charges_vec.len() != self.len() { - return Err(PyValueError::new_err(format!( - "Charges length ({}) doesn't match atom count ({})", - charges_vec.len(), - self.len() - ))); - } - - self.ensure_owned()?.charges = Some(charges_vec); + fn set_charges(&mut self, py: Python<'_>, charges: Py) -> PyResult<()> { + let n_atoms = self.len(); + self.ensure_owned()?.charges = Some(parse_float_array_field( + charges.bind(py), + "charges", + n_atoms, + )?); Ok(()) } /// velocities property (getter) #[getter] - fn velocities<'py>(slf: Bound<'py, Self>) -> PyResult>>> { + fn velocities<'py>(slf: Bound<'py, Self>) -> PyResult>> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - return inner - .velocities - .as_ref() - .map(|vels| { - let n_atoms = vels.len(); - let flat: Vec = vels.iter().flat_map(|v| [v[0], v[1], v[2]]).collect(); - pyarray2_from_flat(py, flat, n_atoms, 3) - }) - .transpose(); - } - let Some(view) = molecule.as_view() else { - return Ok(None); - }; - let Some(slot) = view.velocities else { - return Ok(None); - }; - if slot.2 != TYPE_VEC3_F32 || slot.1 != view.n_atoms * 12 { - return Err(PyValueError::new_err("Invalid velocities section")); - } - let payload = view.builtin_payload(slot); - let data = cast_or_decode_f32(payload)?; - Ok(Some(pyarray2_from_cow(py, data, view.n_atoms, 3)?)) + slf.borrow().velocities_py(py) } /// velocities property (setter) #[setter] - fn set_velocities(&mut self, velocities: &Bound<'_, PyArray2>) -> PyResult<()> { - let readonly = velocities.readonly(); - let arr = readonly.as_array(); - let shape = arr.shape(); - - if shape[1] != 3 { - return Err(PyValueError::new_err( - "Velocities must have shape (n_atoms, 3)", - )); - } - - let vel_vec: Vec<[f32; 3]> = arr - .outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(); - - if vel_vec.len() != self.len() { - return Err(PyValueError::new_err(format!( - "Velocities length ({}) doesn't match atom count ({})", - vel_vec.len(), - self.len() - ))); - } - - self.ensure_owned()?.velocities = Some(vel_vec); + fn set_velocities(&mut self, py: Python<'_>, velocities: Py) -> PyResult<()> { + let n_atoms = self.len(); + self.ensure_owned()?.velocities = Some(parse_vec3_field( + velocities.bind(py), + "velocities", + n_atoms, + )?); Ok(()) } /// cell property (getter) #[getter] - fn cell<'py>(slf: Bound<'py, Self>) -> PyResult>>> { + fn cell<'py>(slf: Bound<'py, Self>) -> PyResult>> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - return inner - .cell - .as_ref() - .map(|cell| { - let flat: Vec = cell - .iter() - .flat_map(|row| [row[0], row[1], row[2]]) - .collect(); - pyarray2_from_flat(py, flat, 3, 3) - }) - .transpose(); - } - let Some(view) = molecule.as_view() else { - return Ok(None); - }; - let Some(slot) = view.cell else { - return Ok(None); - }; - if slot.2 != TYPE_MAT3X3_F64 || slot.1 != 72 { - return Err(PyValueError::new_err("Invalid cell section")); - } - let payload = view.builtin_payload(slot); - let data = cast_or_decode_f64(payload)?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?)) + slf.borrow().cell_py(py) } /// cell property (setter) #[setter] - fn set_cell(&mut self, cell: &Bound<'_, PyArray2>) -> PyResult<()> { - let readonly = cell.readonly(); - let arr = readonly.as_array(); - let shape = arr.shape(); - - if shape != [3, 3] { - return Err(PyValueError::new_err("Cell must have shape (3, 3)")); - } - - let cell_array: [[f64; 3]; 3] = [ - [arr[[0, 0]], arr[[0, 1]], arr[[0, 2]]], - [arr[[1, 0]], arr[[1, 1]], arr[[1, 2]]], - [arr[[2, 0]], arr[[2, 1]], arr[[2, 2]]], - ]; - - self.ensure_owned()?.cell = Some(cell_array); + fn set_cell(&mut self, py: Python<'_>, cell: Py) -> PyResult<()> { + self.ensure_owned()?.cell = Some(parse_mat3_field(cell.bind(py), "cell")?); Ok(()) } /// stress property (getter) #[getter] - fn stress<'py>(slf: Bound<'py, Self>) -> PyResult>>> { + fn stress<'py>(slf: Bound<'py, Self>) -> PyResult>> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - return inner - .stress - .as_ref() - .map(|stress| { - let flat: Vec = stress - .iter() - .flat_map(|row| [row[0], row[1], row[2]]) - .collect(); - pyarray2_from_flat(py, flat, 3, 3) - }) - .transpose(); - } - let Some(view) = molecule.as_view() else { - return Ok(None); - }; - let Some(slot) = view.stress else { - return Ok(None); - }; - if slot.2 != TYPE_MAT3X3_F64 || slot.1 != 72 { - return Err(PyValueError::new_err("Invalid stress section")); - } - let payload = view.builtin_payload(slot); - let data = cast_or_decode_f64(payload)?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?)) + slf.borrow().stress_py(py) } /// stress property (setter) #[setter] fn set_stress(&mut self, py: Python<'_>, stress: Py) -> PyResult<()> { - let stress = stress.bind(py); - if let Ok(arr) = stress.cast::>() { - let readonly = arr.readonly(); - let arr = readonly.as_array(); - let shape = arr.shape(); - - if shape != [3, 3] { - return Err(PyValueError::new_err("Stress must have shape (3, 3)")); - } - - let inner = self.ensure_owned()?; - inner.stress = Some([ - [arr[[0, 0]], arr[[0, 1]], arr[[0, 2]]], - [arr[[1, 0]], arr[[1, 1]], arr[[1, 2]]], - [arr[[2, 0]], arr[[2, 1]], arr[[2, 2]]], - ]); - inner.properties.remove("stress"); - return Ok(()); - } - - if let Ok(arr) = stress.cast::>() { - let readonly = arr.readonly(); - let arr = readonly.as_array(); - let shape = arr.shape(); - - if shape != [3, 3] { - return Err(PyValueError::new_err("Stress must have shape (3, 3)")); - } - - let inner = self.ensure_owned()?; - inner.stress = Some([ - [arr[[0, 0]] as f64, arr[[0, 1]] as f64, arr[[0, 2]] as f64], - [arr[[1, 0]] as f64, arr[[1, 1]] as f64, arr[[1, 2]] as f64], - [arr[[2, 0]] as f64, arr[[2, 1]] as f64, arr[[2, 2]] as f64], - ]); - inner.properties.remove("stress"); - return Ok(()); - } - - Err(PyValueError::new_err( - "Stress must be a float32 or float64 ndarray with shape (3, 3)", - )) + let inner = self.ensure_owned()?; + inner.stress = Some(parse_mat3_field(stress.bind(py), "stress")?); + inner.properties.remove("stress"); + Ok(()) } /// pbc property (getter) @@ -629,19 +408,9 @@ impl PyMolecule { /// positions property (read-only) #[getter] - fn positions<'py>(slf: Bound<'py, Self>) -> PyResult>> { + fn positions<'py>(slf: Bound<'py, Self>) -> PyResult> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - let flat = inner.positions_flat(); - let n_atoms = inner.len(); - return pyarray2_from_flat(py, flat, n_atoms, 3); - } - let view = molecule.as_view().ok_or_else(|| { - PyValueError::new_err("Molecule is missing both owned and view state") - })?; - let pos_f32 = cast_or_decode_f32(view.positions_bytes())?; - pyarray2_from_cow(py, pos_f32, view.n_atoms, 3) + slf.borrow().positions_py(py) } /// atomic_numbers property (read-only) diff --git a/atompack-py/src/molecule_helpers.rs b/atompack-py/src/molecule_helpers.rs index 4512f7f..497c2aa 100644 --- a/atompack-py/src/molecule_helpers.rs +++ b/atompack-py/src/molecule_helpers.rs @@ -105,157 +105,75 @@ pub(crate) fn pyarray2_from_cow<'py, T: Element + Clone>( } } -struct WrittenSoaSection { - slot: BuiltinSlot, - section: LazySection, -} - -fn write_soa_section_raw( - buf: &mut Vec, - kind: u8, - key: &str, - type_tag: u8, - payload: &[u8], -) -> Result { - let key_len: u8 = key - .len() - .try_into() - .map_err(|_| format!("Section key '{}' is too long", key))?; - let payload_len: u32 = payload - .len() - .try_into() - .map_err(|_| format!("Section '{}' payload is too large", key))?; - buf.push(kind); - buf.push(key_len); - let key_start = buf.len(); - buf.extend_from_slice(key.as_bytes()); - buf.push(type_tag); - buf.extend_from_slice(&payload_len.to_le_bytes()); - let payload_start = buf.len(); - buf.extend_from_slice(payload); - Ok(WrittenSoaSection { - slot: (payload_start, payload.len(), type_tag), - section: LazySection { - kind, - key_start, - key_len, - type_tag, - payload_start, - payload_len: payload.len(), - }, - }) -} - -pub(crate) struct SoaBuiltinPayloads<'a> { - pub(crate) energy: Option, - pub(crate) forces: Option<&'a [u8]>, - pub(crate) charges: Option<&'a [u8]>, - pub(crate) velocities: Option<&'a [u8]>, - pub(crate) cell: Option<&'a [u8]>, - pub(crate) stress: Option<&'a [u8]>, - pub(crate) pbc: Option<[bool; 3]>, - pub(crate) name: Option<&'a str>, -} - -pub(crate) struct SoaCustomSection<'a> { +pub(crate) struct SoaSection<'a> { pub(crate) kind: u8, pub(crate) key: &'a str, pub(crate) type_tag: u8, pub(crate) payload: &'a [u8], } -pub(crate) struct BuiltSoaRecord { - bytes: Vec, - n_atoms: usize, - positions_start: usize, - atomic_numbers_start: usize, - forces: Option, - energy: Option, - cell: Option, - stress: Option, - charges: Option, - velocities: Option, - pbc: Option, - name: Option, - custom_sections: Vec, -} - -impl BuiltSoaRecord { - pub(crate) fn into_parts(self) -> (Vec, u32) { - (self.bytes, self.n_atoms as u32) - } - - pub(crate) fn into_view(self) -> SoaMoleculeView { - SoaMoleculeView { - bytes: SoaBytes::Owned(self.bytes), - n_atoms: self.n_atoms, - positions_start: self.positions_start, - atomic_numbers_start: self.atomic_numbers_start, - forces: self.forces, - energy: self.energy, - cell: self.cell, - stress: self.stress, - charges: self.charges, - velocities: self.velocities, - pbc: self.pbc, - name: self.name, - custom_sections: self.custom_sections, - } - } +pub(crate) struct SoaRecord<'a> { + pub(crate) record_format: u32, + pub(crate) positions_type: u8, + pub(crate) positions: &'a [u8], + pub(crate) atomic_numbers: &'a [u8], + pub(crate) sections: &'a [SoaSection<'a>], } -pub(crate) fn build_soa_record( - positions: &[f32], - atomic_numbers: &[u8], - builtins: SoaBuiltinPayloads<'_>, -) -> Result { - build_soa_record_with_custom(positions, atomic_numbers, builtins, &[]) +fn write_soa_section_raw(buf: &mut Vec, section: &SoaSection<'_>) -> Result<(), String> { + let key_len: u8 = section + .key + .len() + .try_into() + .map_err(|_| format!("Section key '{}' is too long", section.key))?; + let payload_len: u32 = section + .payload + .len() + .try_into() + .map_err(|_| format!("Section '{}' payload is too large", section.key))?; + buf.push(section.kind); + buf.push(key_len); + buf.extend_from_slice(section.key.as_bytes()); + buf.push(section.type_tag); + buf.extend_from_slice(&payload_len.to_le_bytes()); + buf.extend_from_slice(section.payload); + Ok(()) } -pub(crate) fn build_soa_record_with_custom( - positions: &[f32], - atomic_numbers: &[u8], - builtins: SoaBuiltinPayloads<'_>, - custom_sections: &[SoaCustomSection<'_>], -) -> Result { - if !positions.len().is_multiple_of(3) { - return Err("positions length must be divisible by 3".to_string()); - } - let n_atoms = positions.len() / 3; - if atomic_numbers.len() != n_atoms { +pub(crate) fn build_soa_record(record: SoaRecord<'_>) -> Result, String> { + if !matches!( + record.record_format, + RECORD_FORMAT_SOA_V2 | RECORD_FORMAT_SOA_V3 + ) { return Err(format!( - "Atomic numbers length ({}) doesn't match atom count ({})", - atomic_numbers.len(), - n_atoms + "Unsupported record format {}", + record.record_format )); } - - let validate_bytes = |payload: &[u8], expected: usize, label: &str| -> Result<(), String> { - if payload.len() != expected { - return Err(format!( - "{} payload length ({}) doesn't match expected byte length ({})", - label, - payload.len(), - expected - )); + let positions_elem_bytes = match record.positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + other => { + return Err(format!("Unsupported positions type tag {}", other)); } - Ok(()) }; - - if let Some(payload) = builtins.forces { - validate_bytes(payload, n_atoms * 12, "forces")?; - } - if let Some(payload) = builtins.charges { - validate_bytes(payload, n_atoms * 8, "charges")?; + if record.record_format == RECORD_FORMAT_SOA_V2 && record.positions_type != TYPE_VEC3_F32 { + return Err("record format 2 only supports float32 positions".to_string()); } - if let Some(payload) = builtins.velocities { - validate_bytes(payload, n_atoms * 12, "velocities")?; - } - if let Some(payload) = builtins.cell { - validate_bytes(payload, 72, "cell")?; + if !record.positions.len().is_multiple_of(positions_elem_bytes) { + return Err(format!( + "positions payload length ({}) is not a multiple of {}", + record.positions.len(), + positions_elem_bytes + )); } - if let Some(payload) = builtins.stress { - validate_bytes(payload, 72, "stress")?; + let n_atoms = record.positions.len() / positions_elem_bytes; + if record.atomic_numbers.len() != n_atoms { + return Err(format!( + "Atomic numbers length ({}) doesn't match atom count ({})", + record.atomic_numbers.len(), + n_atoms + )); } let mut n_sections = 0u16; @@ -267,31 +185,7 @@ pub(crate) fn build_soa_record_with_custom( section_overhead += 1 + 1 + key_len + 1 + 4; }; - if let Some(payload) = builtins.charges { - account_section(payload.len(), "charges".len()); - } - if let Some(payload) = builtins.cell { - account_section(payload.len(), "cell".len()); - } - if builtins.energy.is_some() { - account_section(std::mem::size_of::(), "energy".len()); - } - if let Some(payload) = builtins.forces { - account_section(payload.len(), "forces".len()); - } - if let Some(value) = builtins.name { - account_section(value.len(), "name".len()); - } - if builtins.pbc.is_some() { - account_section(3, "pbc".len()); - } - if let Some(payload) = builtins.stress { - account_section(payload.len(), "stress".len()); - } - if let Some(payload) = builtins.velocities { - account_section(payload.len(), "velocities".len()); - } - for section in custom_sections { + for section in record.sections { let parsed = SectionRef { kind: section.kind, key: section.key, @@ -312,7 +206,9 @@ pub(crate) fn build_soa_record_with_custom( elem_bytes } TYPE_FLOAT | TYPE_INT => 8, + TYPE_FLOAT32 => 4, TYPE_BOOL3 => 3, + TYPE_MAT3X3_F32 => 36, TYPE_MAT3X3_F64 => 72, _ => parsed.payload.len(), }; @@ -328,183 +224,146 @@ pub(crate) fn build_soa_record_with_custom( account_section(parsed.payload.len(), parsed.key.len()); } - let positions_start = 4usize; - let atomic_numbers_start = positions_start + positions.len() * 4; + let positions_type_bytes = usize::from(record.record_format == RECORD_FORMAT_SOA_V3); let mut buf = Vec::with_capacity( - 4 + positions.len() * 4 + atomic_numbers.len() + 2 + section_overhead + payload_bytes, + 4 + positions_type_bytes + + record.positions.len() + + record.atomic_numbers.len() + + 2 + + section_overhead + + payload_bytes, ); buf.extend_from_slice(&(n_atoms as u32).to_le_bytes()); - buf.extend_from_slice(bytemuck::cast_slice::(positions)); - buf.extend_from_slice(atomic_numbers); + if record.record_format == RECORD_FORMAT_SOA_V3 { + buf.push(record.positions_type); + } + buf.extend_from_slice(record.positions); + buf.extend_from_slice(record.atomic_numbers); buf.extend_from_slice(&n_sections.to_le_bytes()); - let mut charges = None; - let mut cell = None; - let mut energy_slot = None; - let mut forces = None; - let mut name = None; - let mut pbc = None; - let mut stress = None; - let mut velocities = None; - let mut custom_slots = Vec::with_capacity(custom_sections.len()); - - if let Some(payload) = builtins.charges { - charges = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "charges", TYPE_F64_ARRAY, payload)?.slot, - ); - } - if let Some(payload) = builtins.cell { - cell = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "cell", TYPE_MAT3X3_F64, payload)?.slot, - ); - } - if let Some(value) = builtins.energy { - let payload = value.to_le_bytes(); - energy_slot = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "energy", TYPE_FLOAT, &payload)?.slot, - ); - } - if let Some(payload) = builtins.forces { - forces = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "forces", TYPE_VEC3_F32, payload)?.slot, - ); - } - if let Some(value) = builtins.name { - name = Some( - write_soa_section_raw( - &mut buf, - KIND_BUILTIN, - "name", - TYPE_STRING, - value.as_bytes(), - )? - .slot, - ); - } - if let Some([a, b, c]) = builtins.pbc { - let payload = [a as u8, b as u8, c as u8]; - pbc = - Some(write_soa_section_raw(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload)?.slot); - } - if let Some(payload) = builtins.stress { - stress = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "stress", TYPE_MAT3X3_F64, payload)?.slot, - ); - } - if let Some(payload) = builtins.velocities { - velocities = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "velocities", TYPE_VEC3_F32, payload)? - .slot, - ); - } - for section in custom_sections { - custom_slots.push( - write_soa_section_raw( - &mut buf, - section.kind, - section.key, - section.type_tag, - section.payload, - )? - .section, - ); - } - - Ok(BuiltSoaRecord { - bytes: buf, - n_atoms, - positions_start, - atomic_numbers_start, - forces, - energy: energy_slot, - cell, - stress, - charges, - velocities, - pbc, - name, - custom_sections: custom_slots, - }) + for section in record.sections { + write_soa_section_raw(&mut buf, section)?; + } + + Ok(buf) } -fn vec3_f32_payload<'py>( - readonly: &'py numpy::PyReadonlyArray2<'py, f32>, - label: &str, - expected_rows: usize, -) -> PyResult<&'py [u8]> { - let arr = readonly.as_array(); - let shape = arr.shape(); - if shape != [expected_rows, 3] { - return Err(PyValueError::new_err(format!( - "{} must have shape ({}, 3)", - label, expected_rows - ))); +fn molecule_from_positions( + positions: Vec3Data, + atomic_numbers: Vec, +) -> Result { + match positions { + Vec3Data::F32(values) => Molecule::new(values, atomic_numbers), + Vec3Data::F64(values) => Molecule::new_f64(values, atomic_numbers), } - let slice = readonly - .as_slice() - .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", label)))?; - Ok(bytemuck::cast_slice::(slice)) } -fn vec1_f64_payload<'py>( - readonly: &'py numpy::PyReadonlyArray1<'py, f64>, +pub(crate) fn parse_vec3_field( + value: &Bound<'_, PyAny>, label: &str, - expected_len: usize, -) -> PyResult<&'py [u8]> { - let arr = readonly.as_array(); - if arr.len() != expected_len { - return Err(PyValueError::new_err(format!( - "{} length ({}) doesn't match atom count ({})", - label, - arr.len(), - expected_len - ))); + expected_rows: usize, +) -> PyResult { + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [expected_rows, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, 3)", + label, expected_rows + ))); + } + return Ok(Vec3Data::F32( + view.outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + )); + } + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [expected_rows, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, 3)", + label, expected_rows + ))); + } + return Ok(Vec3Data::F64( + view.outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + )); } - let slice = readonly - .as_slice() - .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", label)))?; - Ok(bytemuck::cast_slice::(slice)) + Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape ({}, 3)", + label, expected_rows + ))) } -fn mat3x3_f64_payload<'py>( - readonly: &'py numpy::PyReadonlyArray2<'py, f64>, +pub(crate) fn parse_float_array_field( + value: &Bound<'_, PyAny>, label: &str, -) -> PyResult<&'py [u8]> { - let arr = readonly.as_array(); - if arr.shape() != [3, 3] { - return Err(PyValueError::new_err(format!( - "{} must have shape (3, 3)", - label - ))); + expected_len: usize, +) -> PyResult { + if let Ok(arr) = value.cast::>() { + let values = arr.readonly().as_array().to_vec(); + if values.len() != expected_len { + return Err(PyValueError::new_err(format!( + "{} length ({}) doesn't match atom count ({})", + label, + values.len(), + expected_len + ))); + } + return Ok(FloatArrayData::F32(values)); } - let slice = readonly - .as_slice() - .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", label)))?; - Ok(bytemuck::cast_slice::(slice)) + if let Ok(arr) = value.cast::>() { + let values = arr.readonly().as_array().to_vec(); + if values.len() != expected_len { + return Err(PyValueError::new_err(format!( + "{} length ({}) doesn't match atom count ({})", + label, + values.len(), + expected_len + ))); + } + return Ok(FloatArrayData::F64(values)); + } + Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape ({},)", + label, expected_len + ))) } -fn mat3x3_f64_payload_from_any(py: Python<'_>, value: Py, label: &str) -> PyResult> { - let value = value.bind(py); - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - return Ok(mat3x3_f64_payload(&readonly, label)?.to_vec()); - } +pub(crate) fn parse_mat3_field(value: &Bound<'_, PyAny>, label: &str) -> PyResult { if let Ok(arr) = value.cast::>() { let readonly = arr.readonly(); - let arr = readonly.as_array(); - if arr.shape() != [3, 3] { + let view = readonly.as_array(); + if view.shape() != [3, 3] { return Err(PyValueError::new_err(format!( "{} must have shape (3, 3)", label ))); } - let mut payload = Vec::with_capacity(72); - for row in arr.outer_iter() { - for value in row { - payload.extend_from_slice(&(*value as f64).to_le_bytes()); - } + return Ok(Mat3Data::F32([ + [view[[0, 0]], view[[0, 1]], view[[0, 2]]], + [view[[1, 0]], view[[1, 1]], view[[1, 2]]], + [view[[2, 0]], view[[2, 1]], view[[2, 2]]], + ])); + } + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [3, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape (3, 3)", + label + ))); } - return Ok(payload); + return Ok(Mat3Data::F64([ + [view[[0, 0]], view[[0, 1]], view[[0, 2]]], + [view[[1, 0]], view[[1, 1]], view[[1, 2]]], + [view[[2, 0]], view[[2, 1]], view[[2, 2]]], + ])); } Err(PyValueError::new_err(format!( "{} must be a float32 or float64 ndarray with shape (3, 3)", @@ -512,6 +371,17 @@ fn mat3x3_f64_payload_from_any(py: Python<'_>, value: Py, label: &str) -> ))) } +pub(crate) fn molecule_from_numpy_arrays( + positions: &Bound<'_, PyAny>, + atomic_numbers: &Bound<'_, PyArray1>, +) -> PyResult { + let z = atomic_numbers.readonly(); + let z_arr = z.as_array(); + let atomic_numbers_vec = z_arr.to_vec(); + let positions = parse_vec3_field(positions, "positions", atomic_numbers_vec.len())?; + molecule_from_positions(positions, atomic_numbers_vec).map_err(PyValueError::new_err) +} + pub(super) fn into_py_any<'py, T>(py: Python<'py>, value: T) -> PyResult> where T: IntoPyObject<'py>, @@ -624,31 +494,58 @@ fn is_reserved_ase_array_key(key: &str) -> bool { matches!(key, "numbers" | "positions") } -fn owned_vec3_array<'py>( - py: Python<'py>, - values: &[[f32; 3]], -) -> PyResult>> { - let n_atoms = values.len(); - let flat: Vec = values - .iter() - .flat_map(|value| [value[0], value[1], value[2]]) - .collect(); - pyarray2_from_flat(py, flat, n_atoms, 3) +fn owned_vec3_array<'py>(py: Python<'py>, values: &Vec3Data) -> PyResult> { + Ok(match values { + Vec3Data::F32(values) => { + let n_atoms = values.len(); + let flat: Vec = values + .iter() + .flat_map(|value| [value[0], value[1], value[2]]) + .collect(); + pyarray2_from_flat(py, flat, n_atoms, 3)? + .into_any() + .unbind() + } + Vec3Data::F64(values) => { + let n_atoms = values.len(); + let flat: Vec = values + .iter() + .flat_map(|value| [value[0], value[1], value[2]]) + .collect(); + pyarray2_from_flat(py, flat, n_atoms, 3)? + .into_any() + .unbind() + } + }) } -fn owned_mat3x3_array<'py>( - py: Python<'py>, - values: &[[f64; 3]; 3], -) -> PyResult>> { - let flat: Vec = values - .iter() - .flat_map(|row| [row[0], row[1], row[2]]) - .collect(); - pyarray2_from_flat(py, flat, 3, 3) +fn owned_float_array<'py>(py: Python<'py>, values: &FloatArrayData) -> Py { + match values { + FloatArrayData::F32(values) => PyArray1::from_slice(py, values).into_any().unbind(), + FloatArrayData::F64(values) => PyArray1::from_slice(py, values).into_any().unbind(), + } +} + +fn owned_mat3x3_array<'py>(py: Python<'py>, values: &Mat3Data) -> PyResult> { + Ok(match values { + Mat3Data::F32(values) => { + let flat: Vec = values + .iter() + .flat_map(|row| [row[0], row[1], row[2]]) + .collect(); + pyarray2_from_flat(py, flat, 3, 3)?.into_any().unbind() + } + Mat3Data::F64(values) => { + let flat: Vec = values + .iter() + .flat_map(|row| [row[0], row[1], row[2]]) + .collect(); + pyarray2_from_flat(py, flat, 3, 3)?.into_any().unbind() + } + }) } impl PyMolecule { - #[allow(dead_code)] pub(crate) fn from_owned(inner: Molecule) -> Self { Self { backing: MoleculeBacking::Owned(inner), @@ -661,14 +558,14 @@ impl PyMolecule { } } - pub(super) fn as_owned(&self) -> Option<&Molecule> { + pub(crate) fn as_owned(&self) -> Option<&Molecule> { match &self.backing { MoleculeBacking::Owned(inner) => Some(inner), MoleculeBacking::View(_) => None, } } - pub(super) fn as_view(&self) -> Option<&SoaMoleculeView> { + pub(crate) fn as_view(&self) -> Option<&SoaMoleculeView> { match &self.backing { MoleculeBacking::Owned(_) => None, MoleculeBacking::View(view) => Some(view), @@ -701,22 +598,31 @@ impl PyMolecule { } } - pub(crate) fn soa_bytes(&self) -> Option<(&[u8], u32)> { - match &self.backing { - MoleculeBacking::View(view) => Some((view.bytes.as_slice(), view.n_atoms as u32)), - MoleculeBacking::Owned(_) => None, - } - } - - pub(super) fn positions_py<'py>(&self, py: Python<'py>) -> PyResult>> { + pub(super) fn positions_py<'py>(&self, py: Python<'py>) -> PyResult> { if let Some(inner) = self.as_owned() { - return pyarray2_from_flat(py, inner.positions_flat(), inner.len(), 3); + return owned_vec3_array(py, &inner.positions); } let view = self.as_view().ok_or_else(|| { PyValueError::new_err("Molecule is missing both owned and view state") })?; - let positions = cast_or_decode_f32(view.positions_bytes())?; - pyarray2_from_cow(py, positions, view.n_atoms, 3) + match view.positions_type { + TYPE_VEC3_F32 => { + let positions = cast_or_decode_f32(view.positions_bytes())?; + Ok(pyarray2_from_cow(py, positions, view.n_atoms, 3)? + .into_any() + .unbind()) + } + TYPE_VEC3_F64 => { + let positions = cast_or_decode_f64(view.positions_bytes())?; + Ok(pyarray2_from_cow(py, positions, view.n_atoms, 3)? + .into_any() + .unbind()) + } + other => Err(PyValueError::new_err(format!( + "Unsupported positions type tag {}", + other + ))), + } } pub(super) fn atomic_numbers_py<'py>( @@ -732,10 +638,7 @@ impl PyMolecule { Ok(PyArray1::from_slice(py, view.atomic_numbers_bytes())) } - pub(super) fn forces_py<'py>( - &self, - py: Python<'py>, - ) -> PyResult>>> { + pub(super) fn forces_py<'py>(&self, py: Python<'py>) -> PyResult>> { if let Some(inner) = self.as_owned() { return inner .forces @@ -749,22 +652,39 @@ impl PyMolecule { let Some(slot) = view.forces else { return Ok(None); }; - if slot.2 != TYPE_VEC3_F32 || slot.1 != view.n_atoms * 12 { - return Err(PyValueError::new_err("Invalid forces section")); + match slot.2 { + TYPE_VEC3_F32 => { + if slot.1 != view.n_atoms * 12 { + return Err(PyValueError::new_err("Invalid forces section")); + } + let data = cast_or_decode_f32(view.builtin_payload(slot))?; + Ok(Some( + pyarray2_from_cow(py, data, view.n_atoms, 3)? + .into_any() + .unbind(), + )) + } + TYPE_VEC3_F64 => { + if slot.1 != view.n_atoms * 24 { + return Err(PyValueError::new_err("Invalid forces section")); + } + let data = cast_or_decode_f64(view.builtin_payload(slot))?; + Ok(Some( + pyarray2_from_cow(py, data, view.n_atoms, 3)? + .into_any() + .unbind(), + )) + } + _ => Err(PyValueError::new_err("Invalid forces section")), } - let data = cast_or_decode_f32(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, view.n_atoms, 3)?)) } - pub(super) fn charges_py<'py>( - &self, - py: Python<'py>, - ) -> PyResult>>> { + pub(super) fn charges_py<'py>(&self, py: Python<'py>) -> PyResult>> { if let Some(inner) = self.as_owned() { return Ok(inner .charges .as_ref() - .map(|charges| PyArray1::from_slice(py, charges))); + .map(|charges| owned_float_array(py, charges))); } let view = self.as_view().ok_or_else(|| { PyValueError::new_err("Molecule is missing both owned and view state") @@ -772,17 +692,26 @@ impl PyMolecule { let Some(slot) = view.charges else { return Ok(None); }; - if slot.2 != TYPE_F64_ARRAY || slot.1 != view.n_atoms * 8 { - return Err(PyValueError::new_err("Invalid charges section")); + match slot.2 { + TYPE_F32_ARRAY => { + if slot.1 != view.n_atoms * 4 { + return Err(PyValueError::new_err("Invalid charges section")); + } + let data = cast_or_decode_f32(view.builtin_payload(slot))?; + Ok(Some(pyarray1_from_cow(py, data).into_any().unbind())) + } + TYPE_F64_ARRAY => { + if slot.1 != view.n_atoms * 8 { + return Err(PyValueError::new_err("Invalid charges section")); + } + let data = cast_or_decode_f64(view.builtin_payload(slot))?; + Ok(Some(pyarray1_from_cow(py, data).into_any().unbind())) + } + _ => Err(PyValueError::new_err("Invalid charges section")), } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some(pyarray1_from_cow(py, data))) } - pub(super) fn velocities_py<'py>( - &self, - py: Python<'py>, - ) -> PyResult>>> { + pub(super) fn velocities_py<'py>(&self, py: Python<'py>) -> PyResult>> { if let Some(inner) = self.as_owned() { return inner .velocities @@ -796,17 +725,34 @@ impl PyMolecule { let Some(slot) = view.velocities else { return Ok(None); }; - if slot.2 != TYPE_VEC3_F32 || slot.1 != view.n_atoms * 12 { - return Err(PyValueError::new_err("Invalid velocities section")); + match slot.2 { + TYPE_VEC3_F32 => { + if slot.1 != view.n_atoms * 12 { + return Err(PyValueError::new_err("Invalid velocities section")); + } + let data = cast_or_decode_f32(view.builtin_payload(slot))?; + Ok(Some( + pyarray2_from_cow(py, data, view.n_atoms, 3)? + .into_any() + .unbind(), + )) + } + TYPE_VEC3_F64 => { + if slot.1 != view.n_atoms * 24 { + return Err(PyValueError::new_err("Invalid velocities section")); + } + let data = cast_or_decode_f64(view.builtin_payload(slot))?; + Ok(Some( + pyarray2_from_cow(py, data, view.n_atoms, 3)? + .into_any() + .unbind(), + )) + } + _ => Err(PyValueError::new_err("Invalid velocities section")), } - let data = cast_or_decode_f32(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, view.n_atoms, 3)?)) } - pub(super) fn cell_py<'py>( - &self, - py: Python<'py>, - ) -> PyResult>>> { + pub(super) fn cell_py<'py>(&self, py: Python<'py>) -> PyResult>> { if let Some(inner) = self.as_owned() { return inner .cell @@ -820,17 +766,26 @@ impl PyMolecule { let Some(slot) = view.cell else { return Ok(None); }; - if slot.2 != TYPE_MAT3X3_F64 || slot.1 != 72 { - return Err(PyValueError::new_err("Invalid cell section")); + match slot.2 { + TYPE_MAT3X3_F32 => { + if slot.1 != 36 { + return Err(PyValueError::new_err("Invalid cell section")); + } + let data = cast_or_decode_f32(view.builtin_payload(slot))?; + Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) + } + TYPE_MAT3X3_F64 => { + if slot.1 != 72 { + return Err(PyValueError::new_err("Invalid cell section")); + } + let data = cast_or_decode_f64(view.builtin_payload(slot))?; + Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) + } + _ => Err(PyValueError::new_err("Invalid cell section")), } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?)) } - pub(super) fn stress_py<'py>( - &self, - py: Python<'py>, - ) -> PyResult>>> { + pub(super) fn stress_py<'py>(&self, py: Python<'py>) -> PyResult>> { if let Some(inner) = self.as_owned() { return inner .stress @@ -844,11 +799,23 @@ impl PyMolecule { let Some(slot) = view.stress else { return Ok(None); }; - if slot.2 != TYPE_MAT3X3_F64 || slot.1 != 72 { - return Err(PyValueError::new_err("Invalid stress section")); + match slot.2 { + TYPE_MAT3X3_F32 => { + if slot.1 != 36 { + return Err(PyValueError::new_err("Invalid stress section")); + } + let data = cast_or_decode_f32(view.builtin_payload(slot))?; + Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) + } + TYPE_MAT3X3_F64 => { + if slot.1 != 72 { + return Err(PyValueError::new_err("Invalid stress section")); + } + let data = cast_or_decode_f64(view.builtin_payload(slot))?; + Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) + } + _ => Err(PyValueError::new_err("Invalid stress section")), } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?)) } pub(super) fn append_owned_ase_properties<'py>( @@ -920,84 +887,59 @@ impl PyMolecule { #[allow(clippy::too_many_arguments)] pub(super) fn from_arrays_impl( py: Python<'_>, - positions: &Bound<'_, PyArray2>, + positions: &Bound<'_, PyAny>, atomic_numbers: &Bound<'_, PyArray1>, energy: Option, - forces: Option<&Bound<'_, PyArray2>>, - charges: Option<&Bound<'_, PyArray1>>, - velocities: Option<&Bound<'_, PyArray2>>, - cell: Option<&Bound<'_, PyArray2>>, + forces: Option>, + charges: Option>, + velocities: Option>, + cell: Option>, stress: Option>, pbc: Option<(bool, bool, bool)>, name: Option, ) -> PyResult { - let pos = positions.readonly(); - let pos_arr = pos.as_array(); - let shape = pos_arr.shape(); - if shape.len() != 2 || shape[1] != 3 { - return Err(PyValueError::new_err( - "positions must have shape (n_atoms, 3)", - )); - } - - let z = atomic_numbers.readonly(); - let z_arr = z.as_array(); - let n_atoms = shape[0]; - if z_arr.len() != n_atoms { - return Err(PyValueError::new_err(format!( - "Atomic numbers length ({}) doesn't match atom count ({})", - z_arr.len(), - n_atoms - ))); + let mut molecule = molecule_from_numpy_arrays(positions, atomic_numbers)?; + + if let Some(name) = name { + molecule.name = Some(name); + } + if let Some(energy) = energy { + molecule.energy = Some(FloatScalarData::F64(energy)); + } + if let Some((a, b, c)) = pbc { + molecule.pbc = Some([a, b, c]); + } + + let n_atoms = molecule.len(); + + if let Some(forces) = forces { + molecule.forces = Some(parse_vec3_field(forces.bind(py), "forces", n_atoms)?); + } + + if let Some(charges) = charges { + molecule.charges = Some(parse_float_array_field( + charges.bind(py), + "charges", + n_atoms, + )?); + } + + if let Some(velocities) = velocities { + molecule.velocities = Some(parse_vec3_field( + velocities.bind(py), + "velocities", + n_atoms, + )?); + } + + if let Some(cell) = cell { + molecule.cell = Some(parse_mat3_field(cell.bind(py), "cell")?); + } + + if let Some(stress) = stress { + molecule.stress = Some(parse_mat3_field(stress.bind(py), "stress")?); } - let pos_bytes = pos_arr - .as_slice() - .ok_or_else(|| PyValueError::new_err("positions must be C-contiguous"))?; - let z_bytes = z_arr - .as_slice() - .ok_or_else(|| PyValueError::new_err("atomic_numbers must be C-contiguous"))?; - - let forces_readonly = forces.map(|value| value.readonly()); - let charges_readonly = charges.map(|value| value.readonly()); - let velocities_readonly = velocities.map(|value| value.readonly()); - let cell_readonly = cell.map(|value| value.readonly()); - - let forces_payload = forces_readonly - .as_ref() - .map(|readonly| vec3_f32_payload(readonly, "forces", n_atoms)) - .transpose()?; - let charges_payload = charges_readonly - .as_ref() - .map(|readonly| vec1_f64_payload(readonly, "charges", n_atoms)) - .transpose()?; - let velocities_payload = velocities_readonly - .as_ref() - .map(|readonly| vec3_f32_payload(readonly, "velocities", n_atoms)) - .transpose()?; - let cell_payload = cell_readonly - .as_ref() - .map(|readonly| mat3x3_f64_payload(readonly, "cell")) - .transpose()?; - let stress_payload = stress - .map(|value| mat3x3_f64_payload_from_any(py, value, "stress")) - .transpose()?; - let record = build_soa_record( - pos_bytes, - z_bytes, - SoaBuiltinPayloads { - energy, - forces: forces_payload, - charges: charges_payload, - velocities: velocities_payload, - cell: cell_payload, - stress: stress_payload.as_deref(), - pbc: pbc.map(|(a, b, c)| [a, b, c]), - name: name.as_deref(), - }, - ) - .map_err(PyValueError::new_err)?; - - Ok(Self::from_view(record.into_view())) + Ok(Self::from_owned(molecule)) } } diff --git a/atompack-py/tests/test_atom_molecule.py b/atompack-py/tests/test_atom_molecule.py index 7dc3e49..5674ca0 100644 --- a/atompack-py/tests/test_atom_molecule.py +++ b/atompack-py/tests/test_atom_molecule.py @@ -127,8 +127,8 @@ def test_molecule_ml_properties_roundtrip() -> None: stress32 = np.eye(3, dtype=np.float32) * 5.0 mol.stress = stress32 - assert mol.stress.dtype == np.float64 - np.testing.assert_allclose(mol.stress, stress32.astype(np.float64)) + assert mol.stress.dtype == np.float32 + np.testing.assert_allclose(mol.stress, stress32) def test_molecule_ml_property_validation() -> None: @@ -136,7 +136,7 @@ def test_molecule_ml_property_validation() -> None: with pytest.raises(ValueError, match=r"Forces must have shape"): mol.forces = np.zeros((2, 2), dtype=np.float32) - with pytest.raises(ValueError, match=r"Forces length"): + with pytest.raises(ValueError, match=r"Forces must have shape"): mol.forces = np.zeros((1, 3), dtype=np.float32) with pytest.raises(ValueError, match=r"Charges length"): @@ -144,7 +144,7 @@ def test_molecule_ml_property_validation() -> None: with pytest.raises(ValueError, match=r"Velocities must have shape"): mol.velocities = np.zeros((2, 2), dtype=np.float32) - with pytest.raises(ValueError, match=r"Velocities length"): + with pytest.raises(ValueError, match=r"Velocities must have shape"): mol.velocities = np.zeros((1, 3), dtype=np.float32) with pytest.raises(ValueError, match=r"Cell must have shape"): @@ -276,13 +276,11 @@ def test_molecule_getitem_validation() -> None: _ = mol[1.5] -def test_from_arrays_rejects_wrong_dtype_positions() -> None: - # positions must be float32; passing float64 should be rejected cleanly, - # not silently truncated or panic across the FFI boundary. - positions = np.zeros((2, 3), dtype=np.float64) # wrong dtype +def test_from_arrays_preserves_float64_positions() -> None: + positions = np.zeros((2, 3), dtype=np.float64) atomic_numbers = np.array([6, 8], dtype=np.uint8) - with pytest.raises((TypeError, ValueError)): - atompack.Molecule.from_arrays(positions, atomic_numbers) + molecule = atompack.Molecule.from_arrays(positions, atomic_numbers) + assert molecule.positions.dtype == np.float64 def test_from_arrays_rejects_wrong_dtype_atomic_numbers() -> None: diff --git a/atompack-py/tests/test_database.py b/atompack-py/tests/test_database.py index 6973049..fbb6619 100644 --- a/atompack-py/tests/test_database.py +++ b/atompack-py/tests/test_database.py @@ -579,6 +579,69 @@ def test_get_molecules_flat_empty(tmp_path: Path) -> None: assert batch["atomic_numbers"].shape == (0,) +def test_database_roundtrip_preserves_float64_geometry(tmp_path: Path) -> None: + path = tmp_path / "float64_geometry.atp" + db = atompack.Database(str(path), overwrite=True) + + positions = np.array([[0.0, 0.1, 0.2], [1.0, 1.1, 1.2]], dtype=np.float64) + atomic_numbers = np.array([6, 8], dtype=np.uint8) + forces = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float64) + charges = np.array([-0.1, 0.1], dtype=np.float32) + cell = np.eye(3, dtype=np.float32) + stress = np.eye(3, dtype=np.float64) * 2.0 + + mol = atompack.Molecule.from_arrays( + positions, + atomic_numbers, + forces=forces, + charges=charges, + cell=cell, + stress=stress, + ) + db.add_molecule(mol) + db.flush() + + reopened = atompack.Database.open(str(path), mmap=False) + got = reopened.get_molecule(0) + flat = reopened.get_molecules_flat([0]) + + assert got.positions.dtype == np.float64 + assert got.forces.dtype == np.float64 + assert got.charges.dtype == np.float32 + assert got.cell.dtype == np.float32 + assert got.stress.dtype == np.float64 + assert flat["positions"].dtype == np.float64 + assert flat["forces"].dtype == np.float64 + assert flat["charges"].dtype == np.float32 + assert flat["cell"].dtype == np.float32 + assert flat["stress"].dtype == np.float64 + + +def test_get_molecules_flat_late_optional_builtin_zero_fills(tmp_path: Path) -> None: + path = tmp_path / "late_optional_builtin.atp" + db = atompack.Database(str(path), overwrite=True) + + positions = np.zeros((2, 3), dtype=np.float64) + atomic_numbers = np.array([6, 8], dtype=np.uint8) + db.add_molecule(atompack.Molecule.from_arrays(positions, atomic_numbers)) + db.add_molecule( + atompack.Molecule.from_arrays( + positions + 1.0, + atomic_numbers, + forces=np.ones((2, 3), dtype=np.float64), + ) + ) + db.flush() + + reopened = atompack.Database.open(str(path), mmap=False) + batch = reopened.get_molecules_flat([0, 1]) + + assert batch["positions"].dtype == np.float64 + assert batch["forces"].dtype == np.float64 + np.testing.assert_allclose(batch["forces"][:2], np.zeros((2, 3), dtype=np.float64)) + np.testing.assert_allclose(batch["forces"][2:], np.ones((2, 3), dtype=np.float64)) + + def test_database_open_mmap_populate(tmp_path: Path) -> None: # Smoke test for the documented populate=True path. On Linux this # prefaults mapped pages via memmap2's PopulateRead advise; on macOS the diff --git a/atompack/examples/basic_usage.rs b/atompack/examples/basic_usage.rs index a5fecbe..9269657 100644 --- a/atompack/examples/basic_usage.rs +++ b/atompack/examples/basic_usage.rs @@ -3,7 +3,7 @@ //! //! Run with: cargo run -p atompack --example basic_usage -use atompack::{Atom, AtomDatabase, Molecule, compression::CompressionType}; +use atompack::{Atom, AtomDatabase, FloatScalarData, Molecule, compression::CompressionType}; fn main() -> atompack::Result<()> { let db_path = std::env::temp_dir().join(format!("atompack-basic-{}.atp", std::process::id())); @@ -14,7 +14,7 @@ fn main() -> atompack::Result<()> { Atom::new(-0.24, 0.93, 0.0, 1), ]); water.name = Some("water".to_string()); - water.energy = Some(-76.4); + water.energy = Some(FloatScalarData::F64(-76.4)); let methane = Molecule::from_atoms(vec![ Atom::new(0.0, 0.0, 0.0, 6), @@ -33,7 +33,7 @@ fn main() -> atompack::Result<()> { let roundtrip = db.get_molecule(0)?; assert_eq!(roundtrip.name.as_deref(), Some("water")); - assert_eq!(roundtrip.energy, Some(-76.4)); + assert_eq!(roundtrip.energy, Some(FloatScalarData::F64(-76.4))); assert_eq!(roundtrip.len(), 3); println!("wrote {}", db_path.display()); diff --git a/atompack/src/atom.rs b/atompack/src/atom.rs index b890ba3..d6daef4 100644 --- a/atompack/src/atom.rs +++ b/atompack/src/atom.rs @@ -69,27 +69,152 @@ impl PropertyValue { } } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum Vec3Data { + F32(Vec<[f32; 3]>), + F64(Vec<[f64; 3]>), +} + +impl Vec3Data { + pub fn len(&self) -> usize { + match self { + Self::F32(values) => values.len(), + Self::F64(values) => values.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn atom_position(&self, index: usize) -> Option<[f32; 3]> { + match self { + Self::F32(values) => values.get(index).copied(), + Self::F64(values) => values + .get(index) + .map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]), + } + } + + pub fn flatten_f32_lossy(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|value| [value[0], value[1], value[2]]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]) + .collect(), + } + } + + pub fn flatten_f64(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|value| [value[0] as f64, value[1] as f64, value[2] as f64]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|value| [value[0], value[1], value[2]]) + .collect(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum FloatScalarData { + F32(f32), + F64(f64), +} + +impl FloatScalarData { + pub fn as_f64(&self) -> f64 { + match self { + Self::F32(value) => *value as f64, + Self::F64(value) => *value, + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum FloatArrayData { + F32(Vec), + F64(Vec), +} + +impl FloatArrayData { + pub fn len(&self) -> usize { + match self { + Self::F32(values) => values.len(), + Self::F64(values) => values.len(), + } + } + + pub fn to_f64_vec(&self) -> Vec { + match self { + Self::F32(values) => values.iter().map(|value| *value as f64).collect(), + Self::F64(values) => values.clone(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum Mat3Data { + F32([[f32; 3]; 3]), + F64([[f64; 3]; 3]), +} + +impl Mat3Data { + pub fn flatten_f32_lossy(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|row| [row[0], row[1], row[2]]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|row| [row[0] as f32, row[1] as f32, row[2] as f32]) + .collect(), + } + } + + pub fn flatten_f64(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|row| [row[0] as f64, row[1] as f64, row[2] as f64]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|row| [row[0], row[1], row[2]]) + .collect(), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Molecule { pub name: Option, - pub positions: Vec<[f32; 3]>, + pub positions: Vec3Data, pub atomic_numbers: Vec, - pub forces: Option>, + pub forces: Option, - pub energy: Option, + pub energy: Option, - pub charges: Option>, + pub charges: Option, - pub velocities: Option>, + pub velocities: Option, - pub cell: Option<[[f64; 3]; 3]>, + pub cell: Option, pub pbc: Option<[bool; 3]>, - pub stress: Option<[[f64; 3]; 3]>, + pub stress: Option, /// Per-atom properties (features or targets) pub atom_properties: HashMap, @@ -100,7 +225,11 @@ pub struct Molecule { impl Molecule { pub fn new(positions: Vec<[f32; 3]>, atomic_numbers: Vec) -> Result { - Self::with_name_internal(None, positions, atomic_numbers) + Self::with_name_internal(None, Vec3Data::F32(positions), atomic_numbers) + } + + pub fn new_f64(positions: Vec<[f64; 3]>, atomic_numbers: Vec) -> Result { + Self::with_name_internal(None, Vec3Data::F64(positions), atomic_numbers) } pub fn with_name( @@ -108,12 +237,20 @@ impl Molecule { positions: Vec<[f32; 3]>, atomic_numbers: Vec, ) -> Result { - Self::with_name_internal(Some(name), positions, atomic_numbers) + Self::with_name_internal(Some(name), Vec3Data::F32(positions), atomic_numbers) + } + + pub fn with_name_f64( + name: String, + positions: Vec<[f64; 3]>, + atomic_numbers: Vec, + ) -> Result { + Self::with_name_internal(Some(name), Vec3Data::F64(positions), atomic_numbers) } fn with_name_internal( name: Option, - positions: Vec<[f32; 3]>, + positions: Vec3Data, atomic_numbers: Vec, ) -> Result { if positions.len() != atomic_numbers.len() { @@ -144,7 +281,7 @@ impl Molecule { let atomic_numbers = atoms.iter().map(|atom| atom.atomic_number).collect(); Self { name: None, - positions, + positions: Vec3Data::F32(positions), atomic_numbers, forces: None, energy: None, @@ -167,38 +304,76 @@ impl Molecule { } pub fn atom(&self, index: usize) -> Option { + let position = self.positions.atom_position(index)?; Some(Atom::new( - self.positions.get(index)?[0], - self.positions.get(index)?[1], - self.positions.get(index)?[2], + position[0], + position[1], + position[2], *self.atomic_numbers.get(index)?, )) } pub fn to_atoms(&self) -> Vec { - self.positions + self.atomic_numbers .iter() - .zip(&self.atomic_numbers) - .map(|(position, &atomic_number)| { - Atom::new(position[0], position[1], position[2], atomic_number) + .enumerate() + .filter_map(|(index, &atomic_number)| { + self.positions + .atom_position(index) + .map(|position| Atom::new(position[0], position[1], position[2], atomic_number)) }) .collect() } pub fn forces_mut(&mut self) -> &mut Vec<[f32; 3]> { let n_atoms = self.len(); - self.forces.get_or_insert_with(|| vec![[0.0; 3]; n_atoms]) + let forces = self + .forces + .get_or_insert_with(|| Vec3Data::F32(vec![[0.0; 3]; n_atoms])); + if let Vec3Data::F64(values) = forces { + let converted: Vec<[f32; 3]> = values + .iter() + .map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]) + .collect(); + *forces = Vec3Data::F32(converted); + } + match forces { + Vec3Data::F32(values) => values, + Vec3Data::F64(_) => unreachable!(), + } } pub fn velocities_mut(&mut self) -> &mut Vec<[f32; 3]> { let n_atoms = self.len(); - self.velocities - .get_or_insert_with(|| vec![[0.0; 3]; n_atoms]) + let velocities = self + .velocities + .get_or_insert_with(|| Vec3Data::F32(vec![[0.0; 3]; n_atoms])); + if let Vec3Data::F64(values) = velocities { + let converted: Vec<[f32; 3]> = values + .iter() + .map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]) + .collect(); + *velocities = Vec3Data::F32(converted); + } + match velocities { + Vec3Data::F32(values) => values, + Vec3Data::F64(_) => unreachable!(), + } } pub fn charges_mut(&mut self) -> &mut Vec { let n_atoms = self.len(); - self.charges.get_or_insert_with(|| vec![0.0; n_atoms]) + let charges = self + .charges + .get_or_insert_with(|| FloatArrayData::F64(vec![0.0; n_atoms])); + if let FloatArrayData::F32(values) = charges { + let converted: Vec = values.iter().map(|value| *value as f64).collect(); + *charges = FloatArrayData::F64(converted); + } + match charges { + FloatArrayData::F64(values) => values, + FloatArrayData::F32(_) => unreachable!(), + } } /// Add a per-atom floating-point property @@ -317,10 +492,7 @@ impl Molecule { /// /// Returns: [x1, y1, z1, x2, y2, z2, ...] pub fn positions_flat(&self) -> Vec { - self.positions - .iter() - .flat_map(|position| [position[0], position[1], position[2]]) - .collect() + self.positions.flatten_f32_lossy() } /// Get atomic numbers as an array @@ -332,25 +504,19 @@ impl Molecule { /// /// Returns: Some([fx1, fy1, fz1, fx2, fy2, fz2, ...]) or None pub fn forces_flat(&self) -> Option> { - self.forces - .as_ref() - .map(|f| f.iter().flat_map(|v| [v[0], v[1], v[2]]).collect()) + self.forces.as_ref().map(Vec3Data::flatten_f32_lossy) } /// Get velocities as a flat array (if present) pub fn velocities_flat(&self) -> Option> { - self.velocities - .as_ref() - .map(|v| v.iter().flat_map(|vec| [vec[0], vec[1], vec[2]]).collect()) + self.velocities.as_ref().map(Vec3Data::flatten_f32_lossy) } /// Get cell as a flat array (if present) /// /// Returns: Some([a_x, a_y, a_z, b_x, b_y, b_z, c_x, c_y, c_z]) or None pub fn cell_flat(&self) -> Option> { - self.cell - .as_ref() - .map(|c| c.iter().flat_map(|row| [row[0], row[1], row[2]]).collect()) + self.cell.as_ref().map(Mat3Data::flatten_f64) } } @@ -393,7 +559,7 @@ mod tests { let mut mol = Molecule::new(vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], vec![6, 8]).unwrap(); // Add forces - mol.forces = Some(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]); + mol.forces = Some(Vec3Data::F32(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])); // Test flat forces let forces_flat = mol.forces_flat().unwrap(); @@ -422,13 +588,13 @@ mod tests { let mut mol = Molecule::new(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); // Set energy - mol.energy = Some(-100.5); + mol.energy = Some(FloatScalarData::F64(-100.5)); // Set custom molecular property mol.set_property("homo_lumo_gap".to_string(), PropertyValue::Float(5.2)); // Retrieve - assert_eq!(mol.energy.unwrap(), -100.5); + assert_eq!(mol.energy.as_ref(), Some(&FloatScalarData::F64(-100.5))); assert_eq!(mol.get_scalar("homo_lumo_gap").unwrap(), 5.2); } @@ -448,7 +614,11 @@ mod tests { let mut mol = Molecule::new(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); // Set unit cell (cubic 10x10x10) - mol.cell = Some([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]); + mol.cell = Some(Mat3Data::F64([ + [10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0], + ])); mol.pbc = Some([true, true, true]); // Test cell_flat diff --git a/atompack/src/bin/atompack-bench.rs b/atompack/src/bin/atompack-bench.rs index db5dca4..8c21bda 100644 --- a/atompack/src/bin/atompack-bench.rs +++ b/atompack/src/bin/atompack-bench.rs @@ -4,7 +4,9 @@ //! Run with: //! cargo run -p atompack --release --bin atompack-bench -- --help -use atompack::{Atom, AtomDatabase, Molecule, compression::CompressionType}; +use atompack::{ + Atom, AtomDatabase, FloatScalarData, Molecule, Vec3Data, compression::CompressionType, +}; use std::env; use std::time::{Duration, Instant}; @@ -101,15 +103,15 @@ fn create_synthetic_molecule(atoms_per_molecule: usize, id: u64, with_props: boo if with_props { let n = atoms_per_molecule; - mol.energy = Some(-1000.0 - (id as f64) * 1e-3); - mol.forces = Some( + mol.energy = Some(FloatScalarData::F64(-1000.0 - (id as f64) * 1e-3)); + mol.forces = Some(Vec3Data::F32( (0..n) .map(|i| { let t = (id as f32) * 0.002 + (i as f32) * 0.02; [(t * 0.7).sin(), (t * 0.9).cos(), (t * 1.1).sin()] }) .collect(), - ); + )); } mol diff --git a/atompack/src/lib.rs b/atompack/src/lib.rs index 8704db5..04a4dfe 100644 --- a/atompack/src/lib.rs +++ b/atompack/src/lib.rs @@ -12,7 +12,9 @@ pub mod atom; pub mod compression; pub mod storage; -pub use atom::{Atom, Molecule}; +pub use atom::{ + Atom, FloatArrayData, FloatScalarData, Mat3Data, Molecule, PropertyValue, Vec3Data, +}; pub use compression::decompress as decompress_bytes; pub use storage::{AtomDatabase, SharedMmapBytes}; diff --git a/atompack/src/storage/mod.rs b/atompack/src/storage/mod.rs index fe9e1be..1e9437b 100644 --- a/atompack/src/storage/mod.rs +++ b/atompack/src/storage/mod.rs @@ -39,7 +39,10 @@ mod soa; use self::header::{Header, encode_header_slot, read_best_header}; use self::index::{IndexStorage, MoleculeIndex, decode_index, encode_index}; -use self::soa::{arr, deserialize_molecule_soa, serialize_molecule_soa}; +use self::soa::{ + SchemaLock, arr, deserialize_molecule_soa, merge_schema_lock, record_schema, + serialize_molecule_soa, +}; // --------------------------------------------------------------------------- // Constants @@ -48,7 +51,9 @@ use self::soa::{arr, deserialize_molecule_soa, serialize_molecule_soa}; const MAGIC: &[u8; 4] = b"ATPK"; /// Bump only for incompatible layout changes (not crate version). const FILE_FORMAT_VERSION: u32 = 2; -const RECORD_FORMAT_SOA: u32 = 2; +const RECORD_FORMAT_SOA_V2: u32 = 2; +const RECORD_FORMAT_SOA_V3: u32 = 3; +const RECORD_FORMAT_SOA: u32 = RECORD_FORMAT_SOA_V3; // Section kind tags (inside each SOA record) const KIND_BUILTIN: u8 = 0; @@ -67,6 +72,8 @@ const TYPE_VEC3_F64: u8 = 7; // Vec<[f64; 3]> const TYPE_I32_ARRAY: u8 = 8; const TYPE_BOOL3: u8 = 9; // [bool; 3] const TYPE_MAT3X3_F64: u8 = 10; // [[f64; 3]; 3] +const TYPE_FLOAT32: u8 = 11; // f32 scalar +const TYPE_MAT3X3_F32: u8 = 12; // [[f32; 3]; 3] // Two redundant page-aligned header slots for crash safety. const HEADER_SLOT_SIZE: usize = 4096; @@ -107,6 +114,7 @@ pub struct AtomDatabase { committed_end: u64, truncate_tail_on_next_write: bool, index: IndexStorage, + schema_lock: Option, file: Option, data_mmap: Option>, } @@ -151,6 +159,7 @@ impl AtomDatabase { committed_end: HEADER_REGION_SIZE, truncate_tail_on_next_write: false, index: IndexStorage::InMemory(Vec::new()), + schema_lock: None, file: None, data_mmap: None, }) @@ -199,9 +208,11 @@ impl AtomDatabase { fn open_v1(path: PathBuf, mut file: File, use_mmap: bool, populate: bool) -> Result { // Read the best valid header slot (crash-safe). let header = read_best_header(&mut file)?; - if header.record_format != RECORD_FORMAT_SOA { + if header.record_format != RECORD_FORMAT_SOA_V2 + && header.record_format != RECORD_FORMAT_SOA_V3 + { return Err(Error::InvalidData(format!( - "Unsupported legacy record format {}. Atompack no longer supports bincode databases.", + "Unsupported record format {}.", header.record_format ))); } @@ -258,6 +269,7 @@ impl AtomDatabase { committed_end, truncate_tail_on_next_write, index, + schema_lock: None, file: Some(file), data_mmap, }) @@ -277,10 +289,73 @@ impl AtomDatabase { } self.truncate_tail_on_next_write = false; + self.schema_lock = None; self.file = None; Ok(()) } + fn rebuild_schema_lock(&self) -> Result { + let mut lock = SchemaLock::default(); + let compression = self.compression; + + if let Some(ref mmap) = self.data_mmap { + for index in 0..self.index.len() { + let entry = self + .index + .get(index) + .ok_or_else(|| Error::InvalidData(format!("Index {} out of bounds", index)))?; + let start = entry.offset as usize; + let end = start + entry.compressed_size as usize; + let bytes = decompress( + &mmap[start..end], + compression, + Some(entry.uncompressed_size as usize), + )?; + let record = record_schema(&bytes, self.record_format)?; + merge_schema_lock(&mut lock, &record)?; + } + return Ok(lock); + } + + let mut file = File::open(&self.path)?; + for index in 0..self.index.len() { + let entry = self + .index + .get(index) + .ok_or_else(|| Error::InvalidData(format!("Index {} out of bounds", index)))?; + file.seek(SeekFrom::Start(entry.offset))?; + let mut compressed = vec![0u8; entry.compressed_size as usize]; + file.read_exact(&mut compressed)?; + let bytes = decompress( + &compressed, + compression, + Some(entry.uncompressed_size as usize), + )?; + let record = record_schema(&bytes, self.record_format)?; + merge_schema_lock(&mut lock, &record)?; + } + Ok(lock) + } + + fn ensure_schema_compatible<'a, I>(&mut self, records: I) -> Result<()> + where + I: IntoIterator, + { + let mut lock = match &self.schema_lock { + Some(lock) => lock.clone(), + None if self.index.is_empty() => SchemaLock::default(), + None => self.rebuild_schema_lock()?, + }; + + for bytes in records { + let record = record_schema(bytes, self.record_format)?; + merge_schema_lock(&mut lock, &record)?; + } + + self.schema_lock = Some(lock); + Ok(()) + } + // -- Writing ------------------------------------------------------------- /// Add a single molecule. @@ -298,7 +373,7 @@ impl AtomDatabase { let serialized: Vec<(Vec, u32)> = molecules .par_iter() .map(|mol| { - let bytes = serialize_molecule_soa(mol)?; + let bytes = serialize_molecule_soa(mol, self.record_format)?; let num_atoms = mol.len() as u32; Ok((bytes, num_atoms)) }) @@ -337,6 +412,7 @@ impl AtomDatabase { } self.truncate_uncommitted_tail_if_needed()?; + self.ensure_schema_compatible(records.iter().map(|(bytes, _)| *bytes))?; let compression = self.compression; @@ -384,6 +460,7 @@ impl AtomDatabase { } self.truncate_uncommitted_tail_if_needed()?; + self.ensure_schema_compatible(records.iter().map(|(bytes, _)| bytes.as_slice()))?; let compression = self.compression; @@ -446,14 +523,14 @@ impl AtomDatabase { self.compression, Some(mol_index.uncompressed_size as usize), )?; - deserialize_molecule_soa(&decompressed) + deserialize_molecule_soa(&decompressed, self.record_format) } /// Read multiple molecules in parallel. pub fn get_molecules(&self, indices: &[usize]) -> Result> { let raw = self.get_raw_bytes(indices)?; raw.into_par_iter() - .map(|bytes| deserialize_molecule_soa(&bytes)) + .map(|bytes| deserialize_molecule_soa(&bytes, self.record_format)) .collect() } @@ -593,6 +670,10 @@ impl AtomDatabase { self.compression } + pub fn record_format(&self) -> u32 { + self.record_format + } + /// Compressed bytes for a molecule from the mmap (None if no mmap). pub fn get_compressed_slice(&self, index: usize) -> Option<&[u8]> { let mol_index = self.index.get(index)?; @@ -636,7 +717,7 @@ impl AtomDatabase { #[cfg(test)] mod tests { use super::*; - use crate::Atom; + use crate::{Atom, FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; use tempfile::NamedTempFile; fn molecule_from_atoms(atoms: Vec) -> Molecule { @@ -689,7 +770,7 @@ mod tests { let mut db = AtomDatabase::open(&path).unwrap(); assert_eq!(db.len(), 1); let mol = db.get_molecule(0).unwrap(); - assert_eq!(mol.positions[0], [1.0, 2.0, 3.0]); + assert_eq!(mol.atom(0).unwrap().position(), [1.0, 2.0, 3.0]); } } @@ -703,8 +784,8 @@ mod tests { Atom::new(0.0, 0.0, 0.0, 6), Atom::new(1.0, 0.0, 0.0, 8), ]); - mol.forces = Some(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]); - mol.energy = Some(-123.456); + mol.forces = Some(Vec3Data::F32(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])); + mol.energy = Some(FloatScalarData::F64(-123.456)); // Write to database { @@ -721,13 +802,14 @@ mod tests { // Check forces are preserved assert!(retrieved.forces.is_some()); let forces = retrieved.forces.unwrap(); - assert_eq!(forces.len(), 2); - assert_eq!(forces[0], [0.1, 0.2, 0.3]); - assert_eq!(forces[1], [0.4, 0.5, 0.6]); + assert_eq!( + forces, + Vec3Data::F32(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + ); // Check energy is preserved assert!(retrieved.energy.is_some()); - assert_eq!(retrieved.energy.unwrap(), -123.456); + assert_eq!(retrieved.energy.unwrap(), FloatScalarData::F64(-123.456)); } } @@ -803,7 +885,7 @@ mod tests { // Opening + writing should truncate the uncommitted tail before appending new data. let mut db = AtomDatabase::open(&path).unwrap(); let mol2 = molecule_from_atoms(vec![Atom::new(1.0, 2.0, 3.0, 8)]); - let mol2_bytes = serialize_molecule_soa(&mol2).unwrap(); + let mol2_bytes = serialize_molecule_soa(&mol2, db.record_format()).unwrap(); let mol2_compressed = compress(&mol2_bytes, compression).unwrap(); db.add_molecule(&mol2).unwrap(); @@ -825,13 +907,21 @@ mod tests { Atom::new(4.0, 5.0, 6.0, 8), ]); mol.name = Some("water".to_string()); - mol.energy = Some(-42.5); - mol.forces = Some(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]); - mol.charges = Some(vec![-0.5, 0.5]); - mol.velocities = Some(vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - mol.cell = Some([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]); + mol.energy = Some(FloatScalarData::F64(-42.5)); + mol.forces = Some(Vec3Data::F32(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])); + mol.charges = Some(FloatArrayData::F64(vec![-0.5, 0.5])); + mol.velocities = Some(Vec3Data::F32(vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])); + mol.cell = Some(Mat3Data::F64([ + [10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0], + ])); mol.pbc = Some([true, true, false]); - mol.stress = Some([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); + mol.stress = Some(Mat3Data::F64([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ])); // atom_properties mol.atom_properties.insert( @@ -866,28 +956,31 @@ mod tests { // Verify all fields assert_eq!(r.name.as_deref(), Some("water")); assert_eq!(r.len(), 2); - assert_eq!(r.positions[0], [1.0, 2.0, 3.0]); + assert_eq!(r.atom(0).unwrap().position(), [1.0, 2.0, 3.0]); assert_eq!(r.atomic_numbers[0], 6); - assert_eq!(r.positions[1], [4.0, 5.0, 6.0]); + assert_eq!(r.atom(1).unwrap().position(), [4.0, 5.0, 6.0]); assert_eq!(r.atomic_numbers[1], 8); - assert_eq!(r.energy, Some(-42.5)); + assert_eq!(r.energy, Some(FloatScalarData::F64(-42.5))); assert_eq!( r.forces.as_ref().unwrap(), - &[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + &Vec3Data::F32(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + ); + assert_eq!( + r.charges.as_ref().unwrap(), + &FloatArrayData::F64(vec![-0.5, 0.5]) ); - assert_eq!(r.charges.as_ref().unwrap(), &[-0.5, 0.5]); assert_eq!( r.velocities.as_ref().unwrap(), - &[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + &Vec3Data::F32(vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) ); assert_eq!( r.cell.unwrap(), - [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]] + Mat3Data::F64([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) ); assert_eq!(r.pbc, Some([true, true, false])); assert_eq!( r.stress.unwrap(), - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + Mat3Data::F64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) ); // atom_properties @@ -930,7 +1023,7 @@ mod tests { let mut db = AtomDatabase::open(&path).unwrap(); let r = db.get_molecule(0).unwrap(); assert_eq!(r.len(), 1); - assert_eq!(r.positions[0], [1.0, 2.0, 3.0]); + assert_eq!(r.atom(0).unwrap().position(), [1.0, 2.0, 3.0]); assert_eq!(r.atomic_numbers[0], 6); assert!(r.name.is_none()); assert!(r.energy.is_none()); @@ -1056,4 +1149,58 @@ mod tests { other => panic!("expected String, got {:?}", other), } } + + #[test] + fn test_schema_lock_rejects_position_dtype_mismatch() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let mol_f64 = Molecule::new_f64(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + let mol_f32 = Molecule::new(vec![[1.0, 1.0, 1.0]], vec![8]).unwrap(); + + db.add_molecule(&mol_f64).unwrap(); + let err = db.add_molecule(&mol_f32).unwrap_err(); + assert!(format!("{}", err).contains("Position dtype mismatch")); + } + + #[test] + fn test_schema_lock_rejects_custom_shape_mismatch() { + use crate::atom::PropertyValue; + + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let mut mol1 = Molecule::new(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + mol1.set_property( + "spectrum".to_string(), + PropertyValue::FloatArray(vec![1.0, 2.0]), + ); + + let mut mol2 = Molecule::new(vec![[1.0, 1.0, 1.0]], vec![8]).unwrap(); + mol2.set_property("spectrum".to_string(), PropertyValue::FloatArray(vec![3.0])); + + db.add_molecule(&mol1).unwrap(); + let err = db.add_molecule(&mol2).unwrap_err(); + assert!(format!("{}", err).contains("Schema mismatch for section 'spectrum'")); + } + + #[test] + fn test_schema_lock_allows_late_optional_builtin() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let mol1 = Molecule::new_f64(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + let mut mol2 = Molecule::new_f64(vec![[1.0, 1.0, 1.0]], vec![8]).unwrap(); + mol2.forces = Some(Vec3Data::F64(vec![[0.1, 0.2, 0.3]])); + + db.add_molecule(&mol1).unwrap(); + db.add_molecule(&mol2).unwrap(); + db.flush().unwrap(); + + let retrieved = db.get_molecule(1).unwrap(); + assert_eq!(retrieved.forces, Some(Vec3Data::F64(vec![[0.1, 0.2, 0.3]]))); + } } diff --git a/atompack/src/storage/soa.rs b/atompack/src/storage/soa.rs index 6741929..9ac5d37 100644 --- a/atompack/src/storage/soa.rs +++ b/atompack/src/storage/soa.rs @@ -1,4 +1,20 @@ use super::*; +use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; +use std::collections::BTreeMap; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(super) struct SchemaLock { + pub(super) positions_type: Option, + pub(super) sections: BTreeMap<(u8, String), SchemaEntry>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct SchemaEntry { + pub(super) type_tag: u8, + pub(super) per_atom: bool, + pub(super) elem_bytes: usize, + pub(super) slot_bytes: usize, +} /// Write a single tagged section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload] fn write_section(buf: &mut Vec, kind: u8, key: &str, type_tag: u8, payload: &[u8]) { @@ -112,6 +128,36 @@ pub(super) fn decode_vec3_f32(payload: &[u8]) -> Result> { .collect() } +pub(super) fn decode_vec3_f64(payload: &[u8]) -> Result> { + if !payload.len().is_multiple_of(24) { + return Err(Error::InvalidData( + "vec3 payload length not divisible by 24".into(), + )); + } + payload + .chunks_exact(24) + .map(|c| { + Ok([ + f64::from_le_bytes(arr(&c[0..8])?), + f64::from_le_bytes(arr(&c[8..16])?), + f64::from_le_bytes(arr(&c[16..24])?), + ]) + }) + .collect() +} + +pub(super) fn decode_f32_array(payload: &[u8]) -> Result> { + if !payload.len().is_multiple_of(4) { + return Err(Error::InvalidData( + "f32 array payload length not divisible by 4".into(), + )); + } + payload + .chunks_exact(4) + .map(|c| Ok(f32::from_le_bytes(arr(c)?))) + .collect() +} + pub(super) fn decode_f64_array(payload: &[u8]) -> Result> { if !payload.len().is_multiple_of(8) { return Err(Error::InvalidData( @@ -124,6 +170,23 @@ pub(super) fn decode_f64_array(payload: &[u8]) -> Result> { .collect() } +pub(super) fn decode_mat3x3_f32(payload: &[u8]) -> Result<[[f32; 3]; 3]> { + if payload.len() != 36 { + return Err(Error::InvalidData(format!( + "mat3x3 payload length {} (expected 36)", + payload.len() + ))); + } + let mut mat = [[0.0f32; 3]; 3]; + for (r, row) in mat.iter_mut().enumerate() { + for (c, cell) in row.iter_mut().enumerate() { + let o = (r * 3 + c) * 4; + *cell = f32::from_le_bytes(arr(&payload[o..o + 4])?); + } + } + Ok(mat) +} + pub(super) fn decode_mat3x3_f64(payload: &[u8]) -> Result<[[f64; 3]; 3]> { if payload.len() != 72 { return Err(Error::InvalidData(format!( @@ -141,6 +204,82 @@ pub(super) fn decode_mat3x3_f64(payload: &[u8]) -> Result<[[f64; 3]; 3]> { Ok(mat) } +fn write_vec3_section(buf: &mut Vec, key: &str, values: &Vec3Data) { + match values { + Vec3Data::F32(values) => { + let mut payload = Vec::with_capacity(values.len() * 12); + for value in values { + extend_f32(&mut payload, value); + } + write_section(buf, KIND_BUILTIN, key, TYPE_VEC3_F32, &payload); + } + Vec3Data::F64(values) => { + let mut payload = Vec::with_capacity(values.len() * 24); + for value in values { + extend_f64(&mut payload, value); + } + write_section(buf, KIND_BUILTIN, key, TYPE_VEC3_F64, &payload); + } + } +} + +fn write_float_array_section(buf: &mut Vec, key: &str, values: &FloatArrayData) { + match values { + FloatArrayData::F32(values) => { + let mut payload = Vec::with_capacity(values.len() * 4); + extend_f32(&mut payload, values); + write_section(buf, KIND_BUILTIN, key, TYPE_F32_ARRAY, &payload); + } + FloatArrayData::F64(values) => { + let mut payload = Vec::with_capacity(values.len() * 8); + extend_f64(&mut payload, values); + write_section(buf, KIND_BUILTIN, key, TYPE_F64_ARRAY, &payload); + } + } +} + +fn write_mat3_section(buf: &mut Vec, key: &str, values: &Mat3Data) { + match values { + Mat3Data::F32(values) => { + let mut payload = Vec::with_capacity(36); + for row in values { + extend_f32(&mut payload, row); + } + write_section(buf, KIND_BUILTIN, key, TYPE_MAT3X3_F32, &payload); + } + Mat3Data::F64(values) => { + let mut payload = Vec::with_capacity(72); + for row in values { + extend_f64(&mut payload, row); + } + write_section(buf, KIND_BUILTIN, key, TYPE_MAT3X3_F64, &payload); + } + } +} + +fn write_energy_section(buf: &mut Vec, value: &FloatScalarData) { + match value { + FloatScalarData::F32(value) => { + write_section( + buf, + KIND_BUILTIN, + "energy", + TYPE_FLOAT32, + &value.to_le_bytes(), + ); + } + FloatScalarData::F64(value) => { + write_section( + buf, + KIND_BUILTIN, + "energy", + TYPE_FLOAT, + &value.to_le_bytes(), + ); + } + } +} + fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result { Ok(match type_tag { TYPE_FLOAT => { @@ -224,12 +363,20 @@ fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result }) } -pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result> { +fn serialize_molecule_soa_v2(molecule: &Molecule) -> Result> { let n = molecule.len(); let mut buf = Vec::new(); buf.extend_from_slice(&(n as u32).to_le_bytes()); - for position in &molecule.positions { + let positions = match &molecule.positions { + Vec3Data::F32(values) => values, + Vec3Data::F64(_) => { + return Err(Error::InvalidData( + "record format 2 does not support float64 positions".into(), + )); + } + }; + for position in positions { buf.extend_from_slice(&position[0].to_le_bytes()); buf.extend_from_slice(&position[1].to_le_bytes()); buf.extend_from_slice(&position[2].to_le_bytes()); @@ -266,18 +413,42 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result> { buf.extend_from_slice(&n_sections.to_le_bytes()); if let Some(ref charges) = molecule.charges { + let charges = match charges { + FloatArrayData::F64(values) => values, + FloatArrayData::F32(_) => { + return Err(Error::InvalidData( + "record format 2 does not support float32 charges".into(), + )); + } + }; let mut payload = Vec::with_capacity(charges.len() * 8); extend_f64(&mut payload, charges); write_section(&mut buf, KIND_BUILTIN, "charges", TYPE_F64_ARRAY, &payload); } if let Some(ref cell) = molecule.cell { + let cell = match cell { + Mat3Data::F64(values) => values, + Mat3Data::F32(_) => { + return Err(Error::InvalidData( + "record format 2 does not support float32 cell".into(), + )); + } + }; let mut payload = Vec::with_capacity(72); for row in cell { extend_f64(&mut payload, row); } write_section(&mut buf, KIND_BUILTIN, "cell", TYPE_MAT3X3_F64, &payload); } - if let Some(energy) = molecule.energy { + if let Some(ref energy) = molecule.energy { + let energy = match energy { + FloatScalarData::F64(value) => value, + FloatScalarData::F32(_) => { + return Err(Error::InvalidData( + "record format 2 does not support float32 energy".into(), + )); + } + }; write_section( &mut buf, KIND_BUILTIN, @@ -287,6 +458,14 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result> { ); } if let Some(ref forces) = molecule.forces { + let forces = match forces { + Vec3Data::F32(values) => values, + Vec3Data::F64(_) => { + return Err(Error::InvalidData( + "record format 2 does not support float64 forces".into(), + )); + } + }; let mut payload = Vec::with_capacity(forces.len() * 12); for f in forces { extend_f32(&mut payload, f); @@ -301,6 +480,14 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result> { write_section(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload); } if let Some(ref stress) = molecule.stress { + let stress = match stress { + Mat3Data::F64(values) => values, + Mat3Data::F32(_) => { + return Err(Error::InvalidData( + "record format 2 does not support float32 stress".into(), + )); + } + }; let mut payload = Vec::with_capacity(72); for row in stress { extend_f64(&mut payload, row); @@ -308,6 +495,14 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result> { write_section(&mut buf, KIND_BUILTIN, "stress", TYPE_MAT3X3_F64, &payload); } if let Some(ref velocities) = molecule.velocities { + let velocities = match velocities { + Vec3Data::F32(values) => values, + Vec3Data::F64(_) => { + return Err(Error::InvalidData( + "record format 2 does not support float64 velocities".into(), + )); + } + }; let mut payload = Vec::with_capacity(velocities.len() * 12); for v in velocities { extend_f32(&mut payload, v); @@ -352,7 +547,129 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result> { Ok(buf) } -pub(super) fn deserialize_molecule_soa(bytes: &[u8]) -> Result { +fn serialize_molecule_soa_v3(molecule: &Molecule) -> Result> { + let n = molecule.len(); + let mut buf = Vec::new(); + + buf.extend_from_slice(&(n as u32).to_le_bytes()); + match &molecule.positions { + Vec3Data::F32(values) => { + buf.push(TYPE_VEC3_F32); + for value in values { + buf.extend_from_slice(&value[0].to_le_bytes()); + buf.extend_from_slice(&value[1].to_le_bytes()); + buf.extend_from_slice(&value[2].to_le_bytes()); + } + } + Vec3Data::F64(values) => { + buf.push(TYPE_VEC3_F64); + for value in values { + buf.extend_from_slice(&value[0].to_le_bytes()); + buf.extend_from_slice(&value[1].to_le_bytes()); + buf.extend_from_slice(&value[2].to_le_bytes()); + } + } + } + buf.extend_from_slice(&molecule.atomic_numbers); + + let mut n_sections: u16 = 0; + if molecule.charges.is_some() { + n_sections += 1; + } + if molecule.cell.is_some() { + n_sections += 1; + } + if molecule.energy.is_some() { + n_sections += 1; + } + if molecule.forces.is_some() { + n_sections += 1; + } + if molecule.name.is_some() { + n_sections += 1; + } + if molecule.pbc.is_some() { + n_sections += 1; + } + if molecule.stress.is_some() { + n_sections += 1; + } + if molecule.velocities.is_some() { + n_sections += 1; + } + n_sections += molecule.atom_properties.len() as u16; + n_sections += molecule.properties.len() as u16; + buf.extend_from_slice(&n_sections.to_le_bytes()); + + if let Some(ref charges) = molecule.charges { + write_float_array_section(&mut buf, "charges", charges); + } + if let Some(ref cell) = molecule.cell { + write_mat3_section(&mut buf, "cell", cell); + } + if let Some(ref energy) = molecule.energy { + write_energy_section(&mut buf, energy); + } + if let Some(ref forces) = molecule.forces { + write_vec3_section(&mut buf, "forces", forces); + } + if let Some(ref name) = molecule.name { + write_section(&mut buf, KIND_BUILTIN, "name", TYPE_STRING, name.as_bytes()); + } + if let Some(ref pbc) = molecule.pbc { + let payload = [pbc[0] as u8, pbc[1] as u8, pbc[2] as u8]; + write_section(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload); + } + if let Some(ref stress) = molecule.stress { + write_mat3_section(&mut buf, "stress", stress); + } + if let Some(ref velocities) = molecule.velocities { + write_vec3_section(&mut buf, "velocities", velocities); + } + + let mut atom_keys: Vec<&String> = molecule.atom_properties.keys().collect(); + atom_keys.sort(); + for key in atom_keys { + let value = &molecule.atom_properties[key]; + let payload = property_value_to_bytes(value); + write_section( + &mut buf, + KIND_ATOM_PROP, + key, + property_value_type_tag(value), + &payload, + ); + } + + let mut prop_keys: Vec<&String> = molecule.properties.keys().collect(); + prop_keys.sort(); + for key in prop_keys { + let value = &molecule.properties[key]; + let payload = property_value_to_bytes(value); + write_section( + &mut buf, + KIND_MOL_PROP, + key, + property_value_type_tag(value), + &payload, + ); + } + + Ok(buf) +} + +pub(super) fn serialize_molecule_soa(molecule: &Molecule, record_format: u32) -> Result> { + match record_format { + RECORD_FORMAT_SOA_V2 => serialize_molecule_soa_v2(molecule), + RECORD_FORMAT_SOA_V3 => serialize_molecule_soa_v3(molecule), + _ => Err(Error::InvalidData(format!( + "Unsupported record format {}", + record_format + ))), + } +} + +fn deserialize_molecule_soa_v2(bytes: &[u8]) -> Result { if bytes.len() < 6 { return Err(Error::InvalidData("SOA record too small".into())); } @@ -426,19 +743,227 @@ pub(super) fn deserialize_molecule_soa(bytes: &[u8]) -> Result { if payload.len() < 8 { return Err(Error::InvalidData("energy payload truncated".into())); } - mol.energy = Some(f64::from_le_bytes(arr(&payload[..8])?)); + mol.energy = Some(FloatScalarData::F64(f64::from_le_bytes( + arr(&payload[..8])?, + ))); } - "forces" => mol.forces = Some(decode_vec3_f32(payload)?), - "charges" => mol.charges = Some(decode_f64_array(payload)?), - "velocities" => mol.velocities = Some(decode_vec3_f32(payload)?), - "cell" => mol.cell = Some(decode_mat3x3_f64(payload)?), + "forces" => mol.forces = Some(Vec3Data::F32(decode_vec3_f32(payload)?)), + "charges" => mol.charges = Some(FloatArrayData::F64(decode_f64_array(payload)?)), + "velocities" => mol.velocities = Some(Vec3Data::F32(decode_vec3_f32(payload)?)), + "cell" => mol.cell = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)), + "pbc" => { + if payload.len() < 3 { + return Err(Error::InvalidData("pbc payload truncated".into())); + } + mol.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); + } + "stress" => mol.stress = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)), + "name" => { + mol.name = Some( + std::str::from_utf8(payload) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in name".into()))? + .to_string(), + ) + } + _ => {} + }, + KIND_ATOM_PROP => { + mol.atom_properties + .insert(key.to_string(), decode_property_value(type_tag, payload)?); + } + KIND_MOL_PROP => { + mol.properties + .insert(key.to_string(), decode_property_value(type_tag, payload)?); + } + _ => {} + } + } + + Ok(mol) +} + +fn deserialize_molecule_soa_v3(bytes: &[u8]) -> Result { + if bytes.len() < 7 { + return Err(Error::InvalidData("SOA v3 record too small".into())); + } + + let mut pos = 0; + let n = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + + let positions_type = bytes[pos]; + pos += 1; + let positions = match positions_type { + TYPE_VEC3_F32 => { + let positions_end = pos + n * 12; + if positions_end > bytes.len() { + return Err(Error::InvalidData( + "SOA v3 record truncated at positions".into(), + )); + } + let values = decode_vec3_f32(&bytes[pos..positions_end])?; + pos = positions_end; + Vec3Data::F32(values) + } + TYPE_VEC3_F64 => { + let positions_end = pos + n * 24; + if positions_end > bytes.len() { + return Err(Error::InvalidData( + "SOA v3 record truncated at positions".into(), + )); + } + let values = decode_vec3_f64(&bytes[pos..positions_end])?; + pos = positions_end; + Vec3Data::F64(values) + } + _ => { + return Err(Error::InvalidData(format!( + "Unsupported positions type tag {} in record format 3", + positions_type + ))); + } + }; + + let z_end = pos + n; + if z_end > bytes.len() { + return Err(Error::InvalidData( + "SOA v3 record truncated at atomic_numbers".into(), + )); + } + let atomic_numbers = bytes[pos..z_end].to_vec(); + pos = z_end; + + if pos + 2 > bytes.len() { + return Err(Error::InvalidData( + "SOA v3 record truncated at n_sections".into(), + )); + } + let n_sections = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; + pos += 2; + + let mut mol = match positions { + Vec3Data::F32(values) => { + Molecule::new(values, atomic_numbers).map_err(Error::InvalidData)? + } + Vec3Data::F64(values) => { + Molecule::new_f64(values, atomic_numbers).map_err(Error::InvalidData)? + } + }; + + for _ in 0..n_sections { + if pos + 7 > bytes.len() { + return Err(Error::InvalidData("SOA section header truncated".into())); + } + let kind = bytes[pos]; + pos += 1; + let key_len = bytes[pos] as usize; + pos += 1; + if pos + key_len > bytes.len() { + return Err(Error::InvalidData("SOA section key truncated".into())); + } + let key = std::str::from_utf8(&bytes[pos..pos + key_len]) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in section key".into()))?; + pos += key_len; + let type_tag = bytes[pos]; + pos += 1; + let payload_len = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + if pos + payload_len > bytes.len() { + return Err(Error::InvalidData("SOA section payload truncated".into())); + } + let payload = &bytes[pos..pos + payload_len]; + pos += payload_len; + + match kind { + KIND_BUILTIN => match key { + "energy" => match type_tag { + TYPE_FLOAT => { + if payload.len() != 8 { + return Err(Error::InvalidData("energy f64 payload truncated".into())); + } + mol.energy = Some(FloatScalarData::F64(f64::from_le_bytes(arr(payload)?))); + } + TYPE_FLOAT32 => { + if payload.len() != 4 { + return Err(Error::InvalidData("energy f32 payload truncated".into())); + } + mol.energy = Some(FloatScalarData::F32(f32::from_le_bytes(arr(payload)?))); + } + _ => { + return Err(Error::InvalidData(format!( + "Unsupported energy type tag {}", + type_tag + ))); + } + }, + "forces" => match type_tag { + TYPE_VEC3_F32 => mol.forces = Some(Vec3Data::F32(decode_vec3_f32(payload)?)), + TYPE_VEC3_F64 => mol.forces = Some(Vec3Data::F64(decode_vec3_f64(payload)?)), + _ => { + return Err(Error::InvalidData(format!( + "Unsupported forces type tag {}", + type_tag + ))); + } + }, + "charges" => match type_tag { + TYPE_F32_ARRAY => { + mol.charges = Some(FloatArrayData::F32(decode_f32_array(payload)?)) + } + TYPE_F64_ARRAY => { + mol.charges = Some(FloatArrayData::F64(decode_f64_array(payload)?)) + } + _ => { + return Err(Error::InvalidData(format!( + "Unsupported charges type tag {}", + type_tag + ))); + } + }, + "velocities" => match type_tag { + TYPE_VEC3_F32 => { + mol.velocities = Some(Vec3Data::F32(decode_vec3_f32(payload)?)) + } + TYPE_VEC3_F64 => { + mol.velocities = Some(Vec3Data::F64(decode_vec3_f64(payload)?)) + } + _ => { + return Err(Error::InvalidData(format!( + "Unsupported velocities type tag {}", + type_tag + ))); + } + }, + "cell" => match type_tag { + TYPE_MAT3X3_F32 => mol.cell = Some(Mat3Data::F32(decode_mat3x3_f32(payload)?)), + TYPE_MAT3X3_F64 => mol.cell = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)), + _ => { + return Err(Error::InvalidData(format!( + "Unsupported cell type tag {}", + type_tag + ))); + } + }, "pbc" => { if payload.len() < 3 { return Err(Error::InvalidData("pbc payload truncated".into())); } mol.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); } - "stress" => mol.stress = Some(decode_mat3x3_f64(payload)?), + "stress" => match type_tag { + TYPE_MAT3X3_F32 => { + mol.stress = Some(Mat3Data::F32(decode_mat3x3_f32(payload)?)) + } + TYPE_MAT3X3_F64 => { + mol.stress = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)) + } + _ => { + return Err(Error::InvalidData(format!( + "Unsupported stress type tag {}", + type_tag + ))); + } + }, "name" => { mol.name = Some( std::str::from_utf8(payload) @@ -462,3 +987,300 @@ pub(super) fn deserialize_molecule_soa(bytes: &[u8]) -> Result { Ok(mol) } + +pub(super) fn deserialize_molecule_soa(bytes: &[u8], record_format: u32) -> Result { + match record_format { + RECORD_FORMAT_SOA_V2 => deserialize_molecule_soa_v2(bytes), + RECORD_FORMAT_SOA_V3 => deserialize_molecule_soa_v3(bytes), + _ => Err(Error::InvalidData(format!( + "Unsupported record format {}", + record_format + ))), + } +} + +fn schema_type_tag_elem_bytes(tag: u8) -> Result { + match tag { + TYPE_FLOAT => Ok(8), + TYPE_INT => Ok(8), + TYPE_STRING => Ok(0), + TYPE_F64_ARRAY => Ok(8), + TYPE_VEC3_F32 => Ok(12), + TYPE_I64_ARRAY => Ok(8), + TYPE_F32_ARRAY => Ok(4), + TYPE_VEC3_F64 => Ok(24), + TYPE_I32_ARRAY => Ok(4), + TYPE_BOOL3 => Ok(3), + TYPE_MAT3X3_F64 => Ok(72), + TYPE_FLOAT32 => Ok(4), + TYPE_MAT3X3_F32 => Ok(36), + _ => Err(Error::InvalidData(format!( + "Unsupported section type tag {}", + tag + ))), + } +} + +fn schema_is_per_atom(kind: u8, key: &str) -> bool { + match kind { + KIND_ATOM_PROP => true, + KIND_MOL_PROP => false, + KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), + _ => false, + } +} + +fn schema_entry( + kind: u8, + key: &str, + type_tag: u8, + payload_len: usize, + n_atoms: usize, +) -> Result { + let per_atom = schema_is_per_atom(kind, key); + let elem_bytes = schema_type_tag_elem_bytes(type_tag)?; + let slot_bytes = if type_tag == TYPE_STRING { + 0 + } else if per_atom { + match type_tag { + TYPE_F64_ARRAY | TYPE_I64_ARRAY | TYPE_F32_ARRAY | TYPE_I32_ARRAY => elem_bytes, + TYPE_VEC3_F32 | TYPE_VEC3_F64 => elem_bytes, + _ => payload_len + .checked_div(n_atoms.max(1)) + .unwrap_or(elem_bytes), + } + } else { + payload_len + }; + + if per_atom { + let expected = elem_bytes + .checked_mul(n_atoms) + .ok_or_else(|| Error::InvalidData(format!("Schema overflow for section '{}'", key)))?; + if payload_len != expected { + return Err(Error::InvalidData(format!( + "Section '{}' payload length {} does not match expected {}", + key, payload_len, expected + ))); + } + } + + Ok(SchemaEntry { + type_tag, + per_atom, + elem_bytes, + slot_bytes, + }) +} + +fn parse_record_schema_v2(bytes: &[u8]) -> Result { + if bytes.len() < 6 { + return Err(Error::InvalidData("SOA record too small".into())); + } + + let mut pos = 0usize; + let n_atoms = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + + let positions_end = pos + .checked_add( + n_atoms + .checked_mul(12) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?, + ) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; + if positions_end > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at positions".into(), + )); + } + pos = positions_end; + + let z_end = pos + .checked_add(n_atoms) + .ok_or_else(|| Error::InvalidData("SOA atomic_numbers overflow".into()))?; + if z_end > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at atomic_numbers".into(), + )); + } + pos = z_end; + + if pos + 2 > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at n_sections".into(), + )); + } + let n_sections = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; + pos += 2; + + let mut schema = SchemaLock { + positions_type: Some(TYPE_VEC3_F32), + sections: BTreeMap::new(), + }; + + for _ in 0..n_sections { + if pos + 7 > bytes.len() { + return Err(Error::InvalidData("SOA section header truncated".into())); + } + let kind = bytes[pos]; + pos += 1; + let key_len = bytes[pos] as usize; + pos += 1; + if pos + key_len > bytes.len() { + return Err(Error::InvalidData("SOA section key truncated".into())); + } + let key = std::str::from_utf8(&bytes[pos..pos + key_len]) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in section key".into()))? + .to_string(); + pos += key_len; + let type_tag = bytes[pos]; + pos += 1; + let payload_len = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + let payload_end = pos + .checked_add(payload_len) + .ok_or_else(|| Error::InvalidData("SOA section payload overflow".into()))?; + if payload_end > bytes.len() { + return Err(Error::InvalidData("SOA section payload truncated".into())); + } + pos = payload_end; + let entry = schema_entry(kind, &key, type_tag, payload_len, n_atoms)?; + schema.sections.insert((kind, key), entry); + } + + Ok(schema) +} + +fn parse_record_schema_v3(bytes: &[u8]) -> Result { + if bytes.len() < 7 { + return Err(Error::InvalidData("SOA record too small".into())); + } + + let mut pos = 0usize; + let n_atoms = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + + let positions_type = bytes[pos]; + pos += 1; + let positions_elem_bytes = match positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + _ => { + return Err(Error::InvalidData(format!( + "Unsupported positions type tag {}", + positions_type + ))); + } + }; + + let positions_end = pos + .checked_add( + n_atoms + .checked_mul(positions_elem_bytes) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?, + ) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; + if positions_end > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at positions".into(), + )); + } + pos = positions_end; + + let z_end = pos + .checked_add(n_atoms) + .ok_or_else(|| Error::InvalidData("SOA atomic_numbers overflow".into()))?; + if z_end > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at atomic_numbers".into(), + )); + } + pos = z_end; + + if pos + 2 > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at n_sections".into(), + )); + } + let n_sections = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; + pos += 2; + + let mut schema = SchemaLock { + positions_type: Some(positions_type), + sections: BTreeMap::new(), + }; + + for _ in 0..n_sections { + if pos + 7 > bytes.len() { + return Err(Error::InvalidData("SOA section header truncated".into())); + } + let kind = bytes[pos]; + pos += 1; + let key_len = bytes[pos] as usize; + pos += 1; + if pos + key_len > bytes.len() { + return Err(Error::InvalidData("SOA section key truncated".into())); + } + let key = std::str::from_utf8(&bytes[pos..pos + key_len]) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in section key".into()))? + .to_string(); + pos += key_len; + let type_tag = bytes[pos]; + pos += 1; + let payload_len = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + let payload_end = pos + .checked_add(payload_len) + .ok_or_else(|| Error::InvalidData("SOA section payload overflow".into()))?; + if payload_end > bytes.len() { + return Err(Error::InvalidData("SOA section payload truncated".into())); + } + pos = payload_end; + let entry = schema_entry(kind, &key, type_tag, payload_len, n_atoms)?; + schema.sections.insert((kind, key), entry); + } + + Ok(schema) +} + +pub(super) fn record_schema(bytes: &[u8], record_format: u32) -> Result { + match record_format { + RECORD_FORMAT_SOA_V2 => parse_record_schema_v2(bytes), + RECORD_FORMAT_SOA_V3 => parse_record_schema_v3(bytes), + _ => Err(Error::InvalidData(format!( + "Unsupported record format {}", + record_format + ))), + } +} + +pub(super) fn merge_schema_lock(lock: &mut SchemaLock, record: &SchemaLock) -> Result<()> { + match (lock.positions_type, record.positions_type) { + (None, Some(tag)) => lock.positions_type = Some(tag), + (Some(expected), Some(actual)) if expected != actual => { + return Err(Error::InvalidData(format!( + "Position dtype mismatch: expected type tag {}, got {}", + expected, actual + ))); + } + _ => {} + } + + for ((kind, key), entry) in &record.sections { + match lock.sections.get(&(*kind, key.clone())) { + Some(expected) if expected != entry => { + return Err(Error::InvalidData(format!( + "Schema mismatch for section '{}': expected {:?}, got {:?}", + key, expected, entry + ))); + } + Some(_) => {} + None => { + lock.sections.insert((*kind, key.clone()), entry.clone()); + } + } + } + + Ok(()) +} diff --git a/atompack/tests/throughput_smoke.rs b/atompack/tests/throughput_smoke.rs index 1c58da4..28f58b6 100644 --- a/atompack/tests/throughput_smoke.rs +++ b/atompack/tests/throughput_smoke.rs @@ -1,5 +1,7 @@ // Copyright 2026 Entalpic -use atompack::{Atom, AtomDatabase, Molecule, compression::CompressionType}; +use atompack::{ + Atom, AtomDatabase, FloatScalarData, Molecule, Vec3Data, compression::CompressionType, +}; use std::hint::black_box; use std::time::{Duration, Instant}; @@ -179,15 +181,15 @@ fn synthetic_molecule(id: usize) -> Molecule { }) .collect(); let mut molecule = Molecule::from_atoms(atoms); - molecule.energy = Some(-1_000.0 - id as f64 * 1e-3); - molecule.forces = Some( + molecule.energy = Some(FloatScalarData::F64(-1_000.0 - id as f64 * 1e-3)); + molecule.forces = Some(Vec3Data::F32( (0..ATOMS_PER_MOLECULE) .map(|i| { let t = id as f32 * 0.002 + i as f32 * 0.02; [(t * 0.7).sin(), (t * 0.9).cos(), (t * 1.1).sin()] }) .collect(), - ); + )); molecule } @@ -231,7 +233,7 @@ fn run_smoke() -> atompack::Result { for _ in 0..RANDOM_READS { let molecule = db.get_molecule(rng.index(db.len()))?; rand_atoms += molecule.len(); - black_box(molecule.forces.as_ref().map(Vec::len)); + black_box(molecule.forces.as_ref().map(Vec3Data::len)); } let rand_elapsed = start_rand.elapsed(); assert_eq!(rand_atoms, RANDOM_READS * ATOMS_PER_MOLECULE); From d99a9286adafcb56a84b52dd27c06be27365771c Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sat, 9 May 2026 22:02:40 +0200 Subject: [PATCH 2/9] fix: restore CI validation expectations --- atompack-py/src/molecule.rs | 10 ++++---- atompack-py/src/molecule_helpers.rs | 36 ++++++++++++++++++++++++++++- atompack/src/atom.rs | 4 ++++ 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/atompack-py/src/molecule.rs b/atompack-py/src/molecule.rs index 6fce82a..3e082f5 100644 --- a/atompack-py/src/molecule.rs +++ b/atompack-py/src/molecule.rs @@ -296,7 +296,7 @@ impl PyMolecule { #[setter] fn set_forces(&mut self, py: Python<'_>, forces: Py) -> PyResult<()> { let n_atoms = self.len(); - self.ensure_owned()?.forces = Some(parse_vec3_field(forces.bind(py), "forces", n_atoms)?); + self.ensure_owned()?.forces = Some(parse_vec3_field(forces.bind(py), "Forces", n_atoms)?); Ok(()) } @@ -332,7 +332,7 @@ impl PyMolecule { let n_atoms = self.len(); self.ensure_owned()?.charges = Some(parse_float_array_field( charges.bind(py), - "charges", + "Charges", n_atoms, )?); Ok(()) @@ -351,7 +351,7 @@ impl PyMolecule { let n_atoms = self.len(); self.ensure_owned()?.velocities = Some(parse_vec3_field( velocities.bind(py), - "velocities", + "Velocities", n_atoms, )?); Ok(()) @@ -367,7 +367,7 @@ impl PyMolecule { /// cell property (setter) #[setter] fn set_cell(&mut self, py: Python<'_>, cell: Py) -> PyResult<()> { - self.ensure_owned()?.cell = Some(parse_mat3_field(cell.bind(py), "cell")?); + self.ensure_owned()?.cell = Some(parse_mat3_field(cell.bind(py), "Cell")?); Ok(()) } @@ -382,7 +382,7 @@ impl PyMolecule { #[setter] fn set_stress(&mut self, py: Python<'_>, stress: Py) -> PyResult<()> { let inner = self.ensure_owned()?; - inner.stress = Some(parse_mat3_field(stress.bind(py), "stress")?); + inner.stress = Some(parse_mat3_field(stress.bind(py), "Stress")?); inner.properties.remove("stress"); Ok(()) } diff --git a/atompack-py/src/molecule_helpers.rs b/atompack-py/src/molecule_helpers.rs index 497c2aa..67553db 100644 --- a/atompack-py/src/molecule_helpers.rs +++ b/atompack-py/src/molecule_helpers.rs @@ -299,6 +299,40 @@ pub(crate) fn parse_vec3_field( ))) } +fn parse_positions_field(value: &Bound<'_, PyAny>) -> PyResult { + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape().len() != 2 || view.shape()[1] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (n_atoms, 3)", + )); + } + return Ok(Vec3Data::F32( + view.outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + )); + } + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape().len() != 2 || view.shape()[1] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (n_atoms, 3)", + )); + } + return Ok(Vec3Data::F64( + view.outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + )); + } + Err(PyValueError::new_err( + "positions must be a float32 or float64 ndarray with shape (n_atoms, 3)", + )) +} + pub(crate) fn parse_float_array_field( value: &Bound<'_, PyAny>, label: &str, @@ -378,7 +412,7 @@ pub(crate) fn molecule_from_numpy_arrays( let z = atomic_numbers.readonly(); let z_arr = z.as_array(); let atomic_numbers_vec = z_arr.to_vec(); - let positions = parse_vec3_field(positions, "positions", atomic_numbers_vec.len())?; + let positions = parse_positions_field(positions)?; molecule_from_positions(positions, atomic_numbers_vec).map_err(PyValueError::new_err) } diff --git a/atompack/src/atom.rs b/atompack/src/atom.rs index d6daef4..c7be8e7 100644 --- a/atompack/src/atom.rs +++ b/atompack/src/atom.rs @@ -152,6 +152,10 @@ impl FloatArrayData { } } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn to_f64_vec(&self) -> Vec { match self { Self::F32(values) => values.iter().map(|value| *value as f64).collect(), From 3b1fdeec120000f7c604faf0282ed3e8d3071a70 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sat, 9 May 2026 22:26:19 +0200 Subject: [PATCH 3/9] refactor: store v3 positions dtype in schema metadata --- atompack-py/src/database.rs | 57 +- atompack-py/src/database_batch.rs | 2 +- atompack-py/src/database_flat.rs | 41 +- atompack-py/src/lib.rs | 50 +- atompack-py/src/molecule_helpers.rs | 7 +- atompack/src/storage/header.rs | 43 +- atompack/src/storage/mod.rs | 78 +- atompack/src/storage/soa.rs | 1087 ++++++++++++--------------- 8 files changed, 634 insertions(+), 731 deletions(-) diff --git a/atompack-py/src/database.rs b/atompack-py/src/database.rs index 3cb30e2..bea4369 100644 --- a/atompack-py/src/database.rs +++ b/atompack-py/src/database.rs @@ -14,6 +14,7 @@ impl PyAtomDatabase { fn single_molecule_view(&self, py: Python<'_>, index: usize) -> PyResult { let compression = self.inner.compression(); let record_format = self.inner.record_format(); + let positions_type = self.inner.positions_type(); let use_mmap = self.inner.get_compressed_slice(0).is_some(); if use_mmap { @@ -23,7 +24,11 @@ impl PyAtomDatabase { let bytes = self.inner.get_shared_mmap_bytes(index).ok_or_else(|| { invalid_data(format!("Missing mmap bytes for molecule {}", index)) })?; - SoaMoleculeView::from_shared_bytes_inner(bytes, record_format) + SoaMoleculeView::from_shared_bytes_inner( + bytes, + record_format, + positions_type, + ) } else { let compressed = self.inner.get_compressed_slice(index).ok_or_else(|| { @@ -44,7 +49,11 @@ impl PyAtomDatabase { compression, Some(uncompressed_size), )?; - SoaMoleculeView::from_bytes_inner(decompressed, record_format) + SoaMoleculeView::from_bytes_inner( + decompressed, + record_format, + positions_type, + ) } }) .map_err(|e| PyValueError::new_err(format!("{}", e))); @@ -56,7 +65,7 @@ impl PyAtomDatabase { let raw = raw_bytes.pop().ok_or_else(|| { PyValueError::new_err(format!("Missing raw bytes for molecule {}", index)) })?; - SoaMoleculeView::from_bytes(raw, record_format) + SoaMoleculeView::from_bytes(raw, record_format, positions_type) } } @@ -118,16 +127,6 @@ impl PyAtomDatabase { /// Add a molecule to the database fn add_molecule(&mut self, molecule: &PyMolecule) -> PyResult<()> { - if let Some(view) = molecule.as_view() - && view.record_format == self.inner.record_format() - { - let soa_bytes = view.bytes.as_slice(); - let n_atoms = view.n_atoms as u32; - return self - .inner - .add_raw_soa_records(&[(soa_bytes, n_atoms)]) - .map_err(|e| PyValueError::new_err(format!("{}", e))); - } let owned = molecule.clone_as_owned()?; self.inner .add_molecule(&owned) @@ -136,25 +135,10 @@ impl PyAtomDatabase { /// Add multiple molecules (processed in parallel) fn add_molecules(&mut self, molecules: Vec>) -> PyResult<()> { - // Split into view-backed (fast path) and owned molecules - let mut raw_records: Vec<(&[u8], u32)> = Vec::new(); let mut owned_molecules: Vec = Vec::new(); - let record_format = self.inner.record_format(); for m in &molecules { - if let Some(view) = m.as_view() - && view.record_format == record_format - { - raw_records.push((view.bytes.as_slice(), view.n_atoms as u32)); - } else { - owned_molecules.push(m.clone_as_owned()?); - } - } - - if !raw_records.is_empty() { - self.inner - .add_raw_soa_records(&raw_records) - .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + owned_molecules.push(m.clone_as_owned()?); } if !owned_molecules.is_empty() { let mol_refs: Vec<&Molecule> = owned_molecules.iter().collect(); @@ -247,6 +231,7 @@ impl PyAtomDatabase { let compression = self.inner.compression(); let record_format = self.inner.record_format(); + let positions_type = self.inner.positions_type(); let use_mmap = self.inner.get_compressed_slice(0).is_some(); let views: Vec = if use_mmap { @@ -259,7 +244,11 @@ impl PyAtomDatabase { let bytes = self.inner.get_shared_mmap_bytes(idx).ok_or_else(|| { invalid_data(format!("Missing mmap bytes for molecule {}", idx)) })?; - SoaMoleculeView::from_shared_bytes_inner(bytes, record_format) + SoaMoleculeView::from_shared_bytes_inner( + bytes, + record_format, + positions_type, + ) } else { let compressed = self.inner.get_compressed_slice(idx).ok_or_else(|| { @@ -280,7 +269,11 @@ impl PyAtomDatabase { compression, Some(uncompressed_size), )?; - SoaMoleculeView::from_bytes_inner(decompressed, record_format) + SoaMoleculeView::from_bytes_inner( + decompressed, + record_format, + positions_type, + ) } }) .collect() @@ -292,7 +285,7 @@ impl PyAtomDatabase { .map_err(|e| PyValueError::new_err(format!("{}", e)))?; raw_bytes .into_iter() - .map(|bytes| SoaMoleculeView::from_bytes(bytes, record_format)) + .map(|bytes| SoaMoleculeView::from_bytes(bytes, record_format, positions_type)) .collect::>>() .map_err(|e| invalid_data(format!("{}", e))) } diff --git a/atompack-py/src/database_batch.rs b/atompack-py/src/database_batch.rs index 9433a35..ff11c96 100644 --- a/atompack-py/src/database_batch.rs +++ b/atompack-py/src/database_batch.rs @@ -817,7 +817,7 @@ pub(super) fn add_arrays_batch_impl( sections: §ions, }) .map_err(PyValueError::new_err)?; - records.push((record, n_atoms as u32)); + records.push((record, n_atoms as u32, positions_type)); } py.detach(move || inner.add_owned_soa_records(records)) diff --git a/atompack-py/src/database_flat.rs b/atompack-py/src/database_flat.rs index 0ea2722..63b0a56 100644 --- a/atompack-py/src/database_flat.rs +++ b/atompack-py/src/database_flat.rs @@ -44,29 +44,20 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .ok_or_else(|| invalid_data("missing final atom offset"))?; let record_format = inner.record_format(); + let positions_type = inner + .positions_type() + .ok_or_else(|| invalid_data("Missing position dtype for batch"))?; let (raw_bytes, _) = inner.read_decompress_parallel(&indices)?; - let mut positions_type: Option = None; let mut schema: Vec = Vec::new(); for bytes in &raw_bytes { - let md = parse_mol_fast_soa(bytes, record_format)?; - match positions_type { - None => positions_type = Some(md.positions_type), - Some(expected) if expected != md.positions_type => { - return Err(invalid_data(format!( - "Position dtype mismatch across selected molecules: expected type tag {}, got {}", - expected, md.positions_type - ))); - } - _ => {} - } + let md = parse_mol_fast_soa(bytes, record_format, Some(positions_type))?; for section in &md.sections { let incoming = section_schema_from_ref(section, md.n_atoms)?; - if let Some(existing) = schema - .iter() - .find(|candidate| candidate.kind == incoming.kind && candidate.key == incoming.key) - { + if let Some(existing) = schema.iter().find(|candidate| { + candidate.kind == incoming.kind && candidate.key == incoming.key + }) { if existing.type_tag != incoming.type_tag || existing.per_atom != incoming.per_atom || existing.elem_bytes != incoming.elem_bytes @@ -82,9 +73,6 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( } } } - let positions_type = - positions_type.ok_or_else(|| invalid_data("Missing position dtype for batch"))?; - let positions_stride = match positions_type { TYPE_VEC3_F32 => 12usize, TYPE_VEC3_F64 => 24usize, @@ -144,15 +132,9 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .collect(); let process_mol = |i: usize, mol_bytes: &[u8]| -> atompack::Result<()> { - let md = parse_mol_fast_soa(mol_bytes, record_format)?; + let md = parse_mol_fast_soa(mol_bytes, record_format, Some(positions_type))?; let atom_off = offsets[i]; let n = md.n_atoms; - if md.positions_type != positions_type { - return Err(invalid_data(format!( - "Position dtype mismatch for molecule {}: expected type tag {}, got {}", - i, positions_type, md.positions_type - ))); - } unsafe { std::ptr::copy_nonoverlapping( @@ -168,9 +150,10 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( } for (section_idx, schema_entry) in schema.iter().enumerate() { - let sec = md.sections.iter().find(|sec| { - sec.kind == schema_entry.kind && sec.key == schema_entry.key - }); + let sec = md + .sections + .iter() + .find(|sec| sec.kind == schema_entry.kind && sec.key == schema_entry.key); let Some(sec) = sec else { continue; }; diff --git a/atompack-py/src/lib.rs b/atompack-py/src/lib.rs index f62eaa3..df83c8b 100644 --- a/atompack-py/src/lib.rs +++ b/atompack-py/src/lib.rs @@ -129,7 +129,6 @@ struct SectionRef<'a> { /// Per-molecule extracted data (references into decompressed bytes). struct MolData<'a> { n_atoms: usize, - positions_type: u8, positions_bytes: &'a [u8], atomic_numbers_bytes: &'a [u8], // n_atoms sections: Vec>, @@ -148,15 +147,20 @@ struct SectionSchema { /// Parse SOA format bytes into MolData without allocation. /// /// Layout: -/// [n_atoms:u32][positions:n*12][atomic_numbers:n] +/// [n_atoms:u32][positions:n*(12|24)][atomic_numbers:n] /// [n_sections:u16] /// per section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload] -fn parse_mol_fast_soa(bytes: &[u8], record_format: u32) -> atompack::Result> { +fn parse_mol_fast_soa( + bytes: &[u8], + record_format: u32, + positions_type_hint: Option, +) -> atompack::Result> { let mut pos = 0usize; let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; let positions_type = match record_format { RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, - RECORD_FORMAT_SOA_V3 => read_u8_at(bytes, &mut pos, "SOA positions type")?, + RECORD_FORMAT_SOA_V3 => positions_type_hint + .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, _ => { return Err(invalid_data(format!( "Unsupported record format {}", @@ -201,7 +205,6 @@ fn parse_mol_fast_soa(bytes: &[u8], record_format: u32) -> atompack::Result atompack::Result { + fn from_storage_inner( + bytes: SoaBytes, + record_format: u32, + positions_type_hint: Option, + ) -> atompack::Result { if bytes.len() < 6 { return Err(invalid_data("SOA record too small")); } @@ -415,14 +421,8 @@ impl SoaMoleculeView { let mut pos = 4usize; let positions_type = match record_format { RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, - RECORD_FORMAT_SOA_V3 => { - if pos + 1 > bytes.len() { - return Err(invalid_data("SOA record truncated at positions type")); - } - let tag = bytes[pos]; - pos += 1; - tag - } + RECORD_FORMAT_SOA_V3 => positions_type_hint + .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, _ => { return Err(invalid_data(format!( "Unsupported record format {}", @@ -541,7 +541,6 @@ impl SoaMoleculeView { Ok(Self { bytes, - record_format, n_atoms, positions_type, positions_start, @@ -559,20 +558,29 @@ impl SoaMoleculeView { }) } - fn from_bytes_inner(bytes: Vec, record_format: u32) -> atompack::Result { - Self::from_storage_inner(SoaBytes::Owned(bytes), record_format) + fn from_bytes_inner( + bytes: Vec, + record_format: u32, + positions_type_hint: Option, + ) -> atompack::Result { + Self::from_storage_inner(SoaBytes::Owned(bytes), record_format, positions_type_hint) } fn from_shared_bytes_inner( bytes: SharedMmapBytes, record_format: u32, + positions_type_hint: Option, ) -> atompack::Result { - Self::from_storage_inner(SoaBytes::Shared(bytes), record_format) + Self::from_storage_inner(SoaBytes::Shared(bytes), record_format, positions_type_hint) } /// Thin wrapper for call sites that need PyResult. - fn from_bytes(bytes: Vec, record_format: u32) -> PyResult { - Self::from_bytes_inner(bytes, record_format) + fn from_bytes( + bytes: Vec, + record_format: u32, + positions_type_hint: Option, + ) -> PyResult { + Self::from_bytes_inner(bytes, record_format, positions_type_hint) .map_err(|e| PyValueError::new_err(format!("{}", e))) } diff --git a/atompack-py/src/molecule_helpers.rs b/atompack-py/src/molecule_helpers.rs index 67553db..cb2f768 100644 --- a/atompack-py/src/molecule_helpers.rs +++ b/atompack-py/src/molecule_helpers.rs @@ -224,19 +224,14 @@ pub(crate) fn build_soa_record(record: SoaRecord<'_>) -> Result, String> account_section(parsed.payload.len(), parsed.key.len()); } - let positions_type_bytes = usize::from(record.record_format == RECORD_FORMAT_SOA_V3); let mut buf = Vec::with_capacity( - 4 + positions_type_bytes - + record.positions.len() + 4 + record.positions.len() + record.atomic_numbers.len() + 2 + section_overhead + payload_bytes, ); buf.extend_from_slice(&(n_atoms as u32).to_le_bytes()); - if record.record_format == RECORD_FORMAT_SOA_V3 { - buf.push(record.positions_type); - } buf.extend_from_slice(record.positions); buf.extend_from_slice(record.atomic_numbers); buf.extend_from_slice(&n_sections.to_le_bytes()); diff --git a/atompack/src/storage/header.rs b/atompack/src/storage/header.rs index 63dc4d9..2652156 100644 --- a/atompack/src/storage/header.rs +++ b/atompack/src/storage/header.rs @@ -7,6 +7,8 @@ pub(super) struct Header { pub(super) num_molecules: u64, pub(super) compression: CompressionType, pub(super) record_format: u32, + pub(super) schema_offset: u64, + pub(super) schema_len: u64, pub(super) index_offset: u64, pub(super) index_len: u64, } @@ -30,18 +32,20 @@ pub(super) fn encode_header_slot(header: Header) -> [u8; HEADER_SLOT_SIZE] { slot[4..8].copy_from_slice(&FILE_FORMAT_VERSION.to_le_bytes()); slot[8..16].copy_from_slice(&header.generation.to_le_bytes()); slot[16..24].copy_from_slice(&header.data_start.to_le_bytes()); - slot[24..32].copy_from_slice(&header.index_offset.to_le_bytes()); - slot[32..40].copy_from_slice(&header.index_len.to_le_bytes()); - slot[40..48].copy_from_slice(&header.num_molecules.to_le_bytes()); + slot[24..32].copy_from_slice(&header.schema_offset.to_le_bytes()); + slot[32..40].copy_from_slice(&header.schema_len.to_le_bytes()); + slot[40..48].copy_from_slice(&header.index_offset.to_le_bytes()); + slot[48..56].copy_from_slice(&header.index_len.to_le_bytes()); + slot[56..64].copy_from_slice(&header.num_molecules.to_le_bytes()); let (compression_type, compression_level) = match header.compression { CompressionType::None => (0u8, 0i32), CompressionType::Lz4 => (1u8, 0i32), CompressionType::Zstd(level) => (2u8, level), }; - slot[48] = compression_type; - slot[52..56].copy_from_slice(&compression_level.to_le_bytes()); - slot[56..60].copy_from_slice(&header.record_format.to_le_bytes()); + slot[64] = compression_type; + slot[68..72].copy_from_slice(&compression_level.to_le_bytes()); + slot[72..76].copy_from_slice(&header.record_format.to_le_bytes()); let checksum = adler32(&slot[..HEADER_SLOT_SIZE - 4]); slot[HEADER_SLOT_SIZE - 4..HEADER_SLOT_SIZE].copy_from_slice(&checksum.to_le_bytes()); @@ -70,12 +74,14 @@ fn decode_header_slot(slot: &[u8; HEADER_SLOT_SIZE], file_size: u64) -> Option CompressionType::None, @@ -88,6 +94,17 @@ fn decode_header_slot(slot: &[u8; HEADER_SLOT_SIZE], file_size: u64) -> Option file_size { + return None; + } + } + if index_offset == 0 || index_len == 0 { if num_molecules != 0 || index_offset != 0 || index_len != 0 { return None; @@ -113,6 +130,8 @@ fn decode_header_slot(slot: &[u8; HEADER_SLOT_SIZE], file_size: u64) -> Option 0 && header.schema_len > 0 { + file.seek(SeekFrom::Start(header.schema_offset))?; + let mut schema_bytes = vec![0u8; header.schema_len as usize]; + file.read_exact(&mut schema_bytes)?; + Some(decode_schema_lock(&schema_bytes)?) + } else { + None + }; + Ok(Self { path, compression: header.compression, @@ -269,7 +280,7 @@ impl AtomDatabase { committed_end, truncate_tail_on_next_write, index, - schema_lock: None, + schema_lock, file: Some(file), data_mmap, }) @@ -289,7 +300,6 @@ impl AtomDatabase { } self.truncate_tail_on_next_write = false; - self.schema_lock = None; self.file = None; Ok(()) } @@ -297,6 +307,7 @@ impl AtomDatabase { fn rebuild_schema_lock(&self) -> Result { let mut lock = SchemaLock::default(); let compression = self.compression; + let positions_type_hint = self.positions_type(); if let Some(ref mmap) = self.data_mmap { for index in 0..self.index.len() { @@ -311,7 +322,7 @@ impl AtomDatabase { compression, Some(entry.uncompressed_size as usize), )?; - let record = record_schema(&bytes, self.record_format)?; + let record = record_schema(&bytes, self.record_format, positions_type_hint)?; merge_schema_lock(&mut lock, &record)?; } return Ok(lock); @@ -331,7 +342,7 @@ impl AtomDatabase { compression, Some(entry.uncompressed_size as usize), )?; - let record = record_schema(&bytes, self.record_format)?; + let record = record_schema(&bytes, self.record_format, positions_type_hint)?; merge_schema_lock(&mut lock, &record)?; } Ok(lock) @@ -339,7 +350,7 @@ impl AtomDatabase { fn ensure_schema_compatible<'a, I>(&mut self, records: I) -> Result<()> where - I: IntoIterator, + I: IntoIterator)>, { let mut lock = match &self.schema_lock { Some(lock) => lock.clone(), @@ -347,8 +358,9 @@ impl AtomDatabase { None => self.rebuild_schema_lock()?, }; - for bytes in records { - let record = record_schema(bytes, self.record_format)?; + for (bytes, positions_type_hint) in records { + let hint = positions_type_hint.or(lock.positions_type); + let record = record_schema(bytes, self.record_format, hint)?; merge_schema_lock(&mut lock, &record)?; } @@ -370,12 +382,16 @@ impl AtomDatabase { return Ok(()); } - let serialized: Vec<(Vec, u32)> = molecules + let serialized: Vec<(Vec, u32, u8)> = molecules .par_iter() .map(|mol| { let bytes = serialize_molecule_soa(mol, self.record_format)?; let num_atoms = mol.len() as u32; - Ok((bytes, num_atoms)) + Ok(( + bytes, + num_atoms, + schema_from_molecule(mol)?.positions_type.unwrap(), + )) }) .collect::>>()?; @@ -396,7 +412,7 @@ impl AtomDatabase { } #[doc(hidden)] - pub fn add_owned_soa_records(&mut self, records: Vec<(Vec, u32)>) -> Result<()> { + pub fn add_owned_soa_records(&mut self, records: Vec<(Vec, u32, u8)>) -> Result<()> { if records.is_empty() { return Ok(()); } @@ -412,7 +428,7 @@ impl AtomDatabase { } self.truncate_uncommitted_tail_if_needed()?; - self.ensure_schema_compatible(records.iter().map(|(bytes, _)| *bytes))?; + self.ensure_schema_compatible(records.iter().map(|(bytes, _)| (*bytes, None)))?; let compression = self.compression; @@ -451,7 +467,7 @@ impl AtomDatabase { Ok(()) } - fn append_owned_soa_records(&mut self, records: Vec<(Vec, u32)>) -> Result<()> { + fn append_owned_soa_records(&mut self, records: Vec<(Vec, u32, u8)>) -> Result<()> { if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { return Err(Error::InvalidData( "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." @@ -460,13 +476,17 @@ impl AtomDatabase { } self.truncate_uncommitted_tail_if_needed()?; - self.ensure_schema_compatible(records.iter().map(|(bytes, _)| bytes.as_slice()))?; + self.ensure_schema_compatible( + records + .iter() + .map(|(bytes, _, positions_type)| (bytes.as_slice(), Some(*positions_type))), + )?; let compression = self.compression; let compressed_records: Vec<(Vec, u32, u32)> = records .into_par_iter() - .map(|(bytes, num_atoms)| { + .map(|(bytes, num_atoms, _positions_type)| { let uncompressed_size = bytes.len() as u32; let compressed = compress(&bytes, compression)?; Ok((compressed, uncompressed_size, num_atoms)) @@ -523,14 +543,16 @@ impl AtomDatabase { self.compression, Some(mol_index.uncompressed_size as usize), )?; - deserialize_molecule_soa(&decompressed, self.record_format) + deserialize_molecule_soa(&decompressed, self.record_format, self.positions_type()) } /// Read multiple molecules in parallel. pub fn get_molecules(&self, indices: &[usize]) -> Result> { let raw = self.get_raw_bytes(indices)?; raw.into_par_iter() - .map(|bytes| deserialize_molecule_soa(&bytes, self.record_format)) + .map(|bytes| { + deserialize_molecule_soa(&bytes, self.record_format, self.positions_type()) + }) .collect() } @@ -618,8 +640,20 @@ impl AtomDatabase { }; let index_bytes = encode_index(index_vec); + let schema_bytes = self + .schema_lock + .as_ref() + .map(encode_schema_lock) + .transpose()?; let mut file = OpenOptions::new().write(true).open(&self.path)?; + let schema_offset = file.seek(SeekFrom::End(0))?; + let schema_len = if let Some(schema_bytes) = &schema_bytes { + file.write_all(schema_bytes)?; + schema_bytes.len() as u64 + } else { + 0 + }; let index_offset = file.seek(SeekFrom::End(0))?; file.write_all(&index_bytes)?; file.flush()?; @@ -633,6 +667,8 @@ impl AtomDatabase { num_molecules: self.index.len() as u64, compression: self.compression, record_format: self.record_format, + schema_offset: if schema_len > 0 { schema_offset } else { 0 }, + schema_len, index_offset, index_len: index_bytes.len() as u64, }; @@ -674,6 +710,12 @@ impl AtomDatabase { self.record_format } + pub fn positions_type(&self) -> Option { + self.schema_lock + .as_ref() + .and_then(|lock| lock.positions_type) + } + /// Compressed bytes for a molecule from the mmap (None if no mmap). pub fn get_compressed_slice(&self, index: usize) -> Option<&[u8]> { let mol_index = self.index.get(index)?; diff --git a/atompack/src/storage/soa.rs b/atompack/src/storage/soa.rs index 9ac5d37..1ddfec0 100644 --- a/atompack/src/storage/soa.rs +++ b/atompack/src/storage/soa.rs @@ -16,6 +16,114 @@ pub(super) struct SchemaEntry { pub(super) slot_bytes: usize, } +const SCHEMA_BLOB_VERSION: u32 = 1; + +fn positions_type_from_molecule(molecule: &Molecule) -> u8 { + match molecule.positions { + Vec3Data::F32(_) => TYPE_VEC3_F32, + Vec3Data::F64(_) => TYPE_VEC3_F64, + } +} + +pub(super) fn encode_schema_lock(lock: &SchemaLock) -> Result> { + let mut buf = Vec::new(); + buf.extend_from_slice(&SCHEMA_BLOB_VERSION.to_le_bytes()); + buf.push(lock.positions_type.unwrap_or(255)); + buf.extend_from_slice(&(lock.sections.len() as u32).to_le_bytes()); + for ((kind, key), entry) in &lock.sections { + let key_len: u16 = key + .len() + .try_into() + .map_err(|_| Error::InvalidData(format!("Schema key '{}' is too long", key)))?; + let elem_bytes: u32 = entry + .elem_bytes + .try_into() + .map_err(|_| Error::InvalidData(format!("Schema elem_bytes overflow for '{}'", key)))?; + let slot_bytes: u32 = entry + .slot_bytes + .try_into() + .map_err(|_| Error::InvalidData(format!("Schema slot_bytes overflow for '{}'", key)))?; + buf.push(*kind); + buf.push(entry.type_tag); + buf.push(u8::from(entry.per_atom)); + buf.extend_from_slice(&key_len.to_le_bytes()); + buf.extend_from_slice(&elem_bytes.to_le_bytes()); + buf.extend_from_slice(&slot_bytes.to_le_bytes()); + buf.extend_from_slice(key.as_bytes()); + } + Ok(buf) +} + +pub(super) fn decode_schema_lock(bytes: &[u8]) -> Result { + if bytes.len() < 9 { + return Err(Error::InvalidData("Schema blob too small".into())); + } + let version = u32::from_le_bytes(arr(&bytes[0..4])?); + if version != SCHEMA_BLOB_VERSION { + return Err(Error::InvalidData(format!( + "Unsupported schema blob version {}", + version + ))); + } + let positions_type = match bytes[4] { + 255 => None, + TYPE_VEC3_F32 | TYPE_VEC3_F64 => Some(bytes[4]), + other => { + return Err(Error::InvalidData(format!( + "Unsupported schema positions type tag {}", + other + ))); + } + }; + let count = u32::from_le_bytes(arr(&bytes[5..9])?) as usize; + let mut pos = 9usize; + let mut sections = BTreeMap::new(); + for _ in 0..count { + if pos + 13 > bytes.len() { + return Err(Error::InvalidData("Schema blob truncated".into())); + } + let kind = bytes[pos]; + pos += 1; + let type_tag = bytes[pos]; + pos += 1; + let per_atom = match bytes[pos] { + 0 => false, + 1 => true, + _ => return Err(Error::InvalidData("Invalid schema per_atom flag".into())), + }; + pos += 1; + let key_len = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; + pos += 2; + let elem_bytes = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + let slot_bytes = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + if pos + key_len > bytes.len() { + return Err(Error::InvalidData("Schema blob truncated at key".into())); + } + let key = std::str::from_utf8(&bytes[pos..pos + key_len]) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in schema key".into()))? + .to_string(); + pos += key_len; + sections.insert( + (kind, key), + SchemaEntry { + type_tag, + per_atom, + elem_bytes, + slot_bytes, + }, + ); + } + if pos != bytes.len() { + return Err(Error::InvalidData("Schema blob trailing bytes".into())); + } + Ok(SchemaLock { + positions_type, + sections, + }) +} + /// Write a single tagged section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload] fn write_section(buf: &mut Vec, kind: u8, key: &str, type_tag: u8, payload: &[u8]) { buf.push(kind); @@ -280,299 +388,95 @@ fn write_energy_section(buf: &mut Vec, value: &FloatScalarData) { } } -fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result { - Ok(match type_tag { - TYPE_FLOAT => { - if payload.len() < 8 { - return Err(Error::InvalidData("f64 property truncated".into())); - } - PropertyValue::Float(f64::from_le_bytes(arr(&payload[..8])?)) - } - TYPE_INT => { - if payload.len() < 8 { - return Err(Error::InvalidData("i64 property truncated".into())); - } - PropertyValue::Int(i64::from_le_bytes(arr(&payload[..8])?)) - } - TYPE_STRING => PropertyValue::String( - std::str::from_utf8(payload) - .map_err(|_| Error::InvalidData("Invalid UTF-8 in property".into()))? - .to_string(), - ), - TYPE_F64_ARRAY => PropertyValue::FloatArray(decode_f64_array(payload)?), - TYPE_VEC3_F32 => PropertyValue::Vec3Array(decode_vec3_f32(payload)?), - TYPE_I64_ARRAY => { - if !payload.len().is_multiple_of(8) { - return Err(Error::InvalidData( - "i64 array payload length not divisible by 8".into(), - )); - } - PropertyValue::IntArray( - payload - .chunks_exact(8) - .map(|c| Ok(i64::from_le_bytes(arr(c)?))) - .collect::>()?, - ) - } - TYPE_F32_ARRAY => { - if !payload.len().is_multiple_of(4) { - return Err(Error::InvalidData( - "f32 array payload length not divisible by 4".into(), - )); - } - PropertyValue::Float32Array( - payload - .chunks_exact(4) - .map(|c| Ok(f32::from_le_bytes(arr(c)?))) - .collect::>()?, - ) - } - TYPE_VEC3_F64 => { - if !payload.len().is_multiple_of(24) { - return Err(Error::InvalidData( - "vec3 payload length not divisible by 24".into(), - )); - } - PropertyValue::Vec3ArrayF64( - payload - .chunks_exact(24) - .map(|c| { - Ok([ - f64::from_le_bytes(arr(&c[0..8])?), - f64::from_le_bytes(arr(&c[8..16])?), - f64::from_le_bytes(arr(&c[16..24])?), - ]) - }) - .collect::>()?, - ) - } - TYPE_I32_ARRAY => { - if !payload.len().is_multiple_of(4) { - return Err(Error::InvalidData( - "i32 array payload length not divisible by 4".into(), - )); - } - PropertyValue::Int32Array( - payload - .chunks_exact(4) - .map(|c| Ok(i32::from_le_bytes(arr(c)?))) - .collect::>()?, - ) - } - _ => return Err(Error::InvalidData(format!("Unknown type tag {}", type_tag))), - }) -} - -fn serialize_molecule_soa_v2(molecule: &Molecule) -> Result> { - let n = molecule.len(); - let mut buf = Vec::new(); - - buf.extend_from_slice(&(n as u32).to_le_bytes()); - let positions = match &molecule.positions { - Vec3Data::F32(values) => values, - Vec3Data::F64(_) => { - return Err(Error::InvalidData( - "record format 2 does not support float64 positions".into(), - )); - } - }; - for position in positions { - buf.extend_from_slice(&position[0].to_le_bytes()); - buf.extend_from_slice(&position[1].to_le_bytes()); - buf.extend_from_slice(&position[2].to_le_bytes()); +fn resolve_positions_type(record_format: u32, positions_type_hint: Option) -> Result { + match record_format { + RECORD_FORMAT_SOA_V2 => Ok(TYPE_VEC3_F32), + RECORD_FORMAT_SOA_V3 => positions_type_hint.ok_or_else(|| { + Error::InvalidData("Missing positions dtype for record format 3".into()) + }), + _ => Err(Error::InvalidData(format!( + "Unsupported record format {}", + record_format + ))), } - buf.extend_from_slice(&molecule.atomic_numbers); +} - let mut n_sections: u16 = 0; - if molecule.charges.is_some() { - n_sections += 1; - } - if molecule.cell.is_some() { - n_sections += 1; - } - if molecule.energy.is_some() { - n_sections += 1; - } - if molecule.forces.is_some() { - n_sections += 1; - } - if molecule.name.is_some() { - n_sections += 1; - } - if molecule.pbc.is_some() { - n_sections += 1; - } - if molecule.stress.is_some() { - n_sections += 1; - } - if molecule.velocities.is_some() { - n_sections += 1; +fn positions_stride(positions_type: u8) -> Result { + match positions_type { + TYPE_VEC3_F32 => Ok(12), + TYPE_VEC3_F64 => Ok(24), + _ => Err(Error::InvalidData(format!( + "Unsupported positions type tag {}", + positions_type + ))), } - n_sections += molecule.atom_properties.len() as u16; - n_sections += molecule.properties.len() as u16; - buf.extend_from_slice(&n_sections.to_le_bytes()); +} - if let Some(ref charges) = molecule.charges { - let charges = match charges { - FloatArrayData::F64(values) => values, - FloatArrayData::F32(_) => { +fn validate_record_format_compat(molecule: &Molecule, record_format: u32) -> Result<()> { + match record_format { + RECORD_FORMAT_SOA_V3 => Ok(()), + RECORD_FORMAT_SOA_V2 => { + if matches!(molecule.positions, Vec3Data::F64(_)) { + return Err(Error::InvalidData( + "record format 2 does not support float64 positions".into(), + )); + } + if matches!(molecule.charges, Some(FloatArrayData::F32(_))) { return Err(Error::InvalidData( "record format 2 does not support float32 charges".into(), )); } - }; - let mut payload = Vec::with_capacity(charges.len() * 8); - extend_f64(&mut payload, charges); - write_section(&mut buf, KIND_BUILTIN, "charges", TYPE_F64_ARRAY, &payload); - } - if let Some(ref cell) = molecule.cell { - let cell = match cell { - Mat3Data::F64(values) => values, - Mat3Data::F32(_) => { + if matches!(molecule.cell, Some(Mat3Data::F32(_))) { return Err(Error::InvalidData( "record format 2 does not support float32 cell".into(), )); } - }; - let mut payload = Vec::with_capacity(72); - for row in cell { - extend_f64(&mut payload, row); - } - write_section(&mut buf, KIND_BUILTIN, "cell", TYPE_MAT3X3_F64, &payload); - } - if let Some(ref energy) = molecule.energy { - let energy = match energy { - FloatScalarData::F64(value) => value, - FloatScalarData::F32(_) => { + if matches!(molecule.energy, Some(FloatScalarData::F32(_))) { return Err(Error::InvalidData( "record format 2 does not support float32 energy".into(), )); } - }; - write_section( - &mut buf, - KIND_BUILTIN, - "energy", - TYPE_FLOAT, - &energy.to_le_bytes(), - ); - } - if let Some(ref forces) = molecule.forces { - let forces = match forces { - Vec3Data::F32(values) => values, - Vec3Data::F64(_) => { + if matches!(molecule.forces, Some(Vec3Data::F64(_))) { return Err(Error::InvalidData( "record format 2 does not support float64 forces".into(), )); } - }; - let mut payload = Vec::with_capacity(forces.len() * 12); - for f in forces { - extend_f32(&mut payload, f); - } - write_section(&mut buf, KIND_BUILTIN, "forces", TYPE_VEC3_F32, &payload); - } - if let Some(ref name) = molecule.name { - write_section(&mut buf, KIND_BUILTIN, "name", TYPE_STRING, name.as_bytes()); - } - if let Some(ref pbc) = molecule.pbc { - let payload = [pbc[0] as u8, pbc[1] as u8, pbc[2] as u8]; - write_section(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload); - } - if let Some(ref stress) = molecule.stress { - let stress = match stress { - Mat3Data::F64(values) => values, - Mat3Data::F32(_) => { + if matches!(molecule.stress, Some(Mat3Data::F32(_))) { return Err(Error::InvalidData( "record format 2 does not support float32 stress".into(), )); } - }; - let mut payload = Vec::with_capacity(72); - for row in stress { - extend_f64(&mut payload, row); - } - write_section(&mut buf, KIND_BUILTIN, "stress", TYPE_MAT3X3_F64, &payload); - } - if let Some(ref velocities) = molecule.velocities { - let velocities = match velocities { - Vec3Data::F32(values) => values, - Vec3Data::F64(_) => { + if matches!(molecule.velocities, Some(Vec3Data::F64(_))) { return Err(Error::InvalidData( "record format 2 does not support float64 velocities".into(), )); } - }; - let mut payload = Vec::with_capacity(velocities.len() * 12); - for v in velocities { - extend_f32(&mut payload, v); + Ok(()) } - write_section( - &mut buf, - KIND_BUILTIN, - "velocities", - TYPE_VEC3_F32, - &payload, - ); - } - - let mut atom_keys: Vec<&String> = molecule.atom_properties.keys().collect(); - atom_keys.sort(); - for key in atom_keys { - let value = &molecule.atom_properties[key]; - let payload = property_value_to_bytes(value); - write_section( - &mut buf, - KIND_ATOM_PROP, - key, - property_value_type_tag(value), - &payload, - ); - } - - let mut prop_keys: Vec<&String> = molecule.properties.keys().collect(); - prop_keys.sort(); - for key in prop_keys { - let value = &molecule.properties[key]; - let payload = property_value_to_bytes(value); - write_section( - &mut buf, - KIND_MOL_PROP, - key, - property_value_type_tag(value), - &payload, - ); + _ => Err(Error::InvalidData(format!( + "Unsupported record format {}", + record_format + ))), } - - Ok(buf) } -fn serialize_molecule_soa_v3(molecule: &Molecule) -> Result> { - let n = molecule.len(); - let mut buf = Vec::new(); - - buf.extend_from_slice(&(n as u32).to_le_bytes()); - match &molecule.positions { +fn write_positions(buf: &mut Vec, positions: &Vec3Data) { + match positions { Vec3Data::F32(values) => { - buf.push(TYPE_VEC3_F32); for value in values { - buf.extend_from_slice(&value[0].to_le_bytes()); - buf.extend_from_slice(&value[1].to_le_bytes()); - buf.extend_from_slice(&value[2].to_le_bytes()); + extend_f32(buf, value); } } Vec3Data::F64(values) => { - buf.push(TYPE_VEC3_F64); for value in values { - buf.extend_from_slice(&value[0].to_le_bytes()); - buf.extend_from_slice(&value[1].to_le_bytes()); - buf.extend_from_slice(&value[2].to_le_bytes()); + extend_f64(buf, value); } } } - buf.extend_from_slice(&molecule.atomic_numbers); +} - let mut n_sections: u16 = 0; +fn count_sections(molecule: &Molecule) -> u16 { + let mut n_sections = 0; if molecule.charges.is_some() { n_sections += 1; } @@ -599,32 +503,34 @@ fn serialize_molecule_soa_v3(molecule: &Molecule) -> Result> { } n_sections += molecule.atom_properties.len() as u16; n_sections += molecule.properties.len() as u16; - buf.extend_from_slice(&n_sections.to_le_bytes()); + n_sections +} +fn write_sections(buf: &mut Vec, molecule: &Molecule) { if let Some(ref charges) = molecule.charges { - write_float_array_section(&mut buf, "charges", charges); + write_float_array_section(buf, "charges", charges); } if let Some(ref cell) = molecule.cell { - write_mat3_section(&mut buf, "cell", cell); + write_mat3_section(buf, "cell", cell); } if let Some(ref energy) = molecule.energy { - write_energy_section(&mut buf, energy); + write_energy_section(buf, energy); } if let Some(ref forces) = molecule.forces { - write_vec3_section(&mut buf, "forces", forces); + write_vec3_section(buf, "forces", forces); } if let Some(ref name) = molecule.name { - write_section(&mut buf, KIND_BUILTIN, "name", TYPE_STRING, name.as_bytes()); + write_section(buf, KIND_BUILTIN, "name", TYPE_STRING, name.as_bytes()); } if let Some(ref pbc) = molecule.pbc { let payload = [pbc[0] as u8, pbc[1] as u8, pbc[2] as u8]; - write_section(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload); + write_section(buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload); } if let Some(ref stress) = molecule.stress { - write_mat3_section(&mut buf, "stress", stress); + write_mat3_section(buf, "stress", stress); } if let Some(ref velocities) = molecule.velocities { - write_vec3_section(&mut buf, "velocities", velocities); + write_vec3_section(buf, "velocities", velocities); } let mut atom_keys: Vec<&String> = molecule.atom_properties.keys().collect(); @@ -633,7 +539,7 @@ fn serialize_molecule_soa_v3(molecule: &Molecule) -> Result> { let value = &molecule.atom_properties[key]; let payload = property_value_to_bytes(value); write_section( - &mut buf, + buf, KIND_ATOM_PROP, key, property_value_type_tag(value), @@ -647,187 +553,255 @@ fn serialize_molecule_soa_v3(molecule: &Molecule) -> Result> { let value = &molecule.properties[key]; let payload = property_value_to_bytes(value); write_section( - &mut buf, + buf, KIND_MOL_PROP, key, property_value_type_tag(value), &payload, ); } - - Ok(buf) -} - -pub(super) fn serialize_molecule_soa(molecule: &Molecule, record_format: u32) -> Result> { - match record_format { - RECORD_FORMAT_SOA_V2 => serialize_molecule_soa_v2(molecule), - RECORD_FORMAT_SOA_V3 => serialize_molecule_soa_v3(molecule), - _ => Err(Error::InvalidData(format!( - "Unsupported record format {}", - record_format - ))), - } } -fn deserialize_molecule_soa_v2(bytes: &[u8]) -> Result { - if bytes.len() < 6 { - return Err(Error::InvalidData("SOA record too small".into())); - } - - let mut pos = 0; - let n = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; - pos += 4; - - let positions_end = pos + n * 12; - if positions_end > bytes.len() { - return Err(Error::InvalidData( - "SOA record truncated at positions".into(), - )); - } - let mut positions = Vec::with_capacity(n); - for i in 0..n { - let o = pos + i * 12; - let x = f32::from_le_bytes(arr(&bytes[o..o + 4])?); - let y = f32::from_le_bytes(arr(&bytes[o + 4..o + 8])?); - let z = f32::from_le_bytes(arr(&bytes[o + 8..o + 12])?); - positions.push([x, y, z]); - } - pos = positions_end; +fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result { + Ok(match type_tag { + TYPE_FLOAT => { + if payload.len() < 8 { + return Err(Error::InvalidData("f64 property truncated".into())); + } + PropertyValue::Float(f64::from_le_bytes(arr(&payload[..8])?)) + } + TYPE_INT => { + if payload.len() < 8 { + return Err(Error::InvalidData("i64 property truncated".into())); + } + PropertyValue::Int(i64::from_le_bytes(arr(&payload[..8])?)) + } + TYPE_STRING => PropertyValue::String( + std::str::from_utf8(payload) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in property".into()))? + .to_string(), + ), + TYPE_F64_ARRAY => PropertyValue::FloatArray(decode_f64_array(payload)?), + TYPE_VEC3_F32 => PropertyValue::Vec3Array(decode_vec3_f32(payload)?), + TYPE_I64_ARRAY => { + if !payload.len().is_multiple_of(8) { + return Err(Error::InvalidData( + "i64 array payload length not divisible by 8".into(), + )); + } + PropertyValue::IntArray( + payload + .chunks_exact(8) + .map(|c| Ok(i64::from_le_bytes(arr(c)?))) + .collect::>()?, + ) + } + TYPE_F32_ARRAY => { + if !payload.len().is_multiple_of(4) { + return Err(Error::InvalidData( + "f32 array payload length not divisible by 4".into(), + )); + } + PropertyValue::Float32Array( + payload + .chunks_exact(4) + .map(|c| Ok(f32::from_le_bytes(arr(c)?))) + .collect::>()?, + ) + } + TYPE_VEC3_F64 => { + if !payload.len().is_multiple_of(24) { + return Err(Error::InvalidData( + "vec3 payload length not divisible by 24".into(), + )); + } + PropertyValue::Vec3ArrayF64( + payload + .chunks_exact(24) + .map(|c| { + Ok([ + f64::from_le_bytes(arr(&c[0..8])?), + f64::from_le_bytes(arr(&c[8..16])?), + f64::from_le_bytes(arr(&c[16..24])?), + ]) + }) + .collect::>()?, + ) + } + TYPE_I32_ARRAY => { + if !payload.len().is_multiple_of(4) { + return Err(Error::InvalidData( + "i32 array payload length not divisible by 4".into(), + )); + } + PropertyValue::Int32Array( + payload + .chunks_exact(4) + .map(|c| Ok(i32::from_le_bytes(arr(c)?))) + .collect::>()?, + ) + } + _ => return Err(Error::InvalidData(format!("Unknown type tag {}", type_tag))), + }) +} - let z_end = pos + n; - if z_end > bytes.len() { - return Err(Error::InvalidData( - "SOA record truncated at atomic_numbers".into(), - )); - } - let atomic_numbers = bytes[pos..z_end].to_vec(); - pos = z_end; +pub(super) fn serialize_molecule_soa(molecule: &Molecule, record_format: u32) -> Result> { + validate_record_format_compat(molecule, record_format)?; - if pos + 2 > bytes.len() { + let n = molecule.len(); + let mut buf = Vec::new(); + + buf.extend_from_slice(&(n as u32).to_le_bytes()); + write_positions(&mut buf, &molecule.positions); + buf.extend_from_slice(&molecule.atomic_numbers); + buf.extend_from_slice(&count_sections(molecule).to_le_bytes()); + write_sections(&mut buf, molecule); + + Ok(buf) +} + +fn decode_positions( + bytes: &[u8], + pos: &mut usize, + n_atoms: usize, + positions_type: u8, +) -> Result { + let positions_len = n_atoms + .checked_mul(positions_stride(positions_type)?) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; + let positions_end = pos + .checked_add(positions_len) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; + if positions_end > bytes.len() { return Err(Error::InvalidData( - "SOA record truncated at n_sections".into(), + "SOA record truncated at positions".into(), )); } - let n_sections = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; - pos += 2; - - let mut mol = Molecule::new(positions, atomic_numbers).map_err(Error::InvalidData)?; - - for _ in 0..n_sections { - if pos + 7 > bytes.len() { - return Err(Error::InvalidData("SOA section header truncated".into())); - } - let kind = bytes[pos]; - pos += 1; - let key_len = bytes[pos] as usize; - pos += 1; - if pos + key_len > bytes.len() { - return Err(Error::InvalidData("SOA section key truncated".into())); - } - let key = std::str::from_utf8(&bytes[pos..pos + key_len]) - .map_err(|_| Error::InvalidData("Invalid UTF-8 in section key".into()))?; - pos += key_len; - let type_tag = bytes[pos]; - pos += 1; - let payload_len = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; - pos += 4; - if pos + payload_len > bytes.len() { - return Err(Error::InvalidData("SOA section payload truncated".into())); + let positions = match positions_type { + TYPE_VEC3_F32 => Vec3Data::F32(decode_vec3_f32(&bytes[*pos..positions_end])?), + TYPE_VEC3_F64 => Vec3Data::F64(decode_vec3_f64(&bytes[*pos..positions_end])?), + _ => { + return Err(Error::InvalidData(format!( + "Unsupported positions type tag {}", + positions_type + ))); } - let payload = &bytes[pos..pos + payload_len]; - pos += payload_len; + }; + *pos = positions_end; + Ok(positions) +} - match kind { - KIND_BUILTIN => match key { - "energy" => { - if payload.len() < 8 { - return Err(Error::InvalidData("energy payload truncated".into())); - } - mol.energy = Some(FloatScalarData::F64(f64::from_le_bytes( - arr(&payload[..8])?, - ))); - } - "forces" => mol.forces = Some(Vec3Data::F32(decode_vec3_f32(payload)?)), - "charges" => mol.charges = Some(FloatArrayData::F64(decode_f64_array(payload)?)), - "velocities" => mol.velocities = Some(Vec3Data::F32(decode_vec3_f32(payload)?)), - "cell" => mol.cell = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)), - "pbc" => { - if payload.len() < 3 { - return Err(Error::InvalidData("pbc payload truncated".into())); - } - mol.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); +fn decode_builtin_section( + mol: &mut Molecule, + key: &str, + type_tag: u8, + payload: &[u8], +) -> Result<()> { + match key { + "energy" => match type_tag { + TYPE_FLOAT => { + if payload.len() != 8 { + return Err(Error::InvalidData("energy f64 payload truncated".into())); } - "stress" => mol.stress = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)), - "name" => { - mol.name = Some( - std::str::from_utf8(payload) - .map_err(|_| Error::InvalidData("Invalid UTF-8 in name".into()))? - .to_string(), - ) + mol.energy = Some(FloatScalarData::F64(f64::from_le_bytes(arr(payload)?))); + } + TYPE_FLOAT32 => { + if payload.len() != 4 { + return Err(Error::InvalidData("energy f32 payload truncated".into())); } - _ => {} - }, - KIND_ATOM_PROP => { - mol.atom_properties - .insert(key.to_string(), decode_property_value(type_tag, payload)?); + mol.energy = Some(FloatScalarData::F32(f32::from_le_bytes(arr(payload)?))); } - KIND_MOL_PROP => { - mol.properties - .insert(key.to_string(), decode_property_value(type_tag, payload)?); + _ => { + return Err(Error::InvalidData(format!( + "Unsupported energy type tag {}", + type_tag + ))); } - _ => {} + }, + "forces" => match type_tag { + TYPE_VEC3_F32 => mol.forces = Some(Vec3Data::F32(decode_vec3_f32(payload)?)), + TYPE_VEC3_F64 => mol.forces = Some(Vec3Data::F64(decode_vec3_f64(payload)?)), + _ => { + return Err(Error::InvalidData(format!( + "Unsupported forces type tag {}", + type_tag + ))); + } + }, + "charges" => match type_tag { + TYPE_F32_ARRAY => mol.charges = Some(FloatArrayData::F32(decode_f32_array(payload)?)), + TYPE_F64_ARRAY => mol.charges = Some(FloatArrayData::F64(decode_f64_array(payload)?)), + _ => { + return Err(Error::InvalidData(format!( + "Unsupported charges type tag {}", + type_tag + ))); + } + }, + "velocities" => match type_tag { + TYPE_VEC3_F32 => mol.velocities = Some(Vec3Data::F32(decode_vec3_f32(payload)?)), + TYPE_VEC3_F64 => mol.velocities = Some(Vec3Data::F64(decode_vec3_f64(payload)?)), + _ => { + return Err(Error::InvalidData(format!( + "Unsupported velocities type tag {}", + type_tag + ))); + } + }, + "cell" => match type_tag { + TYPE_MAT3X3_F32 => mol.cell = Some(Mat3Data::F32(decode_mat3x3_f32(payload)?)), + TYPE_MAT3X3_F64 => mol.cell = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)), + _ => { + return Err(Error::InvalidData(format!( + "Unsupported cell type tag {}", + type_tag + ))); + } + }, + "pbc" => { + if payload.len() < 3 { + return Err(Error::InvalidData("pbc payload truncated".into())); + } + mol.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); + } + "stress" => match type_tag { + TYPE_MAT3X3_F32 => mol.stress = Some(Mat3Data::F32(decode_mat3x3_f32(payload)?)), + TYPE_MAT3X3_F64 => mol.stress = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)), + _ => { + return Err(Error::InvalidData(format!( + "Unsupported stress type tag {}", + type_tag + ))); + } + }, + "name" => { + mol.name = Some( + std::str::from_utf8(payload) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in name".into()))? + .to_string(), + ); } + _ => {} } - - Ok(mol) + Ok(()) } -fn deserialize_molecule_soa_v3(bytes: &[u8]) -> Result { - if bytes.len() < 7 { - return Err(Error::InvalidData("SOA v3 record too small".into())); +fn deserialize_molecule_soa_with_positions(bytes: &[u8], positions_type: u8) -> Result { + if bytes.len() < 6 { + return Err(Error::InvalidData("SOA record too small".into())); } let mut pos = 0; let n = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; pos += 4; - let positions_type = bytes[pos]; - pos += 1; - let positions = match positions_type { - TYPE_VEC3_F32 => { - let positions_end = pos + n * 12; - if positions_end > bytes.len() { - return Err(Error::InvalidData( - "SOA v3 record truncated at positions".into(), - )); - } - let values = decode_vec3_f32(&bytes[pos..positions_end])?; - pos = positions_end; - Vec3Data::F32(values) - } - TYPE_VEC3_F64 => { - let positions_end = pos + n * 24; - if positions_end > bytes.len() { - return Err(Error::InvalidData( - "SOA v3 record truncated at positions".into(), - )); - } - let values = decode_vec3_f64(&bytes[pos..positions_end])?; - pos = positions_end; - Vec3Data::F64(values) - } - _ => { - return Err(Error::InvalidData(format!( - "Unsupported positions type tag {} in record format 3", - positions_type - ))); - } - }; - - let z_end = pos + n; + let positions = decode_positions(bytes, &mut pos, n, positions_type)?; + let z_end = pos + .checked_add(n) + .ok_or_else(|| Error::InvalidData("SOA atomic_numbers overflow".into()))?; if z_end > bytes.len() { return Err(Error::InvalidData( - "SOA v3 record truncated at atomic_numbers".into(), + "SOA record truncated at atomic_numbers".into(), )); } let atomic_numbers = bytes[pos..z_end].to_vec(); @@ -835,7 +809,7 @@ fn deserialize_molecule_soa_v3(bytes: &[u8]) -> Result { if pos + 2 > bytes.len() { return Err(Error::InvalidData( - "SOA v3 record truncated at n_sections".into(), + "SOA record truncated at n_sections".into(), )); } let n_sections = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; @@ -875,104 +849,7 @@ fn deserialize_molecule_soa_v3(bytes: &[u8]) -> Result { pos += payload_len; match kind { - KIND_BUILTIN => match key { - "energy" => match type_tag { - TYPE_FLOAT => { - if payload.len() != 8 { - return Err(Error::InvalidData("energy f64 payload truncated".into())); - } - mol.energy = Some(FloatScalarData::F64(f64::from_le_bytes(arr(payload)?))); - } - TYPE_FLOAT32 => { - if payload.len() != 4 { - return Err(Error::InvalidData("energy f32 payload truncated".into())); - } - mol.energy = Some(FloatScalarData::F32(f32::from_le_bytes(arr(payload)?))); - } - _ => { - return Err(Error::InvalidData(format!( - "Unsupported energy type tag {}", - type_tag - ))); - } - }, - "forces" => match type_tag { - TYPE_VEC3_F32 => mol.forces = Some(Vec3Data::F32(decode_vec3_f32(payload)?)), - TYPE_VEC3_F64 => mol.forces = Some(Vec3Data::F64(decode_vec3_f64(payload)?)), - _ => { - return Err(Error::InvalidData(format!( - "Unsupported forces type tag {}", - type_tag - ))); - } - }, - "charges" => match type_tag { - TYPE_F32_ARRAY => { - mol.charges = Some(FloatArrayData::F32(decode_f32_array(payload)?)) - } - TYPE_F64_ARRAY => { - mol.charges = Some(FloatArrayData::F64(decode_f64_array(payload)?)) - } - _ => { - return Err(Error::InvalidData(format!( - "Unsupported charges type tag {}", - type_tag - ))); - } - }, - "velocities" => match type_tag { - TYPE_VEC3_F32 => { - mol.velocities = Some(Vec3Data::F32(decode_vec3_f32(payload)?)) - } - TYPE_VEC3_F64 => { - mol.velocities = Some(Vec3Data::F64(decode_vec3_f64(payload)?)) - } - _ => { - return Err(Error::InvalidData(format!( - "Unsupported velocities type tag {}", - type_tag - ))); - } - }, - "cell" => match type_tag { - TYPE_MAT3X3_F32 => mol.cell = Some(Mat3Data::F32(decode_mat3x3_f32(payload)?)), - TYPE_MAT3X3_F64 => mol.cell = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)), - _ => { - return Err(Error::InvalidData(format!( - "Unsupported cell type tag {}", - type_tag - ))); - } - }, - "pbc" => { - if payload.len() < 3 { - return Err(Error::InvalidData("pbc payload truncated".into())); - } - mol.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); - } - "stress" => match type_tag { - TYPE_MAT3X3_F32 => { - mol.stress = Some(Mat3Data::F32(decode_mat3x3_f32(payload)?)) - } - TYPE_MAT3X3_F64 => { - mol.stress = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)) - } - _ => { - return Err(Error::InvalidData(format!( - "Unsupported stress type tag {}", - type_tag - ))); - } - }, - "name" => { - mol.name = Some( - std::str::from_utf8(payload) - .map_err(|_| Error::InvalidData("Invalid UTF-8 in name".into()))? - .to_string(), - ) - } - _ => {} - }, + KIND_BUILTIN => decode_builtin_section(&mut mol, key, type_tag, payload)?, KIND_ATOM_PROP => { mol.atom_properties .insert(key.to_string(), decode_property_value(type_tag, payload)?); @@ -988,15 +865,13 @@ fn deserialize_molecule_soa_v3(bytes: &[u8]) -> Result { Ok(mol) } -pub(super) fn deserialize_molecule_soa(bytes: &[u8], record_format: u32) -> Result { - match record_format { - RECORD_FORMAT_SOA_V2 => deserialize_molecule_soa_v2(bytes), - RECORD_FORMAT_SOA_V3 => deserialize_molecule_soa_v3(bytes), - _ => Err(Error::InvalidData(format!( - "Unsupported record format {}", - record_format - ))), - } +pub(super) fn deserialize_molecule_soa( + bytes: &[u8], + record_format: u32, + positions_type_hint: Option, +) -> Result { + let positions_type = resolve_positions_type(record_format, positions_type_hint)?; + deserialize_molecule_soa_with_positions(bytes, positions_type) } fn schema_type_tag_elem_bytes(tag: u8) -> Result { @@ -1073,87 +948,90 @@ fn schema_entry( }) } -fn parse_record_schema_v2(bytes: &[u8]) -> Result { - if bytes.len() < 6 { - return Err(Error::InvalidData("SOA record too small".into())); - } +pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { + let n_atoms = molecule.len(); + let mut schema = SchemaLock { + positions_type: Some(positions_type_from_molecule(molecule)), + sections: BTreeMap::new(), + }; - let mut pos = 0usize; - let n_atoms = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; - pos += 4; + let mut insert = |kind: u8, key: &str, type_tag: u8, payload_len: usize| -> Result<()> { + let entry = schema_entry(kind, key, type_tag, payload_len, n_atoms)?; + schema.sections.insert((kind, key.to_string()), entry); + Ok(()) + }; - let positions_end = pos - .checked_add( - n_atoms - .checked_mul(12) - .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?, - ) - .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; - if positions_end > bytes.len() { - return Err(Error::InvalidData( - "SOA record truncated at positions".into(), - )); + if let Some(charges) = &molecule.charges { + let (type_tag, payload_len) = match charges { + FloatArrayData::F32(values) => (TYPE_F32_ARRAY, values.len() * 4), + FloatArrayData::F64(values) => (TYPE_F64_ARRAY, values.len() * 8), + }; + insert(KIND_BUILTIN, "charges", type_tag, payload_len)?; } - pos = positions_end; - - let z_end = pos - .checked_add(n_atoms) - .ok_or_else(|| Error::InvalidData("SOA atomic_numbers overflow".into()))?; - if z_end > bytes.len() { - return Err(Error::InvalidData( - "SOA record truncated at atomic_numbers".into(), - )); + if let Some(cell) = &molecule.cell { + let (type_tag, payload_len) = match cell { + Mat3Data::F32(_) => (TYPE_MAT3X3_F32, 36), + Mat3Data::F64(_) => (TYPE_MAT3X3_F64, 72), + }; + insert(KIND_BUILTIN, "cell", type_tag, payload_len)?; } - pos = z_end; - - if pos + 2 > bytes.len() { - return Err(Error::InvalidData( - "SOA record truncated at n_sections".into(), - )); + if let Some(energy) = &molecule.energy { + let (type_tag, payload_len) = match energy { + FloatScalarData::F32(_) => (TYPE_FLOAT32, 4), + FloatScalarData::F64(_) => (TYPE_FLOAT, 8), + }; + insert(KIND_BUILTIN, "energy", type_tag, payload_len)?; + } + if let Some(forces) = &molecule.forces { + let (type_tag, payload_len) = match forces { + Vec3Data::F32(values) => (TYPE_VEC3_F32, values.len() * 12), + Vec3Data::F64(values) => (TYPE_VEC3_F64, values.len() * 24), + }; + insert(KIND_BUILTIN, "forces", type_tag, payload_len)?; + } + if let Some(name) = &molecule.name { + insert(KIND_BUILTIN, "name", TYPE_STRING, name.len())?; + } + if molecule.pbc.is_some() { + insert(KIND_BUILTIN, "pbc", TYPE_BOOL3, 3)?; + } + if let Some(stress) = &molecule.stress { + let (type_tag, payload_len) = match stress { + Mat3Data::F32(_) => (TYPE_MAT3X3_F32, 36), + Mat3Data::F64(_) => (TYPE_MAT3X3_F64, 72), + }; + insert(KIND_BUILTIN, "stress", type_tag, payload_len)?; + } + if let Some(velocities) = &molecule.velocities { + let (type_tag, payload_len) = match velocities { + Vec3Data::F32(values) => (TYPE_VEC3_F32, values.len() * 12), + Vec3Data::F64(values) => (TYPE_VEC3_F64, values.len() * 24), + }; + insert(KIND_BUILTIN, "velocities", type_tag, payload_len)?; } - let n_sections = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; - pos += 2; - - let mut schema = SchemaLock { - positions_type: Some(TYPE_VEC3_F32), - sections: BTreeMap::new(), - }; - for _ in 0..n_sections { - if pos + 7 > bytes.len() { - return Err(Error::InvalidData("SOA section header truncated".into())); - } - let kind = bytes[pos]; - pos += 1; - let key_len = bytes[pos] as usize; - pos += 1; - if pos + key_len > bytes.len() { - return Err(Error::InvalidData("SOA section key truncated".into())); - } - let key = std::str::from_utf8(&bytes[pos..pos + key_len]) - .map_err(|_| Error::InvalidData("Invalid UTF-8 in section key".into()))? - .to_string(); - pos += key_len; - let type_tag = bytes[pos]; - pos += 1; - let payload_len = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; - pos += 4; - let payload_end = pos - .checked_add(payload_len) - .ok_or_else(|| Error::InvalidData("SOA section payload overflow".into()))?; - if payload_end > bytes.len() { - return Err(Error::InvalidData("SOA section payload truncated".into())); - } - pos = payload_end; - let entry = schema_entry(kind, &key, type_tag, payload_len, n_atoms)?; - schema.sections.insert((kind, key), entry); + for (key, value) in &molecule.atom_properties { + insert( + KIND_ATOM_PROP, + key, + property_value_type_tag(value), + property_value_to_bytes(value).len(), + )?; + } + for (key, value) in &molecule.properties { + insert( + KIND_MOL_PROP, + key, + property_value_type_tag(value), + property_value_to_bytes(value).len(), + )?; } Ok(schema) } -fn parse_record_schema_v3(bytes: &[u8]) -> Result { - if bytes.len() < 7 { +fn parse_record_schema_with_positions(bytes: &[u8], positions_type: u8) -> Result { + if bytes.len() < 6 { return Err(Error::InvalidData("SOA record too small".into())); } @@ -1161,23 +1039,10 @@ fn parse_record_schema_v3(bytes: &[u8]) -> Result { let n_atoms = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; pos += 4; - let positions_type = bytes[pos]; - pos += 1; - let positions_elem_bytes = match positions_type { - TYPE_VEC3_F32 => 12usize, - TYPE_VEC3_F64 => 24usize, - _ => { - return Err(Error::InvalidData(format!( - "Unsupported positions type tag {}", - positions_type - ))); - } - }; - let positions_end = pos .checked_add( n_atoms - .checked_mul(positions_elem_bytes) + .checked_mul(positions_stride(positions_type)?) .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?, ) .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; @@ -1244,15 +1109,13 @@ fn parse_record_schema_v3(bytes: &[u8]) -> Result { Ok(schema) } -pub(super) fn record_schema(bytes: &[u8], record_format: u32) -> Result { - match record_format { - RECORD_FORMAT_SOA_V2 => parse_record_schema_v2(bytes), - RECORD_FORMAT_SOA_V3 => parse_record_schema_v3(bytes), - _ => Err(Error::InvalidData(format!( - "Unsupported record format {}", - record_format - ))), - } +pub(super) fn record_schema( + bytes: &[u8], + record_format: u32, + positions_type_hint: Option, +) -> Result { + let positions_type = resolve_positions_type(record_format, positions_type_hint)?; + parse_record_schema_with_positions(bytes, positions_type) } pub(super) fn merge_schema_lock(lock: &mut SchemaLock, record: &SchemaLock) -> Result<()> { From c3aa0e9b80cfbcfb9c7755c7299f89cd3d03836c Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sat, 9 May 2026 23:31:42 +0200 Subject: [PATCH 4/9] fix: restore v2 compatibility and split python soa parsing --- atompack-py/src/lib.rs | 915 +--------------------------- atompack-py/src/soa.rs | 929 +++++++++++++++++++++++++++++ atompack-py/tests/test_database.py | 28 + atompack/src/storage/header.rs | 35 +- atompack/src/storage/mod.rs | 97 ++- atompack/src/storage/schema.rs | 444 ++++++++++++++ atompack/src/storage/soa.rs | 557 +++-------------- 7 files changed, 1605 insertions(+), 1400 deletions(-) create mode 100644 atompack-py/src/soa.rs create mode 100644 atompack/src/storage/schema.rs diff --git a/atompack-py/src/lib.rs b/atompack-py/src/lib.rs index df83c8b..e20e0d1 100644 --- a/atompack-py/src/lib.rs +++ b/atompack-py/src/lib.rs @@ -117,916 +117,13 @@ const TYPE_MAT3X3_F32: u8 = 12; const RECORD_FORMAT_SOA_V2: u32 = 2; const RECORD_FORMAT_SOA_V3: u32 = 3; -/// A single parsed section reference (zero-copy into decompressed bytes). -#[derive(Clone)] -struct SectionRef<'a> { - kind: u8, - key: &'a str, - type_tag: u8, - payload: &'a [u8], -} - -/// Per-molecule extracted data (references into decompressed bytes). -struct MolData<'a> { - n_atoms: usize, - positions_bytes: &'a [u8], - atomic_numbers_bytes: &'a [u8], // n_atoms - sections: Vec>, -} - -#[derive(Clone)] -struct SectionSchema { - kind: u8, - key: String, - type_tag: u8, - per_atom: bool, - elem_bytes: usize, - slot_bytes: usize, -} - -/// Parse SOA format bytes into MolData without allocation. -/// -/// Layout: -/// [n_atoms:u32][positions:n*(12|24)][atomic_numbers:n] -/// [n_sections:u16] -/// per section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload] -fn parse_mol_fast_soa( - bytes: &[u8], - record_format: u32, - positions_type_hint: Option, -) -> atompack::Result> { - let mut pos = 0usize; - let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; - let positions_type = match record_format { - RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, - RECORD_FORMAT_SOA_V3 => positions_type_hint - .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, - _ => { - return Err(invalid_data(format!( - "Unsupported record format {}", - record_format - ))); - } - }; - let positions_stride = match positions_type { - TYPE_VEC3_F32 => 12usize, - TYPE_VEC3_F64 => 24usize, - _ => { - return Err(invalid_data(format!( - "Unsupported positions type tag {}", - positions_type - ))); - } - }; - let positions_len = n_atoms - .checked_mul(positions_stride) - .ok_or_else(|| invalid_data("SOA positions byte length overflow"))?; - let positions_bytes = read_bytes_at(bytes, &mut pos, positions_len, "SOA positions")?; - let atomic_numbers_bytes = read_bytes_at(bytes, &mut pos, n_atoms, "SOA atomic_numbers")?; - let n_sections = read_u16_le_at(bytes, &mut pos, "SOA n_sections")? as usize; - - let mut sections = Vec::with_capacity(n_sections); - for _ in 0..n_sections { - let kind = read_u8_at(bytes, &mut pos, "SOA section kind")?; - let key_len = read_u8_at(bytes, &mut pos, "SOA section key length")? as usize; - let key_bytes = read_bytes_at(bytes, &mut pos, key_len, "SOA section key")?; - let key = std::str::from_utf8(key_bytes) - .map_err(|_| invalid_data("Invalid UTF-8 in SOA section key"))?; - let type_tag = read_u8_at(bytes, &mut pos, "SOA section type tag")?; - let payload_len = read_u32_le_at(bytes, &mut pos, "SOA section payload length")? as usize; - let payload = read_bytes_at(bytes, &mut pos, payload_len, "SOA section payload")?; - sections.push(SectionRef { - kind, - key, - type_tag, - payload, - }); - } - - Ok(MolData { - n_atoms, - positions_bytes, - atomic_numbers_bytes, - sections, - }) -} - -fn section_schema_from_ref( - section: &SectionRef<'_>, - n_atoms: usize, -) -> atompack::Result { - let per_atom = is_per_atom(section.kind, section.key, section.type_tag); - let elem_bytes = match section.type_tag { - TYPE_STRING => 0, - tag if per_atom => { - let elem_bytes = type_tag_elem_bytes(tag); - if elem_bytes == 0 { - return Err(invalid_data(format!( - "Unsupported per-atom section type tag {} for key '{}'", - tag, section.key - ))); - } - elem_bytes - } - TYPE_FLOAT | TYPE_INT => 8, - TYPE_FLOAT32 => 4, - TYPE_BOOL3 => 3, - TYPE_MAT3X3_F32 => 36, - TYPE_MAT3X3_F64 => 72, - _ => section.payload.len(), - }; - let slot_bytes = if section.type_tag == TYPE_STRING { - 0 - } else if per_atom { - elem_bytes - } else { - section.payload.len() - }; - - validate_section_payload(section, per_atom, elem_bytes, slot_bytes, n_atoms)?; - - Ok(SectionSchema { - kind: section.kind, - key: section.key.to_string(), - type_tag: section.type_tag, - per_atom, - elem_bytes, - slot_bytes, - }) -} - -fn validate_section_payload( - section: &SectionRef<'_>, - per_atom: bool, - elem_bytes: usize, - slot_bytes: usize, - n_atoms: usize, -) -> atompack::Result<()> { - match section.type_tag { - TYPE_STRING => { - std::str::from_utf8(section.payload) - .map_err(|_| invalid_data(format!("Invalid UTF-8 in section '{}'", section.key)))?; - if per_atom { - return Err(invalid_data(format!( - "String section '{}' cannot be per-atom in flat extraction", - section.key - ))); - } - } - TYPE_FLOAT | TYPE_INT | TYPE_FLOAT32 | TYPE_BOOL3 | TYPE_MAT3X3_F32 | TYPE_MAT3X3_F64 => { - if section.payload.len() != slot_bytes { - return Err(invalid_data(format!( - "Section '{}' has invalid payload length {} (expected {})", - section.key, - section.payload.len(), - slot_bytes - ))); - } - } - _ if per_atom => { - let expected = n_atoms.checked_mul(elem_bytes).ok_or_else(|| { - invalid_data(format!("Section '{}' payload length overflow", section.key)) - })?; - if section.payload.len() != expected { - return Err(invalid_data(format!( - "Section '{}' has invalid payload length {} (expected {})", - section.key, - section.payload.len(), - expected - ))); - } - } - _ => { - if slot_bytes != 0 && section.payload.len() != slot_bytes { - return Err(invalid_data(format!( - "Section '{}' has invalid payload length {} (expected {})", - section.key, - section.payload.len(), - slot_bytes - ))); - } - if elem_bytes != 0 && !section.payload.len().is_multiple_of(elem_bytes) { - return Err(invalid_data(format!( - "Section '{}' has invalid payload length {} for element size {}", - section.key, - section.payload.len(), - elem_bytes - ))); - } - } - } - Ok(()) -} - -/// Element size in bytes for a given type tag. Returns 0 for variable-length types. -fn type_tag_elem_bytes(tag: u8) -> usize { - match tag { - TYPE_FLOAT => 8, - TYPE_INT => 8, - TYPE_STRING => 0, // variable - TYPE_F64_ARRAY => 8, - TYPE_VEC3_F32 => 12, - TYPE_I64_ARRAY => 8, - TYPE_F32_ARRAY => 4, - TYPE_VEC3_F64 => 24, - TYPE_I32_ARRAY => 4, - TYPE_BOOL3 => 3, - TYPE_FLOAT32 => 4, - TYPE_MAT3X3_F32 => 36, - TYPE_MAT3X3_F64 => 72, - _ => 0, - } -} - -/// Whether a section with the given kind/key/type_tag is per-atom (vs per-molecule). -fn is_per_atom(kind: u8, key: &str, _type_tag: u8) -> bool { - match kind { - KIND_ATOM_PROP => true, - KIND_MOL_PROP => false, - KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), - _ => false, - } -} - -/// Lightweight section descriptor — stores byte offsets into the parent `bytes` buffer. -/// Key is NOT parsed eagerly; use `key()` to read it lazily from `bytes`. -#[derive(Clone, Copy)] -struct LazySection { - kind: u8, - key_start: usize, - key_len: u8, - type_tag: u8, - payload_start: usize, - payload_len: usize, -} - -/// Byte-offset pair for a known builtin section (payload_start, payload_len, type_tag). -type BuiltinSlot = (usize, usize, u8); - -enum SoaBytes { - Owned(Vec), - Shared(SharedMmapBytes), -} - -impl SoaBytes { - #[inline] - fn as_slice(&self) -> &[u8] { - match self { - Self::Owned(bytes) => bytes, - Self::Shared(bytes) => bytes.as_slice(), - } - } -} - -impl std::ops::Deref for SoaBytes { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - self.as_slice() - } -} - -struct SoaMoleculeView { - bytes: SoaBytes, - n_atoms: usize, - positions_type: u8, - positions_start: usize, - positions_len: usize, - atomic_numbers_start: usize, - // Known builtins — zero-alloc, set during from_bytes - forces: Option, - energy: Option, - cell: Option, - stress: Option, - charges: Option, - velocities: Option, - pbc: Option, - name: Option, - // Custom properties — lazy, no String parsing until accessed - custom_sections: Vec, -} - -impl SoaMoleculeView { - /// Pure-Rust parser — no Python dependency, safe to call from rayon threads. - fn from_storage_inner( - bytes: SoaBytes, - record_format: u32, - positions_type_hint: Option, - ) -> atompack::Result { - if bytes.len() < 6 { - return Err(invalid_data("SOA record too small")); - } - - let n_atoms = u32::from_le_bytes(slice_to_array(&bytes[0..4], "SOA atom count")?) as usize; - let mut pos = 4usize; - let positions_type = match record_format { - RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, - RECORD_FORMAT_SOA_V3 => positions_type_hint - .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, - _ => { - return Err(invalid_data(format!( - "Unsupported record format {}", - record_format - ))); - } - }; - let positions_stride = match positions_type { - TYPE_VEC3_F32 => 12usize, - TYPE_VEC3_F64 => 24usize, - _ => { - return Err(invalid_data(format!( - "Unsupported positions type tag {}", - positions_type - ))); - } - }; - let positions_start = pos; - let positions_len = n_atoms - .checked_mul(positions_stride) - .ok_or_else(|| invalid_data("SOA positions overflow"))?; - pos = pos - .checked_add(positions_len) - .ok_or_else(|| invalid_data("SOA positions overflow"))?; - if pos > bytes.len() { - return Err(invalid_data("SOA record truncated at positions")); - } - - let atomic_numbers_start = pos; - pos = pos - .checked_add(n_atoms) - .ok_or_else(|| invalid_data("SOA atomic_numbers overflow"))?; - if pos + 2 > bytes.len() { - return Err(invalid_data("SOA record truncated at atomic_numbers")); - } - - let n_sections = - u16::from_le_bytes(slice_to_array(&bytes[pos..pos + 2], "SOA section count")?) as usize; - pos += 2; - - let mut forces = None; - let mut energy = None; - let mut cell = None; - let mut stress = None; - let mut charges = None; - let mut velocities = None; - let mut pbc = None; - let mut name = None; - let mut custom_sections = Vec::new(); - - for _ in 0..n_sections { - if pos + 2 > bytes.len() { - return Err(invalid_data("SOA section header truncated")); - } - let kind = bytes[pos]; - pos += 1; - let key_len = bytes[pos] as usize; - pos += 1; - if pos + key_len > bytes.len() { - return Err(invalid_data("SOA section key truncated")); - } - let key_start = pos; - pos += key_len; - if pos + 5 > bytes.len() { - return Err(invalid_data("SOA section header truncated")); - } - let type_tag = bytes[pos]; - pos += 1; - let payload_len = u32::from_le_bytes(slice_to_array( - &bytes[pos..pos + 4], - "SOA section payload length", - )?) as usize; - pos += 4; - let payload_start = pos; - pos = pos - .checked_add(payload_len) - .ok_or_else(|| invalid_data("SOA section payload overflow"))?; - if pos > bytes.len() { - return Err(invalid_data("SOA section payload truncated")); - } - - let key_bytes = &bytes[key_start..key_start + key_len]; - if kind == KIND_BUILTIN { - let slot = (payload_start, payload_len, type_tag); - match key_bytes { - b"forces" => forces = Some(slot), - b"energy" => energy = Some(slot), - b"cell" => cell = Some(slot), - b"stress" => stress = Some(slot), - b"charges" => charges = Some(slot), - b"velocities" => velocities = Some(slot), - b"pbc" => pbc = Some(slot), - b"name" => name = Some(slot), - _ => { - custom_sections.push(LazySection { - kind, - key_start, - key_len: key_len as u8, - type_tag, - payload_start, - payload_len, - }); - } - } - } else { - custom_sections.push(LazySection { - kind, - key_start, - key_len: key_len as u8, - type_tag, - payload_start, - payload_len, - }); - } - } - - Ok(Self { - bytes, - n_atoms, - positions_type, - positions_start, - positions_len, - atomic_numbers_start, - forces, - energy, - cell, - stress, - charges, - velocities, - pbc, - name, - custom_sections, - }) - } - - fn from_bytes_inner( - bytes: Vec, - record_format: u32, - positions_type_hint: Option, - ) -> atompack::Result { - Self::from_storage_inner(SoaBytes::Owned(bytes), record_format, positions_type_hint) - } - - fn from_shared_bytes_inner( - bytes: SharedMmapBytes, - record_format: u32, - positions_type_hint: Option, - ) -> atompack::Result { - Self::from_storage_inner(SoaBytes::Shared(bytes), record_format, positions_type_hint) - } - - /// Thin wrapper for call sites that need PyResult. - fn from_bytes( - bytes: Vec, - record_format: u32, - positions_type_hint: Option, - ) -> PyResult { - Self::from_bytes_inner(bytes, record_format, positions_type_hint) - .map_err(|e| PyValueError::new_err(format!("{}", e))) - } - - fn positions_bytes(&self) -> &[u8] { - &self.bytes[self.positions_start..self.positions_start + self.positions_len] - } - - fn atomic_numbers_bytes(&self) -> &[u8] { - &self.bytes[self.atomic_numbers_start..self.atomic_numbers_start + self.n_atoms] - } +mod soa; - #[inline] - fn builtin_payload(&self, slot: BuiltinSlot) -> &[u8] { - &self.bytes[slot.0..slot.0 + slot.1] - } - - fn lazy_section_key(&self, s: &LazySection) -> PyResult<&str> { - std::str::from_utf8(&self.bytes[s.key_start..s.key_start + s.key_len as usize]) - .map_err(|_| PyValueError::new_err("Invalid UTF-8 in section key")) - } - - fn lazy_section_payload(&self, s: &LazySection) -> &[u8] { - &self.bytes[s.payload_start..s.payload_start + s.payload_len] - } - - fn find_custom_section(&self, kind: u8, key: &str) -> PyResult> { - for section in &self.custom_sections { - if section.kind == kind && self.lazy_section_key(section)? == key { - return Ok(Some(section)); - } - } - Ok(None) - } - - fn property_keys(&self) -> PyResult> { - self.custom_sections - .iter() - .filter(|s| s.kind == KIND_MOL_PROP) - .map(|s| Ok(self.lazy_section_key(s)?.to_string())) - .collect() - } - - fn atom_at(&self, index: usize) -> PyResult> { - if index >= self.n_atoms { - return Ok(None); - } - let atomic_number = self.atomic_numbers_bytes()[index]; - Ok(Some(match self.positions_type { - TYPE_VEC3_F32 => { - let pos = &self.positions_bytes()[index * 12..(index + 1) * 12]; - Atom::new( - f32::from_le_bytes(py_slice_to_array(&pos[0..4], "atom x")?), - f32::from_le_bytes(py_slice_to_array(&pos[4..8], "atom y")?), - f32::from_le_bytes(py_slice_to_array(&pos[8..12], "atom z")?), - atomic_number, - ) - } - TYPE_VEC3_F64 => { - let pos = &self.positions_bytes()[index * 24..(index + 1) * 24]; - Atom::new( - f64::from_le_bytes(py_slice_to_array(&pos[0..8], "atom x")?) as f32, - f64::from_le_bytes(py_slice_to_array(&pos[8..16], "atom y")?) as f32, - f64::from_le_bytes(py_slice_to_array(&pos[16..24], "atom z")?) as f32, - atomic_number, - ) - } - other => { - return Err(PyValueError::new_err(format!( - "Unsupported positions type tag {}", - other - ))); - } - })) - } - - fn energy(&self) -> PyResult> { - match self.energy { - Some(slot) => match slot.2 { - TYPE_FLOAT => Ok(Some(read_f64_scalar(self.builtin_payload(slot))?)), - TYPE_FLOAT32 => Ok(Some(read_f32_scalar(self.builtin_payload(slot))? as f64)), - other => Err(PyValueError::new_err(format!( - "Unsupported energy type tag {}", - other - ))), - }, - None => Ok(None), - } - } - - fn pbc(&self) -> PyResult> { - match self.pbc { - Some(slot) => { - let payload = self.builtin_payload(slot); - if payload.len() != 3 { - return Err(PyValueError::new_err("Invalid pbc payload length")); - } - Ok(Some((payload[0] != 0, payload[1] != 0, payload[2] != 0))) - } - None => Ok(None), - } - } - - fn materialize(&self) -> PyResult { - let atomic_numbers = self.atomic_numbers_bytes().to_vec(); - let mut molecule = match self.positions_type { - TYPE_VEC3_F32 => { - let positions: Vec<[f32; 3]> = (0..self.n_atoms) - .map(|i| { - let pos = &self.positions_bytes()[i * 12..(i + 1) * 12]; - Ok([ - f32::from_le_bytes(py_slice_to_array(&pos[0..4], "position x")?), - f32::from_le_bytes(py_slice_to_array(&pos[4..8], "position y")?), - f32::from_le_bytes(py_slice_to_array(&pos[8..12], "position z")?), - ]) - }) - .collect::>()?; - Molecule::new(positions, atomic_numbers).map_err(PyValueError::new_err)? - } - TYPE_VEC3_F64 => { - let positions: Vec<[f64; 3]> = (0..self.n_atoms) - .map(|i| { - let pos = &self.positions_bytes()[i * 24..(i + 1) * 24]; - Ok([ - f64::from_le_bytes(py_slice_to_array(&pos[0..8], "position x")?), - f64::from_le_bytes(py_slice_to_array(&pos[8..16], "position y")?), - f64::from_le_bytes(py_slice_to_array(&pos[16..24], "position z")?), - ]) - }) - .collect::>()?; - Molecule::new_f64(positions, atomic_numbers).map_err(PyValueError::new_err)? - } - other => { - return Err(PyValueError::new_err(format!( - "Unsupported positions type tag {}", - other - ))); - } - }; - - // Builtins - if let Some(slot) = self.charges { - molecule.charges = Some(match slot.2 { - TYPE_F32_ARRAY => { - FloatArrayData::F32(decode_f32_array(self.builtin_payload(slot))?) - } - TYPE_F64_ARRAY => { - FloatArrayData::F64(decode_f64_array(self.builtin_payload(slot))?) - } - other => { - return Err(PyValueError::new_err(format!( - "Unsupported charges type tag {}", - other - ))); - } - }); - } - if let Some(slot) = self.cell { - molecule.cell = Some(match slot.2 { - TYPE_MAT3X3_F32 => Mat3Data::F32(decode_mat3x3_f32(self.builtin_payload(slot))?), - TYPE_MAT3X3_F64 => Mat3Data::F64(decode_mat3x3_f64(self.builtin_payload(slot))?), - other => { - return Err(PyValueError::new_err(format!( - "Unsupported cell type tag {}", - other - ))); - } - }); - } - if let Some(slot) = self.energy { - molecule.energy = Some(match slot.2 { - TYPE_FLOAT => FloatScalarData::F64(read_f64_scalar(self.builtin_payload(slot))?), - TYPE_FLOAT32 => FloatScalarData::F32(read_f32_scalar(self.builtin_payload(slot))?), - other => { - return Err(PyValueError::new_err(format!( - "Unsupported energy type tag {}", - other - ))); - } - }); - } - if let Some(slot) = self.forces { - molecule.forces = Some(match slot.2 { - TYPE_VEC3_F32 => Vec3Data::F32(decode_vec3_f32(self.builtin_payload(slot))?), - TYPE_VEC3_F64 => Vec3Data::F64(decode_vec3_f64(self.builtin_payload(slot))?), - other => { - return Err(PyValueError::new_err(format!( - "Unsupported forces type tag {}", - other - ))); - } - }); - } - if let Some(slot) = self.name { - let payload = self.builtin_payload(slot); - molecule.name = Some( - std::str::from_utf8(payload) - .map_err(|_| PyValueError::new_err("Invalid UTF-8 in name"))? - .to_string(), - ); - } - if let Some(slot) = self.pbc { - let payload = self.builtin_payload(slot); - if payload.len() != 3 { - return Err(PyValueError::new_err("Invalid pbc payload length")); - } - molecule.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); - } - if let Some(slot) = self.stress { - molecule.stress = Some(match slot.2 { - TYPE_MAT3X3_F32 => Mat3Data::F32(decode_mat3x3_f32(self.builtin_payload(slot))?), - TYPE_MAT3X3_F64 => Mat3Data::F64(decode_mat3x3_f64(self.builtin_payload(slot))?), - other => { - return Err(PyValueError::new_err(format!( - "Unsupported stress type tag {}", - other - ))); - } - }); - } - if let Some(slot) = self.velocities { - molecule.velocities = Some(match slot.2 { - TYPE_VEC3_F32 => Vec3Data::F32(decode_vec3_f32(self.builtin_payload(slot))?), - TYPE_VEC3_F64 => Vec3Data::F64(decode_vec3_f64(self.builtin_payload(slot))?), - other => { - return Err(PyValueError::new_err(format!( - "Unsupported velocities type tag {}", - other - ))); - } - }); - } - - // Custom properties (lazy — key parsed here only) - for s in &self.custom_sections { - let key = self.lazy_section_key(s)?.to_string(); - let payload = self.lazy_section_payload(s); - match s.kind { - KIND_ATOM_PROP => { - molecule - .atom_properties - .insert(key, decode_property_value(s.type_tag, payload)?); - } - KIND_MOL_PROP => { - molecule - .properties - .insert(key, decode_property_value(s.type_tag, payload)?); - } - _ => {} - } - } - - Ok(molecule) - } -} - -fn read_f64_scalar(payload: &[u8]) -> PyResult { - if payload.len() != 8 { - return Err(PyValueError::new_err("Invalid f64 payload length")); - } - Ok(f64::from_le_bytes(py_slice_to_array( - payload, - "f64 payload", - )?)) -} - -fn read_f32_scalar(payload: &[u8]) -> PyResult { - if payload.len() != 4 { - return Err(PyValueError::new_err("Invalid f32 payload length")); - } - Ok(f32::from_le_bytes(py_slice_to_array( - payload, - "f32 payload", - )?)) -} - -fn read_i64_scalar(payload: &[u8]) -> PyResult { - if payload.len() != 8 { - return Err(PyValueError::new_err("Invalid i64 payload length")); - } - Ok(i64::from_le_bytes(py_slice_to_array( - payload, - "i64 payload", - )?)) -} - -fn decode_f64_array(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(8) { - return Err(PyValueError::new_err("Invalid f64 array payload length")); - } - payload - .chunks_exact(8) - .map(|chunk| { - Ok(f64::from_le_bytes(py_slice_to_array( - chunk, - "f64 array chunk", - )?)) - }) - .collect() -} - -fn decode_i64_array(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(8) { - return Err(PyValueError::new_err("Invalid i64 array payload length")); - } - payload - .chunks_exact(8) - .map(|chunk| { - Ok(i64::from_le_bytes(py_slice_to_array( - chunk, - "i64 array chunk", - )?)) - }) - .collect() -} - -fn decode_i32_array(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(4) { - return Err(PyValueError::new_err("Invalid i32 array payload length")); - } - payload - .chunks_exact(4) - .map(|chunk| { - Ok(i32::from_le_bytes(py_slice_to_array( - chunk, - "i32 array chunk", - )?)) - }) - .collect() -} - -fn decode_f32_array(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(4) { - return Err(PyValueError::new_err("Invalid f32 array payload length")); - } - payload - .chunks_exact(4) - .map(|chunk| { - Ok(f32::from_le_bytes(py_slice_to_array( - chunk, - "f32 array chunk", - )?)) - }) - .collect() -} - -fn decode_vec3_f32(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(12) { - return Err(PyValueError::new_err("Invalid vec3 payload length")); - } - payload - .chunks_exact(12) - .map(|chunk| { - Ok([ - f32::from_le_bytes(py_slice_to_array(&chunk[0..4], "vec3 x")?), - f32::from_le_bytes(py_slice_to_array(&chunk[4..8], "vec3 y")?), - f32::from_le_bytes(py_slice_to_array(&chunk[8..12], "vec3 z")?), - ]) - }) - .collect() -} - -fn decode_vec3_f64(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(24) { - return Err(PyValueError::new_err("Invalid vec3 payload length")); - } - payload - .chunks_exact(24) - .map(|chunk| { - Ok([ - f64::from_le_bytes(py_slice_to_array(&chunk[0..8], "vec3 x")?), - f64::from_le_bytes(py_slice_to_array(&chunk[8..16], "vec3 y")?), - f64::from_le_bytes(py_slice_to_array(&chunk[16..24], "vec3 z")?), - ]) - }) - .collect() -} - -fn decode_mat3x3_f64(payload: &[u8]) -> PyResult<[[f64; 3]; 3]> { - if payload.len() != 72 { - return Err(PyValueError::new_err("Invalid mat3x3 payload length")); - } - Ok([ - [ - f64::from_le_bytes(py_slice_to_array(&payload[0..8], "mat3x3 [0][0]")?), - f64::from_le_bytes(py_slice_to_array(&payload[8..16], "mat3x3 [0][1]")?), - f64::from_le_bytes(py_slice_to_array(&payload[16..24], "mat3x3 [0][2]")?), - ], - [ - f64::from_le_bytes(py_slice_to_array(&payload[24..32], "mat3x3 [1][0]")?), - f64::from_le_bytes(py_slice_to_array(&payload[32..40], "mat3x3 [1][1]")?), - f64::from_le_bytes(py_slice_to_array(&payload[40..48], "mat3x3 [1][2]")?), - ], - [ - f64::from_le_bytes(py_slice_to_array(&payload[48..56], "mat3x3 [2][0]")?), - f64::from_le_bytes(py_slice_to_array(&payload[56..64], "mat3x3 [2][1]")?), - f64::from_le_bytes(py_slice_to_array(&payload[64..72], "mat3x3 [2][2]")?), - ], - ]) -} - -fn decode_mat3x3_f32(payload: &[u8]) -> PyResult<[[f32; 3]; 3]> { - if payload.len() != 36 { - return Err(PyValueError::new_err("Invalid mat3x3 payload length")); - } - Ok([ - [ - f32::from_le_bytes(py_slice_to_array(&payload[0..4], "mat3x3 [0][0]")?), - f32::from_le_bytes(py_slice_to_array(&payload[4..8], "mat3x3 [0][1]")?), - f32::from_le_bytes(py_slice_to_array(&payload[8..12], "mat3x3 [0][2]")?), - ], - [ - f32::from_le_bytes(py_slice_to_array(&payload[12..16], "mat3x3 [1][0]")?), - f32::from_le_bytes(py_slice_to_array(&payload[16..20], "mat3x3 [1][1]")?), - f32::from_le_bytes(py_slice_to_array(&payload[20..24], "mat3x3 [1][2]")?), - ], - [ - f32::from_le_bytes(py_slice_to_array(&payload[24..28], "mat3x3 [2][0]")?), - f32::from_le_bytes(py_slice_to_array(&payload[28..32], "mat3x3 [2][1]")?), - f32::from_le_bytes(py_slice_to_array(&payload[32..36], "mat3x3 [2][2]")?), - ], - ]) -} - -fn decode_property_value(type_tag: u8, payload: &[u8]) -> PyResult { - Ok(match type_tag { - TYPE_FLOAT => PropertyValue::Float(read_f64_scalar(payload)?), - TYPE_INT => PropertyValue::Int(read_i64_scalar(payload)?), - TYPE_STRING => PropertyValue::String( - std::str::from_utf8(payload) - .map_err(|_| PyValueError::new_err("Invalid UTF-8 in string property"))? - .to_string(), - ), - TYPE_F64_ARRAY => PropertyValue::FloatArray(decode_f64_array(payload)?), - TYPE_VEC3_F32 => PropertyValue::Vec3Array(decode_vec3_f32(payload)?), - TYPE_I64_ARRAY => PropertyValue::IntArray(decode_i64_array(payload)?), - TYPE_F32_ARRAY => PropertyValue::Float32Array(decode_f32_array(payload)?), - TYPE_VEC3_F64 => PropertyValue::Vec3ArrayF64(decode_vec3_f64(payload)?), - TYPE_I32_ARRAY => PropertyValue::Int32Array(decode_i32_array(payload)?), - _ => { - return Err(PyValueError::new_err(format!( - "Unsupported property type tag {}", - type_tag - ))); - } - }) -} +pub(crate) use self::soa::{ + LazySection, SectionRef, SectionSchema, SoaMoleculeView, is_per_atom, parse_mol_fast_soa, + read_f64_scalar, read_i64_scalar, section_schema_from_ref, type_tag_elem_bytes, + validate_section_payload, +}; mod database; mod molecule; diff --git a/atompack-py/src/soa.rs b/atompack-py/src/soa.rs new file mode 100644 index 0000000..eb2d2cb --- /dev/null +++ b/atompack-py/src/soa.rs @@ -0,0 +1,929 @@ +use super::*; + +/// A single parsed section reference (zero-copy into decompressed bytes). +#[derive(Clone)] +pub(crate) struct SectionRef<'a> { + pub(crate) kind: u8, + pub(crate) key: &'a str, + pub(crate) type_tag: u8, + pub(crate) payload: &'a [u8], +} + +/// Per-molecule extracted data (references into decompressed bytes). +pub(crate) struct MolData<'a> { + pub(crate) n_atoms: usize, + pub(crate) positions_bytes: &'a [u8], + pub(crate) atomic_numbers_bytes: &'a [u8], + pub(crate) sections: Vec>, +} + +#[derive(Clone)] +pub(crate) struct SectionSchema { + pub(crate) kind: u8, + pub(crate) key: String, + pub(crate) type_tag: u8, + pub(crate) per_atom: bool, + pub(crate) elem_bytes: usize, + pub(crate) slot_bytes: usize, +} + +/// Parse SOA format bytes into MolData without allocation. +/// +/// Layout: +/// [n_atoms:u32][positions:n*(12|24)][atomic_numbers:n] +/// [n_sections:u16] +/// per section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload] +pub(crate) fn parse_mol_fast_soa( + bytes: &[u8], + record_format: u32, + positions_type_hint: Option, +) -> atompack::Result> { + let mut pos = 0usize; + let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; + let positions_type = match record_format { + RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, + RECORD_FORMAT_SOA_V3 => positions_type_hint + .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, + _ => { + return Err(invalid_data(format!( + "Unsupported record format {}", + record_format + ))); + } + }; + let positions_stride = match positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + _ => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + positions_type + ))); + } + }; + let positions_len = n_atoms + .checked_mul(positions_stride) + .ok_or_else(|| invalid_data("SOA positions byte length overflow"))?; + let positions_bytes = read_bytes_at(bytes, &mut pos, positions_len, "SOA positions")?; + let atomic_numbers_bytes = read_bytes_at(bytes, &mut pos, n_atoms, "SOA atomic_numbers")?; + let n_sections = read_u16_le_at(bytes, &mut pos, "SOA n_sections")? as usize; + + let mut sections = Vec::with_capacity(n_sections); + for _ in 0..n_sections { + let kind = read_u8_at(bytes, &mut pos, "SOA section kind")?; + let key_len = read_u8_at(bytes, &mut pos, "SOA section key length")? as usize; + let key_bytes = read_bytes_at(bytes, &mut pos, key_len, "SOA section key")?; + let key = std::str::from_utf8(key_bytes) + .map_err(|_| invalid_data("Invalid UTF-8 in SOA section key"))?; + let type_tag = read_u8_at(bytes, &mut pos, "SOA section type tag")?; + let payload_len = read_u32_le_at(bytes, &mut pos, "SOA section payload length")? as usize; + let payload = read_bytes_at(bytes, &mut pos, payload_len, "SOA section payload")?; + sections.push(SectionRef { + kind, + key, + type_tag, + payload, + }); + } + + Ok(MolData { + n_atoms, + positions_bytes, + atomic_numbers_bytes, + sections, + }) +} + +pub(crate) fn section_schema_from_ref( + section: &SectionRef<'_>, + n_atoms: usize, +) -> atompack::Result { + let per_atom = is_per_atom(section.kind, section.key, section.type_tag); + let elem_bytes = match section.type_tag { + TYPE_STRING => 0, + tag if per_atom => { + let elem_bytes = type_tag_elem_bytes(tag); + if elem_bytes == 0 { + return Err(invalid_data(format!( + "Unsupported per-atom section type tag {} for key '{}'", + tag, section.key + ))); + } + elem_bytes + } + TYPE_FLOAT | TYPE_INT => 8, + TYPE_FLOAT32 => 4, + TYPE_BOOL3 => 3, + TYPE_MAT3X3_F32 => 36, + TYPE_MAT3X3_F64 => 72, + _ => section.payload.len(), + }; + let slot_bytes = if section.type_tag == TYPE_STRING { + 0 + } else if per_atom { + elem_bytes + } else { + section.payload.len() + }; + + validate_section_payload(section, per_atom, elem_bytes, slot_bytes, n_atoms)?; + + Ok(SectionSchema { + kind: section.kind, + key: section.key.to_string(), + type_tag: section.type_tag, + per_atom, + elem_bytes, + slot_bytes, + }) +} + +pub(crate) fn validate_section_payload( + section: &SectionRef<'_>, + per_atom: bool, + elem_bytes: usize, + slot_bytes: usize, + n_atoms: usize, +) -> atompack::Result<()> { + match section.type_tag { + TYPE_STRING => { + std::str::from_utf8(section.payload) + .map_err(|_| invalid_data(format!("Invalid UTF-8 in section '{}'", section.key)))?; + if per_atom { + return Err(invalid_data(format!( + "String section '{}' cannot be per-atom in flat extraction", + section.key + ))); + } + } + TYPE_FLOAT | TYPE_INT | TYPE_FLOAT32 | TYPE_BOOL3 | TYPE_MAT3X3_F32 | TYPE_MAT3X3_F64 => { + if section.payload.len() != slot_bytes { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + section.key, + section.payload.len(), + slot_bytes + ))); + } + } + _ if per_atom => { + let expected = n_atoms.checked_mul(elem_bytes).ok_or_else(|| { + invalid_data(format!("Section '{}' payload length overflow", section.key)) + })?; + if section.payload.len() != expected { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + section.key, + section.payload.len(), + expected + ))); + } + } + _ => { + if slot_bytes != 0 && section.payload.len() != slot_bytes { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + section.key, + section.payload.len(), + slot_bytes + ))); + } + if elem_bytes != 0 && !section.payload.len().is_multiple_of(elem_bytes) { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} for element size {}", + section.key, + section.payload.len(), + elem_bytes + ))); + } + } + } + Ok(()) +} + +/// Element size in bytes for a given type tag. Returns 0 for variable-length types. +pub(crate) fn type_tag_elem_bytes(tag: u8) -> usize { + match tag { + TYPE_FLOAT => 8, + TYPE_INT => 8, + TYPE_STRING => 0, + TYPE_F64_ARRAY => 8, + TYPE_VEC3_F32 => 12, + TYPE_I64_ARRAY => 8, + TYPE_F32_ARRAY => 4, + TYPE_VEC3_F64 => 24, + TYPE_I32_ARRAY => 4, + TYPE_BOOL3 => 3, + TYPE_FLOAT32 => 4, + TYPE_MAT3X3_F32 => 36, + TYPE_MAT3X3_F64 => 72, + _ => 0, + } +} + +/// Whether a section with the given kind/key/type_tag is per-atom (vs per-molecule). +pub(crate) fn is_per_atom(kind: u8, key: &str, _type_tag: u8) -> bool { + match kind { + KIND_ATOM_PROP => true, + KIND_MOL_PROP => false, + KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), + _ => false, + } +} + +fn py_decode_float_array_data( + payload: &[u8], + type_tag: u8, + field_name: &str, +) -> PyResult { + match type_tag { + TYPE_F32_ARRAY => Ok(FloatArrayData::F32(decode_f32_array(payload)?)), + TYPE_F64_ARRAY => Ok(FloatArrayData::F64(decode_f64_array(payload)?)), + other => Err(PyValueError::new_err(format!( + "Unsupported {field_name} type tag {}", + other + ))), + } +} + +fn py_decode_mat3_data(payload: &[u8], type_tag: u8, field_name: &str) -> PyResult { + match type_tag { + TYPE_MAT3X3_F32 => Ok(Mat3Data::F32(decode_mat3x3_f32(payload)?)), + TYPE_MAT3X3_F64 => Ok(Mat3Data::F64(decode_mat3x3_f64(payload)?)), + other => Err(PyValueError::new_err(format!( + "Unsupported {field_name} type tag {}", + other + ))), + } +} + +fn py_decode_float_scalar_data( + payload: &[u8], + type_tag: u8, + field_name: &str, +) -> PyResult { + match type_tag { + TYPE_FLOAT => Ok(FloatScalarData::F64(read_f64_scalar(payload)?)), + TYPE_FLOAT32 => Ok(FloatScalarData::F32(read_f32_scalar(payload)?)), + other => Err(PyValueError::new_err(format!( + "Unsupported {field_name} type tag {}", + other + ))), + } +} + +fn py_decode_vec3_data(payload: &[u8], type_tag: u8, field_name: &str) -> PyResult { + match type_tag { + TYPE_VEC3_F32 => Ok(Vec3Data::F32(decode_vec3_f32(payload)?)), + TYPE_VEC3_F64 => Ok(Vec3Data::F64(decode_vec3_f64(payload)?)), + other => Err(PyValueError::new_err(format!( + "Unsupported {field_name} type tag {}", + other + ))), + } +} + +/// Lightweight section descriptor — stores byte offsets into the parent `bytes` buffer. +/// Key is NOT parsed eagerly; use `key()` to read it lazily from `bytes`. +#[derive(Clone, Copy)] +pub(crate) struct LazySection { + pub(crate) kind: u8, + pub(crate) key_start: usize, + pub(crate) key_len: u8, + pub(crate) type_tag: u8, + pub(crate) payload_start: usize, + pub(crate) payload_len: usize, +} + +/// Byte-offset pair for a known builtin section (payload_start, payload_len, type_tag). +pub(crate) type BuiltinSlot = (usize, usize, u8); + +enum SoaBytes { + Owned(Vec), + Shared(SharedMmapBytes), +} + +impl SoaBytes { + #[inline] + fn as_slice(&self) -> &[u8] { + match self { + Self::Owned(bytes) => bytes, + Self::Shared(bytes) => bytes.as_slice(), + } + } +} + +impl std::ops::Deref for SoaBytes { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.as_slice() + } +} + +pub(crate) struct SoaMoleculeView { + bytes: SoaBytes, + pub(crate) n_atoms: usize, + pub(crate) positions_type: u8, + positions_start: usize, + positions_len: usize, + atomic_numbers_start: usize, + pub(crate) forces: Option, + pub(crate) energy: Option, + pub(crate) cell: Option, + pub(crate) stress: Option, + pub(crate) charges: Option, + pub(crate) velocities: Option, + pbc: Option, + name: Option, + pub(crate) custom_sections: Vec, +} + +impl SoaMoleculeView { + /// Pure-Rust parser — no Python dependency, safe to call from rayon threads. + fn from_storage_inner( + bytes: SoaBytes, + record_format: u32, + positions_type_hint: Option, + ) -> atompack::Result { + if bytes.len() < 6 { + return Err(invalid_data("SOA record too small")); + } + + let n_atoms = u32::from_le_bytes(slice_to_array(&bytes[0..4], "SOA atom count")?) as usize; + let mut pos = 4usize; + let positions_type = match record_format { + RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, + RECORD_FORMAT_SOA_V3 => positions_type_hint + .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, + _ => { + return Err(invalid_data(format!( + "Unsupported record format {}", + record_format + ))); + } + }; + let positions_stride = match positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + _ => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + positions_type + ))); + } + }; + let positions_start = pos; + let positions_len = n_atoms + .checked_mul(positions_stride) + .ok_or_else(|| invalid_data("SOA positions overflow"))?; + pos = pos + .checked_add(positions_len) + .ok_or_else(|| invalid_data("SOA positions overflow"))?; + if pos > bytes.len() { + return Err(invalid_data("SOA record truncated at positions")); + } + + let atomic_numbers_start = pos; + pos = pos + .checked_add(n_atoms) + .ok_or_else(|| invalid_data("SOA atomic_numbers overflow"))?; + if pos + 2 > bytes.len() { + return Err(invalid_data("SOA record truncated at atomic_numbers")); + } + + let n_sections = + u16::from_le_bytes(slice_to_array(&bytes[pos..pos + 2], "SOA section count")?) as usize; + pos += 2; + + let mut forces = None; + let mut energy = None; + let mut cell = None; + let mut stress = None; + let mut charges = None; + let mut velocities = None; + let mut pbc = None; + let mut name = None; + let mut custom_sections = Vec::new(); + + for _ in 0..n_sections { + if pos + 2 > bytes.len() { + return Err(invalid_data("SOA section header truncated")); + } + let kind = bytes[pos]; + pos += 1; + let key_len = bytes[pos] as usize; + pos += 1; + if pos + key_len > bytes.len() { + return Err(invalid_data("SOA section key truncated")); + } + let key_start = pos; + pos += key_len; + if pos + 5 > bytes.len() { + return Err(invalid_data("SOA section header truncated")); + } + let type_tag = bytes[pos]; + pos += 1; + let payload_len = u32::from_le_bytes(slice_to_array( + &bytes[pos..pos + 4], + "SOA section payload length", + )?) as usize; + pos += 4; + let payload_start = pos; + pos = pos + .checked_add(payload_len) + .ok_or_else(|| invalid_data("SOA section payload overflow"))?; + if pos > bytes.len() { + return Err(invalid_data("SOA section payload truncated")); + } + + let key_bytes = &bytes[key_start..key_start + key_len]; + if kind == KIND_BUILTIN { + let slot = (payload_start, payload_len, type_tag); + match key_bytes { + b"forces" => forces = Some(slot), + b"energy" => energy = Some(slot), + b"cell" => cell = Some(slot), + b"stress" => stress = Some(slot), + b"charges" => charges = Some(slot), + b"velocities" => velocities = Some(slot), + b"pbc" => pbc = Some(slot), + b"name" => name = Some(slot), + _ => { + custom_sections.push(LazySection { + kind, + key_start, + key_len: key_len as u8, + type_tag, + payload_start, + payload_len, + }); + } + } + } else { + custom_sections.push(LazySection { + kind, + key_start, + key_len: key_len as u8, + type_tag, + payload_start, + payload_len, + }); + } + } + + Ok(Self { + bytes, + n_atoms, + positions_type, + positions_start, + positions_len, + atomic_numbers_start, + forces, + energy, + cell, + stress, + charges, + velocities, + pbc, + name, + custom_sections, + }) + } + + pub(crate) fn from_bytes_inner( + bytes: Vec, + record_format: u32, + positions_type_hint: Option, + ) -> atompack::Result { + Self::from_storage_inner(SoaBytes::Owned(bytes), record_format, positions_type_hint) + } + + pub(crate) fn from_shared_bytes_inner( + bytes: SharedMmapBytes, + record_format: u32, + positions_type_hint: Option, + ) -> atompack::Result { + Self::from_storage_inner(SoaBytes::Shared(bytes), record_format, positions_type_hint) + } + + pub(crate) fn from_bytes( + bytes: Vec, + record_format: u32, + positions_type_hint: Option, + ) -> PyResult { + Self::from_bytes_inner(bytes, record_format, positions_type_hint) + .map_err(|e| PyValueError::new_err(format!("{}", e))) + } + + pub(crate) fn positions_bytes(&self) -> &[u8] { + &self.bytes[self.positions_start..self.positions_start + self.positions_len] + } + + pub(crate) fn atomic_numbers_bytes(&self) -> &[u8] { + &self.bytes[self.atomic_numbers_start..self.atomic_numbers_start + self.n_atoms] + } + + #[inline] + pub(crate) fn builtin_payload(&self, slot: BuiltinSlot) -> &[u8] { + &self.bytes[slot.0..slot.0 + slot.1] + } + + pub(crate) fn lazy_section_key(&self, s: &LazySection) -> PyResult<&str> { + std::str::from_utf8(&self.bytes[s.key_start..s.key_start + s.key_len as usize]) + .map_err(|_| PyValueError::new_err("Invalid UTF-8 in section key")) + } + + pub(crate) fn lazy_section_payload(&self, s: &LazySection) -> &[u8] { + &self.bytes[s.payload_start..s.payload_start + s.payload_len] + } + + pub(crate) fn find_custom_section( + &self, + kind: u8, + key: &str, + ) -> PyResult> { + for section in &self.custom_sections { + if section.kind == kind && self.lazy_section_key(section)? == key { + return Ok(Some(section)); + } + } + Ok(None) + } + + pub(crate) fn property_keys(&self) -> PyResult> { + self.custom_sections + .iter() + .filter(|s| s.kind == KIND_MOL_PROP) + .map(|s| Ok(self.lazy_section_key(s)?.to_string())) + .collect() + } + + pub(crate) fn atom_at(&self, index: usize) -> PyResult> { + if index >= self.n_atoms { + return Ok(None); + } + let atomic_number = self.atomic_numbers_bytes()[index]; + Ok(Some(match self.positions_type { + TYPE_VEC3_F32 => { + let pos = &self.positions_bytes()[index * 12..(index + 1) * 12]; + Atom::new( + f32::from_le_bytes(py_slice_to_array(&pos[0..4], "atom x")?), + f32::from_le_bytes(py_slice_to_array(&pos[4..8], "atom y")?), + f32::from_le_bytes(py_slice_to_array(&pos[8..12], "atom z")?), + atomic_number, + ) + } + TYPE_VEC3_F64 => { + let pos = &self.positions_bytes()[index * 24..(index + 1) * 24]; + Atom::new( + f64::from_le_bytes(py_slice_to_array(&pos[0..8], "atom x")?) as f32, + f64::from_le_bytes(py_slice_to_array(&pos[8..16], "atom y")?) as f32, + f64::from_le_bytes(py_slice_to_array(&pos[16..24], "atom z")?) as f32, + atomic_number, + ) + } + other => { + return Err(PyValueError::new_err(format!( + "Unsupported positions type tag {}", + other + ))); + } + })) + } + + pub(crate) fn energy(&self) -> PyResult> { + match self.energy { + Some(slot) => match slot.2 { + TYPE_FLOAT => Ok(Some(read_f64_scalar(self.builtin_payload(slot))?)), + TYPE_FLOAT32 => Ok(Some(read_f32_scalar(self.builtin_payload(slot))? as f64)), + other => Err(PyValueError::new_err(format!( + "Unsupported energy type tag {}", + other + ))), + }, + None => Ok(None), + } + } + + pub(crate) fn pbc(&self) -> PyResult> { + match self.pbc { + Some(slot) => { + let payload = self.builtin_payload(slot); + if payload.len() != 3 { + return Err(PyValueError::new_err("Invalid pbc payload length")); + } + Ok(Some((payload[0] != 0, payload[1] != 0, payload[2] != 0))) + } + None => Ok(None), + } + } + + pub(crate) fn materialize(&self) -> PyResult { + let atomic_numbers = self.atomic_numbers_bytes().to_vec(); + let mut molecule = match self.positions_type { + TYPE_VEC3_F32 => { + let positions: Vec<[f32; 3]> = (0..self.n_atoms) + .map(|i| { + let pos = &self.positions_bytes()[i * 12..(i + 1) * 12]; + Ok([ + f32::from_le_bytes(py_slice_to_array(&pos[0..4], "position x")?), + f32::from_le_bytes(py_slice_to_array(&pos[4..8], "position y")?), + f32::from_le_bytes(py_slice_to_array(&pos[8..12], "position z")?), + ]) + }) + .collect::>()?; + Molecule::new(positions, atomic_numbers).map_err(PyValueError::new_err)? + } + TYPE_VEC3_F64 => { + let positions: Vec<[f64; 3]> = (0..self.n_atoms) + .map(|i| { + let pos = &self.positions_bytes()[i * 24..(i + 1) * 24]; + Ok([ + f64::from_le_bytes(py_slice_to_array(&pos[0..8], "position x")?), + f64::from_le_bytes(py_slice_to_array(&pos[8..16], "position y")?), + f64::from_le_bytes(py_slice_to_array(&pos[16..24], "position z")?), + ]) + }) + .collect::>()?; + Molecule::new_f64(positions, atomic_numbers).map_err(PyValueError::new_err)? + } + other => { + return Err(PyValueError::new_err(format!( + "Unsupported positions type tag {}", + other + ))); + } + }; + + if let Some(slot) = self.charges { + molecule.charges = Some(py_decode_float_array_data( + self.builtin_payload(slot), + slot.2, + "charges", + )?); + } + if let Some(slot) = self.cell { + molecule.cell = Some(py_decode_mat3_data( + self.builtin_payload(slot), + slot.2, + "cell", + )?); + } + if let Some(slot) = self.energy { + molecule.energy = Some(py_decode_float_scalar_data( + self.builtin_payload(slot), + slot.2, + "energy", + )?); + } + if let Some(slot) = self.forces { + molecule.forces = Some(py_decode_vec3_data( + self.builtin_payload(slot), + slot.2, + "forces", + )?); + } + if let Some(slot) = self.name { + let payload = self.builtin_payload(slot); + molecule.name = Some( + std::str::from_utf8(payload) + .map_err(|_| PyValueError::new_err("Invalid UTF-8 in name"))? + .to_string(), + ); + } + if let Some(slot) = self.pbc { + let payload = self.builtin_payload(slot); + if payload.len() != 3 { + return Err(PyValueError::new_err("Invalid pbc payload length")); + } + molecule.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); + } + if let Some(slot) = self.stress { + molecule.stress = Some(py_decode_mat3_data( + self.builtin_payload(slot), + slot.2, + "stress", + )?); + } + if let Some(slot) = self.velocities { + molecule.velocities = Some(py_decode_vec3_data( + self.builtin_payload(slot), + slot.2, + "velocities", + )?); + } + + for s in &self.custom_sections { + let key = self.lazy_section_key(s)?.to_string(); + let payload = self.lazy_section_payload(s); + match s.kind { + KIND_ATOM_PROP => { + molecule + .atom_properties + .insert(key, decode_property_value(s.type_tag, payload)?); + } + KIND_MOL_PROP => { + molecule + .properties + .insert(key, decode_property_value(s.type_tag, payload)?); + } + _ => {} + } + } + + Ok(molecule) + } +} + +pub(crate) fn read_f64_scalar(payload: &[u8]) -> PyResult { + if payload.len() != 8 { + return Err(PyValueError::new_err("Invalid f64 payload length")); + } + Ok(f64::from_le_bytes(py_slice_to_array( + payload, + "f64 payload", + )?)) +} + +fn read_f32_scalar(payload: &[u8]) -> PyResult { + if payload.len() != 4 { + return Err(PyValueError::new_err("Invalid f32 payload length")); + } + Ok(f32::from_le_bytes(py_slice_to_array( + payload, + "f32 payload", + )?)) +} + +pub(crate) fn read_i64_scalar(payload: &[u8]) -> PyResult { + if payload.len() != 8 { + return Err(PyValueError::new_err("Invalid i64 payload length")); + } + Ok(i64::from_le_bytes(py_slice_to_array( + payload, + "i64 payload", + )?)) +} + +fn decode_f64_array(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(8) { + return Err(PyValueError::new_err("Invalid f64 array payload length")); + } + payload + .chunks_exact(8) + .map(|chunk| { + Ok(f64::from_le_bytes(py_slice_to_array( + chunk, + "f64 array chunk", + )?)) + }) + .collect() +} + +fn decode_i64_array(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(8) { + return Err(PyValueError::new_err("Invalid i64 array payload length")); + } + payload + .chunks_exact(8) + .map(|chunk| { + Ok(i64::from_le_bytes(py_slice_to_array( + chunk, + "i64 array chunk", + )?)) + }) + .collect() +} + +fn decode_i32_array(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(4) { + return Err(PyValueError::new_err("Invalid i32 array payload length")); + } + payload + .chunks_exact(4) + .map(|chunk| { + Ok(i32::from_le_bytes(py_slice_to_array( + chunk, + "i32 array chunk", + )?)) + }) + .collect() +} + +fn decode_f32_array(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(4) { + return Err(PyValueError::new_err("Invalid f32 array payload length")); + } + payload + .chunks_exact(4) + .map(|chunk| { + Ok(f32::from_le_bytes(py_slice_to_array( + chunk, + "f32 array chunk", + )?)) + }) + .collect() +} + +fn decode_vec3_f32(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(12) { + return Err(PyValueError::new_err("Invalid vec3 payload length")); + } + payload + .chunks_exact(12) + .map(|chunk| { + Ok([ + f32::from_le_bytes(py_slice_to_array(&chunk[0..4], "vec3 x")?), + f32::from_le_bytes(py_slice_to_array(&chunk[4..8], "vec3 y")?), + f32::from_le_bytes(py_slice_to_array(&chunk[8..12], "vec3 z")?), + ]) + }) + .collect() +} + +fn decode_vec3_f64(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(24) { + return Err(PyValueError::new_err("Invalid vec3 payload length")); + } + payload + .chunks_exact(24) + .map(|chunk| { + Ok([ + f64::from_le_bytes(py_slice_to_array(&chunk[0..8], "vec3 x")?), + f64::from_le_bytes(py_slice_to_array(&chunk[8..16], "vec3 y")?), + f64::from_le_bytes(py_slice_to_array(&chunk[16..24], "vec3 z")?), + ]) + }) + .collect() +} + +fn decode_mat3x3_f64(payload: &[u8]) -> PyResult<[[f64; 3]; 3]> { + if payload.len() != 72 { + return Err(PyValueError::new_err("Invalid mat3x3 payload length")); + } + Ok([ + [ + f64::from_le_bytes(py_slice_to_array(&payload[0..8], "mat3x3 [0][0]")?), + f64::from_le_bytes(py_slice_to_array(&payload[8..16], "mat3x3 [0][1]")?), + f64::from_le_bytes(py_slice_to_array(&payload[16..24], "mat3x3 [0][2]")?), + ], + [ + f64::from_le_bytes(py_slice_to_array(&payload[24..32], "mat3x3 [1][0]")?), + f64::from_le_bytes(py_slice_to_array(&payload[32..40], "mat3x3 [1][1]")?), + f64::from_le_bytes(py_slice_to_array(&payload[40..48], "mat3x3 [1][2]")?), + ], + [ + f64::from_le_bytes(py_slice_to_array(&payload[48..56], "mat3x3 [2][0]")?), + f64::from_le_bytes(py_slice_to_array(&payload[56..64], "mat3x3 [2][1]")?), + f64::from_le_bytes(py_slice_to_array(&payload[64..72], "mat3x3 [2][2]")?), + ], + ]) +} + +fn decode_mat3x3_f32(payload: &[u8]) -> PyResult<[[f32; 3]; 3]> { + if payload.len() != 36 { + return Err(PyValueError::new_err("Invalid mat3x3 payload length")); + } + Ok([ + [ + f32::from_le_bytes(py_slice_to_array(&payload[0..4], "mat3x3 [0][0]")?), + f32::from_le_bytes(py_slice_to_array(&payload[4..8], "mat3x3 [0][1]")?), + f32::from_le_bytes(py_slice_to_array(&payload[8..12], "mat3x3 [0][2]")?), + ], + [ + f32::from_le_bytes(py_slice_to_array(&payload[12..16], "mat3x3 [1][0]")?), + f32::from_le_bytes(py_slice_to_array(&payload[16..20], "mat3x3 [1][1]")?), + f32::from_le_bytes(py_slice_to_array(&payload[20..24], "mat3x3 [1][2]")?), + ], + [ + f32::from_le_bytes(py_slice_to_array(&payload[24..28], "mat3x3 [2][0]")?), + f32::from_le_bytes(py_slice_to_array(&payload[28..32], "mat3x3 [2][1]")?), + f32::from_le_bytes(py_slice_to_array(&payload[32..36], "mat3x3 [2][2]")?), + ], + ]) +} + +fn decode_property_value(type_tag: u8, payload: &[u8]) -> PyResult { + Ok(match type_tag { + TYPE_FLOAT => PropertyValue::Float(read_f64_scalar(payload)?), + TYPE_INT => PropertyValue::Int(read_i64_scalar(payload)?), + TYPE_STRING => PropertyValue::String( + std::str::from_utf8(payload) + .map_err(|_| PyValueError::new_err("Invalid UTF-8 in string property"))? + .to_string(), + ), + TYPE_F64_ARRAY => PropertyValue::FloatArray(decode_f64_array(payload)?), + TYPE_VEC3_F32 => PropertyValue::Vec3Array(decode_vec3_f32(payload)?), + TYPE_I64_ARRAY => PropertyValue::IntArray(decode_i64_array(payload)?), + TYPE_F32_ARRAY => PropertyValue::Float32Array(decode_f32_array(payload)?), + TYPE_VEC3_F64 => PropertyValue::Vec3ArrayF64(decode_vec3_f64(payload)?), + TYPE_I32_ARRAY => PropertyValue::Int32Array(decode_i32_array(payload)?), + _ => { + return Err(PyValueError::new_err(format!( + "Unsupported property type tag {}", + type_tag + ))); + } + }) +} diff --git a/atompack-py/tests/test_database.py b/atompack-py/tests/test_database.py index fbb6619..62daf7e 100644 --- a/atompack-py/tests/test_database.py +++ b/atompack-py/tests/test_database.py @@ -3,6 +3,7 @@ from pathlib import Path import pickle +import zlib import atompack import numpy as np @@ -63,6 +64,33 @@ def test_database_rejects_invalid_compression(tmp_path: Path) -> None: atompack.Database(str(tmp_path / "bad.atp"), compression="definitely-not-a-codec") +def _rewrite_record_format_v2(path: Path) -> None: + header_slot_size = 4096 + for slot_offset in (0, header_slot_size): + with path.open("r+b") as handle: + handle.seek(slot_offset) + slot = bytearray(handle.read(header_slot_size)) + slot[56:60] = (2).to_bytes(4, "little") + checksum = zlib.adler32(slot[:-4]) & 0xFFFFFFFF + slot[-4:] = checksum.to_bytes(4, "little") + handle.seek(slot_offset) + handle.write(slot) + + +def test_database_add_arrays_batch_rejects_v2_incompatible_builtin_dtype(tmp_path: Path) -> None: + path = tmp_path / "batch_arrays_v2_compat.atp" + atompack.Database(str(path)) + _rewrite_record_format_v2(path) + + db = atompack.Database.open(str(path)) + positions = np.array([[[0.0, 0.0, 0.0]]], dtype=np.float32) + atomic_numbers = np.array([[6]], dtype=np.uint8) + cell = np.eye(3, dtype=np.float32)[None, ...] + + with pytest.raises(ValueError, match="record format 2 does not support float32 cell"): + db.add_arrays_batch(positions, atomic_numbers, cell=cell) + + def test_database_roundtrip_from_arrays_with_builtins(tmp_path: Path) -> None: path = tmp_path / "from_arrays_builtins.atp" positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32) diff --git a/atompack/src/storage/header.rs b/atompack/src/storage/header.rs index 2652156..6e9565f 100644 --- a/atompack/src/storage/header.rs +++ b/atompack/src/storage/header.rs @@ -32,20 +32,23 @@ pub(super) fn encode_header_slot(header: Header) -> [u8; HEADER_SLOT_SIZE] { slot[4..8].copy_from_slice(&FILE_FORMAT_VERSION.to_le_bytes()); slot[8..16].copy_from_slice(&header.generation.to_le_bytes()); slot[16..24].copy_from_slice(&header.data_start.to_le_bytes()); - slot[24..32].copy_from_slice(&header.schema_offset.to_le_bytes()); - slot[32..40].copy_from_slice(&header.schema_len.to_le_bytes()); - slot[40..48].copy_from_slice(&header.index_offset.to_le_bytes()); - slot[48..56].copy_from_slice(&header.index_len.to_le_bytes()); - slot[56..64].copy_from_slice(&header.num_molecules.to_le_bytes()); + slot[24..32].copy_from_slice(&header.index_offset.to_le_bytes()); + slot[32..40].copy_from_slice(&header.index_len.to_le_bytes()); + slot[40..48].copy_from_slice(&header.num_molecules.to_le_bytes()); let (compression_type, compression_level) = match header.compression { CompressionType::None => (0u8, 0i32), CompressionType::Lz4 => (1u8, 0i32), CompressionType::Zstd(level) => (2u8, level), }; - slot[64] = compression_type; - slot[68..72].copy_from_slice(&compression_level.to_le_bytes()); - slot[72..76].copy_from_slice(&header.record_format.to_le_bytes()); + slot[48] = compression_type; + slot[52..56].copy_from_slice(&compression_level.to_le_bytes()); + slot[56..60].copy_from_slice(&header.record_format.to_le_bytes()); + + // Keep the legacy v2 header field offsets intact; schema metadata lives in + // bytes that were previously unused. + slot[60..68].copy_from_slice(&header.schema_offset.to_le_bytes()); + slot[68..76].copy_from_slice(&header.schema_len.to_le_bytes()); let checksum = adler32(&slot[..HEADER_SLOT_SIZE - 4]); slot[HEADER_SLOT_SIZE - 4..HEADER_SLOT_SIZE].copy_from_slice(&checksum.to_le_bytes()); @@ -74,14 +77,14 @@ fn decode_header_slot(slot: &[u8; HEADER_SLOT_SIZE], file_size: u64) -> Option CompressionType::None, diff --git a/atompack/src/storage/mod.rs b/atompack/src/storage/mod.rs index bda291f..48f8112 100644 --- a/atompack/src/storage/mod.rs +++ b/atompack/src/storage/mod.rs @@ -35,14 +35,16 @@ use std::sync::Arc; mod header; mod index; +mod schema; mod soa; use self::header::{Header, encode_header_slot, read_best_header}; use self::index::{IndexStorage, MoleculeIndex, decode_index, encode_index}; -use self::soa::{ - SchemaLock, arr, decode_schema_lock, deserialize_molecule_soa, encode_schema_lock, - merge_schema_lock, record_schema, schema_from_molecule, serialize_molecule_soa, +use self::schema::{ + SchemaLock, decode_schema_lock, encode_schema_lock, merge_schema_lock, record_schema, + schema_from_molecule, }; +use self::soa::{arr, deserialize_molecule_soa, serialize_molecule_soa}; // --------------------------------------------------------------------------- // Constants @@ -762,6 +764,41 @@ mod tests { use crate::{Atom, FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; use tempfile::NamedTempFile; + fn adler32_for_test(bytes: &[u8]) -> u32 { + const MOD_ADLER: u32 = 65_521; + let mut a: u32 = 1; + let mut b: u32 = 0; + for &byte in bytes { + a = (a + (byte as u32)) % MOD_ADLER; + b = (b + a) % MOD_ADLER; + } + (b << 16) | a + } + + fn encode_legacy_v2_header_slot(header: Header) -> [u8; HEADER_SLOT_SIZE] { + let mut slot = [0u8; HEADER_SLOT_SIZE]; + slot[0..4].copy_from_slice(MAGIC); + slot[4..8].copy_from_slice(&FILE_FORMAT_VERSION.to_le_bytes()); + slot[8..16].copy_from_slice(&header.generation.to_le_bytes()); + slot[16..24].copy_from_slice(&header.data_start.to_le_bytes()); + slot[24..32].copy_from_slice(&header.index_offset.to_le_bytes()); + slot[32..40].copy_from_slice(&header.index_len.to_le_bytes()); + slot[40..48].copy_from_slice(&header.num_molecules.to_le_bytes()); + + let (compression_type, compression_level) = match header.compression { + CompressionType::None => (0u8, 0i32), + CompressionType::Lz4 => (1u8, 0i32), + CompressionType::Zstd(level) => (2u8, level), + }; + slot[48] = compression_type; + slot[52..56].copy_from_slice(&compression_level.to_le_bytes()); + slot[56..60].copy_from_slice(&header.record_format.to_le_bytes()); + + let checksum = adler32_for_test(&slot[..HEADER_SLOT_SIZE - 4]); + slot[HEADER_SLOT_SIZE - 4..HEADER_SLOT_SIZE].copy_from_slice(&checksum.to_le_bytes()); + slot + } + fn molecule_from_atoms(atoms: Vec) -> Molecule { Molecule::from_atoms(atoms) } @@ -816,6 +853,35 @@ mod tests { } } + #[test] + fn test_database_open_legacy_v2_header_layout() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + + let header = Header { + generation: 0, + data_start: HEADER_REGION_SIZE, + num_molecules: 0, + compression: CompressionType::None, + record_format: RECORD_FORMAT_SOA_V2, + schema_offset: 0, + schema_len: 0, + index_offset: 0, + index_len: 0, + }; + let slot = encode_legacy_v2_header_slot(header); + let mut file = File::create(&path).unwrap(); + file.write_all(&slot).unwrap(); + file.write_all(&slot).unwrap(); + file.flush().unwrap(); + file.sync_all().unwrap(); + drop(file); + + let db = AtomDatabase::open(&path).unwrap(); + assert_eq!(db.len(), 0); + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V2); + } + #[test] fn test_database_with_forces_and_energy() { let temp = NamedTempFile::new().unwrap(); @@ -1228,6 +1294,31 @@ mod tests { assert!(format!("{}", err).contains("Schema mismatch for section 'spectrum'")); } + #[test] + fn test_add_owned_soa_records_rejects_v2_incompatible_builtin_dtype() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + db.record_format = RECORD_FORMAT_SOA_V2; + + let mut mol = molecule_from_atoms(vec![Atom::new(0.0, 0.0, 0.0, 6)]); + mol.cell = Some(Mat3Data::F32([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ])); + + let bytes = serialize_molecule_soa(&mol, RECORD_FORMAT_SOA_V3).unwrap(); + let err = db + .add_owned_soa_records(vec![(bytes, mol.len() as u32, TYPE_VEC3_F32)]) + .unwrap_err(); + + assert!( + err.to_string() + .contains("record format 2 does not support float32 cell") + ); + } + #[test] fn test_schema_lock_allows_late_optional_builtin() { let temp = NamedTempFile::new().unwrap(); diff --git a/atompack/src/storage/schema.rs b/atompack/src/storage/schema.rs new file mode 100644 index 0000000..b9c036a --- /dev/null +++ b/atompack/src/storage/schema.rs @@ -0,0 +1,444 @@ +use super::soa::{ + arr, positions_stride, property_value_to_bytes, property_value_type_tag, resolve_positions_type, +}; +use super::*; +use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; +use std::collections::BTreeMap; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(super) struct SchemaLock { + pub(super) positions_type: Option, + pub(super) sections: BTreeMap<(u8, String), SchemaEntry>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct SchemaEntry { + pub(super) type_tag: u8, + pub(super) per_atom: bool, + pub(super) elem_bytes: usize, + pub(super) slot_bytes: usize, +} + +const SCHEMA_BLOB_VERSION: u32 = 1; + +fn positions_type_from_molecule(molecule: &Molecule) -> u8 { + match molecule.positions { + Vec3Data::F32(_) => TYPE_VEC3_F32, + Vec3Data::F64(_) => TYPE_VEC3_F64, + } +} + +pub(super) fn encode_schema_lock(lock: &SchemaLock) -> Result> { + let mut buf = Vec::new(); + buf.extend_from_slice(&SCHEMA_BLOB_VERSION.to_le_bytes()); + buf.push(lock.positions_type.unwrap_or(255)); + buf.extend_from_slice(&(lock.sections.len() as u32).to_le_bytes()); + for ((kind, key), entry) in &lock.sections { + let key_len: u16 = key + .len() + .try_into() + .map_err(|_| Error::InvalidData(format!("Schema key '{}' is too long", key)))?; + let elem_bytes: u32 = entry + .elem_bytes + .try_into() + .map_err(|_| Error::InvalidData(format!("Schema elem_bytes overflow for '{}'", key)))?; + let slot_bytes: u32 = entry + .slot_bytes + .try_into() + .map_err(|_| Error::InvalidData(format!("Schema slot_bytes overflow for '{}'", key)))?; + buf.push(*kind); + buf.push(entry.type_tag); + buf.push(u8::from(entry.per_atom)); + buf.extend_from_slice(&key_len.to_le_bytes()); + buf.extend_from_slice(&elem_bytes.to_le_bytes()); + buf.extend_from_slice(&slot_bytes.to_le_bytes()); + buf.extend_from_slice(key.as_bytes()); + } + Ok(buf) +} + +pub(super) fn decode_schema_lock(bytes: &[u8]) -> Result { + if bytes.len() < 9 { + return Err(Error::InvalidData("Schema blob too small".into())); + } + let version = u32::from_le_bytes(arr(&bytes[0..4])?); + if version != SCHEMA_BLOB_VERSION { + return Err(Error::InvalidData(format!( + "Unsupported schema blob version {}", + version + ))); + } + let positions_type = match bytes[4] { + 255 => None, + TYPE_VEC3_F32 | TYPE_VEC3_F64 => Some(bytes[4]), + other => { + return Err(Error::InvalidData(format!( + "Unsupported schema positions type tag {}", + other + ))); + } + }; + let count = u32::from_le_bytes(arr(&bytes[5..9])?) as usize; + let mut pos = 9usize; + let mut sections = BTreeMap::new(); + for _ in 0..count { + if pos + 13 > bytes.len() { + return Err(Error::InvalidData("Schema blob truncated".into())); + } + let kind = bytes[pos]; + pos += 1; + let type_tag = bytes[pos]; + pos += 1; + let per_atom = match bytes[pos] { + 0 => false, + 1 => true, + _ => return Err(Error::InvalidData("Invalid schema per_atom flag".into())), + }; + pos += 1; + let key_len = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; + pos += 2; + let elem_bytes = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + let slot_bytes = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + if pos + key_len > bytes.len() { + return Err(Error::InvalidData("Schema blob truncated at key".into())); + } + let key = std::str::from_utf8(&bytes[pos..pos + key_len]) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in schema key".into()))? + .to_string(); + pos += key_len; + sections.insert( + (kind, key), + SchemaEntry { + type_tag, + per_atom, + elem_bytes, + slot_bytes, + }, + ); + } + if pos != bytes.len() { + return Err(Error::InvalidData("Schema blob trailing bytes".into())); + } + Ok(SchemaLock { + positions_type, + sections, + }) +} + +fn schema_type_tag_elem_bytes(tag: u8) -> Result { + match tag { + TYPE_FLOAT => Ok(8), + TYPE_INT => Ok(8), + TYPE_STRING => Ok(0), + TYPE_F64_ARRAY => Ok(8), + TYPE_VEC3_F32 => Ok(12), + TYPE_I64_ARRAY => Ok(8), + TYPE_F32_ARRAY => Ok(4), + TYPE_VEC3_F64 => Ok(24), + TYPE_I32_ARRAY => Ok(4), + TYPE_BOOL3 => Ok(3), + TYPE_MAT3X3_F64 => Ok(72), + TYPE_FLOAT32 => Ok(4), + TYPE_MAT3X3_F32 => Ok(36), + _ => Err(Error::InvalidData(format!( + "Unsupported section type tag {}", + tag + ))), + } +} + +fn schema_is_per_atom(kind: u8, key: &str) -> bool { + match kind { + KIND_ATOM_PROP => true, + KIND_MOL_PROP => false, + KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), + _ => false, + } +} + +fn schema_entry( + kind: u8, + key: &str, + type_tag: u8, + payload_len: usize, + n_atoms: usize, +) -> Result { + let per_atom = schema_is_per_atom(kind, key); + let elem_bytes = schema_type_tag_elem_bytes(type_tag)?; + let slot_bytes = if type_tag == TYPE_STRING { + 0 + } else if per_atom { + match type_tag { + TYPE_F64_ARRAY | TYPE_I64_ARRAY | TYPE_F32_ARRAY | TYPE_I32_ARRAY => elem_bytes, + TYPE_VEC3_F32 | TYPE_VEC3_F64 => elem_bytes, + _ => payload_len + .checked_div(n_atoms.max(1)) + .unwrap_or(elem_bytes), + } + } else { + payload_len + }; + + if per_atom { + let expected = elem_bytes + .checked_mul(n_atoms) + .ok_or_else(|| Error::InvalidData(format!("Schema overflow for section '{}'", key)))?; + if payload_len != expected { + return Err(Error::InvalidData(format!( + "Section '{}' payload length {} does not match expected {}", + key, payload_len, expected + ))); + } + } + + Ok(SchemaEntry { + type_tag, + per_atom, + elem_bytes, + slot_bytes, + }) +} + +fn validate_builtin_type_tag_for_record_format( + record_format: u32, + key: &str, + type_tag: u8, +) -> Result<()> { + match record_format { + RECORD_FORMAT_SOA_V3 => Ok(()), + RECORD_FORMAT_SOA_V2 => match key { + "charges" if type_tag != TYPE_F64_ARRAY => Err(Error::InvalidData( + "record format 2 does not support float32 charges".into(), + )), + "cell" if type_tag != TYPE_MAT3X3_F64 => Err(Error::InvalidData( + "record format 2 does not support float32 cell".into(), + )), + "energy" if type_tag != TYPE_FLOAT => Err(Error::InvalidData( + "record format 2 does not support float32 energy".into(), + )), + "forces" if type_tag != TYPE_VEC3_F32 => Err(Error::InvalidData( + "record format 2 does not support float64 forces".into(), + )), + "stress" if type_tag != TYPE_MAT3X3_F64 => Err(Error::InvalidData( + "record format 2 does not support float32 stress".into(), + )), + "velocities" if type_tag != TYPE_VEC3_F32 => Err(Error::InvalidData( + "record format 2 does not support float64 velocities".into(), + )), + _ => Ok(()), + }, + _ => Err(Error::InvalidData(format!( + "Unsupported record format {}", + record_format + ))), + } +} + +pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { + let n_atoms = molecule.len(); + let mut schema = SchemaLock { + positions_type: Some(positions_type_from_molecule(molecule)), + sections: BTreeMap::new(), + }; + + let mut insert = |kind: u8, key: &str, type_tag: u8, payload_len: usize| -> Result<()> { + let entry = schema_entry(kind, key, type_tag, payload_len, n_atoms)?; + schema.sections.insert((kind, key.to_string()), entry); + Ok(()) + }; + + if let Some(charges) = &molecule.charges { + let (type_tag, payload_len) = match charges { + FloatArrayData::F32(values) => (TYPE_F32_ARRAY, values.len() * 4), + FloatArrayData::F64(values) => (TYPE_F64_ARRAY, values.len() * 8), + }; + insert(KIND_BUILTIN, "charges", type_tag, payload_len)?; + } + if let Some(cell) = &molecule.cell { + let (type_tag, payload_len) = match cell { + Mat3Data::F32(_) => (TYPE_MAT3X3_F32, 36), + Mat3Data::F64(_) => (TYPE_MAT3X3_F64, 72), + }; + insert(KIND_BUILTIN, "cell", type_tag, payload_len)?; + } + if let Some(energy) = &molecule.energy { + let (type_tag, payload_len) = match energy { + FloatScalarData::F32(_) => (TYPE_FLOAT32, 4), + FloatScalarData::F64(_) => (TYPE_FLOAT, 8), + }; + insert(KIND_BUILTIN, "energy", type_tag, payload_len)?; + } + if let Some(forces) = &molecule.forces { + let (type_tag, payload_len) = match forces { + Vec3Data::F32(values) => (TYPE_VEC3_F32, values.len() * 12), + Vec3Data::F64(values) => (TYPE_VEC3_F64, values.len() * 24), + }; + insert(KIND_BUILTIN, "forces", type_tag, payload_len)?; + } + if let Some(name) = &molecule.name { + insert(KIND_BUILTIN, "name", TYPE_STRING, name.len())?; + } + if molecule.pbc.is_some() { + insert(KIND_BUILTIN, "pbc", TYPE_BOOL3, 3)?; + } + if let Some(stress) = &molecule.stress { + let (type_tag, payload_len) = match stress { + Mat3Data::F32(_) => (TYPE_MAT3X3_F32, 36), + Mat3Data::F64(_) => (TYPE_MAT3X3_F64, 72), + }; + insert(KIND_BUILTIN, "stress", type_tag, payload_len)?; + } + if let Some(velocities) = &molecule.velocities { + let (type_tag, payload_len) = match velocities { + Vec3Data::F32(values) => (TYPE_VEC3_F32, values.len() * 12), + Vec3Data::F64(values) => (TYPE_VEC3_F64, values.len() * 24), + }; + insert(KIND_BUILTIN, "velocities", type_tag, payload_len)?; + } + + for (key, value) in &molecule.atom_properties { + insert( + KIND_ATOM_PROP, + key, + property_value_type_tag(value), + property_value_to_bytes(value).len(), + )?; + } + for (key, value) in &molecule.properties { + insert( + KIND_MOL_PROP, + key, + property_value_type_tag(value), + property_value_to_bytes(value).len(), + )?; + } + + Ok(schema) +} + +fn parse_record_schema_with_positions( + bytes: &[u8], + record_format: u32, + positions_type: u8, +) -> Result { + if bytes.len() < 6 { + return Err(Error::InvalidData("SOA record too small".into())); + } + + let mut pos = 0usize; + let n_atoms = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + + let positions_end = pos + .checked_add( + n_atoms + .checked_mul(positions_stride(positions_type)?) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?, + ) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; + if positions_end > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at positions".into(), + )); + } + pos = positions_end; + + let z_end = pos + .checked_add(n_atoms) + .ok_or_else(|| Error::InvalidData("SOA atomic_numbers overflow".into()))?; + if z_end > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at atomic_numbers".into(), + )); + } + pos = z_end; + + if pos + 2 > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at n_sections".into(), + )); + } + let n_sections = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; + pos += 2; + + let mut schema = SchemaLock { + positions_type: Some(positions_type), + sections: BTreeMap::new(), + }; + + for _ in 0..n_sections { + if pos + 7 > bytes.len() { + return Err(Error::InvalidData("SOA section header truncated".into())); + } + let kind = bytes[pos]; + pos += 1; + let key_len = bytes[pos] as usize; + pos += 1; + if pos + key_len > bytes.len() { + return Err(Error::InvalidData("SOA section key truncated".into())); + } + let key = std::str::from_utf8(&bytes[pos..pos + key_len]) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in section key".into()))? + .to_string(); + pos += key_len; + let type_tag = bytes[pos]; + pos += 1; + let payload_len = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + let payload_end = pos + .checked_add(payload_len) + .ok_or_else(|| Error::InvalidData("SOA section payload overflow".into()))?; + if payload_end > bytes.len() { + return Err(Error::InvalidData("SOA section payload truncated".into())); + } + pos = payload_end; + if kind == KIND_BUILTIN { + validate_builtin_type_tag_for_record_format(record_format, &key, type_tag)?; + } + let entry = schema_entry(kind, &key, type_tag, payload_len, n_atoms)?; + schema.sections.insert((kind, key), entry); + } + + Ok(schema) +} + +pub(super) fn record_schema( + bytes: &[u8], + record_format: u32, + positions_type_hint: Option, +) -> Result { + let positions_type = resolve_positions_type(record_format, positions_type_hint)?; + parse_record_schema_with_positions(bytes, record_format, positions_type) +} + +pub(super) fn merge_schema_lock(lock: &mut SchemaLock, record: &SchemaLock) -> Result<()> { + match (lock.positions_type, record.positions_type) { + (None, Some(tag)) => lock.positions_type = Some(tag), + (Some(expected), Some(actual)) if expected != actual => { + return Err(Error::InvalidData(format!( + "Position dtype mismatch: expected type tag {}, got {}", + expected, actual + ))); + } + _ => {} + } + + for ((kind, key), entry) in &record.sections { + match lock.sections.get(&(*kind, key.clone())) { + Some(expected) if expected != entry => { + return Err(Error::InvalidData(format!( + "Schema mismatch for section '{}': expected {:?}, got {:?}", + key, expected, entry + ))); + } + Some(_) => {} + None => { + lock.sections.insert((*kind, key.clone()), entry.clone()); + } + } + } + + Ok(()) +} diff --git a/atompack/src/storage/soa.rs b/atompack/src/storage/soa.rs index 1ddfec0..df37f9f 100644 --- a/atompack/src/storage/soa.rs +++ b/atompack/src/storage/soa.rs @@ -1,128 +1,5 @@ use super::*; use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; -use std::collections::BTreeMap; - -#[derive(Debug, Clone, Default, PartialEq, Eq)] -pub(super) struct SchemaLock { - pub(super) positions_type: Option, - pub(super) sections: BTreeMap<(u8, String), SchemaEntry>, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub(super) struct SchemaEntry { - pub(super) type_tag: u8, - pub(super) per_atom: bool, - pub(super) elem_bytes: usize, - pub(super) slot_bytes: usize, -} - -const SCHEMA_BLOB_VERSION: u32 = 1; - -fn positions_type_from_molecule(molecule: &Molecule) -> u8 { - match molecule.positions { - Vec3Data::F32(_) => TYPE_VEC3_F32, - Vec3Data::F64(_) => TYPE_VEC3_F64, - } -} - -pub(super) fn encode_schema_lock(lock: &SchemaLock) -> Result> { - let mut buf = Vec::new(); - buf.extend_from_slice(&SCHEMA_BLOB_VERSION.to_le_bytes()); - buf.push(lock.positions_type.unwrap_or(255)); - buf.extend_from_slice(&(lock.sections.len() as u32).to_le_bytes()); - for ((kind, key), entry) in &lock.sections { - let key_len: u16 = key - .len() - .try_into() - .map_err(|_| Error::InvalidData(format!("Schema key '{}' is too long", key)))?; - let elem_bytes: u32 = entry - .elem_bytes - .try_into() - .map_err(|_| Error::InvalidData(format!("Schema elem_bytes overflow for '{}'", key)))?; - let slot_bytes: u32 = entry - .slot_bytes - .try_into() - .map_err(|_| Error::InvalidData(format!("Schema slot_bytes overflow for '{}'", key)))?; - buf.push(*kind); - buf.push(entry.type_tag); - buf.push(u8::from(entry.per_atom)); - buf.extend_from_slice(&key_len.to_le_bytes()); - buf.extend_from_slice(&elem_bytes.to_le_bytes()); - buf.extend_from_slice(&slot_bytes.to_le_bytes()); - buf.extend_from_slice(key.as_bytes()); - } - Ok(buf) -} - -pub(super) fn decode_schema_lock(bytes: &[u8]) -> Result { - if bytes.len() < 9 { - return Err(Error::InvalidData("Schema blob too small".into())); - } - let version = u32::from_le_bytes(arr(&bytes[0..4])?); - if version != SCHEMA_BLOB_VERSION { - return Err(Error::InvalidData(format!( - "Unsupported schema blob version {}", - version - ))); - } - let positions_type = match bytes[4] { - 255 => None, - TYPE_VEC3_F32 | TYPE_VEC3_F64 => Some(bytes[4]), - other => { - return Err(Error::InvalidData(format!( - "Unsupported schema positions type tag {}", - other - ))); - } - }; - let count = u32::from_le_bytes(arr(&bytes[5..9])?) as usize; - let mut pos = 9usize; - let mut sections = BTreeMap::new(); - for _ in 0..count { - if pos + 13 > bytes.len() { - return Err(Error::InvalidData("Schema blob truncated".into())); - } - let kind = bytes[pos]; - pos += 1; - let type_tag = bytes[pos]; - pos += 1; - let per_atom = match bytes[pos] { - 0 => false, - 1 => true, - _ => return Err(Error::InvalidData("Invalid schema per_atom flag".into())), - }; - pos += 1; - let key_len = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; - pos += 2; - let elem_bytes = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; - pos += 4; - let slot_bytes = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; - pos += 4; - if pos + key_len > bytes.len() { - return Err(Error::InvalidData("Schema blob truncated at key".into())); - } - let key = std::str::from_utf8(&bytes[pos..pos + key_len]) - .map_err(|_| Error::InvalidData("Invalid UTF-8 in schema key".into()))? - .to_string(); - pos += key_len; - sections.insert( - (kind, key), - SchemaEntry { - type_tag, - per_atom, - elem_bytes, - slot_bytes, - }, - ); - } - if pos != bytes.len() { - return Err(Error::InvalidData("Schema blob trailing bytes".into())); - } - Ok(SchemaLock { - positions_type, - sections, - }) -} /// Write a single tagged section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload] fn write_section(buf: &mut Vec, kind: u8, key: &str, type_tag: u8, payload: &[u8]) { @@ -134,7 +11,7 @@ fn write_section(buf: &mut Vec, kind: u8, key: &str, type_tag: u8, payload: buf.extend_from_slice(payload); } -fn property_value_type_tag(value: &PropertyValue) -> u8 { +pub(super) fn property_value_type_tag(value: &PropertyValue) -> u8 { match value { PropertyValue::Float(_) => TYPE_FLOAT, PropertyValue::Int(_) => TYPE_INT, @@ -169,7 +46,7 @@ fn extend_i32(b: &mut Vec, v: &[i32]) { } } -fn property_value_to_bytes(value: &PropertyValue) -> Vec { +pub(super) fn property_value_to_bytes(value: &PropertyValue) -> Vec { match value { PropertyValue::Float(v) => v.to_le_bytes().to_vec(), PropertyValue::Int(v) => v.to_le_bytes().to_vec(), @@ -388,9 +265,17 @@ fn write_energy_section(buf: &mut Vec, value: &FloatScalarData) { } } -fn resolve_positions_type(record_format: u32, positions_type_hint: Option) -> Result { +pub(super) fn resolve_positions_type( + record_format: u32, + positions_type_hint: Option, +) -> Result { match record_format { - RECORD_FORMAT_SOA_V2 => Ok(TYPE_VEC3_F32), + RECORD_FORMAT_SOA_V2 => match positions_type_hint { + Some(TYPE_VEC3_F64) => Err(Error::InvalidData( + "record format 2 does not support float64 positions".into(), + )), + _ => Ok(TYPE_VEC3_F32), + }, RECORD_FORMAT_SOA_V3 => positions_type_hint.ok_or_else(|| { Error::InvalidData("Missing positions dtype for record format 3".into()) }), @@ -401,7 +286,7 @@ fn resolve_positions_type(record_format: u32, positions_type_hint: Option) - } } -fn positions_stride(positions_type: u8) -> Result { +pub(super) fn positions_stride(positions_type: u8) -> Result { match positions_type { TYPE_VEC3_F32 => Ok(12), TYPE_VEC3_F64 => Ok(24), @@ -691,89 +576,91 @@ fn decode_positions( Ok(positions) } -fn decode_builtin_section( - mol: &mut Molecule, - key: &str, - type_tag: u8, +fn decode_float_scalar_data( payload: &[u8], -) -> Result<()> { - match key { - "energy" => match type_tag { - TYPE_FLOAT => { - if payload.len() != 8 { - return Err(Error::InvalidData("energy f64 payload truncated".into())); - } - mol.energy = Some(FloatScalarData::F64(f64::from_le_bytes(arr(payload)?))); - } - TYPE_FLOAT32 => { - if payload.len() != 4 { - return Err(Error::InvalidData("energy f32 payload truncated".into())); - } - mol.energy = Some(FloatScalarData::F32(f32::from_le_bytes(arr(payload)?))); - } - _ => { - return Err(Error::InvalidData(format!( - "Unsupported energy type tag {}", - type_tag - ))); - } - }, - "forces" => match type_tag { - TYPE_VEC3_F32 => mol.forces = Some(Vec3Data::F32(decode_vec3_f32(payload)?)), - TYPE_VEC3_F64 => mol.forces = Some(Vec3Data::F64(decode_vec3_f64(payload)?)), - _ => { - return Err(Error::InvalidData(format!( - "Unsupported forces type tag {}", - type_tag - ))); - } - }, - "charges" => match type_tag { - TYPE_F32_ARRAY => mol.charges = Some(FloatArrayData::F32(decode_f32_array(payload)?)), - TYPE_F64_ARRAY => mol.charges = Some(FloatArrayData::F64(decode_f64_array(payload)?)), - _ => { - return Err(Error::InvalidData(format!( - "Unsupported charges type tag {}", - type_tag - ))); - } - }, - "velocities" => match type_tag { - TYPE_VEC3_F32 => mol.velocities = Some(Vec3Data::F32(decode_vec3_f32(payload)?)), - TYPE_VEC3_F64 => mol.velocities = Some(Vec3Data::F64(decode_vec3_f64(payload)?)), - _ => { + type_tag: u8, + field_name: &str, +) -> Result { + match type_tag { + TYPE_FLOAT => { + if payload.len() != 8 { return Err(Error::InvalidData(format!( - "Unsupported velocities type tag {}", - type_tag + "{field_name} f64 payload truncated" ))); } - }, - "cell" => match type_tag { - TYPE_MAT3X3_F32 => mol.cell = Some(Mat3Data::F32(decode_mat3x3_f32(payload)?)), - TYPE_MAT3X3_F64 => mol.cell = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)), - _ => { + Ok(FloatScalarData::F64(f64::from_le_bytes(arr(payload)?))) + } + TYPE_FLOAT32 => { + if payload.len() != 4 { return Err(Error::InvalidData(format!( - "Unsupported cell type tag {}", - type_tag + "{field_name} f32 payload truncated" ))); } - }, + Ok(FloatScalarData::F32(f32::from_le_bytes(arr(payload)?))) + } + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), + } +} + +fn decode_vec3_data(payload: &[u8], type_tag: u8, field_name: &str) -> Result { + match type_tag { + TYPE_VEC3_F32 => Ok(Vec3Data::F32(decode_vec3_f32(payload)?)), + TYPE_VEC3_F64 => Ok(Vec3Data::F64(decode_vec3_f64(payload)?)), + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), + } +} + +fn decode_float_array_data( + payload: &[u8], + type_tag: u8, + field_name: &str, +) -> Result { + match type_tag { + TYPE_F32_ARRAY => Ok(FloatArrayData::F32(decode_f32_array(payload)?)), + TYPE_F64_ARRAY => Ok(FloatArrayData::F64(decode_f64_array(payload)?)), + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), + } +} + +fn decode_mat3_data(payload: &[u8], type_tag: u8, field_name: &str) -> Result { + match type_tag { + TYPE_MAT3X3_F32 => Ok(Mat3Data::F32(decode_mat3x3_f32(payload)?)), + TYPE_MAT3X3_F64 => Ok(Mat3Data::F64(decode_mat3x3_f64(payload)?)), + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), + } +} + +fn decode_builtin_section( + mol: &mut Molecule, + key: &str, + type_tag: u8, + payload: &[u8], +) -> Result<()> { + match key { + "energy" => mol.energy = Some(decode_float_scalar_data(payload, type_tag, "energy")?), + "forces" => mol.forces = Some(decode_vec3_data(payload, type_tag, "forces")?), + "charges" => mol.charges = Some(decode_float_array_data(payload, type_tag, "charges")?), + "velocities" => mol.velocities = Some(decode_vec3_data(payload, type_tag, "velocities")?), + "cell" => mol.cell = Some(decode_mat3_data(payload, type_tag, "cell")?), "pbc" => { if payload.len() < 3 { return Err(Error::InvalidData("pbc payload truncated".into())); } mol.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); } - "stress" => match type_tag { - TYPE_MAT3X3_F32 => mol.stress = Some(Mat3Data::F32(decode_mat3x3_f32(payload)?)), - TYPE_MAT3X3_F64 => mol.stress = Some(Mat3Data::F64(decode_mat3x3_f64(payload)?)), - _ => { - return Err(Error::InvalidData(format!( - "Unsupported stress type tag {}", - type_tag - ))); - } - }, + "stress" => mol.stress = Some(decode_mat3_data(payload, type_tag, "stress")?), "name" => { mol.name = Some( std::str::from_utf8(payload) @@ -873,277 +760,3 @@ pub(super) fn deserialize_molecule_soa( let positions_type = resolve_positions_type(record_format, positions_type_hint)?; deserialize_molecule_soa_with_positions(bytes, positions_type) } - -fn schema_type_tag_elem_bytes(tag: u8) -> Result { - match tag { - TYPE_FLOAT => Ok(8), - TYPE_INT => Ok(8), - TYPE_STRING => Ok(0), - TYPE_F64_ARRAY => Ok(8), - TYPE_VEC3_F32 => Ok(12), - TYPE_I64_ARRAY => Ok(8), - TYPE_F32_ARRAY => Ok(4), - TYPE_VEC3_F64 => Ok(24), - TYPE_I32_ARRAY => Ok(4), - TYPE_BOOL3 => Ok(3), - TYPE_MAT3X3_F64 => Ok(72), - TYPE_FLOAT32 => Ok(4), - TYPE_MAT3X3_F32 => Ok(36), - _ => Err(Error::InvalidData(format!( - "Unsupported section type tag {}", - tag - ))), - } -} - -fn schema_is_per_atom(kind: u8, key: &str) -> bool { - match kind { - KIND_ATOM_PROP => true, - KIND_MOL_PROP => false, - KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), - _ => false, - } -} - -fn schema_entry( - kind: u8, - key: &str, - type_tag: u8, - payload_len: usize, - n_atoms: usize, -) -> Result { - let per_atom = schema_is_per_atom(kind, key); - let elem_bytes = schema_type_tag_elem_bytes(type_tag)?; - let slot_bytes = if type_tag == TYPE_STRING { - 0 - } else if per_atom { - match type_tag { - TYPE_F64_ARRAY | TYPE_I64_ARRAY | TYPE_F32_ARRAY | TYPE_I32_ARRAY => elem_bytes, - TYPE_VEC3_F32 | TYPE_VEC3_F64 => elem_bytes, - _ => payload_len - .checked_div(n_atoms.max(1)) - .unwrap_or(elem_bytes), - } - } else { - payload_len - }; - - if per_atom { - let expected = elem_bytes - .checked_mul(n_atoms) - .ok_or_else(|| Error::InvalidData(format!("Schema overflow for section '{}'", key)))?; - if payload_len != expected { - return Err(Error::InvalidData(format!( - "Section '{}' payload length {} does not match expected {}", - key, payload_len, expected - ))); - } - } - - Ok(SchemaEntry { - type_tag, - per_atom, - elem_bytes, - slot_bytes, - }) -} - -pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { - let n_atoms = molecule.len(); - let mut schema = SchemaLock { - positions_type: Some(positions_type_from_molecule(molecule)), - sections: BTreeMap::new(), - }; - - let mut insert = |kind: u8, key: &str, type_tag: u8, payload_len: usize| -> Result<()> { - let entry = schema_entry(kind, key, type_tag, payload_len, n_atoms)?; - schema.sections.insert((kind, key.to_string()), entry); - Ok(()) - }; - - if let Some(charges) = &molecule.charges { - let (type_tag, payload_len) = match charges { - FloatArrayData::F32(values) => (TYPE_F32_ARRAY, values.len() * 4), - FloatArrayData::F64(values) => (TYPE_F64_ARRAY, values.len() * 8), - }; - insert(KIND_BUILTIN, "charges", type_tag, payload_len)?; - } - if let Some(cell) = &molecule.cell { - let (type_tag, payload_len) = match cell { - Mat3Data::F32(_) => (TYPE_MAT3X3_F32, 36), - Mat3Data::F64(_) => (TYPE_MAT3X3_F64, 72), - }; - insert(KIND_BUILTIN, "cell", type_tag, payload_len)?; - } - if let Some(energy) = &molecule.energy { - let (type_tag, payload_len) = match energy { - FloatScalarData::F32(_) => (TYPE_FLOAT32, 4), - FloatScalarData::F64(_) => (TYPE_FLOAT, 8), - }; - insert(KIND_BUILTIN, "energy", type_tag, payload_len)?; - } - if let Some(forces) = &molecule.forces { - let (type_tag, payload_len) = match forces { - Vec3Data::F32(values) => (TYPE_VEC3_F32, values.len() * 12), - Vec3Data::F64(values) => (TYPE_VEC3_F64, values.len() * 24), - }; - insert(KIND_BUILTIN, "forces", type_tag, payload_len)?; - } - if let Some(name) = &molecule.name { - insert(KIND_BUILTIN, "name", TYPE_STRING, name.len())?; - } - if molecule.pbc.is_some() { - insert(KIND_BUILTIN, "pbc", TYPE_BOOL3, 3)?; - } - if let Some(stress) = &molecule.stress { - let (type_tag, payload_len) = match stress { - Mat3Data::F32(_) => (TYPE_MAT3X3_F32, 36), - Mat3Data::F64(_) => (TYPE_MAT3X3_F64, 72), - }; - insert(KIND_BUILTIN, "stress", type_tag, payload_len)?; - } - if let Some(velocities) = &molecule.velocities { - let (type_tag, payload_len) = match velocities { - Vec3Data::F32(values) => (TYPE_VEC3_F32, values.len() * 12), - Vec3Data::F64(values) => (TYPE_VEC3_F64, values.len() * 24), - }; - insert(KIND_BUILTIN, "velocities", type_tag, payload_len)?; - } - - for (key, value) in &molecule.atom_properties { - insert( - KIND_ATOM_PROP, - key, - property_value_type_tag(value), - property_value_to_bytes(value).len(), - )?; - } - for (key, value) in &molecule.properties { - insert( - KIND_MOL_PROP, - key, - property_value_type_tag(value), - property_value_to_bytes(value).len(), - )?; - } - - Ok(schema) -} - -fn parse_record_schema_with_positions(bytes: &[u8], positions_type: u8) -> Result { - if bytes.len() < 6 { - return Err(Error::InvalidData("SOA record too small".into())); - } - - let mut pos = 0usize; - let n_atoms = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; - pos += 4; - - let positions_end = pos - .checked_add( - n_atoms - .checked_mul(positions_stride(positions_type)?) - .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?, - ) - .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; - if positions_end > bytes.len() { - return Err(Error::InvalidData( - "SOA record truncated at positions".into(), - )); - } - pos = positions_end; - - let z_end = pos - .checked_add(n_atoms) - .ok_or_else(|| Error::InvalidData("SOA atomic_numbers overflow".into()))?; - if z_end > bytes.len() { - return Err(Error::InvalidData( - "SOA record truncated at atomic_numbers".into(), - )); - } - pos = z_end; - - if pos + 2 > bytes.len() { - return Err(Error::InvalidData( - "SOA record truncated at n_sections".into(), - )); - } - let n_sections = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; - pos += 2; - - let mut schema = SchemaLock { - positions_type: Some(positions_type), - sections: BTreeMap::new(), - }; - - for _ in 0..n_sections { - if pos + 7 > bytes.len() { - return Err(Error::InvalidData("SOA section header truncated".into())); - } - let kind = bytes[pos]; - pos += 1; - let key_len = bytes[pos] as usize; - pos += 1; - if pos + key_len > bytes.len() { - return Err(Error::InvalidData("SOA section key truncated".into())); - } - let key = std::str::from_utf8(&bytes[pos..pos + key_len]) - .map_err(|_| Error::InvalidData("Invalid UTF-8 in section key".into()))? - .to_string(); - pos += key_len; - let type_tag = bytes[pos]; - pos += 1; - let payload_len = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; - pos += 4; - let payload_end = pos - .checked_add(payload_len) - .ok_or_else(|| Error::InvalidData("SOA section payload overflow".into()))?; - if payload_end > bytes.len() { - return Err(Error::InvalidData("SOA section payload truncated".into())); - } - pos = payload_end; - let entry = schema_entry(kind, &key, type_tag, payload_len, n_atoms)?; - schema.sections.insert((kind, key), entry); - } - - Ok(schema) -} - -pub(super) fn record_schema( - bytes: &[u8], - record_format: u32, - positions_type_hint: Option, -) -> Result { - let positions_type = resolve_positions_type(record_format, positions_type_hint)?; - parse_record_schema_with_positions(bytes, positions_type) -} - -pub(super) fn merge_schema_lock(lock: &mut SchemaLock, record: &SchemaLock) -> Result<()> { - match (lock.positions_type, record.positions_type) { - (None, Some(tag)) => lock.positions_type = Some(tag), - (Some(expected), Some(actual)) if expected != actual => { - return Err(Error::InvalidData(format!( - "Position dtype mismatch: expected type tag {}, got {}", - expected, actual - ))); - } - _ => {} - } - - for ((kind, key), entry) in &record.sections { - match lock.sections.get(&(*kind, key.clone())) { - Some(expected) if expected != entry => { - return Err(Error::InvalidData(format!( - "Schema mismatch for section '{}': expected {:?}, got {:?}", - key, expected, entry - ))); - } - Some(_) => {} - None => { - lock.sections.insert((*kind, key.clone()), entry.clone()); - } - } - } - - Ok(()) -} From 40cb6661b0e734f14935e4af2a492a0d69ee660e Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sun, 10 May 2026 13:23:27 +0200 Subject: [PATCH 5/9] perf: restore typed-float fast paths --- atompack-py/src/database.rs | 48 ++- atompack-py/src/database_batch.rs | 578 ++++++++++++++++++++++++++++- atompack-py/src/database_flat.rs | 527 ++++++++++++++++++++------ atompack-py/src/soa.rs | 347 ++++++++++++++++- atompack-py/tests/test_database.py | 84 ++++- atompack/src/atom.rs | 8 +- atompack/src/storage/mod.rs | 417 +++++++++++++++++---- atompack/src/storage/schema.rs | 36 +- atompack/src/storage/soa.rs | 262 +++++++++++-- 9 files changed, 2053 insertions(+), 254 deletions(-) diff --git a/atompack-py/src/database.rs b/atompack-py/src/database.rs index bea4369..ac3688d 100644 --- a/atompack-py/src/database.rs +++ b/atompack-py/src/database.rs @@ -127,6 +127,15 @@ impl PyAtomDatabase { /// Add a molecule to the database fn add_molecule(&mut self, molecule: &PyMolecule) -> PyResult<()> { + if let Some(view) = molecule.as_view() { + return self + .inner + .add_raw_soa_records_with_schema( + &[(view.raw_bytes(), view.n_atoms as u32)], + view.database_schema()?, + ) + .map_err(|e| PyValueError::new_err(format!("{}", e))); + } let owned = molecule.clone_as_owned()?; self.inner .add_molecule(&owned) @@ -135,10 +144,47 @@ impl PyAtomDatabase { /// Add multiple molecules (processed in parallel) fn add_molecules(&mut self, molecules: Vec>) -> PyResult<()> { + let mut raw_records: Vec<(&[u8], u32)> = Vec::new(); + let mut raw_views: Vec<&SoaMoleculeView> = Vec::new(); let mut owned_molecules: Vec = Vec::new(); for m in &molecules { - owned_molecules.push(m.clone_as_owned()?); + if let Some(view) = m.as_view() { + raw_records.push((view.raw_bytes(), view.n_atoms as u32)); + raw_views.push(view); + } else { + owned_molecules.push(m.clone_as_owned()?); + } + } + + if !raw_records.is_empty() { + let mut fast_schema = None; + if let Some((first, rest)) = raw_views.split_first() { + let mut all_match = true; + for view in rest { + if !view.same_schema_as(first)? { + all_match = false; + break; + } + } + if all_match { + fast_schema = Some(first.database_schema()?); + } + } + + if let Some(schema) = fast_schema { + self.inner + .add_raw_soa_records_with_schema(&raw_records, schema) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + } else { + let raw_records_with_positions = raw_views + .iter() + .map(|view| (view.raw_bytes(), view.n_atoms as u32, view.positions_type)) + .collect::>(); + self.inner + .add_raw_soa_records_with_positions_type(&raw_records_with_positions) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + } } if !owned_molecules.is_empty() { let mol_refs: Vec<&Molecule> = owned_molecules.iter().collect(); diff --git a/atompack-py/src/database_batch.rs b/atompack-py/src/database_batch.rs index ff11c96..40560d4 100644 --- a/atompack-py/src/database_batch.rs +++ b/atompack-py/src/database_batch.rs @@ -1,5 +1,6 @@ use super::*; use crate::molecule::{SoaRecord, SoaSection, build_soa_record}; +use atompack::storage::{DatabaseSchema, DatabaseSchemaSection}; struct BatchSectionColumn { key: String, @@ -31,6 +32,75 @@ impl BatchSectionColumn { } } +fn batch_section_is_per_atom(kind: u8, key: &str) -> bool { + match kind { + KIND_ATOM_PROP => true, + KIND_MOL_PROP => false, + KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), + _ => false, + } +} + +fn database_schema_section(column: &BatchSectionColumn) -> PyResult { + let per_atom = batch_section_is_per_atom(column.kind, &column.key); + let elem_bytes = if column.type_tag == TYPE_STRING { + 0 + } else { + let elem_bytes = type_tag_elem_bytes(column.type_tag); + if elem_bytes == 0 { + return Err(PyValueError::new_err(format!( + "Unsupported section type tag {} for '{}'", + column.type_tag, column.key + ))); + } + elem_bytes + }; + let slot_bytes = if column.type_tag == TYPE_STRING { + 0 + } else if per_atom { + elem_bytes + } else { + column.slot_bytes + }; + + Ok(DatabaseSchemaSection { + kind: column.kind, + key: column.key.clone(), + type_tag: column.type_tag, + per_atom, + elem_bytes, + slot_bytes, + }) +} + +fn build_batch_schema<'a, I>(positions_type: u8, columns: I) -> PyResult +where + I: IntoIterator, +{ + let sections = columns + .into_iter() + .map(database_schema_section) + .collect::>>()?; + Ok(DatabaseSchema { + positions_type: Some(positions_type), + sections, + }) +} + +fn push_builtin_section<'a>( + sections: &mut Vec>, + key: &'a str, + type_tag: u8, + payload: &'a [u8], +) { + sections.push(SoaSection { + kind: KIND_BUILTIN, + key, + type_tag, + payload, + }); +} + fn reject_reserved_key(key: &str) -> PyResult<()> { if key == "stress" { return Err(PyValueError::new_err( @@ -408,6 +478,460 @@ fn extract_custom_columns( Ok(columns) } +struct FastMat3Column { + type_tag: u8, + slot_bytes: usize, + payload: Vec, +} + +impl FastMat3Column { + fn from_optional( + value: Option<&Bound<'_, PyAny>>, + batch: usize, + label: &str, + ) -> PyResult> { + let Some(value) = value else { + return Ok(None); + }; + if let Ok(arr) = value.cast::>() { + let ro = arr.readonly(); + if ro.as_array().shape() != [batch, 3, 3] { + return Err(PyValueError::new_err(format!( + "{label} must have shape ({}, 3, 3)", + batch + ))); + } + let slice = ro + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{label} must be C-contiguous")))?; + return Ok(Some(Self { + type_tag: TYPE_MAT3X3_F32, + slot_bytes: 36, + payload: bytemuck::cast_slice::(slice).to_vec(), + })); + } + if let Ok(arr) = value.cast::>() { + let ro = arr.readonly(); + if ro.as_array().shape() != [batch, 3, 3] { + return Err(PyValueError::new_err(format!( + "{label} must have shape ({}, 3, 3)", + batch + ))); + } + let slice = ro + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{label} must be C-contiguous")))?; + return Ok(Some(Self { + type_tag: TYPE_MAT3X3_F64, + slot_bytes: 72, + payload: bytemuck::cast_slice::(slice).to_vec(), + })); + } + Ok(None) + } + + fn type_tag(&self) -> u8 { + self.type_tag + } + + fn slot_bytes(&self) -> usize { + self.slot_bytes + } + + fn payload_bytes(&self, index: usize) -> &[u8] { + let start = index * self.slot_bytes; + let end = start + self.slot_bytes; + &self.payload[start..end] + } +} + +#[allow(clippy::too_many_arguments)] +fn try_add_arrays_batch_fast_canonical( + inner: &mut AtomDatabase, + py: Python<'_>, + positions: &Bound<'_, PyAny>, + atomic_numbers: &Bound<'_, PyArray2>, + energy: Option<&Bound<'_, PyAny>>, + forces: Option<&Bound<'_, PyAny>>, + charges: Option<&Bound<'_, PyAny>>, + velocities: Option<&Bound<'_, PyAny>>, + cell: Option<&Bound<'_, PyAny>>, + stress: Option<&Bound<'_, PyAny>>, + pbc: Option<&Bound<'_, PyArray2>>, + name: Option>, + properties: Option<&Bound<'_, PyDict>>, + atom_properties: Option<&Bound<'_, PyDict>>, +) -> PyResult { + let Ok(positions) = positions.cast::>() else { + return Ok(false); + }; + let energy = match energy { + Some(value) => { + let Ok(arr) = value.cast::>() else { + return Ok(false); + }; + Some(arr) + } + None => None, + }; + let forces = match forces { + Some(value) => { + let Ok(arr) = value.cast::>() else { + return Ok(false); + }; + Some(arr) + } + None => None, + }; + let charges = match charges { + Some(value) => { + let Ok(arr) = value.cast::>() else { + return Ok(false); + }; + Some(arr) + } + None => None, + }; + let velocities = match velocities { + Some(value) => { + let Ok(arr) = value.cast::>() else { + return Ok(false); + }; + Some(arr) + } + None => None, + }; + + let pos = positions.readonly(); + let pos_arr = pos.as_array(); + let pos_shape = pos_arr.shape(); + if pos_shape.len() != 3 || pos_shape[2] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (batch, n_atoms, 3)", + )); + } + let batch = pos_shape[0]; + let n_atoms = pos_shape[1]; + let pos_slice = pos_arr + .as_slice() + .ok_or_else(|| PyValueError::new_err("positions must be C-contiguous"))?; + + let z = atomic_numbers.readonly(); + let z_arr = z.as_array(); + if z_arr.shape() != [batch, n_atoms] { + return Err(PyValueError::new_err(format!( + "atomic_numbers must have shape ({}, {})", + batch, n_atoms + ))); + } + let z_slice = z_arr + .as_slice() + .ok_or_else(|| PyValueError::new_err("atomic_numbers must be C-contiguous"))?; + + let energy_ro = energy.map(|arr| arr.readonly()); + let energy_slice = if let Some(ro) = energy_ro.as_ref() { + let view = ro.as_array(); + if view.len() != batch { + return Err(PyValueError::new_err(format!( + "energy length ({}) doesn't match batch size ({})", + view.len(), + batch + ))); + } + Some( + ro.as_slice() + .map_err(|_| PyValueError::new_err("energy must be C-contiguous"))?, + ) + } else { + None + }; + + let forces_ro = forces.map(|arr| arr.readonly()); + let forces_slice = if let Some(ro) = forces_ro.as_ref() { + let view = ro.as_array(); + if view.shape() != [batch, n_atoms, 3] { + return Err(PyValueError::new_err(format!( + "forces must have shape ({}, {}, 3)", + batch, n_atoms + ))); + } + Some( + ro.as_slice() + .map_err(|_| PyValueError::new_err("forces must be C-contiguous"))?, + ) + } else { + None + }; + + let charges_ro = charges.map(|arr| arr.readonly()); + let charges_slice = if let Some(ro) = charges_ro.as_ref() { + let view = ro.as_array(); + if view.shape() != [batch, n_atoms] { + return Err(PyValueError::new_err(format!( + "charges must have shape ({}, {})", + batch, n_atoms + ))); + } + Some( + ro.as_slice() + .map_err(|_| PyValueError::new_err("charges must be C-contiguous"))?, + ) + } else { + None + }; + + let velocities_ro = velocities.map(|arr| arr.readonly()); + let velocities_slice = if let Some(ro) = velocities_ro.as_ref() { + let view = ro.as_array(); + if view.shape() != [batch, n_atoms, 3] { + return Err(PyValueError::new_err(format!( + "velocities must have shape ({}, {}, 3)", + batch, n_atoms + ))); + } + Some( + ro.as_slice() + .map_err(|_| PyValueError::new_err("velocities must be C-contiguous"))?, + ) + } else { + None + }; + + let cell_slice = match FastMat3Column::from_optional(cell, batch, "cell")? { + Some(column) => Some(column), + None if cell.is_some() => return Ok(false), + None => None, + }; + + let stress_slice = match FastMat3Column::from_optional(stress, batch, "stress")? { + Some(column) => Some(column), + None if stress.is_some() => return Ok(false), + None => None, + }; + + let pbc_ro = pbc.map(|arr| arr.readonly()); + let pbc_slice = if let Some(ro) = pbc_ro.as_ref() { + let view = ro.as_array(); + if view.shape() != [batch, 3] { + return Err(PyValueError::new_err(format!( + "pbc must have shape ({}, 3)", + batch + ))); + } + Some( + ro.as_slice() + .map_err(|_| PyValueError::new_err("pbc must be C-contiguous"))?, + ) + } else { + None + }; + + if let Some(names) = &name + && names.len() != batch + { + return Err(PyValueError::new_err(format!( + "name length ({}) doesn't match batch size ({})", + names.len(), + batch + ))); + } + + let custom_columns = extract_custom_columns(properties, atom_properties, batch, n_atoms)?; + let mut builtin_columns = Vec::new(); + if energy_slice.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "energy".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_FLOAT, + slot_bytes: 8, + payload: Vec::new(), + strings: None, + }); + } + if forces_slice.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "forces".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_VEC3_F32, + slot_bytes: n_atoms * 12, + payload: Vec::new(), + strings: None, + }); + } + if charges_slice.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "charges".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_F64_ARRAY, + slot_bytes: n_atoms * 8, + payload: Vec::new(), + strings: None, + }); + } + if velocities_slice.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "velocities".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_VEC3_F32, + slot_bytes: n_atoms * 12, + payload: Vec::new(), + strings: None, + }); + } + if let Some(column) = cell_slice.as_ref() { + builtin_columns.push(BatchSectionColumn { + key: "cell".to_string(), + kind: KIND_BUILTIN, + type_tag: column.type_tag(), + slot_bytes: column.slot_bytes(), + payload: Vec::new(), + strings: None, + }); + } + if let Some(column) = stress_slice.as_ref() { + builtin_columns.push(BatchSectionColumn { + key: "stress".to_string(), + kind: KIND_BUILTIN, + type_tag: column.type_tag(), + slot_bytes: column.slot_bytes(), + payload: Vec::new(), + strings: None, + }); + } + if pbc_slice.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "pbc".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_BOOL3, + slot_bytes: 3, + payload: Vec::new(), + strings: None, + }); + } + if name.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "name".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_STRING, + slot_bytes: 0, + payload: Vec::new(), + strings: None, + }); + } + + let batch_schema = build_batch_schema( + TYPE_VEC3_F32, + builtin_columns.iter().chain(custom_columns.iter()), + )?; + let record_format = inner + .record_format_for_schema(batch_schema.clone()) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + let builtin_section_count = usize::from(energy_slice.is_some()) + + usize::from(forces_slice.is_some()) + + usize::from(charges_slice.is_some()) + + usize::from(velocities_slice.is_some()) + + usize::from(cell_slice.is_some()) + + usize::from(stress_slice.is_some()) + + usize::from(pbc_slice.is_some()) + + usize::from(name.is_some()); + + let build_record = |i: usize| { + let pos_start = i * n_atoms * 3; + let pos_end = pos_start + n_atoms * 3; + let z_start = i * n_atoms; + let z_end = z_start + n_atoms; + let forces_payload = forces_slice + .as_ref() + .map(|slice| bytemuck::cast_slice::(&slice[pos_start..pos_end])); + let charges_payload = charges_slice + .as_ref() + .map(|slice| bytemuck::cast_slice::(&slice[z_start..z_end])); + let velocities_payload = velocities_slice + .as_ref() + .map(|slice| bytemuck::cast_slice::(&slice[pos_start..pos_end])); + let cell_payload = cell_slice.as_ref().map(|column| column.payload_bytes(i)); + let stress_payload = stress_slice.as_ref().map(|column| column.payload_bytes(i)); + let energy_bytes = energy_slice.as_ref().map(|slice| slice[i].to_le_bytes()); + let pbc_payload = pbc_slice.as_ref().map(|slice| { + [ + slice[i * 3] as u8, + slice[i * 3 + 1] as u8, + slice[i * 3 + 2] as u8, + ] + }); + + let mut sections = Vec::with_capacity(builtin_section_count + custom_columns.len()); + if let Some(payload) = charges_payload { + push_builtin_section(&mut sections, "charges", TYPE_F64_ARRAY, payload); + } + if let Some(payload) = cell_payload { + push_builtin_section( + &mut sections, + "cell", + cell_slice + .as_ref() + .map(FastMat3Column::type_tag) + .expect("cell type tag must exist when payload exists"), + payload, + ); + } + if let Some(bytes) = energy_bytes.as_ref() { + push_builtin_section(&mut sections, "energy", TYPE_FLOAT, bytes); + } + if let Some(payload) = forces_payload { + push_builtin_section(&mut sections, "forces", TYPE_VEC3_F32, payload); + } + if let Some(names) = name.as_ref() { + push_builtin_section(&mut sections, "name", TYPE_STRING, names[i].as_bytes()); + } + if let Some(payload) = pbc_payload.as_ref() { + push_builtin_section(&mut sections, "pbc", TYPE_BOOL3, payload); + } + if let Some(payload) = stress_payload { + push_builtin_section( + &mut sections, + "stress", + stress_slice + .as_ref() + .map(FastMat3Column::type_tag) + .expect("stress type tag must exist when payload exists"), + payload, + ); + } + if let Some(payload) = velocities_payload { + push_builtin_section(&mut sections, "velocities", TYPE_VEC3_F32, payload); + } + sections.extend(custom_columns.iter().map(|column| column.section_for(i))); + + build_soa_record(SoaRecord { + record_format, + positions_type: TYPE_VEC3_F32, + positions: bytemuck::cast_slice::(&pos_slice[pos_start..pos_end]), + atomic_numbers: &z_slice[z_start..z_end], + sections: §ions, + }) + .map(|record| (record, n_atoms as u32)) + }; + + let records: Vec<(Vec, u32)> = if batch >= 1024 { + use rayon::prelude::*; + (0..batch) + .into_par_iter() + .map(build_record) + .collect::, _>>() + .map_err(PyValueError::new_err)? + } else { + (0..batch) + .map(build_record) + .collect::, _>>() + .map_err(PyValueError::new_err)? + }; + + py.detach(move || inner.add_owned_soa_records_with_schema(records, batch_schema)) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + Ok(true) +} + fn extract_positions_payload(value: &Bound<'_, PyAny>) -> PyResult<(usize, usize, u8, Vec)> { if let Ok(arr) = value.cast::>() { let readonly = arr.readonly(); @@ -644,6 +1168,25 @@ pub(super) fn add_arrays_batch_impl( properties: Option<&Bound<'_, PyDict>>, atom_properties: Option<&Bound<'_, PyDict>>, ) -> PyResult<()> { + if try_add_arrays_batch_fast_canonical( + inner, + py, + positions, + atomic_numbers, + energy, + forces, + charges, + velocities, + cell, + stress, + pbc, + name.clone(), + properties, + atom_properties, + )? { + return Ok(()); + } + let (batch, n_atoms, positions_type, positions_payload) = extract_positions_payload(positions)?; let atomic_numbers_payload = extract_atomic_numbers_payload(atomic_numbers, batch, n_atoms)?; @@ -788,10 +1331,16 @@ pub(super) fn add_arrays_batch_impl( let positions_slot_bytes = n_atoms .checked_mul(type_tag_elem_bytes(positions_type)) .ok_or_else(|| PyValueError::new_err("positions byte length overflow"))?; - let record_format = inner.record_format(); - let mut records = Vec::with_capacity(batch); - for index in 0..batch { + let batch_schema = build_batch_schema( + positions_type, + builtin_columns.iter().chain(custom_columns.iter()), + )?; + let record_format = inner + .record_format_for_schema(batch_schema.clone()) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + + let build_record = |index: usize| -> Result<(Vec, u32), String> { let pos_start = index * positions_slot_bytes; let pos_end = pos_start + positions_slot_bytes; let z_start = index * n_atoms; @@ -815,11 +1364,24 @@ pub(super) fn add_arrays_batch_impl( positions: &positions_payload[pos_start..pos_end], atomic_numbers: &atomic_numbers_payload[z_start..z_end], sections: §ions, - }) - .map_err(PyValueError::new_err)?; - records.push((record, n_atoms as u32, positions_type)); - } + })?; + Ok((record, n_atoms as u32)) + }; + + let records: Vec<(Vec, u32)> = if batch >= 1024 { + use rayon::prelude::*; + (0..batch) + .into_par_iter() + .map(build_record) + .collect::, _>>() + .map_err(PyValueError::new_err)? + } else { + (0..batch) + .map(build_record) + .collect::, _>>() + .map_err(PyValueError::new_err)? + }; - py.detach(move || inner.add_owned_soa_records(records)) + py.detach(move || inner.add_owned_soa_records_with_schema(records, batch_schema)) .map_err(|e| PyValueError::new_err(format!("{}", e))) } diff --git a/atompack-py/src/database_flat.rs b/atompack-py/src/database_flat.rs index 63b0a56..56de991 100644 --- a/atompack-py/src/database_flat.rs +++ b/atompack-py/src/database_flat.rs @@ -4,6 +4,11 @@ use crate::molecule::{ pyarray1_from_cow, pyarray2_from_cow, }; +enum FlatPositions { + F32(Vec), + F64(Vec), +} + pub(super) fn get_molecules_flat_soa_impl<'py>( inner: &AtomDatabase, py: Python<'py>, @@ -43,34 +48,179 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .copied() .ok_or_else(|| invalid_data("missing final atom offset"))?; + let compression = inner.compression(); + let use_mmap = inner.get_compressed_slice(0).is_some(); let record_format = inner.record_format(); - let positions_type = inner - .positions_type() - .ok_or_else(|| invalid_data("Missing position dtype for batch"))?; - let (raw_bytes, _) = inner.read_decompress_parallel(&indices)?; - - let mut schema: Vec = Vec::new(); - for bytes in &raw_bytes { - let md = parse_mol_fast_soa(bytes, record_format, Some(positions_type))?; - - for section in &md.sections { - let incoming = section_schema_from_ref(section, md.n_atoms)?; - if let Some(existing) = schema.iter().find(|candidate| { - candidate.kind == incoming.kind && candidate.key == incoming.key - }) { - if existing.type_tag != incoming.type_tag - || existing.per_atom != incoming.per_atom - || existing.elem_bytes != incoming.elem_bytes - || existing.slot_bytes != incoming.slot_bytes - { - return Err(invalid_data(format!( - "SOA schema mismatch for section '{}'", - incoming.key - ))); - } + let positions_type = match record_format { + RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, + RECORD_FORMAT_SOA_V3 => inner + .positions_type() + .ok_or_else(|| invalid_data("Missing position dtype for batch"))?, + _ => { + return Err(invalid_data(format!( + "Unsupported record format {}", + record_format + ))); + } + }; + let schema_info = inner.schema_info(); + let raw_bytes_owned: Option>>; + let schema: Vec; + let use_ordered_schema: bool; + + let ordered_schema_from_first = |bytes: &[u8]| -> atompack::Result> { + let first_md = parse_mol_fast_soa(bytes, record_format, Some(positions_type))?; + let n = first_md.n_atoms; + first_md + .sections + .iter() + .map(|s| section_schema_from_ref(s, n)) + .collect::>() + }; + + let schema_matches_ordered = + |ordered: &[SectionSchema], from_lock: &[SectionSchema]| { + if ordered.len() != from_lock.len() { + return false; + } + let ordered_lookup: std::collections::HashMap<(u8, &str), &SectionSchema> = + ordered + .iter() + .map(|entry| ((entry.kind, entry.key.as_str()), entry)) + .collect(); + from_lock.iter().all(|entry| { + ordered_lookup + .get(&(entry.kind, entry.key.as_str())) + .is_some_and(|candidate| { + candidate.type_tag == entry.type_tag + && candidate.per_atom == entry.per_atom + && candidate.elem_bytes == entry.elem_bytes + && candidate.slot_bytes == entry.slot_bytes + }) + }) + }; + + if let Some(schema_info) = schema_info { + let schema_from_lock: Vec = schema_info + .sections + .into_iter() + .map(|section| SectionSchema { + kind: section.kind, + key: section.key, + type_tag: section.type_tag, + per_atom: section.per_atom, + elem_bytes: section.elem_bytes, + slot_bytes: section.slot_bytes, + }) + .collect(); + if use_mmap { + if compression == CompressionType::None { + let shared = inner.get_shared_mmap_bytes(indices[0]).ok_or_else(|| { + invalid_data(format!("Missing mmap bytes for molecule {}", indices[0])) + })?; + let ordered = ordered_schema_from_first(shared.as_slice())?; + use_ordered_schema = schema_matches_ordered(&ordered, &schema_from_lock); + schema = if use_ordered_schema { + ordered + } else { + schema_from_lock + }; + raw_bytes_owned = None; } else { - schema.push(incoming); + let compressed = + inner.get_compressed_slice(indices[0]).ok_or_else(|| { + invalid_data(format!( + "Missing compressed bytes for molecule {}", + indices[0] + )) + })?; + let uncompressed_size = + inner.uncompressed_size(indices[0]).ok_or_else(|| { + invalid_data(format!( + "Missing uncompressed size for molecule {}", + indices[0] + )) + })? as usize; + let first_bytes = atompack::decompress_bytes( + compressed, + compression, + Some(uncompressed_size), + )?; + let ordered = ordered_schema_from_first(&first_bytes)?; + use_ordered_schema = schema_matches_ordered(&ordered, &schema_from_lock); + schema = if use_ordered_schema { + ordered + } else { + schema_from_lock + }; + raw_bytes_owned = None; } + } else { + let (raw_bytes, _) = inner.read_decompress_parallel(&indices)?; + let ordered = ordered_schema_from_first(&raw_bytes[0])?; + use_ordered_schema = schema_matches_ordered(&ordered, &schema_from_lock); + schema = if use_ordered_schema { + ordered + } else { + schema_from_lock + }; + raw_bytes_owned = Some(raw_bytes); + } + } else if use_mmap { + if compression == CompressionType::None { + let shared = inner.get_shared_mmap_bytes(indices[0]).ok_or_else(|| { + invalid_data(format!("Missing mmap bytes for molecule {}", indices[0])) + })?; + schema = ordered_schema_from_first(shared.as_slice())?; + use_ordered_schema = true; + } else { + let compressed = inner.get_compressed_slice(indices[0]).ok_or_else(|| { + invalid_data(format!( + "Missing compressed bytes for molecule {}", + indices[0] + )) + })?; + let uncompressed_size = + inner.uncompressed_size(indices[0]).ok_or_else(|| { + invalid_data(format!( + "Missing uncompressed size for molecule {}", + indices[0] + )) + })? as usize; + let first_bytes = atompack::decompress_bytes( + compressed, + compression, + Some(uncompressed_size), + )?; + schema = ordered_schema_from_first(&first_bytes)?; + use_ordered_schema = true; + } + raw_bytes_owned = None; + } else { + let (raw_bytes, _) = inner.read_decompress_parallel(&indices)?; + schema = ordered_schema_from_first(&raw_bytes[0])?; + use_ordered_schema = true; + raw_bytes_owned = Some(raw_bytes); + } + + let schema_keys: Vec<(u8, &[u8])> = if use_ordered_schema { + schema + .iter() + .map(|entry| (entry.kind, entry.key.as_bytes())) + .collect() + } else { + Vec::new() + }; + let mut schema_lookup: std::collections::HashMap< + u8, + std::collections::HashMap, + > = std::collections::HashMap::new(); + if !use_ordered_schema { + for (index, entry) in schema.iter().enumerate() { + schema_lookup + .entry(entry.kind) + .or_default() + .insert(entry.key.clone(), index); } } let positions_stride = match positions_type { @@ -83,7 +233,11 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( ))); } }; - let mut positions = vec![0u8; total_atoms * positions_stride]; + let mut positions = match positions_type { + TYPE_VEC3_F32 => FlatPositions::F32(vec![0f32; total_atoms * 3]), + TYPE_VEC3_F64 => FlatPositions::F64(vec![0u8; total_atoms * positions_stride]), + _ => unreachable!(), + }; let mut atomic_numbers = vec![0u8; total_atoms]; let mut section_buffers: Vec> = schema @@ -110,7 +264,14 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( }) .collect(); - let pos_buf = RawBuf::new(&mut positions); + let pos_buf_f32 = match &mut positions { + FlatPositions::F32(values) => Some(RawBuf::new(values)), + FlatPositions::F64(_) => None, + }; + let pos_buf_f64 = match &mut positions { + FlatPositions::F32(_) => None, + FlatPositions::F64(values) => Some(RawBuf::new(values)), + }; let z_buf = RawBuf::new(&mut atomic_numbers); let sec_bufs: Vec> = section_buffers .iter_mut() @@ -137,11 +298,34 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let n = md.n_atoms; unsafe { - std::ptr::copy_nonoverlapping( - md.positions_bytes.as_ptr(), - pos_buf.at(atom_off * positions_stride), - n * positions_stride, - ); + match positions_type { + TYPE_VEC3_F32 => { + let pos_buf = pos_buf_f32 + .as_ref() + .ok_or_else(|| invalid_data("missing f32 position buffer"))?; + std::ptr::copy_nonoverlapping( + md.positions_bytes.as_ptr(), + pos_buf.at(atom_off * 3) as *mut u8, + n * 12, + ); + } + TYPE_VEC3_F64 => { + let pos_buf = pos_buf_f64 + .as_ref() + .ok_or_else(|| invalid_data("missing f64 position buffer"))?; + std::ptr::copy_nonoverlapping( + md.positions_bytes.as_ptr(), + pos_buf.at(atom_off * positions_stride), + n * positions_stride, + ); + } + other => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + other + ))); + } + } std::ptr::copy_nonoverlapping( md.atomic_numbers_bytes.as_ptr(), z_buf.at(atom_off), @@ -149,97 +333,228 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( ); } - for (section_idx, schema_entry) in schema.iter().enumerate() { - let sec = md - .sections - .iter() - .find(|sec| sec.kind == schema_entry.kind && sec.key == schema_entry.key); - let Some(sec) = sec else { - continue; - }; - - if sec.type_tag != schema_entry.type_tag { + if use_ordered_schema { + if md.sections.len() != schema.len() { return Err(invalid_data(format!( - "SOA schema mismatch at molecule {} for section '{}'", - i, sec.key + "SOA schema mismatch for molecule {}: expected {} sections, got {}", + i, + schema.len(), + md.sections.len() ))); } + for (section_idx, sec) in md.sections.iter().enumerate() { + let schema_entry = &schema[section_idx]; + let expected_key = &schema_keys[section_idx]; - if schema_entry.per_atom { - let expected = n.checked_mul(schema_entry.elem_bytes).ok_or_else(|| { - invalid_data(format!("Section '{}' payload length overflow", sec.key)) - })?; - if sec.payload.len() != expected { + if sec.kind != expected_key.0 || sec.key.as_bytes() != expected_key.1 { + return Err(invalid_data(format!( + "SOA schema order mismatch at molecule {} for section '{}'", + i, sec.key + ))); + } + + if sec.type_tag != schema_entry.type_tag { + return Err(invalid_data(format!( + "SOA schema mismatch at molecule {} for section '{}'", + i, sec.key + ))); + } + + if schema_entry.per_atom { + let expected = + n.checked_mul(schema_entry.elem_bytes).ok_or_else(|| { + invalid_data(format!( + "Section '{}' payload length overflow", + sec.key + )) + })?; + if sec.payload.len() != expected { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + sec.key, + sec.payload.len(), + expected + ))); + } + } else if schema_entry.slot_bytes != 0 + && sec.payload.len() != schema_entry.slot_bytes + { return Err(invalid_data(format!( "Section '{}' has invalid payload length {} (expected {})", sec.key, sec.payload.len(), - expected + schema_entry.slot_bytes ))); } - } else if schema_entry.slot_bytes != 0 - && sec.payload.len() != schema_entry.slot_bytes - { - return Err(invalid_data(format!( - "Section '{}' has invalid payload length {} (expected {})", - sec.key, - sec.payload.len(), - schema_entry.slot_bytes - ))); + + if schema_entry.slot_bytes == 0 { + if let Some(ref mtx) = string_mutexes[section_idx] { + let val = Some( + std::str::from_utf8(sec.payload) + .map_err(|_| { + invalid_data(format!( + "Invalid UTF-8 in section '{}'", + sec.key + )) + })? + .to_string(), + ); + let mut guard = mtx + .lock() + .map_err(|_| invalid_data("string section mutex poisoned"))?; + guard[i] = val; + } + } else { + let buf = &sec_bufs[section_idx]; + let offset = if schema_entry.per_atom { + atom_off * schema_entry.elem_bytes + } else { + i * schema_entry.slot_bytes + }; + unsafe { + std::ptr::copy_nonoverlapping( + sec.payload.as_ptr(), + buf.at(offset), + sec.payload.len(), + ); + } + } } + } else { + for sec in &md.sections { + let section_idx = schema_lookup + .get(&sec.kind) + .and_then(|entries| entries.get(sec.key)) + .copied() + .ok_or_else(|| { + invalid_data(format!( + "Unexpected SOA section '{}' in molecule {}", + sec.key, i + )) + })?; + let schema_entry = &schema[section_idx]; - if schema_entry.slot_bytes == 0 { - if let Some(ref mtx) = string_mutexes[section_idx] { - let val = Some( - std::str::from_utf8(sec.payload) - .map_err(|_| { - invalid_data(format!( - "Invalid UTF-8 in section '{}'", - sec.key - )) - })? - .to_string(), - ); - let mut guard = mtx - .lock() - .map_err(|_| invalid_data("string section mutex poisoned"))?; - guard[i] = val; + if sec.type_tag != schema_entry.type_tag { + return Err(invalid_data(format!( + "SOA schema mismatch at molecule {} for section '{}'", + i, sec.key + ))); } - } else { - let buf = &sec_bufs[section_idx]; - let offset = if schema_entry.per_atom { - atom_off * schema_entry.elem_bytes - } else { - i * schema_entry.slot_bytes - }; - unsafe { - std::ptr::copy_nonoverlapping( - sec.payload.as_ptr(), - buf.at(offset), + + if schema_entry.per_atom { + let expected = + n.checked_mul(schema_entry.elem_bytes).ok_or_else(|| { + invalid_data(format!( + "Section '{}' payload length overflow", + sec.key + )) + })?; + if sec.payload.len() != expected { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + sec.key, + sec.payload.len(), + expected + ))); + } + } else if schema_entry.slot_bytes != 0 + && sec.payload.len() != schema_entry.slot_bytes + { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + sec.key, sec.payload.len(), - ); + schema_entry.slot_bytes + ))); + } + + if schema_entry.slot_bytes == 0 { + if let Some(ref mtx) = string_mutexes[section_idx] { + let val = Some( + std::str::from_utf8(sec.payload) + .map_err(|_| { + invalid_data(format!( + "Invalid UTF-8 in section '{}'", + sec.key + )) + })? + .to_string(), + ); + let mut guard = mtx + .lock() + .map_err(|_| invalid_data("string section mutex poisoned"))?; + guard[i] = val; + } + } else { + let buf = &sec_bufs[section_idx]; + let offset = if schema_entry.per_atom { + atom_off * schema_entry.elem_bytes + } else { + i * schema_entry.slot_bytes + }; + unsafe { + std::ptr::copy_nonoverlapping( + sec.payload.as_ptr(), + buf.at(offset), + sec.payload.len(), + ); + } } } } Ok(()) }; - let results: Vec> = raw_bytes - .par_iter() - .enumerate() - .map(|(i, bytes)| process_mol(i, bytes)) - .collect(); + let results: Vec> = if use_mmap { + (0..n_mols) + .into_par_iter() + .map(|i| { + let idx = indices[i]; + if compression == CompressionType::None { + let shared = inner.get_shared_mmap_bytes(idx).ok_or_else(|| { + invalid_data(format!("Missing mmap bytes for molecule {}", idx)) + })?; + process_mol(i, shared.as_slice()) + } else { + let compressed = inner.get_compressed_slice(idx).ok_or_else(|| { + invalid_data(format!( + "Missing compressed bytes for molecule {}", + idx + )) + })?; + let uncompressed_size = + inner.uncompressed_size(idx).ok_or_else(|| { + invalid_data(format!( + "Missing uncompressed size for molecule {}", + idx + )) + })? as usize; + let decompressed = atompack::decompress_bytes( + compressed, + compression, + Some(uncompressed_size), + )?; + process_mol(i, &decompressed) + } + }) + .collect() + } else { + let raw_bytes = raw_bytes_owned.expect("raw bytes must exist without mmap"); + raw_bytes + .par_iter() + .enumerate() + .map(|(i, bytes)| process_mol(i, bytes)) + .collect() + }; results.into_iter().collect::>>()?; Ok(Some(( n_atoms_vec, - positions_type, positions, atomic_numbers, schema, section_buffers, string_sections, - n_mols, total_atoms, ))) }) @@ -247,13 +562,11 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let ( n_atoms_vec, - positions_type, positions, atomic_numbers, schema, section_buffers, string_results, - _n_mols, total_atoms, ) = match result { None => { @@ -273,21 +586,19 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let dict = PyDict::new(py); dict.set_item("n_atoms", PyArray1::from_vec(py, n_atoms_vec))?; - match positions_type { - TYPE_VEC3_F32 => { - let arr = cast_or_decode_f32(&positions)?; - dict.set_item("positions", pyarray2_from_cow(py, arr, total_atoms, 3)?)?; + match positions { + FlatPositions::F32(values) => { + dict.set_item( + "positions", + PyArray1::from_vec(py, values) + .reshape([total_atoms, 3]) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?, + )?; } - TYPE_VEC3_F64 => { - let arr = cast_or_decode_f64(&positions)?; + FlatPositions::F64(bytes) => { + let arr = cast_or_decode_f64(&bytes)?; dict.set_item("positions", pyarray2_from_cow(py, arr, total_atoms, 3)?)?; } - other => { - return Err(PyValueError::new_err(format!( - "Unsupported positions type tag {}", - other - ))); - } } dict.set_item("atomic_numbers", PyArray1::from_vec(py, atomic_numbers))?; diff --git a/atompack-py/src/soa.rs b/atompack-py/src/soa.rs index eb2d2cb..ce401e2 100644 --- a/atompack-py/src/soa.rs +++ b/atompack-py/src/soa.rs @@ -1,4 +1,5 @@ use super::*; +use atompack::storage::{DatabaseSchema, DatabaseSchemaSection}; /// A single parsed section reference (zero-copy into decompressed bytes). #[derive(Clone)] @@ -27,6 +28,42 @@ pub(crate) struct SectionSchema { pub(crate) slot_bytes: usize, } +pub(crate) fn parse_mol_fast_soa_v2(bytes: &[u8]) -> atompack::Result> { + let mut pos = 0usize; + let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; + let positions_len = n_atoms + .checked_mul(12) + .ok_or_else(|| invalid_data("SOA positions byte length overflow"))?; + let positions_bytes = read_bytes_at(bytes, &mut pos, positions_len, "SOA positions")?; + let atomic_numbers_bytes = read_bytes_at(bytes, &mut pos, n_atoms, "SOA atomic_numbers")?; + let n_sections = read_u16_le_at(bytes, &mut pos, "SOA n_sections")? as usize; + + let mut sections = Vec::with_capacity(n_sections); + for _ in 0..n_sections { + let kind = read_u8_at(bytes, &mut pos, "SOA section kind")?; + let key_len = read_u8_at(bytes, &mut pos, "SOA section key length")? as usize; + let key_bytes = read_bytes_at(bytes, &mut pos, key_len, "SOA section key")?; + let key = std::str::from_utf8(key_bytes) + .map_err(|_| invalid_data("Invalid UTF-8 in SOA section key"))?; + let type_tag = read_u8_at(bytes, &mut pos, "SOA section type tag")?; + let payload_len = read_u32_le_at(bytes, &mut pos, "SOA section payload length")? as usize; + let payload = read_bytes_at(bytes, &mut pos, payload_len, "SOA section payload")?; + sections.push(SectionRef { + kind, + key, + type_tag, + payload, + }); + } + + Ok(MolData { + n_atoms, + positions_bytes, + atomic_numbers_bytes, + sections, + }) +} + /// Parse SOA format bytes into MolData without allocation. /// /// Layout: @@ -38,10 +75,13 @@ pub(crate) fn parse_mol_fast_soa( record_format: u32, positions_type_hint: Option, ) -> atompack::Result> { + if record_format == RECORD_FORMAT_SOA_V2 { + return parse_mol_fast_soa_v2(bytes); + } + let mut pos = 0usize; let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; let positions_type = match record_format { - RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, RECORD_FORMAT_SOA_V3 => positions_type_hint .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, _ => { @@ -231,6 +271,53 @@ pub(crate) fn is_per_atom(kind: u8, key: &str, _type_tag: u8) -> bool { } } +fn database_schema_section( + kind: u8, + key: &str, + type_tag: u8, + payload_len: usize, + n_atoms: usize, +) -> PyResult { + let per_atom = is_per_atom(kind, key, type_tag); + let elem_bytes = if type_tag == TYPE_STRING { + 0 + } else { + let elem_bytes = type_tag_elem_bytes(type_tag); + if elem_bytes == 0 { + return Err(PyValueError::new_err(format!( + "Unsupported section type tag {} for '{}'", + type_tag, key + ))); + } + elem_bytes + }; + let slot_bytes = if type_tag == TYPE_STRING { + 0 + } else if per_atom { + let expected = n_atoms.checked_mul(elem_bytes).ok_or_else(|| { + PyValueError::new_err(format!("Section '{}' payload length overflow", key)) + })?; + if payload_len != expected { + return Err(PyValueError::new_err(format!( + "Section '{}' has invalid payload length {} (expected {})", + key, payload_len, expected + ))); + } + elem_bytes + } else { + payload_len + }; + + Ok(DatabaseSchemaSection { + kind, + key: key.to_string(), + type_tag, + per_atom, + elem_bytes, + slot_bytes, + }) +} + fn py_decode_float_array_data( payload: &[u8], type_tag: u8, @@ -340,12 +427,141 @@ pub(crate) struct SoaMoleculeView { } impl SoaMoleculeView { + fn from_storage_v2(bytes: SoaBytes) -> atompack::Result { + if bytes.len() < 6 { + return Err(invalid_data("SOA record too small")); + } + + let n_atoms = u32::from_le_bytes(slice_to_array(&bytes[0..4], "SOA atom count")?) as usize; + let mut pos = 4usize; + let positions_start = pos; + let positions_len = n_atoms + .checked_mul(12) + .ok_or_else(|| invalid_data("SOA positions overflow"))?; + pos = pos + .checked_add(positions_len) + .ok_or_else(|| invalid_data("SOA positions overflow"))?; + if pos > bytes.len() { + return Err(invalid_data("SOA record truncated at positions")); + } + + let atomic_numbers_start = pos; + pos = pos + .checked_add(n_atoms) + .ok_or_else(|| invalid_data("SOA atomic_numbers overflow"))?; + if pos + 2 > bytes.len() { + return Err(invalid_data("SOA record truncated at atomic_numbers")); + } + + let n_sections = + u16::from_le_bytes(slice_to_array(&bytes[pos..pos + 2], "SOA section count")?) as usize; + pos += 2; + + let mut forces = None; + let mut energy = None; + let mut cell = None; + let mut stress = None; + let mut charges = None; + let mut velocities = None; + let mut pbc = None; + let mut name = None; + let mut custom_sections = Vec::new(); + + for _ in 0..n_sections { + if pos + 2 > bytes.len() { + return Err(invalid_data("SOA section header truncated")); + } + let kind = bytes[pos]; + pos += 1; + let key_len = bytes[pos] as usize; + pos += 1; + if pos + key_len > bytes.len() { + return Err(invalid_data("SOA section key truncated")); + } + let key_start = pos; + pos += key_len; + if pos + 5 > bytes.len() { + return Err(invalid_data("SOA section header truncated")); + } + let type_tag = bytes[pos]; + pos += 1; + let payload_len = u32::from_le_bytes(slice_to_array( + &bytes[pos..pos + 4], + "SOA section payload length", + )?) as usize; + pos += 4; + let payload_start = pos; + pos = pos + .checked_add(payload_len) + .ok_or_else(|| invalid_data("SOA section payload overflow"))?; + if pos > bytes.len() { + return Err(invalid_data("SOA section payload truncated")); + } + + let key_bytes = &bytes[key_start..key_start + key_len]; + if kind == KIND_BUILTIN { + let slot = (payload_start, payload_len, type_tag); + match key_bytes { + b"forces" => forces = Some(slot), + b"energy" => energy = Some(slot), + b"cell" => cell = Some(slot), + b"stress" => stress = Some(slot), + b"charges" => charges = Some(slot), + b"velocities" => velocities = Some(slot), + b"pbc" => pbc = Some(slot), + b"name" => name = Some(slot), + _ => { + custom_sections.push(LazySection { + kind, + key_start, + key_len: key_len as u8, + type_tag, + payload_start, + payload_len, + }); + } + } + } else { + custom_sections.push(LazySection { + kind, + key_start, + key_len: key_len as u8, + type_tag, + payload_start, + payload_len, + }); + } + } + + Ok(Self { + bytes, + n_atoms, + positions_type: TYPE_VEC3_F32, + positions_start, + positions_len, + atomic_numbers_start, + forces, + energy, + cell, + stress, + charges, + velocities, + pbc, + name, + custom_sections, + }) + } + /// Pure-Rust parser — no Python dependency, safe to call from rayon threads. fn from_storage_inner( bytes: SoaBytes, record_format: u32, positions_type_hint: Option, ) -> atompack::Result { + if record_format == RECORD_FORMAT_SOA_V2 { + return Self::from_storage_v2(bytes); + } + if bytes.len() < 6 { return Err(invalid_data("SOA record too small")); } @@ -353,7 +569,6 @@ impl SoaMoleculeView { let n_atoms = u32::from_le_bytes(slice_to_array(&bytes[0..4], "SOA atom count")?) as usize; let mut pos = 4usize; let positions_type = match record_format { - RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, RECORD_FORMAT_SOA_V3 => positions_type_hint .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, _ => { @@ -520,6 +735,134 @@ impl SoaMoleculeView { &self.bytes[self.positions_start..self.positions_start + self.positions_len] } + #[inline] + pub(crate) fn raw_bytes(&self) -> &[u8] { + self.bytes.as_slice() + } + + pub(crate) fn database_schema(&self) -> PyResult { + let mut sections = Vec::with_capacity(8 + self.custom_sections.len()); + + if let Some(slot) = self.energy { + sections.push(database_schema_section( + KIND_BUILTIN, + "energy", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.forces { + sections.push(database_schema_section( + KIND_BUILTIN, + "forces", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.charges { + sections.push(database_schema_section( + KIND_BUILTIN, + "charges", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.velocities { + sections.push(database_schema_section( + KIND_BUILTIN, + "velocities", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.cell { + sections.push(database_schema_section( + KIND_BUILTIN, + "cell", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.stress { + sections.push(database_schema_section( + KIND_BUILTIN, + "stress", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.pbc { + sections.push(database_schema_section( + KIND_BUILTIN, + "pbc", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.name { + sections.push(database_schema_section( + KIND_BUILTIN, + "name", + slot.2, + slot.1, + self.n_atoms, + )?); + } + for section in &self.custom_sections { + sections.push(database_schema_section( + section.kind, + self.lazy_section_key(section)?, + section.type_tag, + section.payload_len, + self.n_atoms, + )?); + } + + Ok(DatabaseSchema { + positions_type: Some(self.positions_type), + sections, + }) + } + + pub(crate) fn same_schema_as(&self, other: &Self) -> PyResult { + if self.positions_type != other.positions_type + || self.energy != other.energy + || self.forces != other.forces + || self.charges != other.charges + || self.velocities != other.velocities + || self.cell != other.cell + || self.stress != other.stress + || self.pbc != other.pbc + || self.name != other.name + || self.custom_sections.len() != other.custom_sections.len() + { + return Ok(false); + } + + for (left, right) in self + .custom_sections + .iter() + .zip(other.custom_sections.iter()) + { + if left.kind != right.kind + || left.type_tag != right.type_tag + || left.payload_len != right.payload_len + || self.lazy_section_key(left)? != other.lazy_section_key(right)? + { + return Ok(false); + } + } + + Ok(true) + } + pub(crate) fn atomic_numbers_bytes(&self) -> &[u8] { &self.bytes[self.atomic_numbers_start..self.atomic_numbers_start + self.n_atoms] } diff --git a/atompack-py/tests/test_database.py b/atompack-py/tests/test_database.py index 62daf7e..4d83dfd 100644 --- a/atompack-py/tests/test_database.py +++ b/atompack-py/tests/test_database.py @@ -3,7 +3,6 @@ from pathlib import Path import pickle -import zlib import atompack import numpy as np @@ -63,26 +62,18 @@ def test_database_rejects_invalid_compression(tmp_path: Path) -> None: with pytest.raises(ValueError, match=r"Invalid compression"): atompack.Database(str(tmp_path / "bad.atp"), compression="definitely-not-a-codec") - -def _rewrite_record_format_v2(path: Path) -> None: - header_slot_size = 4096 - for slot_offset in (0, header_slot_size): - with path.open("r+b") as handle: - handle.seek(slot_offset) - slot = bytearray(handle.read(header_slot_size)) - slot[56:60] = (2).to_bytes(4, "little") - checksum = zlib.adler32(slot[:-4]) & 0xFFFFFFFF - slot[-4:] = checksum.to_bytes(4, "little") - handle.seek(slot_offset) - handle.write(slot) - - def test_database_add_arrays_batch_rejects_v2_incompatible_builtin_dtype(tmp_path: Path) -> None: path = tmp_path / "batch_arrays_v2_compat.atp" - atompack.Database(str(path)) - _rewrite_record_format_v2(path) + db = atompack.Database(str(path)) + db.add_molecule( + atompack.Molecule( + np.array([[0.0, 0.0, 0.0]], dtype=np.float32), + np.array([6], dtype=np.uint8), + ) + ) + db.flush() - db = atompack.Database.open(str(path)) + db = atompack.Database.open(str(path), mmap=False) positions = np.array([[[0.0, 0.0, 0.0]]], dtype=np.float32) atomic_numbers = np.array([[6]], dtype=np.uint8) cell = np.eye(3, dtype=np.float32)[None, ...] @@ -91,6 +82,29 @@ def test_database_add_arrays_batch_rejects_v2_incompatible_builtin_dtype(tmp_pat db.add_arrays_batch(positions, atomic_numbers, cell=cell) +def test_database_add_molecules_view_passthrough_preserves_v3_positions_dtype( + tmp_path: Path, +) -> None: + source_path = tmp_path / "source_v3.atp" + target_path = tmp_path / "target_v3.atp" + + positions = np.array([[0.0, 0.0, 0.0], [1.25, 0.0, 0.0]], dtype=np.float64) + atomic_numbers = np.array([6, 8], dtype=np.uint8) + + source = atompack.Database(str(source_path)) + source.add_molecule(atompack.Molecule.from_arrays(positions, atomic_numbers)) + source.flush() + + source_view_db = atompack.Database.open(str(source_path)) + target = atompack.Database(str(target_path)) + target.add_molecules([source_view_db[0]]) + target.flush() + + roundtrip = atompack.Database.open(str(target_path))[0] + np.testing.assert_allclose(roundtrip.positions, positions) + assert roundtrip.positions.dtype == np.float64 + + def test_database_roundtrip_from_arrays_with_builtins(tmp_path: Path) -> None: path = tmp_path / "from_arrays_builtins.atp" positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32) @@ -276,6 +290,40 @@ def test_database_add_arrays_batch_roundtrip_with_custom_properties(tmp_path: Pa assert second.get_property("phase") == "valid" +def test_database_add_arrays_batch_promotes_to_float64_geometry_when_needed( + tmp_path: Path, +) -> None: + path = tmp_path / "batch_arrays_float64_geometry.atp" + positions = np.array( + [ + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], + [[0.5, 0.1, 0.2], [1.5, 0.1, 0.2]], + ], + dtype=np.float64, + ) + atomic_numbers = np.array([[6, 8], [1, 8]], dtype=np.uint8) + forces = np.array( + [ + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + [[0.6, 0.5, 0.4], [0.3, 0.2, 0.1]], + ], + dtype=np.float64, + ) + + db = atompack.Database(str(path)) + db.add_arrays_batch(positions, atomic_numbers, forces=forces) + db.flush() + + reopened = atompack.Database.open(str(path), mmap=False) + first = reopened[0] + flat = reopened.get_molecules_flat([0, 1]) + + assert first.positions.dtype == np.float64 + assert first.forces.dtype == np.float64 + assert flat["positions"].dtype == np.float64 + assert flat["forces"].dtype == np.float64 + + @pytest.mark.parametrize("mmap", [False, True]) @pytest.mark.parametrize("compression", ["none", "lz4", "zstd"]) def test_database_single_item_reads_are_view_compatible( diff --git a/atompack/src/atom.rs b/atompack/src/atom.rs index c7be8e7..3edec41 100644 --- a/atompack/src/atom.rs +++ b/atompack/src/atom.rs @@ -281,8 +281,12 @@ impl Molecule { } pub fn from_atoms(atoms: Vec) -> Self { - let positions = atoms.iter().map(|atom| atom.position()).collect(); - let atomic_numbers = atoms.iter().map(|atom| atom.atomic_number).collect(); + let mut positions = Vec::with_capacity(atoms.len()); + let mut atomic_numbers = Vec::with_capacity(atoms.len()); + for atom in atoms { + positions.push([atom.x, atom.y, atom.z]); + atomic_numbers.push(atom.atomic_number); + } Self { name: None, positions: Vec3Data::F32(positions), diff --git a/atompack/src/storage/mod.rs b/atompack/src/storage/mod.rs index 48f8112..8c070d1 100644 --- a/atompack/src/storage/mod.rs +++ b/atompack/src/storage/mod.rs @@ -41,10 +41,12 @@ mod soa; use self::header::{Header, encode_header_slot, read_best_header}; use self::index::{IndexStorage, MoleculeIndex, decode_index, encode_index}; use self::schema::{ - SchemaLock, decode_schema_lock, encode_schema_lock, merge_schema_lock, record_schema, - schema_from_molecule, + SchemaEntry, SchemaLock, decode_schema_lock, encode_schema_lock, merge_schema_lock, + record_schema, schema_from_molecule, validate_schema_lock_for_record_format, +}; +use self::soa::{ + arr, deserialize_molecule_soa, minimum_record_format_for_molecule, serialize_molecule_soa, }; -use self::soa::{arr, deserialize_molecule_soa, serialize_molecule_soa}; // --------------------------------------------------------------------------- // Constants @@ -55,7 +57,7 @@ const MAGIC: &[u8; 4] = b"ATPK"; const FILE_FORMAT_VERSION: u32 = 2; const RECORD_FORMAT_SOA_V2: u32 = 2; const RECORD_FORMAT_SOA_V3: u32 = 3; -const RECORD_FORMAT_SOA: u32 = RECORD_FORMAT_SOA_V3; +const RECORD_FORMAT_SOA: u32 = RECORD_FORMAT_SOA_V2; // Section kind tags (inside each SOA record) const KIND_BUILTIN: u8 = 0; @@ -106,6 +108,63 @@ impl SharedMmapBytes { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DatabaseSchemaSection { + pub kind: u8, + pub key: String, + pub type_tag: u8, + pub per_atom: bool, + pub elem_bytes: usize, + pub slot_bytes: usize, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct DatabaseSchema { + pub positions_type: Option, + pub sections: Vec, +} + +fn database_schema_from_lock(lock: &SchemaLock) -> DatabaseSchema { + let sections = lock + .sections + .iter() + .map(|((kind, key), entry)| DatabaseSchemaSection { + kind: *kind, + key: key.clone(), + type_tag: entry.type_tag, + per_atom: entry.per_atom, + elem_bytes: entry.elem_bytes, + slot_bytes: entry.slot_bytes, + }) + .collect(); + DatabaseSchema { + positions_type: lock.positions_type, + sections, + } +} + +fn schema_lock_from_database_schema(schema: DatabaseSchema) -> SchemaLock { + let sections = schema + .sections + .into_iter() + .map(|section| { + ( + (section.kind, section.key), + SchemaEntry { + type_tag: section.type_tag, + per_atom: section.per_atom, + elem_bytes: section.elem_bytes, + slot_bytes: section.slot_bytes, + }, + ) + }) + .collect(); + SchemaLock { + positions_type: schema.positions_type, + sections, + } +} + pub struct AtomDatabase { path: PathBuf, compression: CompressionType, @@ -350,6 +409,23 @@ impl AtomDatabase { Ok(lock) } + fn can_promote_record_format(&self) -> bool { + self.record_format == RECORD_FORMAT_SOA_V2 + && self.index.is_empty() + && self.schema_lock.is_none() + } + + fn resolved_record_format_for_schema(&self, schema: &SchemaLock) -> Result { + match validate_schema_lock_for_record_format(self.record_format, schema) { + Ok(()) => Ok(self.record_format), + Err(current_err) if self.can_promote_record_format() => { + validate_schema_lock_for_record_format(RECORD_FORMAT_SOA_V3, schema)?; + Ok(RECORD_FORMAT_SOA_V3) + } + Err(current_err) => Err(current_err), + } + } + fn ensure_schema_compatible<'a, I>(&mut self, records: I) -> Result<()> where I: IntoIterator)>, @@ -359,17 +435,112 @@ impl AtomDatabase { None if self.index.is_empty() => SchemaLock::default(), None => self.rebuild_schema_lock()?, }; + let mut record_format = self.record_format; + let can_promote = self.can_promote_record_format(); for (bytes, positions_type_hint) in records { let hint = positions_type_hint.or(lock.positions_type); - let record = record_schema(bytes, self.record_format, hint)?; + let record = match record_schema(bytes, record_format, hint) { + Ok(record) => record, + Err(current_err) if can_promote && record_format == RECORD_FORMAT_SOA_V2 => { + let record = record_schema(bytes, RECORD_FORMAT_SOA_V3, hint)?; + record_format = RECORD_FORMAT_SOA_V3; + record + } + Err(current_err) => return Err(current_err), + }; merge_schema_lock(&mut lock, &record)?; } + self.record_format = record_format; self.schema_lock = Some(lock); Ok(()) } + fn ensure_schema_lock(&mut self, incoming: &SchemaLock) -> Result<()> { + self.record_format = self.resolved_record_format_for_schema(incoming)?; + let mut lock = match &self.schema_lock { + Some(lock) => lock.clone(), + None if self.index.is_empty() => SchemaLock::default(), + None => self.rebuild_schema_lock()?, + }; + merge_schema_lock(&mut lock, incoming)?; + self.schema_lock = Some(lock); + Ok(()) + } + + fn write_owned_records(&mut self, records: Vec<(Vec, u32)>) -> Result<()> { + let compression = self.compression; + + let compressed_records: Vec<(Vec, u32, u32)> = records + .into_par_iter() + .map(|(bytes, num_atoms)| { + let uncompressed_size = bytes.len() as u32; + let compressed = compress(&bytes, compression)?; + Ok((compressed, uncompressed_size, num_atoms)) + }) + .collect::>>()?; + + let mut file = OpenOptions::new().append(true).open(&self.path)?; + + let mut offset = file.seek(SeekFrom::End(0))?; + let mut new_indices = Vec::with_capacity(compressed_records.len()); + + for (compressed_data, uncompressed_size, num_atoms) in compressed_records { + file.write_all(&compressed_data)?; + + new_indices.push(MoleculeIndex { + offset, + compressed_size: compressed_data.len() as u32, + uncompressed_size, + num_atoms, + }); + + offset += compressed_data.len() as u64; + } + + file.flush()?; + self.index.extend(new_indices)?; + + Ok(()) + } + + fn write_borrowed_records(&mut self, records: &[(&[u8], u32)]) -> Result<()> { + let compression = self.compression; + + let compressed_records: Vec<(Vec, u32, u32)> = records + .par_iter() + .map(|(bytes, num_atoms)| { + let uncompressed_size = bytes.len() as u32; + let compressed = compress(bytes, compression)?; + Ok((compressed, uncompressed_size, *num_atoms)) + }) + .collect::>>()?; + + let mut file = OpenOptions::new().append(true).open(&self.path)?; + + let mut offset = file.seek(SeekFrom::End(0))?; + let mut new_indices = Vec::with_capacity(compressed_records.len()); + + for (compressed_data, uncompressed_size, num_atoms) in compressed_records { + file.write_all(&compressed_data)?; + + new_indices.push(MoleculeIndex { + offset, + compressed_size: compressed_data.len() as u32, + uncompressed_size, + num_atoms, + }); + + offset += compressed_data.len() as u64; + } + + file.flush()?; + self.index.extend(new_indices)?; + + Ok(()) + } + // -- Writing ------------------------------------------------------------- /// Add a single molecule. @@ -384,20 +555,32 @@ impl AtomDatabase { return Ok(()); } - let serialized: Vec<(Vec, u32, u8)> = molecules + let target_format = if self.can_promote_record_format() + && molecules.iter().any(|molecule| { + minimum_record_format_for_molecule(molecule) == RECORD_FORMAT_SOA_V3 + }) { + RECORD_FORMAT_SOA_V3 + } else { + self.record_format + }; + + let serialized: Vec<(Vec, u32, SchemaLock)> = molecules .par_iter() .map(|mol| { - let bytes = serialize_molecule_soa(mol, self.record_format)?; + let bytes = serialize_molecule_soa(mol, target_format)?; let num_atoms = mol.len() as u32; - Ok(( - bytes, - num_atoms, - schema_from_molecule(mol)?.positions_type.unwrap(), - )) + Ok((bytes, num_atoms, schema_from_molecule(mol)?)) }) .collect::>>()?; - self.append_owned_soa_records(serialized) + let mut batch_schema = SchemaLock::default(); + let mut records = Vec::with_capacity(serialized.len()); + for (bytes, num_atoms, schema) in serialized { + merge_schema_lock(&mut batch_schema, &schema)?; + records.push((bytes, num_atoms)); + } + + self.append_owned_soa_records_prevalidated(records, batch_schema) } /// Add pre-serialized SOA records, compressing in parallel and appending to the file. @@ -410,7 +593,38 @@ impl AtomDatabase { if records.is_empty() { return Ok(()); } - self.append_soa_records(records) + self.append_soa_records( + records + .iter() + .map(|(bytes, num_atoms)| (*bytes, *num_atoms, None)), + ) + } + + #[doc(hidden)] + pub fn add_raw_soa_records_with_positions_type( + &mut self, + records: &[(&[u8], u32, u8)], + ) -> Result<()> { + if records.is_empty() { + return Ok(()); + } + self.append_soa_records( + records.iter().map(|(bytes, num_atoms, positions_type)| { + (*bytes, *num_atoms, Some(*positions_type)) + }), + ) + } + + #[doc(hidden)] + pub fn add_raw_soa_records_with_schema( + &mut self, + records: &[(&[u8], u32)], + schema: DatabaseSchema, + ) -> Result<()> { + if records.is_empty() { + return Ok(()); + } + self.append_raw_soa_records_prevalidated(records, schema_lock_from_database_schema(schema)) } #[doc(hidden)] @@ -421,7 +635,25 @@ impl AtomDatabase { self.append_owned_soa_records(records) } - fn append_soa_records(&mut self, records: &[(&[u8], u32)]) -> Result<()> { + #[doc(hidden)] + pub fn add_owned_soa_records_with_schema( + &mut self, + records: Vec<(Vec, u32)>, + schema: DatabaseSchema, + ) -> Result<()> { + if records.is_empty() { + return Ok(()); + } + self.append_owned_soa_records_prevalidated( + records, + schema_lock_from_database_schema(schema), + ) + } + + fn append_soa_records<'a, I>(&mut self, records: I) -> Result<()> + where + I: IntoIterator)>, + { if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { return Err(Error::InvalidData( "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." @@ -429,44 +661,22 @@ impl AtomDatabase { )); } - self.truncate_uncommitted_tail_if_needed()?; - self.ensure_schema_compatible(records.iter().map(|(bytes, _)| (*bytes, None)))?; - - let compression = self.compression; - - // Step 1: Compress all records in parallel. - let compressed_records: Vec<(Vec, u32, u32)> = records - .par_iter() - .map(|(bytes, num_atoms)| { - let uncompressed_size = bytes.len() as u32; - let compressed = compress(bytes, compression)?; - Ok((compressed, uncompressed_size, *num_atoms)) - }) - .collect::>>()?; - - // Step 2: Write all compressed data sequentially (file I/O must be sequential) - let mut file = OpenOptions::new().append(true).open(&self.path)?; - - let mut offset = file.seek(SeekFrom::End(0))?; - let mut new_indices = Vec::with_capacity(compressed_records.len()); - - for (compressed_data, uncompressed_size, num_atoms) in compressed_records { - file.write_all(&compressed_data)?; - - new_indices.push(MoleculeIndex { - offset, - compressed_size: compressed_data.len() as u32, - uncompressed_size, - num_atoms, - }); - - offset += compressed_data.len() as u64; + let records: Vec<(&[u8], u32, Option)> = records.into_iter().collect(); + if records.is_empty() { + return Ok(()); } - file.flush()?; - self.index.extend(new_indices)?; - - Ok(()) + self.truncate_uncommitted_tail_if_needed()?; + self.ensure_schema_compatible( + records + .iter() + .map(|(bytes, _, positions_type_hint)| (*bytes, *positions_type_hint)), + )?; + let borrowed = records + .iter() + .map(|(bytes, num_atoms, _)| (*bytes, *num_atoms)) + .collect::>(); + self.write_borrowed_records(&borrowed) } fn append_owned_soa_records(&mut self, records: Vec<(Vec, u32, u8)>) -> Result<()> { @@ -484,39 +694,45 @@ impl AtomDatabase { .map(|(bytes, _, positions_type)| (bytes.as_slice(), Some(*positions_type))), )?; - let compression = self.compression; - - let compressed_records: Vec<(Vec, u32, u32)> = records - .into_par_iter() - .map(|(bytes, num_atoms, _positions_type)| { - let uncompressed_size = bytes.len() as u32; - let compressed = compress(&bytes, compression)?; - Ok((compressed, uncompressed_size, num_atoms)) - }) - .collect::>>()?; - - let mut file = OpenOptions::new().append(true).open(&self.path)?; - - let mut offset = file.seek(SeekFrom::End(0))?; - let mut new_indices = Vec::with_capacity(compressed_records.len()); + let records = records + .into_iter() + .map(|(bytes, num_atoms, _positions_type)| (bytes, num_atoms)) + .collect(); + self.write_owned_records(records) + } - for (compressed_data, uncompressed_size, num_atoms) in compressed_records { - file.write_all(&compressed_data)?; + fn append_owned_soa_records_prevalidated( + &mut self, + records: Vec<(Vec, u32)>, + batch_schema: SchemaLock, + ) -> Result<()> { + if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { + return Err(Error::InvalidData( + "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." + .into(), + )); + } - new_indices.push(MoleculeIndex { - offset, - compressed_size: compressed_data.len() as u32, - uncompressed_size, - num_atoms, - }); + self.truncate_uncommitted_tail_if_needed()?; + self.ensure_schema_lock(&batch_schema)?; + self.write_owned_records(records) + } - offset += compressed_data.len() as u64; + fn append_raw_soa_records_prevalidated( + &mut self, + records: &[(&[u8], u32)], + batch_schema: SchemaLock, + ) -> Result<()> { + if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { + return Err(Error::InvalidData( + "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." + .into(), + )); } - file.flush()?; - self.index.extend(new_indices)?; - - Ok(()) + self.truncate_uncommitted_tail_if_needed()?; + self.ensure_schema_lock(&batch_schema)?; + self.write_borrowed_records(records) } // -- Reading ------------------------------------------------------------- @@ -718,6 +934,15 @@ impl AtomDatabase { .and_then(|lock| lock.positions_type) } + pub fn schema_info(&self) -> Option { + self.schema_lock.as_ref().map(database_schema_from_lock) + } + + #[doc(hidden)] + pub fn record_format_for_schema(&self, schema: DatabaseSchema) -> Result { + self.resolved_record_format_for_schema(&schema_lock_from_database_schema(schema)) + } + /// Compressed bytes for a molecule from the mmap (None if no mmap). pub fn get_compressed_slice(&self, index: usize) -> Option<&[u8]> { let mol_index = self.index.get(index)?; @@ -1272,6 +1497,35 @@ mod tests { assert!(format!("{}", err).contains("Position dtype mismatch")); } + #[test] + fn test_empty_database_stays_v2_for_v2_compatible_first_write() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let mut mol = Molecule::new(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + mol.energy = Some(FloatScalarData::F64(-1.0)); + mol.forces = Some(Vec3Data::F32(vec![[0.1, 0.2, 0.3]])); + + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V2); + db.add_molecule(&mol).unwrap(); + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V2); + } + + #[test] + fn test_empty_database_promotes_to_v3_when_first_write_requires_it() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let mut mol = Molecule::new_f64(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + mol.forces = Some(Vec3Data::F64(vec![[0.1, 0.2, 0.3]])); + + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V2); + db.add_molecule(&mol).unwrap(); + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V3); + } + #[test] fn test_schema_lock_rejects_custom_shape_mismatch() { use crate::atom::PropertyValue; @@ -1299,7 +1553,10 @@ mod tests { let temp = NamedTempFile::new().unwrap(); let path = temp.path().to_path_buf(); let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); - db.record_format = RECORD_FORMAT_SOA_V2; + + let legacy = Molecule::new(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + db.add_molecule(&legacy).unwrap(); + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V2); let mut mol = molecule_from_atoms(vec![Atom::new(0.0, 0.0, 0.0, 6)]); mol.cell = Some(Mat3Data::F32([ diff --git a/atompack/src/storage/schema.rs b/atompack/src/storage/schema.rs index b9c036a..14da316 100644 --- a/atompack/src/storage/schema.rs +++ b/atompack/src/storage/schema.rs @@ -1,6 +1,4 @@ -use super::soa::{ - arr, positions_stride, property_value_to_bytes, property_value_type_tag, resolve_positions_type, -}; +use super::soa::{arr, positions_stride, property_value_type_tag, resolve_positions_type}; use super::*; use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; use std::collections::BTreeMap; @@ -21,7 +19,7 @@ pub(super) struct SchemaEntry { const SCHEMA_BLOB_VERSION: u32 = 1; -fn positions_type_from_molecule(molecule: &Molecule) -> u8 { +pub(super) fn positions_type_from_molecule(molecule: &Molecule) -> u8 { match molecule.positions { Vec3Data::F32(_) => TYPE_VEC3_F32, Vec3Data::F64(_) => TYPE_VEC3_F64, @@ -236,6 +234,32 @@ fn validate_builtin_type_tag_for_record_format( } } +pub(super) fn validate_schema_lock_for_record_format( + record_format: u32, + schema: &SchemaLock, +) -> Result<()> { + let _ = resolve_positions_type(record_format, schema.positions_type)?; + for ((kind, key), entry) in &schema.sections { + if *kind == KIND_BUILTIN { + validate_builtin_type_tag_for_record_format(record_format, key, entry.type_tag)?; + } + } + Ok(()) +} + +fn property_value_payload_len(value: &PropertyValue) -> usize { + match value { + PropertyValue::Float(_) | PropertyValue::Int(_) => 8, + PropertyValue::String(value) => value.len(), + PropertyValue::FloatArray(values) => values.len() * 8, + PropertyValue::Vec3Array(values) => values.len() * 12, + PropertyValue::IntArray(values) => values.len() * 8, + PropertyValue::Float32Array(values) => values.len() * 4, + PropertyValue::Vec3ArrayF64(values) => values.len() * 24, + PropertyValue::Int32Array(values) => values.len() * 4, + } +} + pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { let n_atoms = molecule.len(); let mut schema = SchemaLock { @@ -303,7 +327,7 @@ pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { KIND_ATOM_PROP, key, property_value_type_tag(value), - property_value_to_bytes(value).len(), + property_value_payload_len(value), )?; } for (key, value) in &molecule.properties { @@ -311,7 +335,7 @@ pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { KIND_MOL_PROP, key, property_value_type_tag(value), - property_value_to_bytes(value).len(), + property_value_payload_len(value), )?; } diff --git a/atompack/src/storage/soa.rs b/atompack/src/storage/soa.rs index df37f9f..6c52d7b 100644 --- a/atompack/src/storage/soa.rs +++ b/atompack/src/storage/soa.rs @@ -25,6 +25,19 @@ pub(super) fn property_value_type_tag(value: &PropertyValue) -> u8 { } } +fn property_value_payload_len(value: &PropertyValue) -> usize { + match value { + PropertyValue::Float(_) | PropertyValue::Int(_) => 8, + PropertyValue::String(value) => value.len(), + PropertyValue::FloatArray(values) => values.len() * 8, + PropertyValue::Vec3Array(values) => values.len() * 12, + PropertyValue::IntArray(values) => values.len() * 8, + PropertyValue::Float32Array(values) => values.len() * 4, + PropertyValue::Vec3ArrayF64(values) => values.len() * 24, + PropertyValue::Int32Array(values) => values.len() * 4, + } +} + fn extend_f64(b: &mut Vec, v: &[f64]) { for x in v { b.extend_from_slice(&f64::to_le_bytes(*x)); @@ -192,18 +205,44 @@ pub(super) fn decode_mat3x3_f64(payload: &[u8]) -> Result<[[f64; 3]; 3]> { fn write_vec3_section(buf: &mut Vec, key: &str, values: &Vec3Data) { match values { Vec3Data::F32(values) => { - let mut payload = Vec::with_capacity(values.len() * 12); - for value in values { - extend_f32(&mut payload, value); + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_VEC3_F32, + bytemuck::cast_slice::<[f32; 3], u8>(values), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(values.len() * 12); + for value in values { + extend_f32(&mut payload, value); + } + write_section(buf, KIND_BUILTIN, key, TYPE_VEC3_F32, &payload); } - write_section(buf, KIND_BUILTIN, key, TYPE_VEC3_F32, &payload); } Vec3Data::F64(values) => { - let mut payload = Vec::with_capacity(values.len() * 24); - for value in values { - extend_f64(&mut payload, value); + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_VEC3_F64, + bytemuck::cast_slice::<[f64; 3], u8>(values), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(values.len() * 24); + for value in values { + extend_f64(&mut payload, value); + } + write_section(buf, KIND_BUILTIN, key, TYPE_VEC3_F64, &payload); } - write_section(buf, KIND_BUILTIN, key, TYPE_VEC3_F64, &payload); } } } @@ -211,14 +250,40 @@ fn write_vec3_section(buf: &mut Vec, key: &str, values: &Vec3Data) { fn write_float_array_section(buf: &mut Vec, key: &str, values: &FloatArrayData) { match values { FloatArrayData::F32(values) => { - let mut payload = Vec::with_capacity(values.len() * 4); - extend_f32(&mut payload, values); - write_section(buf, KIND_BUILTIN, key, TYPE_F32_ARRAY, &payload); + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_F32_ARRAY, + bytemuck::cast_slice::(values), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(values.len() * 4); + extend_f32(&mut payload, values); + write_section(buf, KIND_BUILTIN, key, TYPE_F32_ARRAY, &payload); + } } FloatArrayData::F64(values) => { - let mut payload = Vec::with_capacity(values.len() * 8); - extend_f64(&mut payload, values); - write_section(buf, KIND_BUILTIN, key, TYPE_F64_ARRAY, &payload); + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_F64_ARRAY, + bytemuck::cast_slice::(values), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(values.len() * 8); + extend_f64(&mut payload, values); + write_section(buf, KIND_BUILTIN, key, TYPE_F64_ARRAY, &payload); + } } } } @@ -226,18 +291,44 @@ fn write_float_array_section(buf: &mut Vec, key: &str, values: &FloatArrayDa fn write_mat3_section(buf: &mut Vec, key: &str, values: &Mat3Data) { match values { Mat3Data::F32(values) => { - let mut payload = Vec::with_capacity(36); - for row in values { - extend_f32(&mut payload, row); + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_MAT3X3_F32, + bytemuck::cast_slice::<[f32; 3], u8>(values.as_slice()), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(36); + for row in values { + extend_f32(&mut payload, row); + } + write_section(buf, KIND_BUILTIN, key, TYPE_MAT3X3_F32, &payload); } - write_section(buf, KIND_BUILTIN, key, TYPE_MAT3X3_F32, &payload); } Mat3Data::F64(values) => { - let mut payload = Vec::with_capacity(72); - for row in values { - extend_f64(&mut payload, row); + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_MAT3X3_F64, + bytemuck::cast_slice::<[f64; 3], u8>(values.as_slice()), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(72); + for row in values { + extend_f64(&mut payload, row); + } + write_section(buf, KIND_BUILTIN, key, TYPE_MAT3X3_F64, &payload); } - write_section(buf, KIND_BUILTIN, key, TYPE_MAT3X3_F64, &payload); } } } @@ -345,16 +436,38 @@ fn validate_record_format_compat(molecule: &Molecule, record_format: u32) -> Res } } +pub(super) fn minimum_record_format_for_molecule(molecule: &Molecule) -> u32 { + if validate_record_format_compat(molecule, RECORD_FORMAT_SOA_V2).is_ok() { + RECORD_FORMAT_SOA_V2 + } else { + RECORD_FORMAT_SOA_V3 + } +} + fn write_positions(buf: &mut Vec, positions: &Vec3Data) { match positions { Vec3Data::F32(values) => { - for value in values { - extend_f32(buf, value); + #[cfg(target_endian = "little")] + { + buf.extend_from_slice(bytemuck::cast_slice::<[f32; 3], u8>(values)); + } + #[cfg(not(target_endian = "little"))] + { + for value in values { + extend_f32(buf, value); + } } } Vec3Data::F64(values) => { - for value in values { - extend_f64(buf, value); + #[cfg(target_endian = "little")] + { + buf.extend_from_slice(bytemuck::cast_slice::<[f64; 3], u8>(values)); + } + #[cfg(not(target_endian = "little"))] + { + for value in values { + extend_f64(buf, value); + } } } } @@ -391,6 +504,76 @@ fn count_sections(molecule: &Molecule) -> u16 { n_sections } +fn section_size(key_len: usize, payload_len: usize) -> usize { + 1 + 1 + key_len + 1 + 4 + payload_len +} + +fn estimate_serialized_len(molecule: &Molecule) -> usize { + let positions_bytes = match &molecule.positions { + Vec3Data::F32(values) => values.len() * 12, + Vec3Data::F64(values) => values.len() * 24, + }; + let mut total = 4 + positions_bytes + molecule.atomic_numbers.len() + 2; + + if let Some(charges) = &molecule.charges { + let payload_len = match charges { + FloatArrayData::F32(values) => values.len() * 4, + FloatArrayData::F64(values) => values.len() * 8, + }; + total += section_size("charges".len(), payload_len); + } + if let Some(cell) = &molecule.cell { + let payload_len = match cell { + Mat3Data::F32(_) => 36, + Mat3Data::F64(_) => 72, + }; + total += section_size("cell".len(), payload_len); + } + if let Some(energy) = &molecule.energy { + let payload_len = match energy { + FloatScalarData::F32(_) => 4, + FloatScalarData::F64(_) => 8, + }; + total += section_size("energy".len(), payload_len); + } + if let Some(forces) = &molecule.forces { + let payload_len = match forces { + Vec3Data::F32(values) => values.len() * 12, + Vec3Data::F64(values) => values.len() * 24, + }; + total += section_size("forces".len(), payload_len); + } + if let Some(name) = &molecule.name { + total += section_size("name".len(), name.len()); + } + if molecule.pbc.is_some() { + total += section_size("pbc".len(), 3); + } + if let Some(stress) = &molecule.stress { + let payload_len = match stress { + Mat3Data::F32(_) => 36, + Mat3Data::F64(_) => 72, + }; + total += section_size("stress".len(), payload_len); + } + if let Some(velocities) = &molecule.velocities { + let payload_len = match velocities { + Vec3Data::F32(values) => values.len() * 12, + Vec3Data::F64(values) => values.len() * 24, + }; + total += section_size("velocities".len(), payload_len); + } + + for (key, value) in &molecule.atom_properties { + total += section_size(key.len(), property_value_payload_len(value)); + } + for (key, value) in &molecule.properties { + total += section_size(key.len(), property_value_payload_len(value)); + } + + total +} + fn write_sections(buf: &mut Vec, molecule: &Molecule) { if let Some(ref charges) = molecule.charges { write_float_array_section(buf, "charges", charges); @@ -534,7 +717,7 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule, record_format: u32) -> validate_record_format_compat(molecule, record_format)?; let n = molecule.len(); - let mut buf = Vec::new(); + let mut buf = Vec::with_capacity(estimate_serialized_len(molecule)); buf.extend_from_slice(&(n as u32).to_le_bytes()); write_positions(&mut buf, &molecule.positions); @@ -562,9 +745,30 @@ fn decode_positions( "SOA record truncated at positions".into(), )); } + let payload = &bytes[*pos..positions_end]; let positions = match positions_type { - TYPE_VEC3_F32 => Vec3Data::F32(decode_vec3_f32(&bytes[*pos..positions_end])?), - TYPE_VEC3_F64 => Vec3Data::F64(decode_vec3_f64(&bytes[*pos..positions_end])?), + TYPE_VEC3_F32 => { + let mut values = Vec::with_capacity(n_atoms); + for chunk in payload.chunks_exact(12) { + values.push([ + f32::from_le_bytes(arr(&chunk[0..4])?), + f32::from_le_bytes(arr(&chunk[4..8])?), + f32::from_le_bytes(arr(&chunk[8..12])?), + ]); + } + Vec3Data::F32(values) + } + TYPE_VEC3_F64 => { + let mut values = Vec::with_capacity(n_atoms); + for chunk in payload.chunks_exact(24) { + values.push([ + f64::from_le_bytes(arr(&chunk[0..8])?), + f64::from_le_bytes(arr(&chunk[8..16])?), + f64::from_le_bytes(arr(&chunk[16..24])?), + ]); + } + Vec3Data::F64(values) + } _ => { return Err(Error::InvalidData(format!( "Unsupported positions type tag {}", From ac7949cc643d53d38047a80cc89633d11157906e Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sun, 10 May 2026 14:26:50 +0200 Subject: [PATCH 6/9] refactor: simplify typed-float rust and pyo3 paths --- atompack-py/src/database.rs | 168 ++++----- atompack-py/src/database_batch.rs | 486 ++++++++++++------------- atompack-py/src/database_flat.rs | 30 +- atompack-py/src/lib.rs | 11 +- atompack-py/src/molecule.rs | 61 +--- atompack-py/src/molecule_helpers.rs | 423 +++++++--------------- atompack-py/src/py_dtypes.rs | 310 ++++++++++++++++ atompack-py/src/soa.rs | 321 +++++----------- atompack/src/atom.rs | 162 +-------- atompack/src/lib.rs | 6 +- atompack/src/storage/dtypes.rs | 409 +++++++++++++++++++++ atompack/src/storage/mod.rs | 79 ++-- atompack/src/storage/schema.rs | 147 +++----- atompack/src/storage/soa.rs | 542 +++++----------------------- atompack/src/types.rs | 162 +++++++++ 15 files changed, 1595 insertions(+), 1722 deletions(-) create mode 100644 atompack-py/src/py_dtypes.rs create mode 100644 atompack/src/storage/dtypes.rs create mode 100644 atompack/src/types.rs diff --git a/atompack-py/src/database.rs b/atompack-py/src/database.rs index ac3688d..c890ae8 100644 --- a/atompack-py/src/database.rs +++ b/atompack-py/src/database.rs @@ -11,51 +11,42 @@ pub(crate) struct PyAtomDatabase { } impl PyAtomDatabase { + fn soa_context(&self) -> PyResult { + SoaContext::from_database(&self.inner).map_err(|e| PyValueError::new_err(format!("{}", e))) + } + + fn load_mmap_view( + &self, + index: usize, + compression: CompressionType, + ctx: SoaContext, + ) -> atompack::Result { + if compression == CompressionType::None { + let bytes = self.inner.get_shared_mmap_bytes(index).ok_or_else(|| { + invalid_data(format!("Missing mmap bytes for molecule {}", index)) + })?; + return SoaMoleculeView::from_shared_bytes(bytes, ctx); + } + + let compressed = self.inner.get_compressed_slice(index).ok_or_else(|| { + invalid_data(format!("Missing compressed bytes for molecule {}", index)) + })?; + let uncompressed_size = self.inner.uncompressed_size(index).ok_or_else(|| { + invalid_data(format!("Missing uncompressed size for molecule {}", index)) + })? as usize; + let decompressed = + atompack::decompress_bytes(compressed, compression, Some(uncompressed_size))?; + SoaMoleculeView::from_owned_bytes(decompressed, ctx) + } + fn single_molecule_view(&self, py: Python<'_>, index: usize) -> PyResult { let compression = self.inner.compression(); - let record_format = self.inner.record_format(); - let positions_type = self.inner.positions_type(); + let ctx = self.soa_context()?; let use_mmap = self.inner.get_compressed_slice(0).is_some(); if use_mmap { return py - .detach(|| -> atompack::Result { - if compression == CompressionType::None { - let bytes = self.inner.get_shared_mmap_bytes(index).ok_or_else(|| { - invalid_data(format!("Missing mmap bytes for molecule {}", index)) - })?; - SoaMoleculeView::from_shared_bytes_inner( - bytes, - record_format, - positions_type, - ) - } else { - let compressed = - self.inner.get_compressed_slice(index).ok_or_else(|| { - invalid_data(format!( - "Missing compressed bytes for molecule {}", - index - )) - })?; - let uncompressed_size = - self.inner.uncompressed_size(index).ok_or_else(|| { - invalid_data(format!( - "Missing uncompressed size for molecule {}", - index - )) - })? as usize; - let decompressed = atompack::decompress_bytes( - compressed, - compression, - Some(uncompressed_size), - )?; - SoaMoleculeView::from_bytes_inner( - decompressed, - record_format, - positions_type, - ) - } - }) + .detach(|| self.load_mmap_view(index, compression, ctx)) .map_err(|e| PyValueError::new_err(format!("{}", e))); } @@ -65,7 +56,40 @@ impl PyAtomDatabase { let raw = raw_bytes.pop().ok_or_else(|| { PyValueError::new_err(format!("Missing raw bytes for molecule {}", index)) })?; - SoaMoleculeView::from_bytes(raw, record_format, positions_type) + SoaMoleculeView::from_bytes(raw, ctx) + } + + fn molecule_views( + &self, + py: Python<'_>, + indices: Vec, + ) -> PyResult> { + let compression = self.inner.compression(); + let ctx = self.soa_context()?; + let use_mmap = self.inner.get_compressed_slice(0).is_some(); + + let views = if use_mmap { + let results: Vec> = py.detach(|| { + use rayon::prelude::*; + indices + .par_iter() + .map(|&idx| self.load_mmap_view(idx, compression, ctx)) + .collect() + }); + results.into_iter().collect::>>() + } else { + let raw_bytes = py + .detach(|| self.inner.get_raw_bytes(&indices)) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + raw_bytes + .into_iter() + .map(|bytes| SoaMoleculeView::from_bytes(bytes, ctx)) + .collect::>>() + .map_err(|e| invalid_data(format!("{}", e))) + } + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + + Ok(views) } } @@ -274,69 +298,7 @@ impl PyAtomDatabase { if indices.is_empty() { return Ok(Vec::new()); } - - let compression = self.inner.compression(); - let record_format = self.inner.record_format(); - let positions_type = self.inner.positions_type(); - let use_mmap = self.inner.get_compressed_slice(0).is_some(); - - let views: Vec = if use_mmap { - let results: Vec> = py.detach(|| { - use rayon::prelude::*; - indices - .par_iter() - .map(|&idx| -> atompack::Result { - if compression == CompressionType::None { - let bytes = self.inner.get_shared_mmap_bytes(idx).ok_or_else(|| { - invalid_data(format!("Missing mmap bytes for molecule {}", idx)) - })?; - SoaMoleculeView::from_shared_bytes_inner( - bytes, - record_format, - positions_type, - ) - } else { - let compressed = - self.inner.get_compressed_slice(idx).ok_or_else(|| { - invalid_data(format!( - "Missing compressed bytes for molecule {}", - idx - )) - })?; - let uncompressed_size = - self.inner.uncompressed_size(idx).ok_or_else(|| { - invalid_data(format!( - "Missing uncompressed size for molecule {}", - idx - )) - })? as usize; - let decompressed = atompack::decompress_bytes( - compressed, - compression, - Some(uncompressed_size), - )?; - SoaMoleculeView::from_bytes_inner( - decompressed, - record_format, - positions_type, - ) - } - }) - .collect() - }); - results.into_iter().collect::>>() - } else { - let raw_bytes = py - .detach(|| self.inner.get_raw_bytes(&indices)) - .map_err(|e| PyValueError::new_err(format!("{}", e)))?; - raw_bytes - .into_iter() - .map(|bytes| SoaMoleculeView::from_bytes(bytes, record_format, positions_type)) - .collect::>>() - .map_err(|e| invalid_data(format!("{}", e))) - } - .map_err(|e| PyValueError::new_err(format!("{}", e)))?; - + let views = self.molecule_views(py, indices)?; Ok(views.into_iter().map(PyMolecule::from_view).collect()) } diff --git a/atompack-py/src/database_batch.rs b/atompack-py/src/database_batch.rs index 40560d4..3a55bd3 100644 --- a/atompack-py/src/database_batch.rs +++ b/atompack-py/src/database_batch.rs @@ -282,35 +282,27 @@ fn extract_property_column( if let Some(column) = extract_string_column(value, batch, key, kind)? { return Ok(Some(column)); } - if let Ok(arr) = value.cast::>() { - return Ok(Some(extract_scalar_column_f64(arr, batch, key, kind)?)); + if let Some(arr) = PyFloatArray1::from_any(value) { + return Ok(Some(match arr { + PyFloatArray1::F32(arr) => extract_scalar_column_f64(&arr, batch, key, kind)?, + PyFloatArray1::F64(arr) => extract_scalar_column_f64(&arr, batch, key, kind)?, + })); } - if let Ok(arr) = value.cast::>() { - return Ok(Some(extract_scalar_column_f64(arr, batch, key, kind)?)); + if let Some(arr) = PyIntArray1::from_any(value) { + return Ok(Some(match arr { + PyIntArray1::I32(arr) => extract_scalar_column_i64(&arr, batch, key, kind)?, + PyIntArray1::I64(arr) => extract_scalar_column_i64(&arr, batch, key, kind)?, + })); } - if let Ok(arr) = value.cast::>() { - return Ok(Some(extract_scalar_column_i64(arr, batch, key, kind)?)); - } - if let Ok(arr) = value.cast::>() { - return Ok(Some(extract_scalar_column_i64(arr, batch, key, kind)?)); - } - if let Ok(arr) = value.cast::>() { - return Ok(Some(extract_matrix_column( - arr, - batch, - key, - kind, - TYPE_F64_ARRAY, - )?)); - } - if let Ok(arr) = value.cast::>() { - return Ok(Some(extract_matrix_column( - arr, - batch, - key, - kind, - TYPE_F32_ARRAY, - )?)); + if let Some(arr) = PyFloatArray2::from_any(value) { + return Ok(Some(match arr { + PyFloatArray2::F32(arr) => { + extract_matrix_column(&arr, batch, key, kind, TYPE_F32_ARRAY)? + } + PyFloatArray2::F64(arr) => { + extract_matrix_column(&arr, batch, key, kind, TYPE_F64_ARRAY)? + } + })); } if let Ok(arr) = value.cast::>() { return Ok(Some(extract_matrix_column( @@ -330,36 +322,40 @@ fn extract_property_column( TYPE_I32_ARRAY, )?)); } - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let view = readonly.as_array(); - let shape = view.shape(); - if shape.len() == 3 && shape[0] == batch && shape[2] == 3 { - return Ok(Some(extract_vec3_column( - arr, - batch, - shape[1], - key, - kind, - TYPE_VEC3_F64, - "molecule properties", - )?)); - } - } - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let view = readonly.as_array(); - let shape = view.shape(); - if shape.len() == 3 && shape[0] == batch && shape[2] == 3 { - return Ok(Some(extract_vec3_column( - arr, - batch, - shape[1], - key, - kind, - TYPE_VEC3_F32, - "molecule properties", - )?)); + if let Some(arr) = PyFloatArray3::from_any(value) { + match arr { + PyFloatArray3::F32(arr) => { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let shape = view.shape(); + if shape.len() == 3 && shape[0] == batch && shape[2] == 3 { + return Ok(Some(extract_vec3_column( + &arr, + batch, + shape[1], + key, + kind, + TYPE_VEC3_F32, + "molecule properties", + )?)); + } + } + PyFloatArray3::F64(arr) => { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let shape = view.shape(); + if shape.len() == 3 && shape[0] == batch && shape[2] == 3 { + return Ok(Some(extract_vec3_column( + &arr, + batch, + shape[1], + key, + kind, + TYPE_VEC3_F64, + "molecule properties", + )?)); + } + } } } Ok(None) @@ -371,19 +367,18 @@ fn extract_atom_property_column( n_atoms: usize, key: &str, ) -> PyResult> { - if let Ok(arr) = value.cast::>() { - let column = extract_matrix_column(arr, batch, key, KIND_ATOM_PROP, TYPE_F64_ARRAY)?; - if column.slot_bytes != n_atoms * std::mem::size_of::() { - return Err(PyValueError::new_err(format!( - "atom property '{}' must have shape ({}, {})", - key, batch, n_atoms - ))); - } - return Ok(Some(column)); - } - if let Ok(arr) = value.cast::>() { - let column = extract_matrix_column(arr, batch, key, KIND_ATOM_PROP, TYPE_F32_ARRAY)?; - if column.slot_bytes != n_atoms * std::mem::size_of::() { + if let Some(arr) = PyFloatArray2::from_any(value) { + let (column, expected) = match arr { + PyFloatArray2::F32(arr) => ( + extract_matrix_column(&arr, batch, key, KIND_ATOM_PROP, TYPE_F32_ARRAY)?, + n_atoms * std::mem::size_of::(), + ), + PyFloatArray2::F64(arr) => ( + extract_matrix_column(&arr, batch, key, KIND_ATOM_PROP, TYPE_F64_ARRAY)?, + n_atoms * std::mem::size_of::(), + ), + }; + if column.slot_bytes != expected { return Err(PyValueError::new_err(format!( "atom property '{}' must have shape ({}, {})", key, batch, n_atoms @@ -411,27 +406,27 @@ fn extract_atom_property_column( } return Ok(Some(column)); } - if let Ok(arr) = value.cast::>() { - return Ok(Some(extract_vec3_column( - arr, - batch, - n_atoms, - key, - KIND_ATOM_PROP, - TYPE_VEC3_F64, - "atom properties", - )?)); - } - if let Ok(arr) = value.cast::>() { - return Ok(Some(extract_vec3_column( - arr, - batch, - n_atoms, - key, - KIND_ATOM_PROP, - TYPE_VEC3_F32, - "atom properties", - )?)); + if let Some(arr) = PyFloatArray3::from_any(value) { + return Ok(Some(match arr { + PyFloatArray3::F32(arr) => extract_vec3_column( + &arr, + batch, + n_atoms, + key, + KIND_ATOM_PROP, + TYPE_VEC3_F32, + "atom properties", + )?, + PyFloatArray3::F64(arr) => extract_vec3_column( + &arr, + batch, + n_atoms, + key, + KIND_ATOM_PROP, + TYPE_VEC3_F64, + "atom properties", + )?, + })); } Ok(None) } @@ -493,39 +488,43 @@ impl FastMat3Column { let Some(value) = value else { return Ok(None); }; - if let Ok(arr) = value.cast::>() { - let ro = arr.readonly(); - if ro.as_array().shape() != [batch, 3, 3] { - return Err(PyValueError::new_err(format!( - "{label} must have shape ({}, 3, 3)", - batch - ))); + if let Some(arr) = PyFloatArray3::from_any(value) { + match arr { + PyFloatArray3::F32(arr) => { + let ro = arr.readonly(); + if ro.as_array().shape() != [batch, 3, 3] { + return Err(PyValueError::new_err(format!( + "{label} must have shape ({}, 3, 3)", + batch + ))); + } + let slice = ro.as_slice().map_err(|_| { + PyValueError::new_err(format!("{label} must be C-contiguous")) + })?; + return Ok(Some(Self { + type_tag: TYPE_MAT3X3_F32, + slot_bytes: 36, + payload: bytemuck::cast_slice::(slice).to_vec(), + })); + } + PyFloatArray3::F64(arr) => { + let ro = arr.readonly(); + if ro.as_array().shape() != [batch, 3, 3] { + return Err(PyValueError::new_err(format!( + "{label} must have shape ({}, 3, 3)", + batch + ))); + } + let slice = ro.as_slice().map_err(|_| { + PyValueError::new_err(format!("{label} must be C-contiguous")) + })?; + return Ok(Some(Self { + type_tag: TYPE_MAT3X3_F64, + slot_bytes: 72, + payload: bytemuck::cast_slice::(slice).to_vec(), + })); + } } - let slice = ro - .as_slice() - .map_err(|_| PyValueError::new_err(format!("{label} must be C-contiguous")))?; - return Ok(Some(Self { - type_tag: TYPE_MAT3X3_F32, - slot_bytes: 36, - payload: bytemuck::cast_slice::(slice).to_vec(), - })); - } - if let Ok(arr) = value.cast::>() { - let ro = arr.readonly(); - if ro.as_array().shape() != [batch, 3, 3] { - return Err(PyValueError::new_err(format!( - "{label} must have shape ({}, 3, 3)", - batch - ))); - } - let slice = ro - .as_slice() - .map_err(|_| PyValueError::new_err(format!("{label} must be C-contiguous")))?; - return Ok(Some(Self { - type_tag: TYPE_MAT3X3_F64, - slot_bytes: 72, - payload: bytemuck::cast_slice::(slice).to_vec(), - })); } Ok(None) } @@ -562,12 +561,12 @@ fn try_add_arrays_batch_fast_canonical( properties: Option<&Bound<'_, PyDict>>, atom_properties: Option<&Bound<'_, PyDict>>, ) -> PyResult { - let Ok(positions) = positions.cast::>() else { + let Some(PyFloatArray3::F32(positions)) = PyFloatArray3::from_any(positions) else { return Ok(false); }; let energy = match energy { Some(value) => { - let Ok(arr) = value.cast::>() else { + let Some(PyFloatArray1::F64(arr)) = PyFloatArray1::from_any(value) else { return Ok(false); }; Some(arr) @@ -576,7 +575,7 @@ fn try_add_arrays_batch_fast_canonical( }; let forces = match forces { Some(value) => { - let Ok(arr) = value.cast::>() else { + let Some(PyFloatArray3::F32(arr)) = PyFloatArray3::from_any(value) else { return Ok(false); }; Some(arr) @@ -585,7 +584,7 @@ fn try_add_arrays_batch_fast_canonical( }; let charges = match charges { Some(value) => { - let Ok(arr) = value.cast::>() else { + let Some(PyFloatArray2::F64(arr)) = PyFloatArray2::from_any(value) else { return Ok(false); }; Some(arr) @@ -594,7 +593,7 @@ fn try_add_arrays_batch_fast_canonical( }; let velocities = match velocities { Some(value) => { - let Ok(arr) = value.cast::>() else { + let Some(PyFloatArray3::F32(arr)) = PyFloatArray3::from_any(value) else { return Ok(false); }; Some(arr) @@ -933,41 +932,45 @@ fn try_add_arrays_batch_fast_canonical( } fn extract_positions_payload(value: &Bound<'_, PyAny>) -> PyResult<(usize, usize, u8, Vec)> { - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let view = readonly.as_array(); - if view.shape().len() != 3 || view.shape()[2] != 3 { - return Err(PyValueError::new_err( - "positions must have shape (batch, n_atoms, 3)", - )); - } - let slice = readonly - .as_slice() - .map_err(|_| PyValueError::new_err("positions must be C-contiguous"))?; - return Ok(( - view.shape()[0], - view.shape()[1], - TYPE_VEC3_F32, - bytemuck::cast_slice::(slice).to_vec(), - )); - } - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let view = readonly.as_array(); - if view.shape().len() != 3 || view.shape()[2] != 3 { - return Err(PyValueError::new_err( - "positions must have shape (batch, n_atoms, 3)", - )); + if let Some(arr) = PyFloatArray3::from_any(value) { + match arr { + PyFloatArray3::F32(arr) => { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape().len() != 3 || view.shape()[2] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (batch, n_atoms, 3)", + )); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err("positions must be C-contiguous"))?; + return Ok(( + view.shape()[0], + view.shape()[1], + TYPE_VEC3_F32, + bytemuck::cast_slice::(slice).to_vec(), + )); + } + PyFloatArray3::F64(arr) => { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape().len() != 3 || view.shape()[2] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (batch, n_atoms, 3)", + )); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err("positions must be C-contiguous"))?; + return Ok(( + view.shape()[0], + view.shape()[1], + TYPE_VEC3_F64, + bytemuck::cast_slice::(slice).to_vec(), + )); + } } - let slice = readonly - .as_slice() - .map_err(|_| PyValueError::new_err("positions must be C-contiguous"))?; - return Ok(( - view.shape()[0], - view.shape()[1], - TYPE_VEC3_F64, - bytemuck::cast_slice::(slice).to_vec(), - )); } Err(PyValueError::new_err( "positions must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", @@ -1192,133 +1195,94 @@ pub(super) fn add_arrays_batch_impl( let mut builtin_columns = Vec::new(); if let Some(energy) = energy { - if let Ok(arr) = energy.cast::>() { - builtin_columns.push(extract_builtin_scalar_column( - arr, - batch, - "energy", - TYPE_FLOAT32, - )?); - } else if let Ok(arr) = energy.cast::>() { - builtin_columns.push(extract_builtin_scalar_column( - arr, batch, "energy", TYPE_FLOAT, - )?); - } else { + let Some(array) = PyFloatArray1::from_any(energy) else { return Err(PyValueError::new_err( "energy must be a float32 or float64 ndarray with shape (batch,)", )); - } + }; + builtin_columns.push(match array { + PyFloatArray1::F32(arr) => { + extract_builtin_scalar_column(&arr, batch, "energy", TYPE_FLOAT32)? + } + PyFloatArray1::F64(arr) => { + extract_builtin_scalar_column(&arr, batch, "energy", TYPE_FLOAT)? + } + }); } if let Some(forces) = forces { - if let Ok(arr) = forces.cast::>() { - builtin_columns.push(extract_builtin_vec3_column( - arr, - batch, - n_atoms, - "forces", - TYPE_VEC3_F32, - )?); - } else if let Ok(arr) = forces.cast::>() { - builtin_columns.push(extract_builtin_vec3_column( - arr, - batch, - n_atoms, - "forces", - TYPE_VEC3_F64, - )?); - } else { + let Some(array) = PyFloatArray3::from_any(forces) else { return Err(PyValueError::new_err( "forces must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", )); - } + }; + builtin_columns.push(match array { + PyFloatArray3::F32(arr) => { + extract_builtin_vec3_column(&arr, batch, n_atoms, "forces", TYPE_VEC3_F32)? + } + PyFloatArray3::F64(arr) => { + extract_builtin_vec3_column(&arr, batch, n_atoms, "forces", TYPE_VEC3_F64)? + } + }); } if let Some(charges) = charges { - if let Ok(arr) = charges.cast::>() { - builtin_columns.push(extract_builtin_float_array_column( - arr, - batch, - n_atoms, - "charges", - TYPE_F32_ARRAY, - )?); - } else if let Ok(arr) = charges.cast::>() { - builtin_columns.push(extract_builtin_float_array_column( - arr, - batch, - n_atoms, - "charges", - TYPE_F64_ARRAY, - )?); - } else { + let Some(array) = PyFloatArray2::from_any(charges) else { return Err(PyValueError::new_err( "charges must be a float32 or float64 ndarray with shape (batch, n_atoms)", )); - } + }; + builtin_columns.push(match array { + PyFloatArray2::F32(arr) => { + extract_builtin_float_array_column(&arr, batch, n_atoms, "charges", TYPE_F32_ARRAY)? + } + PyFloatArray2::F64(arr) => { + extract_builtin_float_array_column(&arr, batch, n_atoms, "charges", TYPE_F64_ARRAY)? + } + }); } if let Some(velocities) = velocities { - if let Ok(arr) = velocities.cast::>() { - builtin_columns.push(extract_builtin_vec3_column( - arr, - batch, - n_atoms, - "velocities", - TYPE_VEC3_F32, - )?); - } else if let Ok(arr) = velocities.cast::>() { - builtin_columns.push(extract_builtin_vec3_column( - arr, - batch, - n_atoms, - "velocities", - TYPE_VEC3_F64, - )?); - } else { + let Some(array) = PyFloatArray3::from_any(velocities) else { return Err(PyValueError::new_err( "velocities must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", )); - } + }; + builtin_columns.push(match array { + PyFloatArray3::F32(arr) => { + extract_builtin_vec3_column(&arr, batch, n_atoms, "velocities", TYPE_VEC3_F32)? + } + PyFloatArray3::F64(arr) => { + extract_builtin_vec3_column(&arr, batch, n_atoms, "velocities", TYPE_VEC3_F64)? + } + }); } if let Some(cell) = cell { - if let Ok(arr) = cell.cast::>() { - builtin_columns.push(extract_builtin_mat3_column( - arr, - batch, - "cell", - TYPE_MAT3X3_F32, - )?); - } else if let Ok(arr) = cell.cast::>() { - builtin_columns.push(extract_builtin_mat3_column( - arr, - batch, - "cell", - TYPE_MAT3X3_F64, - )?); - } else { + let Some(array) = PyFloatArray3::from_any(cell) else { return Err(PyValueError::new_err( "cell must be a float32 or float64 ndarray with shape (batch, 3, 3)", )); - } + }; + builtin_columns.push(match array { + PyFloatArray3::F32(arr) => { + extract_builtin_mat3_column(&arr, batch, "cell", TYPE_MAT3X3_F32)? + } + PyFloatArray3::F64(arr) => { + extract_builtin_mat3_column(&arr, batch, "cell", TYPE_MAT3X3_F64)? + } + }); } if let Some(stress) = stress { - if let Ok(arr) = stress.cast::>() { - builtin_columns.push(extract_builtin_mat3_column( - arr, - batch, - "stress", - TYPE_MAT3X3_F32, - )?); - } else if let Ok(arr) = stress.cast::>() { - builtin_columns.push(extract_builtin_mat3_column( - arr, - batch, - "stress", - TYPE_MAT3X3_F64, - )?); - } else { + let Some(array) = PyFloatArray3::from_any(stress) else { return Err(PyValueError::new_err( "stress must be a float32 or float64 ndarray with shape (batch, 3, 3)", )); - } + }; + builtin_columns.push(match array { + PyFloatArray3::F32(arr) => { + extract_builtin_mat3_column(&arr, batch, "stress", TYPE_MAT3X3_F32)? + } + PyFloatArray3::F64(arr) => { + extract_builtin_mat3_column(&arr, batch, "stress", TYPE_MAT3X3_F64)? + } + }); } if let Some(pbc) = pbc { builtin_columns.push(extract_builtin_pbc_column(pbc, batch)?); diff --git a/atompack-py/src/database_flat.rs b/atompack-py/src/database_flat.rs index 56de991..a8ada30 100644 --- a/atompack-py/src/database_flat.rs +++ b/atompack-py/src/database_flat.rs @@ -50,26 +50,15 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let compression = inner.compression(); let use_mmap = inner.get_compressed_slice(0).is_some(); - let record_format = inner.record_format(); - let positions_type = match record_format { - RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, - RECORD_FORMAT_SOA_V3 => inner - .positions_type() - .ok_or_else(|| invalid_data("Missing position dtype for batch"))?, - _ => { - return Err(invalid_data(format!( - "Unsupported record format {}", - record_format - ))); - } - }; + let ctx = SoaContext::from_database(inner)?; + let positions_type = ctx.positions_type(); let schema_info = inner.schema_info(); let raw_bytes_owned: Option>>; let schema: Vec; let use_ordered_schema: bool; let ordered_schema_from_first = |bytes: &[u8]| -> atompack::Result> { - let first_md = parse_mol_fast_soa(bytes, record_format, Some(positions_type))?; + let first_md = parse_mol_fast_soa(bytes, ctx)?; let n = first_md.n_atoms; first_md .sections @@ -223,16 +212,7 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .insert(entry.key.clone(), index); } } - let positions_stride = match positions_type { - TYPE_VEC3_F32 => 12usize, - TYPE_VEC3_F64 => 24usize, - _ => { - return Err(invalid_data(format!( - "Unsupported positions type tag {}", - positions_type - ))); - } - }; + let positions_stride = ctx.layout.positions_stride; let mut positions = match positions_type { TYPE_VEC3_F32 => FlatPositions::F32(vec![0f32; total_atoms * 3]), TYPE_VEC3_F64 => FlatPositions::F64(vec![0u8; total_atoms * positions_stride]), @@ -293,7 +273,7 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .collect(); let process_mol = |i: usize, mol_bytes: &[u8]| -> atompack::Result<()> { - let md = parse_mol_fast_soa(mol_bytes, record_format, Some(positions_type))?; + let md = parse_mol_fast_soa(mol_bytes, ctx)?; let atom_off = offsets[i]; let n = md.n_atoms; diff --git a/atompack-py/src/lib.rs b/atompack-py/src/lib.rs index e20e0d1..251fbd9 100644 --- a/atompack-py/src/lib.rs +++ b/atompack-py/src/lib.rs @@ -117,12 +117,17 @@ const TYPE_MAT3X3_F32: u8 = 12; const RECORD_FORMAT_SOA_V2: u32 = 2; const RECORD_FORMAT_SOA_V3: u32 = 3; +mod py_dtypes; mod soa; +pub(crate) use self::py_dtypes::{ + PyFloatArray1, PyFloatArray2, PyFloatArray3, PyIntArray1, parse_float_array_field, + parse_mat3_field, parse_positions_field, parse_property_value, parse_vec3_field, +}; pub(crate) use self::soa::{ - LazySection, SectionRef, SectionSchema, SoaMoleculeView, is_per_atom, parse_mol_fast_soa, - read_f64_scalar, read_i64_scalar, section_schema_from_ref, type_tag_elem_bytes, - validate_section_payload, + LazySection, SectionRef, SectionSchema, SoaContext, SoaMoleculeView, is_per_atom, + parse_mol_fast_soa, read_f64_scalar, read_i64_scalar, section_schema_from_ref, + type_tag_elem_bytes, validate_section_payload, }; mod database; diff --git a/atompack-py/src/molecule.rs b/atompack-py/src/molecule.rs index 3e082f5..81e128e 100644 --- a/atompack-py/src/molecule.rs +++ b/atompack-py/src/molecule.rs @@ -60,8 +60,7 @@ mod helpers; pub(crate) use self::helpers::{ SoaRecord, SoaSection, build_soa_record, cast_or_decode_f32, cast_or_decode_f64, - cast_or_decode_i32, cast_or_decode_i64, parse_float_array_field, parse_mat3_field, - parse_vec3_field, pyarray1_from_cow, pyarray2_from_cow, + cast_or_decode_i32, cast_or_decode_i64, pyarray1_from_cow, pyarray2_from_cow, }; use self::helpers::{into_py_any, property_section_to_pyobject, property_value_to_pyobject}; @@ -457,63 +456,7 @@ impl PyMolecule { "'stress' is a reserved field; use molecule.stress instead", )); } - - // Try to extract different Python types - if let Ok(v) = value.extract::() { - inner.properties.insert(key, PropertyValue::Int(v)); - } else if let Ok(v) = value.extract::() { - inner.properties.insert(key, PropertyValue::Float(v)); - } else if let Ok(v) = value.extract::() { - inner.properties.insert(key, PropertyValue::String(v)); - } else if let Ok(arr) = value.cast::>() { - let vec = arr.readonly().as_array().to_vec(); - inner - .properties - .insert(key, PropertyValue::Float32Array(vec)); - } else if let Ok(arr) = value.cast::>() { - let vec = arr.readonly().as_array().to_vec(); - inner.properties.insert(key, PropertyValue::FloatArray(vec)); - } else if let Ok(arr) = value.cast::>() { - let vec = arr.readonly().as_array().to_vec(); - inner.properties.insert(key, PropertyValue::Int32Array(vec)); - } else if let Ok(arr) = value.cast::>() { - let vec = arr.readonly().as_array().to_vec(); - inner.properties.insert(key, PropertyValue::IntArray(vec)); - } else if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let arr_view = readonly.as_array(); - let shape = arr_view.shape(); - if shape[1] != 3 { - return Err(PyValueError::new_err( - "Vec3Array properties must have shape (n, 3)", - )); - } - let vec: Vec<[f32; 3]> = arr_view - .outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(); - inner.properties.insert(key, PropertyValue::Vec3Array(vec)); - } else if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let arr_view = readonly.as_array(); - let shape = arr_view.shape(); - if shape[1] != 3 { - return Err(PyValueError::new_err( - "Vec3Array properties must have shape (n, 3)", - )); - } - let vec: Vec<[f64; 3]> = arr_view - .outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(); - inner - .properties - .insert(key, PropertyValue::Vec3ArrayF64(vec)); - } else { - return Err(PyValueError::new_err( - "Unsupported property type. Supported: float, int, str, ndarray", - )); - } + inner.properties.insert(key, parse_property_value(value)?); Ok(()) } diff --git a/atompack-py/src/molecule_helpers.rs b/atompack-py/src/molecule_helpers.rs index cb2f768..a71f6db 100644 --- a/atompack-py/src/molecule_helpers.rs +++ b/atompack-py/src/molecule_helpers.rs @@ -253,153 +253,6 @@ fn molecule_from_positions( } } -pub(crate) fn parse_vec3_field( - value: &Bound<'_, PyAny>, - label: &str, - expected_rows: usize, -) -> PyResult { - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let view = readonly.as_array(); - if view.shape() != [expected_rows, 3] { - return Err(PyValueError::new_err(format!( - "{} must have shape ({}, 3)", - label, expected_rows - ))); - } - return Ok(Vec3Data::F32( - view.outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(), - )); - } - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let view = readonly.as_array(); - if view.shape() != [expected_rows, 3] { - return Err(PyValueError::new_err(format!( - "{} must have shape ({}, 3)", - label, expected_rows - ))); - } - return Ok(Vec3Data::F64( - view.outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(), - )); - } - Err(PyValueError::new_err(format!( - "{} must be a float32 or float64 ndarray with shape ({}, 3)", - label, expected_rows - ))) -} - -fn parse_positions_field(value: &Bound<'_, PyAny>) -> PyResult { - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let view = readonly.as_array(); - if view.shape().len() != 2 || view.shape()[1] != 3 { - return Err(PyValueError::new_err( - "positions must have shape (n_atoms, 3)", - )); - } - return Ok(Vec3Data::F32( - view.outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(), - )); - } - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let view = readonly.as_array(); - if view.shape().len() != 2 || view.shape()[1] != 3 { - return Err(PyValueError::new_err( - "positions must have shape (n_atoms, 3)", - )); - } - return Ok(Vec3Data::F64( - view.outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(), - )); - } - Err(PyValueError::new_err( - "positions must be a float32 or float64 ndarray with shape (n_atoms, 3)", - )) -} - -pub(crate) fn parse_float_array_field( - value: &Bound<'_, PyAny>, - label: &str, - expected_len: usize, -) -> PyResult { - if let Ok(arr) = value.cast::>() { - let values = arr.readonly().as_array().to_vec(); - if values.len() != expected_len { - return Err(PyValueError::new_err(format!( - "{} length ({}) doesn't match atom count ({})", - label, - values.len(), - expected_len - ))); - } - return Ok(FloatArrayData::F32(values)); - } - if let Ok(arr) = value.cast::>() { - let values = arr.readonly().as_array().to_vec(); - if values.len() != expected_len { - return Err(PyValueError::new_err(format!( - "{} length ({}) doesn't match atom count ({})", - label, - values.len(), - expected_len - ))); - } - return Ok(FloatArrayData::F64(values)); - } - Err(PyValueError::new_err(format!( - "{} must be a float32 or float64 ndarray with shape ({},)", - label, expected_len - ))) -} - -pub(crate) fn parse_mat3_field(value: &Bound<'_, PyAny>, label: &str) -> PyResult { - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let view = readonly.as_array(); - if view.shape() != [3, 3] { - return Err(PyValueError::new_err(format!( - "{} must have shape (3, 3)", - label - ))); - } - return Ok(Mat3Data::F32([ - [view[[0, 0]], view[[0, 1]], view[[0, 2]]], - [view[[1, 0]], view[[1, 1]], view[[1, 2]]], - [view[[2, 0]], view[[2, 1]], view[[2, 2]]], - ])); - } - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - let view = readonly.as_array(); - if view.shape() != [3, 3] { - return Err(PyValueError::new_err(format!( - "{} must have shape (3, 3)", - label - ))); - } - return Ok(Mat3Data::F64([ - [view[[0, 0]], view[[0, 1]], view[[0, 2]]], - [view[[1, 0]], view[[1, 1]], view[[1, 2]]], - [view[[2, 0]], view[[2, 1]], view[[2, 2]]], - ])); - } - Err(PyValueError::new_err(format!( - "{} must be a float32 or float64 ndarray with shape (3, 3)", - label - ))) -} - pub(crate) fn molecule_from_numpy_arrays( positions: &Bound<'_, PyAny>, atomic_numbers: &Bound<'_, PyArray1>, @@ -574,6 +427,116 @@ fn owned_mat3x3_array<'py>(py: Python<'py>, values: &Mat3Data) -> PyResult PyErr { + PyValueError::new_err("Molecule is missing both owned and view state") +} + +fn view_vec3_payload_py<'py>( + py: Python<'py>, + payload: &[u8], + type_tag: u8, + rows: usize, + label: &str, +) -> PyResult> { + match type_tag { + TYPE_VEC3_F32 => Ok( + pyarray2_from_cow(py, cast_or_decode_f32(payload)?, rows, 3)? + .into_any() + .unbind(), + ), + TYPE_VEC3_F64 => Ok( + pyarray2_from_cow(py, cast_or_decode_f64(payload)?, rows, 3)? + .into_any() + .unbind(), + ), + _ => Err(PyValueError::new_err(format!("Invalid {label} section"))), + } +} + +fn view_builtin_vec3_slot_py<'py>( + py: Python<'py>, + view: &SoaMoleculeView, + slot: (usize, usize, u8), + label: &str, +) -> PyResult> { + match slot.2 { + TYPE_VEC3_F32 => { + if slot.1 != view.n_atoms * 12 { + return Err(PyValueError::new_err(format!("Invalid {label} section"))); + } + } + TYPE_VEC3_F64 => { + if slot.1 != view.n_atoms * 24 { + return Err(PyValueError::new_err(format!("Invalid {label} section"))); + } + } + _ => return Err(PyValueError::new_err(format!("Invalid {label} section"))), + } + view_vec3_payload_py(py, view.builtin_payload(slot), slot.2, view.n_atoms, label) +} + +fn view_builtin_float_array_slot_py<'py>( + py: Python<'py>, + view: &SoaMoleculeView, + slot: (usize, usize, u8), + label: &str, +) -> PyResult> { + match slot.2 { + TYPE_F32_ARRAY => { + if slot.1 != view.n_atoms * 4 { + return Err(PyValueError::new_err(format!("Invalid {label} section"))); + } + Ok( + pyarray1_from_cow(py, cast_or_decode_f32(view.builtin_payload(slot))?) + .into_any() + .unbind(), + ) + } + TYPE_F64_ARRAY => { + if slot.1 != view.n_atoms * 8 { + return Err(PyValueError::new_err(format!("Invalid {label} section"))); + } + Ok( + pyarray1_from_cow(py, cast_or_decode_f64(view.builtin_payload(slot))?) + .into_any() + .unbind(), + ) + } + _ => Err(PyValueError::new_err(format!("Invalid {label} section"))), + } +} + +fn view_builtin_mat3_slot_py<'py>( + py: Python<'py>, + view: &SoaMoleculeView, + slot: (usize, usize, u8), + label: &str, +) -> PyResult> { + match slot.2 { + TYPE_MAT3X3_F32 => { + if slot.1 != 36 { + return Err(PyValueError::new_err(format!("Invalid {label} section"))); + } + Ok( + pyarray2_from_cow(py, cast_or_decode_f32(view.builtin_payload(slot))?, 3, 3)? + .into_any() + .unbind(), + ) + } + TYPE_MAT3X3_F64 => { + if slot.1 != 72 { + return Err(PyValueError::new_err(format!("Invalid {label} section"))); + } + Ok( + pyarray2_from_cow(py, cast_or_decode_f64(view.builtin_payload(slot))?, 3, 3)? + .into_any() + .unbind(), + ) + } + _ => Err(PyValueError::new_err(format!("Invalid {label} section"))), + } +} + impl PyMolecule { pub(crate) fn from_owned(inner: Molecule) -> Self { Self { @@ -631,27 +594,14 @@ impl PyMolecule { if let Some(inner) = self.as_owned() { return owned_vec3_array(py, &inner.positions); } - let view = self.as_view().ok_or_else(|| { - PyValueError::new_err("Molecule is missing both owned and view state") - })?; - match view.positions_type { - TYPE_VEC3_F32 => { - let positions = cast_or_decode_f32(view.positions_bytes())?; - Ok(pyarray2_from_cow(py, positions, view.n_atoms, 3)? - .into_any() - .unbind()) - } - TYPE_VEC3_F64 => { - let positions = cast_or_decode_f64(view.positions_bytes())?; - Ok(pyarray2_from_cow(py, positions, view.n_atoms, 3)? - .into_any() - .unbind()) - } - other => Err(PyValueError::new_err(format!( - "Unsupported positions type tag {}", - other - ))), - } + let view = self.as_view().ok_or_else(missing_molecule_state)?; + view_vec3_payload_py( + py, + view.positions_bytes(), + view.positions_type, + view.n_atoms, + "positions", + ) } pub(super) fn atomic_numbers_py<'py>( @@ -661,9 +611,7 @@ impl PyMolecule { if let Some(inner) = self.as_owned() { return Ok(PyArray1::from_slice(py, &inner.atomic_numbers)); } - let view = self.as_view().ok_or_else(|| { - PyValueError::new_err("Molecule is missing both owned and view state") - })?; + let view = self.as_view().ok_or_else(missing_molecule_state)?; Ok(PyArray1::from_slice(py, view.atomic_numbers_bytes())) } @@ -675,37 +623,11 @@ impl PyMolecule { .map(|forces| owned_vec3_array(py, forces)) .transpose(); } - let view = self.as_view().ok_or_else(|| { - PyValueError::new_err("Molecule is missing both owned and view state") - })?; + let view = self.as_view().ok_or_else(missing_molecule_state)?; let Some(slot) = view.forces else { return Ok(None); }; - match slot.2 { - TYPE_VEC3_F32 => { - if slot.1 != view.n_atoms * 12 { - return Err(PyValueError::new_err("Invalid forces section")); - } - let data = cast_or_decode_f32(view.builtin_payload(slot))?; - Ok(Some( - pyarray2_from_cow(py, data, view.n_atoms, 3)? - .into_any() - .unbind(), - )) - } - TYPE_VEC3_F64 => { - if slot.1 != view.n_atoms * 24 { - return Err(PyValueError::new_err("Invalid forces section")); - } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some( - pyarray2_from_cow(py, data, view.n_atoms, 3)? - .into_any() - .unbind(), - )) - } - _ => Err(PyValueError::new_err("Invalid forces section")), - } + Ok(Some(view_builtin_vec3_slot_py(py, view, slot, "forces")?)) } pub(super) fn charges_py<'py>(&self, py: Python<'py>) -> PyResult>> { @@ -715,29 +637,13 @@ impl PyMolecule { .as_ref() .map(|charges| owned_float_array(py, charges))); } - let view = self.as_view().ok_or_else(|| { - PyValueError::new_err("Molecule is missing both owned and view state") - })?; + let view = self.as_view().ok_or_else(missing_molecule_state)?; let Some(slot) = view.charges else { return Ok(None); }; - match slot.2 { - TYPE_F32_ARRAY => { - if slot.1 != view.n_atoms * 4 { - return Err(PyValueError::new_err("Invalid charges section")); - } - let data = cast_or_decode_f32(view.builtin_payload(slot))?; - Ok(Some(pyarray1_from_cow(py, data).into_any().unbind())) - } - TYPE_F64_ARRAY => { - if slot.1 != view.n_atoms * 8 { - return Err(PyValueError::new_err("Invalid charges section")); - } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some(pyarray1_from_cow(py, data).into_any().unbind())) - } - _ => Err(PyValueError::new_err("Invalid charges section")), - } + Ok(Some(view_builtin_float_array_slot_py( + py, view, slot, "charges", + )?)) } pub(super) fn velocities_py<'py>(&self, py: Python<'py>) -> PyResult>> { @@ -748,37 +654,16 @@ impl PyMolecule { .map(|velocities| owned_vec3_array(py, velocities)) .transpose(); } - let view = self.as_view().ok_or_else(|| { - PyValueError::new_err("Molecule is missing both owned and view state") - })?; + let view = self.as_view().ok_or_else(missing_molecule_state)?; let Some(slot) = view.velocities else { return Ok(None); }; - match slot.2 { - TYPE_VEC3_F32 => { - if slot.1 != view.n_atoms * 12 { - return Err(PyValueError::new_err("Invalid velocities section")); - } - let data = cast_or_decode_f32(view.builtin_payload(slot))?; - Ok(Some( - pyarray2_from_cow(py, data, view.n_atoms, 3)? - .into_any() - .unbind(), - )) - } - TYPE_VEC3_F64 => { - if slot.1 != view.n_atoms * 24 { - return Err(PyValueError::new_err("Invalid velocities section")); - } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some( - pyarray2_from_cow(py, data, view.n_atoms, 3)? - .into_any() - .unbind(), - )) - } - _ => Err(PyValueError::new_err("Invalid velocities section")), - } + Ok(Some(view_builtin_vec3_slot_py( + py, + view, + slot, + "velocities", + )?)) } pub(super) fn cell_py<'py>(&self, py: Python<'py>) -> PyResult>> { @@ -789,29 +674,11 @@ impl PyMolecule { .map(|cell| owned_mat3x3_array(py, cell)) .transpose(); } - let view = self.as_view().ok_or_else(|| { - PyValueError::new_err("Molecule is missing both owned and view state") - })?; + let view = self.as_view().ok_or_else(missing_molecule_state)?; let Some(slot) = view.cell else { return Ok(None); }; - match slot.2 { - TYPE_MAT3X3_F32 => { - if slot.1 != 36 { - return Err(PyValueError::new_err("Invalid cell section")); - } - let data = cast_or_decode_f32(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) - } - TYPE_MAT3X3_F64 => { - if slot.1 != 72 { - return Err(PyValueError::new_err("Invalid cell section")); - } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) - } - _ => Err(PyValueError::new_err("Invalid cell section")), - } + Ok(Some(view_builtin_mat3_slot_py(py, view, slot, "cell")?)) } pub(super) fn stress_py<'py>(&self, py: Python<'py>) -> PyResult>> { @@ -822,29 +689,11 @@ impl PyMolecule { .map(|stress| owned_mat3x3_array(py, stress)) .transpose(); } - let view = self.as_view().ok_or_else(|| { - PyValueError::new_err("Molecule is missing both owned and view state") - })?; + let view = self.as_view().ok_or_else(missing_molecule_state)?; let Some(slot) = view.stress else { return Ok(None); }; - match slot.2 { - TYPE_MAT3X3_F32 => { - if slot.1 != 36 { - return Err(PyValueError::new_err("Invalid stress section")); - } - let data = cast_or_decode_f32(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) - } - TYPE_MAT3X3_F64 => { - if slot.1 != 72 { - return Err(PyValueError::new_err("Invalid stress section")); - } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) - } - _ => Err(PyValueError::new_err("Invalid stress section")), - } + Ok(Some(view_builtin_mat3_slot_py(py, view, slot, "stress")?)) } pub(super) fn append_owned_ase_properties<'py>( diff --git a/atompack-py/src/py_dtypes.rs b/atompack-py/src/py_dtypes.rs new file mode 100644 index 0000000..e05f112 --- /dev/null +++ b/atompack-py/src/py_dtypes.rs @@ -0,0 +1,310 @@ +use super::*; + +pub(crate) enum PyFloatArray1<'py> { + F32(Bound<'py, PyArray1>), + F64(Bound<'py, PyArray1>), +} + +impl<'py> PyFloatArray1<'py> { + pub(crate) fn from_any(value: &Bound<'py, PyAny>) -> Option { + if let Ok(arr) = value.cast::>() { + return Some(Self::F32(arr.clone())); + } + if let Ok(arr) = value.cast::>() { + return Some(Self::F64(arr.clone())); + } + None + } + + pub(crate) fn to_float_array_data( + &self, + label: &str, + expected_len: usize, + ) -> PyResult { + match self { + Self::F32(arr) => { + let values = arr.readonly().as_array().to_vec(); + if values.len() != expected_len { + return Err(PyValueError::new_err(format!( + "{} length ({}) doesn't match atom count ({})", + label, + values.len(), + expected_len + ))); + } + Ok(FloatArrayData::F32(values)) + } + Self::F64(arr) => { + let values = arr.readonly().as_array().to_vec(); + if values.len() != expected_len { + return Err(PyValueError::new_err(format!( + "{} length ({}) doesn't match atom count ({})", + label, + values.len(), + expected_len + ))); + } + Ok(FloatArrayData::F64(values)) + } + } + } +} + +pub(crate) enum PyIntArray1<'py> { + I32(Bound<'py, PyArray1>), + I64(Bound<'py, PyArray1>), +} + +impl<'py> PyIntArray1<'py> { + pub(crate) fn from_any(value: &Bound<'py, PyAny>) -> Option { + if let Ok(arr) = value.cast::>() { + return Some(Self::I32(arr.clone())); + } + if let Ok(arr) = value.cast::>() { + return Some(Self::I64(arr.clone())); + } + None + } +} + +pub(crate) enum PyFloatArray2<'py> { + F32(Bound<'py, PyArray2>), + F64(Bound<'py, PyArray2>), +} + +impl<'py> PyFloatArray2<'py> { + pub(crate) fn from_any(value: &Bound<'py, PyAny>) -> Option { + if let Ok(arr) = value.cast::>() { + return Some(Self::F32(arr.clone())); + } + if let Ok(arr) = value.cast::>() { + return Some(Self::F64(arr.clone())); + } + None + } + + pub(crate) fn parse_vec3_data( + &self, + label: &str, + expected_rows: Option, + ) -> PyResult { + match self { + Self::F32(arr) => { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let shape = view.shape(); + if shape.len() != 2 || shape[1] != 3 { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, 3)", + label, + expected_rows.map_or("n".to_string(), |rows| rows.to_string()) + ))); + } + if let Some(rows) = expected_rows + && shape[0] != rows + { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, 3)", + label, rows + ))); + } + Ok(Vec3Data::F32( + view.outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + )) + } + Self::F64(arr) => { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let shape = view.shape(); + if shape.len() != 2 || shape[1] != 3 { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, 3)", + label, + expected_rows.map_or("n".to_string(), |rows| rows.to_string()) + ))); + } + if let Some(rows) = expected_rows + && shape[0] != rows + { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, 3)", + label, rows + ))); + } + Ok(Vec3Data::F64( + view.outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + )) + } + } + } + + pub(crate) fn parse_mat3_data(&self, label: &str) -> PyResult { + match self { + Self::F32(arr) => { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [3, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape (3, 3)", + label + ))); + } + Ok(Mat3Data::F32([ + [view[[0, 0]], view[[0, 1]], view[[0, 2]]], + [view[[1, 0]], view[[1, 1]], view[[1, 2]]], + [view[[2, 0]], view[[2, 1]], view[[2, 2]]], + ])) + } + Self::F64(arr) => { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [3, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape (3, 3)", + label + ))); + } + Ok(Mat3Data::F64([ + [view[[0, 0]], view[[0, 1]], view[[0, 2]]], + [view[[1, 0]], view[[1, 1]], view[[1, 2]]], + [view[[2, 0]], view[[2, 1]], view[[2, 2]]], + ])) + } + } + } +} + +pub(crate) enum PyFloatArray3<'py> { + F32(Bound<'py, PyArray3>), + F64(Bound<'py, PyArray3>), +} + +impl<'py> PyFloatArray3<'py> { + pub(crate) fn from_any(value: &Bound<'py, PyAny>) -> Option { + if let Ok(arr) = value.cast::>() { + return Some(Self::F32(arr.clone())); + } + if let Ok(arr) = value.cast::>() { + return Some(Self::F64(arr.clone())); + } + None + } +} + +pub(crate) fn parse_vec3_field( + value: &Bound<'_, PyAny>, + label: &str, + expected_rows: usize, +) -> PyResult { + let Some(array) = PyFloatArray2::from_any(value) else { + return Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape ({}, 3)", + label, expected_rows + ))); + }; + array.parse_vec3_data(label, Some(expected_rows)) +} + +pub(crate) fn parse_positions_field(value: &Bound<'_, PyAny>) -> PyResult { + let Some(array) = PyFloatArray2::from_any(value) else { + return Err(PyValueError::new_err( + "positions must be a float32 or float64 ndarray with shape (n_atoms, 3)", + )); + }; + array.parse_vec3_data("positions", None) +} + +pub(crate) fn parse_float_array_field( + value: &Bound<'_, PyAny>, + label: &str, + expected_len: usize, +) -> PyResult { + let Some(array) = PyFloatArray1::from_any(value) else { + return Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape ({},)", + label, expected_len + ))); + }; + array.to_float_array_data(label, expected_len) +} + +pub(crate) fn parse_mat3_field(value: &Bound<'_, PyAny>, label: &str) -> PyResult { + let Some(array) = PyFloatArray2::from_any(value) else { + return Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape (3, 3)", + label + ))); + }; + array.parse_mat3_data(label) +} + +pub(crate) fn parse_property_value(value: &Bound<'_, PyAny>) -> PyResult { + if let Ok(v) = value.extract::() { + return Ok(PropertyValue::Int(v)); + } + if let Ok(v) = value.extract::() { + return Ok(PropertyValue::Float(v)); + } + if let Ok(v) = value.extract::() { + return Ok(PropertyValue::String(v)); + } + if let Some(array) = PyFloatArray1::from_any(value) { + return Ok(match array { + PyFloatArray1::F32(arr) => { + PropertyValue::Float32Array(arr.readonly().as_array().to_vec()) + } + PyFloatArray1::F64(arr) => { + PropertyValue::FloatArray(arr.readonly().as_array().to_vec()) + } + }); + } + if let Some(array) = PyIntArray1::from_any(value) { + return Ok(match array { + PyIntArray1::I32(arr) => PropertyValue::Int32Array(arr.readonly().as_array().to_vec()), + PyIntArray1::I64(arr) => PropertyValue::IntArray(arr.readonly().as_array().to_vec()), + }); + } + if let Some(array) = PyFloatArray2::from_any(value) { + return Ok(match array { + PyFloatArray2::F32(arr) => { + let readonly = arr.readonly(); + let arr_view = readonly.as_array(); + let shape = arr_view.shape(); + if shape[1] != 3 { + return Err(PyValueError::new_err( + "Vec3Array properties must have shape (n, 3)", + )); + } + PropertyValue::Vec3Array( + arr_view + .outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + ) + } + PyFloatArray2::F64(arr) => { + let readonly = arr.readonly(); + let arr_view = readonly.as_array(); + let shape = arr_view.shape(); + if shape[1] != 3 { + return Err(PyValueError::new_err( + "Vec3Array properties must have shape (n, 3)", + )); + } + PropertyValue::Vec3ArrayF64( + arr_view + .outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + ) + } + }); + } + Err(PyValueError::new_err( + "Unsupported property type. Supported: float, int, str, ndarray", + )) +} diff --git a/atompack-py/src/soa.rs b/atompack-py/src/soa.rs index ce401e2..793df97 100644 --- a/atompack-py/src/soa.rs +++ b/atompack-py/src/soa.rs @@ -28,11 +28,81 @@ pub(crate) struct SectionSchema { pub(crate) slot_bytes: usize, } -pub(crate) fn parse_mol_fast_soa_v2(bytes: &[u8]) -> atompack::Result> { +#[derive(Clone, Copy)] +pub(crate) struct SoaLayout { + pub(crate) positions_type: u8, + pub(crate) positions_stride: usize, +} + +impl SoaLayout { + pub(crate) fn resolve( + record_format: u32, + positions_type_hint: Option, + ) -> atompack::Result { + match record_format { + RECORD_FORMAT_SOA_V2 => Ok(Self { + positions_type: TYPE_VEC3_F32, + positions_stride: 12, + }), + RECORD_FORMAT_SOA_V3 => { + let positions_type = positions_type_hint + .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?; + let positions_stride = match positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + _ => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + positions_type + ))); + } + }; + Ok(Self { + positions_type, + positions_stride, + }) + } + _ => Err(invalid_data(format!( + "Unsupported record format {}", + record_format + ))), + } + } +} + +#[derive(Clone, Copy)] +pub(crate) struct SoaContext { + pub(crate) layout: SoaLayout, +} + +impl SoaContext { + pub(crate) fn resolve( + record_format: u32, + positions_type_hint: Option, + ) -> atompack::Result { + Ok(Self { + layout: SoaLayout::resolve(record_format, positions_type_hint)?, + }) + } + + pub(crate) fn from_database(database: &AtomDatabase) -> atompack::Result { + Self::resolve(database.record_format(), database.positions_type()) + } + + #[inline] + pub(crate) fn positions_type(self) -> u8 { + self.layout.positions_type + } +} + +fn parse_mol_fast_soa_with_layout( + bytes: &[u8], + layout: SoaLayout, +) -> atompack::Result> { let mut pos = 0usize; let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; let positions_len = n_atoms - .checked_mul(12) + .checked_mul(layout.positions_stride) .ok_or_else(|| invalid_data("SOA positions byte length overflow"))?; let positions_bytes = read_bytes_at(bytes, &mut pos, positions_len, "SOA positions")?; let atomic_numbers_bytes = read_bytes_at(bytes, &mut pos, n_atoms, "SOA atomic_numbers")?; @@ -70,68 +140,8 @@ pub(crate) fn parse_mol_fast_soa_v2(bytes: &[u8]) -> atompack::Result, -) -> atompack::Result> { - if record_format == RECORD_FORMAT_SOA_V2 { - return parse_mol_fast_soa_v2(bytes); - } - - let mut pos = 0usize; - let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; - let positions_type = match record_format { - RECORD_FORMAT_SOA_V3 => positions_type_hint - .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, - _ => { - return Err(invalid_data(format!( - "Unsupported record format {}", - record_format - ))); - } - }; - let positions_stride = match positions_type { - TYPE_VEC3_F32 => 12usize, - TYPE_VEC3_F64 => 24usize, - _ => { - return Err(invalid_data(format!( - "Unsupported positions type tag {}", - positions_type - ))); - } - }; - let positions_len = n_atoms - .checked_mul(positions_stride) - .ok_or_else(|| invalid_data("SOA positions byte length overflow"))?; - let positions_bytes = read_bytes_at(bytes, &mut pos, positions_len, "SOA positions")?; - let atomic_numbers_bytes = read_bytes_at(bytes, &mut pos, n_atoms, "SOA atomic_numbers")?; - let n_sections = read_u16_le_at(bytes, &mut pos, "SOA n_sections")? as usize; - - let mut sections = Vec::with_capacity(n_sections); - for _ in 0..n_sections { - let kind = read_u8_at(bytes, &mut pos, "SOA section kind")?; - let key_len = read_u8_at(bytes, &mut pos, "SOA section key length")? as usize; - let key_bytes = read_bytes_at(bytes, &mut pos, key_len, "SOA section key")?; - let key = std::str::from_utf8(key_bytes) - .map_err(|_| invalid_data("Invalid UTF-8 in SOA section key"))?; - let type_tag = read_u8_at(bytes, &mut pos, "SOA section type tag")?; - let payload_len = read_u32_le_at(bytes, &mut pos, "SOA section payload length")? as usize; - let payload = read_bytes_at(bytes, &mut pos, payload_len, "SOA section payload")?; - sections.push(SectionRef { - kind, - key, - type_tag, - payload, - }); - } - - Ok(MolData { - n_atoms, - positions_bytes, - atomic_numbers_bytes, - sections, - }) +pub(crate) fn parse_mol_fast_soa(bytes: &[u8], ctx: SoaContext) -> atompack::Result> { + parse_mol_fast_soa_with_layout(bytes, ctx.layout) } pub(crate) fn section_schema_from_ref( @@ -427,7 +437,7 @@ pub(crate) struct SoaMoleculeView { } impl SoaMoleculeView { - fn from_storage_v2(bytes: SoaBytes) -> atompack::Result { + fn from_storage(bytes: SoaBytes, ctx: SoaContext) -> atompack::Result { if bytes.len() < 6 { return Err(invalid_data("SOA record too small")); } @@ -436,7 +446,7 @@ impl SoaMoleculeView { let mut pos = 4usize; let positions_start = pos; let positions_len = n_atoms - .checked_mul(12) + .checked_mul(ctx.layout.positions_stride) .ok_or_else(|| invalid_data("SOA positions overflow"))?; pos = pos .checked_add(positions_len) @@ -536,7 +546,7 @@ impl SoaMoleculeView { Ok(Self { bytes, n_atoms, - positions_type: TYPE_VEC3_F32, + positions_type: ctx.positions_type(), positions_start, positions_len, atomic_numbers_start, @@ -553,182 +563,19 @@ impl SoaMoleculeView { } /// Pure-Rust parser — no Python dependency, safe to call from rayon threads. - fn from_storage_inner( - bytes: SoaBytes, - record_format: u32, - positions_type_hint: Option, - ) -> atompack::Result { - if record_format == RECORD_FORMAT_SOA_V2 { - return Self::from_storage_v2(bytes); - } - - if bytes.len() < 6 { - return Err(invalid_data("SOA record too small")); - } - - let n_atoms = u32::from_le_bytes(slice_to_array(&bytes[0..4], "SOA atom count")?) as usize; - let mut pos = 4usize; - let positions_type = match record_format { - RECORD_FORMAT_SOA_V3 => positions_type_hint - .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, - _ => { - return Err(invalid_data(format!( - "Unsupported record format {}", - record_format - ))); - } - }; - let positions_stride = match positions_type { - TYPE_VEC3_F32 => 12usize, - TYPE_VEC3_F64 => 24usize, - _ => { - return Err(invalid_data(format!( - "Unsupported positions type tag {}", - positions_type - ))); - } - }; - let positions_start = pos; - let positions_len = n_atoms - .checked_mul(positions_stride) - .ok_or_else(|| invalid_data("SOA positions overflow"))?; - pos = pos - .checked_add(positions_len) - .ok_or_else(|| invalid_data("SOA positions overflow"))?; - if pos > bytes.len() { - return Err(invalid_data("SOA record truncated at positions")); - } - - let atomic_numbers_start = pos; - pos = pos - .checked_add(n_atoms) - .ok_or_else(|| invalid_data("SOA atomic_numbers overflow"))?; - if pos + 2 > bytes.len() { - return Err(invalid_data("SOA record truncated at atomic_numbers")); - } - - let n_sections = - u16::from_le_bytes(slice_to_array(&bytes[pos..pos + 2], "SOA section count")?) as usize; - pos += 2; - - let mut forces = None; - let mut energy = None; - let mut cell = None; - let mut stress = None; - let mut charges = None; - let mut velocities = None; - let mut pbc = None; - let mut name = None; - let mut custom_sections = Vec::new(); - - for _ in 0..n_sections { - if pos + 2 > bytes.len() { - return Err(invalid_data("SOA section header truncated")); - } - let kind = bytes[pos]; - pos += 1; - let key_len = bytes[pos] as usize; - pos += 1; - if pos + key_len > bytes.len() { - return Err(invalid_data("SOA section key truncated")); - } - let key_start = pos; - pos += key_len; - if pos + 5 > bytes.len() { - return Err(invalid_data("SOA section header truncated")); - } - let type_tag = bytes[pos]; - pos += 1; - let payload_len = u32::from_le_bytes(slice_to_array( - &bytes[pos..pos + 4], - "SOA section payload length", - )?) as usize; - pos += 4; - let payload_start = pos; - pos = pos - .checked_add(payload_len) - .ok_or_else(|| invalid_data("SOA section payload overflow"))?; - if pos > bytes.len() { - return Err(invalid_data("SOA section payload truncated")); - } - - let key_bytes = &bytes[key_start..key_start + key_len]; - if kind == KIND_BUILTIN { - let slot = (payload_start, payload_len, type_tag); - match key_bytes { - b"forces" => forces = Some(slot), - b"energy" => energy = Some(slot), - b"cell" => cell = Some(slot), - b"stress" => stress = Some(slot), - b"charges" => charges = Some(slot), - b"velocities" => velocities = Some(slot), - b"pbc" => pbc = Some(slot), - b"name" => name = Some(slot), - _ => { - custom_sections.push(LazySection { - kind, - key_start, - key_len: key_len as u8, - type_tag, - payload_start, - payload_len, - }); - } - } - } else { - custom_sections.push(LazySection { - kind, - key_start, - key_len: key_len as u8, - type_tag, - payload_start, - payload_len, - }); - } - } - - Ok(Self { - bytes, - n_atoms, - positions_type, - positions_start, - positions_len, - atomic_numbers_start, - forces, - energy, - cell, - stress, - charges, - velocities, - pbc, - name, - custom_sections, - }) + pub(crate) fn from_owned_bytes(bytes: Vec, ctx: SoaContext) -> atompack::Result { + Self::from_storage(SoaBytes::Owned(bytes), ctx) } - pub(crate) fn from_bytes_inner( - bytes: Vec, - record_format: u32, - positions_type_hint: Option, - ) -> atompack::Result { - Self::from_storage_inner(SoaBytes::Owned(bytes), record_format, positions_type_hint) - } - - pub(crate) fn from_shared_bytes_inner( + pub(crate) fn from_shared_bytes( bytes: SharedMmapBytes, - record_format: u32, - positions_type_hint: Option, + ctx: SoaContext, ) -> atompack::Result { - Self::from_storage_inner(SoaBytes::Shared(bytes), record_format, positions_type_hint) + Self::from_storage(SoaBytes::Shared(bytes), ctx) } - pub(crate) fn from_bytes( - bytes: Vec, - record_format: u32, - positions_type_hint: Option, - ) -> PyResult { - Self::from_bytes_inner(bytes, record_format, positions_type_hint) - .map_err(|e| PyValueError::new_err(format!("{}", e))) + pub(crate) fn from_bytes(bytes: Vec, ctx: SoaContext) -> PyResult { + Self::from_owned_bytes(bytes, ctx).map_err(|e| PyValueError::new_err(format!("{}", e))) } pub(crate) fn positions_bytes(&self) -> &[u8] { diff --git a/atompack/src/atom.rs b/atompack/src/atom.rs index 3edec41..ed7b575 100644 --- a/atompack/src/atom.rs +++ b/atompack/src/atom.rs @@ -3,6 +3,8 @@ use bytemuck::{Pod, Zeroable}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +pub use crate::types::{FloatArrayData, FloatScalarData, Mat3Data, PropertyValue, Vec3Data}; + #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Pod, Zeroable)] #[repr(C)] pub struct Atom { @@ -38,166 +40,6 @@ impl Atom { } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum PropertyValue { - Float(f64), - Int(i64), - String(String), - FloatArray(Vec), - Vec3Array(Vec<[f32; 3]>), - IntArray(Vec), - Float32Array(Vec), - Vec3ArrayF64(Vec<[f64; 3]>), - Int32Array(Vec), -} - -impl PropertyValue { - pub fn len(&self) -> Option { - match self { - PropertyValue::FloatArray(v) => Some(v.len()), - PropertyValue::Vec3Array(v) => Some(v.len()), - PropertyValue::IntArray(v) => Some(v.len()), - PropertyValue::Float32Array(v) => Some(v.len()), - PropertyValue::Vec3ArrayF64(v) => Some(v.len()), - PropertyValue::Int32Array(v) => Some(v.len()), - _ => None, - } - } - - pub fn is_empty(&self) -> bool { - self.len() == Some(0) - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum Vec3Data { - F32(Vec<[f32; 3]>), - F64(Vec<[f64; 3]>), -} - -impl Vec3Data { - pub fn len(&self) -> usize { - match self { - Self::F32(values) => values.len(), - Self::F64(values) => values.len(), - } - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn atom_position(&self, index: usize) -> Option<[f32; 3]> { - match self { - Self::F32(values) => values.get(index).copied(), - Self::F64(values) => values - .get(index) - .map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]), - } - } - - pub fn flatten_f32_lossy(&self) -> Vec { - match self { - Self::F32(values) => values - .iter() - .flat_map(|value| [value[0], value[1], value[2]]) - .collect(), - Self::F64(values) => values - .iter() - .flat_map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]) - .collect(), - } - } - - pub fn flatten_f64(&self) -> Vec { - match self { - Self::F32(values) => values - .iter() - .flat_map(|value| [value[0] as f64, value[1] as f64, value[2] as f64]) - .collect(), - Self::F64(values) => values - .iter() - .flat_map(|value| [value[0], value[1], value[2]]) - .collect(), - } - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum FloatScalarData { - F32(f32), - F64(f64), -} - -impl FloatScalarData { - pub fn as_f64(&self) -> f64 { - match self { - Self::F32(value) => *value as f64, - Self::F64(value) => *value, - } - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum FloatArrayData { - F32(Vec), - F64(Vec), -} - -impl FloatArrayData { - pub fn len(&self) -> usize { - match self { - Self::F32(values) => values.len(), - Self::F64(values) => values.len(), - } - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn to_f64_vec(&self) -> Vec { - match self { - Self::F32(values) => values.iter().map(|value| *value as f64).collect(), - Self::F64(values) => values.clone(), - } - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum Mat3Data { - F32([[f32; 3]; 3]), - F64([[f64; 3]; 3]), -} - -impl Mat3Data { - pub fn flatten_f32_lossy(&self) -> Vec { - match self { - Self::F32(values) => values - .iter() - .flat_map(|row| [row[0], row[1], row[2]]) - .collect(), - Self::F64(values) => values - .iter() - .flat_map(|row| [row[0] as f32, row[1] as f32, row[2] as f32]) - .collect(), - } - } - - pub fn flatten_f64(&self) -> Vec { - match self { - Self::F32(values) => values - .iter() - .flat_map(|row| [row[0] as f64, row[1] as f64, row[2] as f64]) - .collect(), - Self::F64(values) => values - .iter() - .flat_map(|row| [row[0], row[1], row[2]]) - .collect(), - } - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Molecule { pub name: Option, diff --git a/atompack/src/lib.rs b/atompack/src/lib.rs index 04a4dfe..1ad5c8a 100644 --- a/atompack/src/lib.rs +++ b/atompack/src/lib.rs @@ -11,12 +11,12 @@ pub mod atom; pub mod compression; pub mod storage; +pub mod types; -pub use atom::{ - Atom, FloatArrayData, FloatScalarData, Mat3Data, Molecule, PropertyValue, Vec3Data, -}; +pub use atom::{Atom, Molecule}; pub use compression::decompress as decompress_bytes; pub use storage::{AtomDatabase, SharedMmapBytes}; +pub use types::{FloatArrayData, FloatScalarData, Mat3Data, PropertyValue, Vec3Data}; /// Result type used throughout atompack pub type Result = std::result::Result; diff --git a/atompack/src/storage/dtypes.rs b/atompack/src/storage/dtypes.rs new file mode 100644 index 0000000..950b9b0 --- /dev/null +++ b/atompack/src/storage/dtypes.rs @@ -0,0 +1,409 @@ +use super::*; +use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, PropertyValue, Vec3Data}; + +pub(super) fn arr(bytes: &[u8]) -> Result<[u8; N]> { + bytes + .try_into() + .map_err(|_| Error::InvalidData("byte slice truncated".into())) +} + +pub(super) fn positions_type_from_molecule(molecule: &Molecule) -> u8 { + vec3_data_type_tag(&molecule.positions) +} + +pub(super) fn vec3_data_type_tag(values: &Vec3Data) -> u8 { + match values { + Vec3Data::F32(_) => TYPE_VEC3_F32, + Vec3Data::F64(_) => TYPE_VEC3_F64, + } +} + +pub(super) fn vec3_payload_len(values: &Vec3Data) -> usize { + match values { + Vec3Data::F32(values) => values.len() * 12, + Vec3Data::F64(values) => values.len() * 24, + } +} + +pub(super) fn float_array_data_type_tag(values: &FloatArrayData) -> u8 { + match values { + FloatArrayData::F32(_) => TYPE_F32_ARRAY, + FloatArrayData::F64(_) => TYPE_F64_ARRAY, + } +} + +pub(super) fn float_array_payload_len(values: &FloatArrayData) -> usize { + match values { + FloatArrayData::F32(values) => values.len() * 4, + FloatArrayData::F64(values) => values.len() * 8, + } +} + +pub(super) fn mat3_data_type_tag(values: &Mat3Data) -> u8 { + match values { + Mat3Data::F32(_) => TYPE_MAT3X3_F32, + Mat3Data::F64(_) => TYPE_MAT3X3_F64, + } +} + +pub(super) fn mat3_payload_len(values: &Mat3Data) -> usize { + match values { + Mat3Data::F32(_) => 36, + Mat3Data::F64(_) => 72, + } +} + +pub(super) fn float_scalar_data_type_tag(value: &FloatScalarData) -> u8 { + match value { + FloatScalarData::F32(_) => TYPE_FLOAT32, + FloatScalarData::F64(_) => TYPE_FLOAT, + } +} + +pub(super) fn float_scalar_payload_len(value: &FloatScalarData) -> usize { + match value { + FloatScalarData::F32(_) => 4, + FloatScalarData::F64(_) => 8, + } +} + +pub(super) fn property_value_type_tag(value: &PropertyValue) -> u8 { + match value { + PropertyValue::Float(_) => TYPE_FLOAT, + PropertyValue::Int(_) => TYPE_INT, + PropertyValue::String(_) => TYPE_STRING, + PropertyValue::FloatArray(_) => TYPE_F64_ARRAY, + PropertyValue::Vec3Array(_) => TYPE_VEC3_F32, + PropertyValue::IntArray(_) => TYPE_I64_ARRAY, + PropertyValue::Float32Array(_) => TYPE_F32_ARRAY, + PropertyValue::Vec3ArrayF64(_) => TYPE_VEC3_F64, + PropertyValue::Int32Array(_) => TYPE_I32_ARRAY, + } +} + +pub(super) fn property_value_payload_len(value: &PropertyValue) -> usize { + match value { + PropertyValue::Float(_) | PropertyValue::Int(_) => 8, + PropertyValue::String(value) => value.len(), + PropertyValue::FloatArray(values) => values.len() * 8, + PropertyValue::Vec3Array(values) => values.len() * 12, + PropertyValue::IntArray(values) => values.len() * 8, + PropertyValue::Float32Array(values) => values.len() * 4, + PropertyValue::Vec3ArrayF64(values) => values.len() * 24, + PropertyValue::Int32Array(values) => values.len() * 4, + } +} + +fn extend_f64(buf: &mut Vec, values: &[f64]) { + for value in values { + buf.extend_from_slice(&f64::to_le_bytes(*value)); + } +} + +fn extend_f32(buf: &mut Vec, values: &[f32]) { + for value in values { + buf.extend_from_slice(&f32::to_le_bytes(*value)); + } +} + +fn extend_i64(buf: &mut Vec, values: &[i64]) { + for value in values { + buf.extend_from_slice(&i64::to_le_bytes(*value)); + } +} + +fn extend_i32(buf: &mut Vec, values: &[i32]) { + for value in values { + buf.extend_from_slice(&i32::to_le_bytes(*value)); + } +} + +pub(super) fn property_value_to_bytes(value: &PropertyValue) -> Vec { + match value { + PropertyValue::Float(value) => value.to_le_bytes().to_vec(), + PropertyValue::Int(value) => value.to_le_bytes().to_vec(), + PropertyValue::String(value) => value.as_bytes().to_vec(), + PropertyValue::FloatArray(values) => { + let mut payload = Vec::with_capacity(values.len() * 8); + extend_f64(&mut payload, values); + payload + } + PropertyValue::Vec3Array(values) => { + let mut payload = Vec::with_capacity(values.len() * 12); + for value in values { + extend_f32(&mut payload, value); + } + payload + } + PropertyValue::IntArray(values) => { + let mut payload = Vec::with_capacity(values.len() * 8); + extend_i64(&mut payload, values); + payload + } + PropertyValue::Float32Array(values) => { + let mut payload = Vec::with_capacity(values.len() * 4); + extend_f32(&mut payload, values); + payload + } + PropertyValue::Vec3ArrayF64(values) => { + let mut payload = Vec::with_capacity(values.len() * 24); + for value in values { + extend_f64(&mut payload, value); + } + payload + } + PropertyValue::Int32Array(values) => { + let mut payload = Vec::with_capacity(values.len() * 4); + extend_i32(&mut payload, values); + payload + } + } +} + +pub(super) fn validate_builtin_type_tag_for_record_format( + record_format: u32, + key: &str, + type_tag: u8, +) -> Result<()> { + match record_format { + RECORD_FORMAT_SOA_V3 => Ok(()), + RECORD_FORMAT_SOA_V2 => match key { + "charges" if type_tag != TYPE_F64_ARRAY => Err(Error::InvalidData( + "record format 2 does not support float32 charges".into(), + )), + "cell" if type_tag != TYPE_MAT3X3_F64 => Err(Error::InvalidData( + "record format 2 does not support float32 cell".into(), + )), + "energy" if type_tag != TYPE_FLOAT => Err(Error::InvalidData( + "record format 2 does not support float32 energy".into(), + )), + "forces" if type_tag != TYPE_VEC3_F32 => Err(Error::InvalidData( + "record format 2 does not support float64 forces".into(), + )), + "stress" if type_tag != TYPE_MAT3X3_F64 => Err(Error::InvalidData( + "record format 2 does not support float32 stress".into(), + )), + "velocities" if type_tag != TYPE_VEC3_F32 => Err(Error::InvalidData( + "record format 2 does not support float64 velocities".into(), + )), + _ => Ok(()), + }, + _ => Err(Error::InvalidData(format!( + "Unsupported record format {}", + record_format + ))), + } +} + +pub(super) fn decode_vec3_f32(payload: &[u8]) -> Result> { + if !payload.len().is_multiple_of(12) { + return Err(Error::InvalidData( + "vec3 payload length not divisible by 12".into(), + )); + } + payload + .chunks_exact(12) + .map(|chunk| { + Ok([ + f32::from_le_bytes(arr(&chunk[0..4])?), + f32::from_le_bytes(arr(&chunk[4..8])?), + f32::from_le_bytes(arr(&chunk[8..12])?), + ]) + }) + .collect() +} + +pub(super) fn decode_vec3_f64(payload: &[u8]) -> Result> { + if !payload.len().is_multiple_of(24) { + return Err(Error::InvalidData( + "vec3 payload length not divisible by 24".into(), + )); + } + payload + .chunks_exact(24) + .map(|chunk| { + Ok([ + f64::from_le_bytes(arr(&chunk[0..8])?), + f64::from_le_bytes(arr(&chunk[8..16])?), + f64::from_le_bytes(arr(&chunk[16..24])?), + ]) + }) + .collect() +} + +pub(super) fn decode_f32_array(payload: &[u8]) -> Result> { + if !payload.len().is_multiple_of(4) { + return Err(Error::InvalidData( + "f32 array payload length not divisible by 4".into(), + )); + } + payload + .chunks_exact(4) + .map(|chunk| Ok(f32::from_le_bytes(arr(chunk)?))) + .collect() +} + +pub(super) fn decode_f64_array(payload: &[u8]) -> Result> { + if !payload.len().is_multiple_of(8) { + return Err(Error::InvalidData( + "f64 array payload length not divisible by 8".into(), + )); + } + payload + .chunks_exact(8) + .map(|chunk| Ok(f64::from_le_bytes(arr(chunk)?))) + .collect() +} + +pub(super) fn decode_mat3x3_f32(payload: &[u8]) -> Result<[[f32; 3]; 3]> { + if payload.len() != 36 { + return Err(Error::InvalidData(format!( + "mat3x3 payload length {} (expected 36)", + payload.len() + ))); + } + let mut mat = [[0.0f32; 3]; 3]; + for (row_idx, row) in mat.iter_mut().enumerate() { + for (col_idx, cell) in row.iter_mut().enumerate() { + let offset = (row_idx * 3 + col_idx) * 4; + *cell = f32::from_le_bytes(arr(&payload[offset..offset + 4])?); + } + } + Ok(mat) +} + +pub(super) fn decode_mat3x3_f64(payload: &[u8]) -> Result<[[f64; 3]; 3]> { + if payload.len() != 72 { + return Err(Error::InvalidData(format!( + "mat3x3 payload length {} (expected 72)", + payload.len() + ))); + } + let mut mat = [[0.0f64; 3]; 3]; + for (row_idx, row) in mat.iter_mut().enumerate() { + for (col_idx, cell) in row.iter_mut().enumerate() { + let offset = (row_idx * 3 + col_idx) * 8; + *cell = f64::from_le_bytes(arr(&payload[offset..offset + 8])?); + } + } + Ok(mat) +} + +pub(super) fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result { + Ok(match type_tag { + TYPE_FLOAT => { + if payload.len() < 8 { + return Err(Error::InvalidData("f64 property truncated".into())); + } + PropertyValue::Float(f64::from_le_bytes(arr(&payload[..8])?)) + } + TYPE_INT => { + if payload.len() < 8 { + return Err(Error::InvalidData("i64 property truncated".into())); + } + PropertyValue::Int(i64::from_le_bytes(arr(&payload[..8])?)) + } + TYPE_STRING => PropertyValue::String( + std::str::from_utf8(payload) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in property".into()))? + .to_string(), + ), + TYPE_F64_ARRAY => PropertyValue::FloatArray(decode_f64_array(payload)?), + TYPE_VEC3_F32 => PropertyValue::Vec3Array(decode_vec3_f32(payload)?), + TYPE_I64_ARRAY => { + if !payload.len().is_multiple_of(8) { + return Err(Error::InvalidData( + "i64 array payload length not divisible by 8".into(), + )); + } + PropertyValue::IntArray( + payload + .chunks_exact(8) + .map(|chunk| Ok(i64::from_le_bytes(arr(chunk)?))) + .collect::>()?, + ) + } + TYPE_F32_ARRAY => PropertyValue::Float32Array(decode_f32_array(payload)?), + TYPE_VEC3_F64 => PropertyValue::Vec3ArrayF64(decode_vec3_f64(payload)?), + TYPE_I32_ARRAY => { + if !payload.len().is_multiple_of(4) { + return Err(Error::InvalidData( + "i32 array payload length not divisible by 4".into(), + )); + } + PropertyValue::Int32Array( + payload + .chunks_exact(4) + .map(|chunk| Ok(i32::from_le_bytes(arr(chunk)?))) + .collect::>()?, + ) + } + _ => return Err(Error::InvalidData(format!("Unknown type tag {}", type_tag))), + }) +} + +pub(super) fn decode_float_scalar_data( + payload: &[u8], + type_tag: u8, + field_name: &str, +) -> Result { + match type_tag { + TYPE_FLOAT => { + if payload.len() != 8 { + return Err(Error::InvalidData(format!( + "{field_name} f64 payload truncated" + ))); + } + Ok(FloatScalarData::F64(f64::from_le_bytes(arr(payload)?))) + } + TYPE_FLOAT32 => { + if payload.len() != 4 { + return Err(Error::InvalidData(format!( + "{field_name} f32 payload truncated" + ))); + } + Ok(FloatScalarData::F32(f32::from_le_bytes(arr(payload)?))) + } + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), + } +} + +pub(super) fn decode_vec3_data(payload: &[u8], type_tag: u8, field_name: &str) -> Result { + match type_tag { + TYPE_VEC3_F32 => Ok(Vec3Data::F32(decode_vec3_f32(payload)?)), + TYPE_VEC3_F64 => Ok(Vec3Data::F64(decode_vec3_f64(payload)?)), + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), + } +} + +pub(super) fn decode_float_array_data( + payload: &[u8], + type_tag: u8, + field_name: &str, +) -> Result { + match type_tag { + TYPE_F32_ARRAY => Ok(FloatArrayData::F32(decode_f32_array(payload)?)), + TYPE_F64_ARRAY => Ok(FloatArrayData::F64(decode_f64_array(payload)?)), + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), + } +} + +pub(super) fn decode_mat3_data(payload: &[u8], type_tag: u8, field_name: &str) -> Result { + match type_tag { + TYPE_MAT3X3_F32 => Ok(Mat3Data::F32(decode_mat3x3_f32(payload)?)), + TYPE_MAT3X3_F64 => Ok(Mat3Data::F64(decode_mat3x3_f64(payload)?)), + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), + } +} diff --git a/atompack/src/storage/mod.rs b/atompack/src/storage/mod.rs index 8c070d1..783a26e 100644 --- a/atompack/src/storage/mod.rs +++ b/atompack/src/storage/mod.rs @@ -21,7 +21,6 @@ //! └──────────────────────────────────────┘ //! ``` -use crate::atom::PropertyValue; use crate::compression::{CompressionType, compress, decompress}; use crate::{Error, Molecule, Result}; use bytemuck::{Pod, Zeroable}; @@ -33,11 +32,13 @@ use std::io::{Read, Seek, SeekFrom, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; +mod dtypes; mod header; mod index; mod schema; mod soa; +use self::dtypes::arr; use self::header::{Header, encode_header_slot, read_best_header}; use self::index::{IndexStorage, MoleculeIndex, decode_index, encode_index}; use self::schema::{ @@ -45,7 +46,7 @@ use self::schema::{ record_schema, schema_from_molecule, validate_schema_lock_for_record_format, }; use self::soa::{ - arr, deserialize_molecule_soa, minimum_record_format_for_molecule, serialize_molecule_soa, + deserialize_molecule_soa, minimum_record_format_for_molecule, serialize_molecule_soa, }; // --------------------------------------------------------------------------- @@ -180,6 +181,11 @@ pub struct AtomDatabase { data_mmap: Option>, } +enum AppendSchema<'a> { + Infer(Vec<(&'a [u8], Option)>), + Locked(SchemaLock), +} + impl AtomDatabase { // -- Creation & opening -------------------------------------------------- @@ -469,6 +475,25 @@ impl AtomDatabase { Ok(()) } + fn ensure_writable_for_append(&self) -> Result<()> { + if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { + return Err(Error::InvalidData( + "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." + .into(), + )); + } + Ok(()) + } + + fn prepare_append<'a>(&mut self, schema: AppendSchema<'a>) -> Result<()> { + self.ensure_writable_for_append()?; + self.truncate_uncommitted_tail_if_needed()?; + match schema { + AppendSchema::Infer(records) => self.ensure_schema_compatible(records), + AppendSchema::Locked(schema) => self.ensure_schema_lock(&schema), + } + } + fn write_owned_records(&mut self, records: Vec<(Vec, u32)>) -> Result<()> { let compression = self.compression; @@ -654,24 +679,17 @@ impl AtomDatabase { where I: IntoIterator)>, { - if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { - return Err(Error::InvalidData( - "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." - .into(), - )); - } - let records: Vec<(&[u8], u32, Option)> = records.into_iter().collect(); if records.is_empty() { return Ok(()); } - self.truncate_uncommitted_tail_if_needed()?; - self.ensure_schema_compatible( + self.prepare_append(AppendSchema::Infer( records .iter() - .map(|(bytes, _, positions_type_hint)| (*bytes, *positions_type_hint)), - )?; + .map(|(bytes, _, positions_type_hint)| (*bytes, *positions_type_hint)) + .collect(), + ))?; let borrowed = records .iter() .map(|(bytes, num_atoms, _)| (*bytes, *num_atoms)) @@ -680,19 +698,12 @@ impl AtomDatabase { } fn append_owned_soa_records(&mut self, records: Vec<(Vec, u32, u8)>) -> Result<()> { - if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { - return Err(Error::InvalidData( - "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." - .into(), - )); - } - - self.truncate_uncommitted_tail_if_needed()?; - self.ensure_schema_compatible( + self.prepare_append(AppendSchema::Infer( records .iter() - .map(|(bytes, _, positions_type)| (bytes.as_slice(), Some(*positions_type))), - )?; + .map(|(bytes, _, positions_type)| (bytes.as_slice(), Some(*positions_type))) + .collect(), + ))?; let records = records .into_iter() @@ -706,15 +717,7 @@ impl AtomDatabase { records: Vec<(Vec, u32)>, batch_schema: SchemaLock, ) -> Result<()> { - if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { - return Err(Error::InvalidData( - "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." - .into(), - )); - } - - self.truncate_uncommitted_tail_if_needed()?; - self.ensure_schema_lock(&batch_schema)?; + self.prepare_append(AppendSchema::Locked(batch_schema))?; self.write_owned_records(records) } @@ -723,15 +726,7 @@ impl AtomDatabase { records: &[(&[u8], u32)], batch_schema: SchemaLock, ) -> Result<()> { - if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { - return Err(Error::InvalidData( - "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." - .into(), - )); - } - - self.truncate_uncommitted_tail_if_needed()?; - self.ensure_schema_lock(&batch_schema)?; + self.prepare_append(AppendSchema::Locked(batch_schema))?; self.write_borrowed_records(records) } diff --git a/atompack/src/storage/schema.rs b/atompack/src/storage/schema.rs index 14da316..0f87334 100644 --- a/atompack/src/storage/schema.rs +++ b/atompack/src/storage/schema.rs @@ -1,6 +1,11 @@ -use super::soa::{arr, positions_stride, property_value_type_tag, resolve_positions_type}; +use super::dtypes::{ + arr, float_array_data_type_tag, float_array_payload_len, float_scalar_data_type_tag, + float_scalar_payload_len, mat3_data_type_tag, mat3_payload_len, positions_type_from_molecule, + property_value_payload_len, property_value_type_tag, + validate_builtin_type_tag_for_record_format, vec3_data_type_tag, vec3_payload_len, +}; +use super::soa::{SoaLayout, resolve_layout}; use super::*; -use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; use std::collections::BTreeMap; #[derive(Debug, Clone, Default, PartialEq, Eq)] @@ -19,13 +24,6 @@ pub(super) struct SchemaEntry { const SCHEMA_BLOB_VERSION: u32 = 1; -pub(super) fn positions_type_from_molecule(molecule: &Molecule) -> u8 { - match molecule.positions { - Vec3Data::F32(_) => TYPE_VEC3_F32, - Vec3Data::F64(_) => TYPE_VEC3_F64, - } -} - pub(super) fn encode_schema_lock(lock: &SchemaLock) -> Result> { let mut buf = Vec::new(); buf.extend_from_slice(&SCHEMA_BLOB_VERSION.to_le_bytes()); @@ -199,46 +197,11 @@ fn schema_entry( }) } -fn validate_builtin_type_tag_for_record_format( - record_format: u32, - key: &str, - type_tag: u8, -) -> Result<()> { - match record_format { - RECORD_FORMAT_SOA_V3 => Ok(()), - RECORD_FORMAT_SOA_V2 => match key { - "charges" if type_tag != TYPE_F64_ARRAY => Err(Error::InvalidData( - "record format 2 does not support float32 charges".into(), - )), - "cell" if type_tag != TYPE_MAT3X3_F64 => Err(Error::InvalidData( - "record format 2 does not support float32 cell".into(), - )), - "energy" if type_tag != TYPE_FLOAT => Err(Error::InvalidData( - "record format 2 does not support float32 energy".into(), - )), - "forces" if type_tag != TYPE_VEC3_F32 => Err(Error::InvalidData( - "record format 2 does not support float64 forces".into(), - )), - "stress" if type_tag != TYPE_MAT3X3_F64 => Err(Error::InvalidData( - "record format 2 does not support float32 stress".into(), - )), - "velocities" if type_tag != TYPE_VEC3_F32 => Err(Error::InvalidData( - "record format 2 does not support float64 velocities".into(), - )), - _ => Ok(()), - }, - _ => Err(Error::InvalidData(format!( - "Unsupported record format {}", - record_format - ))), - } -} - pub(super) fn validate_schema_lock_for_record_format( record_format: u32, schema: &SchemaLock, ) -> Result<()> { - let _ = resolve_positions_type(record_format, schema.positions_type)?; + let _ = resolve_layout(record_format, schema.positions_type)?; for ((kind, key), entry) in &schema.sections { if *kind == KIND_BUILTIN { validate_builtin_type_tag_for_record_format(record_format, key, entry.type_tag)?; @@ -247,19 +210,6 @@ pub(super) fn validate_schema_lock_for_record_format( Ok(()) } -fn property_value_payload_len(value: &PropertyValue) -> usize { - match value { - PropertyValue::Float(_) | PropertyValue::Int(_) => 8, - PropertyValue::String(value) => value.len(), - PropertyValue::FloatArray(values) => values.len() * 8, - PropertyValue::Vec3Array(values) => values.len() * 12, - PropertyValue::IntArray(values) => values.len() * 8, - PropertyValue::Float32Array(values) => values.len() * 4, - PropertyValue::Vec3ArrayF64(values) => values.len() * 24, - PropertyValue::Int32Array(values) => values.len() * 4, - } -} - pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { let n_atoms = molecule.len(); let mut schema = SchemaLock { @@ -274,32 +224,36 @@ pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { }; if let Some(charges) = &molecule.charges { - let (type_tag, payload_len) = match charges { - FloatArrayData::F32(values) => (TYPE_F32_ARRAY, values.len() * 4), - FloatArrayData::F64(values) => (TYPE_F64_ARRAY, values.len() * 8), - }; - insert(KIND_BUILTIN, "charges", type_tag, payload_len)?; + insert( + KIND_BUILTIN, + "charges", + float_array_data_type_tag(charges), + float_array_payload_len(charges), + )?; } if let Some(cell) = &molecule.cell { - let (type_tag, payload_len) = match cell { - Mat3Data::F32(_) => (TYPE_MAT3X3_F32, 36), - Mat3Data::F64(_) => (TYPE_MAT3X3_F64, 72), - }; - insert(KIND_BUILTIN, "cell", type_tag, payload_len)?; + insert( + KIND_BUILTIN, + "cell", + mat3_data_type_tag(cell), + mat3_payload_len(cell), + )?; } if let Some(energy) = &molecule.energy { - let (type_tag, payload_len) = match energy { - FloatScalarData::F32(_) => (TYPE_FLOAT32, 4), - FloatScalarData::F64(_) => (TYPE_FLOAT, 8), - }; - insert(KIND_BUILTIN, "energy", type_tag, payload_len)?; + insert( + KIND_BUILTIN, + "energy", + float_scalar_data_type_tag(energy), + float_scalar_payload_len(energy), + )?; } if let Some(forces) = &molecule.forces { - let (type_tag, payload_len) = match forces { - Vec3Data::F32(values) => (TYPE_VEC3_F32, values.len() * 12), - Vec3Data::F64(values) => (TYPE_VEC3_F64, values.len() * 24), - }; - insert(KIND_BUILTIN, "forces", type_tag, payload_len)?; + insert( + KIND_BUILTIN, + "forces", + vec3_data_type_tag(forces), + vec3_payload_len(forces), + )?; } if let Some(name) = &molecule.name { insert(KIND_BUILTIN, "name", TYPE_STRING, name.len())?; @@ -308,18 +262,20 @@ pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { insert(KIND_BUILTIN, "pbc", TYPE_BOOL3, 3)?; } if let Some(stress) = &molecule.stress { - let (type_tag, payload_len) = match stress { - Mat3Data::F32(_) => (TYPE_MAT3X3_F32, 36), - Mat3Data::F64(_) => (TYPE_MAT3X3_F64, 72), - }; - insert(KIND_BUILTIN, "stress", type_tag, payload_len)?; + insert( + KIND_BUILTIN, + "stress", + mat3_data_type_tag(stress), + mat3_payload_len(stress), + )?; } if let Some(velocities) = &molecule.velocities { - let (type_tag, payload_len) = match velocities { - Vec3Data::F32(values) => (TYPE_VEC3_F32, values.len() * 12), - Vec3Data::F64(values) => (TYPE_VEC3_F64, values.len() * 24), - }; - insert(KIND_BUILTIN, "velocities", type_tag, payload_len)?; + insert( + KIND_BUILTIN, + "velocities", + vec3_data_type_tag(velocities), + vec3_payload_len(velocities), + )?; } for (key, value) in &molecule.atom_properties { @@ -342,10 +298,10 @@ pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { Ok(schema) } -fn parse_record_schema_with_positions( +fn parse_record_schema_with_layout( bytes: &[u8], record_format: u32, - positions_type: u8, + layout: SoaLayout, ) -> Result { if bytes.len() < 6 { return Err(Error::InvalidData("SOA record too small".into())); @@ -358,7 +314,7 @@ fn parse_record_schema_with_positions( let positions_end = pos .checked_add( n_atoms - .checked_mul(positions_stride(positions_type)?) + .checked_mul(layout.positions_stride) .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?, ) .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; @@ -388,7 +344,7 @@ fn parse_record_schema_with_positions( pos += 2; let mut schema = SchemaLock { - positions_type: Some(positions_type), + positions_type: Some(layout.positions_type), sections: BTreeMap::new(), }; @@ -433,8 +389,11 @@ pub(super) fn record_schema( record_format: u32, positions_type_hint: Option, ) -> Result { - let positions_type = resolve_positions_type(record_format, positions_type_hint)?; - parse_record_schema_with_positions(bytes, record_format, positions_type) + parse_record_schema_with_layout( + bytes, + record_format, + resolve_layout(record_format, positions_type_hint)?, + ) } pub(super) fn merge_schema_lock(lock: &mut SchemaLock, record: &SchemaLock) -> Result<()> { diff --git a/atompack/src/storage/soa.rs b/atompack/src/storage/soa.rs index 6c52d7b..8a3dbbd 100644 --- a/atompack/src/storage/soa.rs +++ b/atompack/src/storage/soa.rs @@ -1,3 +1,11 @@ +use super::dtypes::{ + arr, decode_float_array_data, decode_float_scalar_data, decode_mat3_data, + decode_property_value, decode_vec3_data, float_array_data_type_tag, float_array_payload_len, + float_scalar_data_type_tag, float_scalar_payload_len, mat3_data_type_tag, mat3_payload_len, + positions_type_from_molecule, property_value_payload_len, property_value_to_bytes, + property_value_type_tag, validate_builtin_type_tag_for_record_format, vec3_data_type_tag, + vec3_payload_len, +}; use super::*; use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; @@ -11,195 +19,18 @@ fn write_section(buf: &mut Vec, kind: u8, key: &str, type_tag: u8, payload: buf.extend_from_slice(payload); } -pub(super) fn property_value_type_tag(value: &PropertyValue) -> u8 { - match value { - PropertyValue::Float(_) => TYPE_FLOAT, - PropertyValue::Int(_) => TYPE_INT, - PropertyValue::String(_) => TYPE_STRING, - PropertyValue::FloatArray(_) => TYPE_F64_ARRAY, - PropertyValue::Vec3Array(_) => TYPE_VEC3_F32, - PropertyValue::IntArray(_) => TYPE_I64_ARRAY, - PropertyValue::Float32Array(_) => TYPE_F32_ARRAY, - PropertyValue::Vec3ArrayF64(_) => TYPE_VEC3_F64, - PropertyValue::Int32Array(_) => TYPE_I32_ARRAY, - } -} - -fn property_value_payload_len(value: &PropertyValue) -> usize { - match value { - PropertyValue::Float(_) | PropertyValue::Int(_) => 8, - PropertyValue::String(value) => value.len(), - PropertyValue::FloatArray(values) => values.len() * 8, - PropertyValue::Vec3Array(values) => values.len() * 12, - PropertyValue::IntArray(values) => values.len() * 8, - PropertyValue::Float32Array(values) => values.len() * 4, - PropertyValue::Vec3ArrayF64(values) => values.len() * 24, - PropertyValue::Int32Array(values) => values.len() * 4, - } -} - -fn extend_f64(b: &mut Vec, v: &[f64]) { - for x in v { - b.extend_from_slice(&f64::to_le_bytes(*x)); - } -} -fn extend_f32(b: &mut Vec, v: &[f32]) { - for x in v { - b.extend_from_slice(&f32::to_le_bytes(*x)); - } -} -fn extend_i64(b: &mut Vec, v: &[i64]) { - for x in v { - b.extend_from_slice(&i64::to_le_bytes(*x)); - } -} -fn extend_i32(b: &mut Vec, v: &[i32]) { - for x in v { - b.extend_from_slice(&i32::to_le_bytes(*x)); - } -} - -pub(super) fn property_value_to_bytes(value: &PropertyValue) -> Vec { - match value { - PropertyValue::Float(v) => v.to_le_bytes().to_vec(), - PropertyValue::Int(v) => v.to_le_bytes().to_vec(), - PropertyValue::String(s) => s.as_bytes().to_vec(), - PropertyValue::FloatArray(v) => { - let mut b = Vec::with_capacity(v.len() * 8); - extend_f64(&mut b, v); - b - } - PropertyValue::Vec3Array(v) => { - let mut b = Vec::with_capacity(v.len() * 12); - for a in v { - extend_f32(&mut b, a); - } - b - } - PropertyValue::IntArray(v) => { - let mut b = Vec::with_capacity(v.len() * 8); - extend_i64(&mut b, v); - b - } - PropertyValue::Float32Array(v) => { - let mut b = Vec::with_capacity(v.len() * 4); - extend_f32(&mut b, v); - b - } - PropertyValue::Vec3ArrayF64(v) => { - let mut b = Vec::with_capacity(v.len() * 24); - for a in v { - extend_f64(&mut b, a); - } - b - } - PropertyValue::Int32Array(v) => { - let mut b = Vec::with_capacity(v.len() * 4); - extend_i32(&mut b, v); - b - } - } -} - -/// Try to read a fixed-size array from a byte slice, returning an error on truncation. -pub(super) fn arr(bytes: &[u8]) -> Result<[u8; N]> { - bytes - .try_into() - .map_err(|_| Error::InvalidData("byte slice truncated".into())) -} - -pub(super) fn decode_vec3_f32(payload: &[u8]) -> Result> { - if !payload.len().is_multiple_of(12) { - return Err(Error::InvalidData( - "vec3 payload length not divisible by 12".into(), - )); - } - payload - .chunks_exact(12) - .map(|c| { - Ok([ - f32::from_le_bytes(arr(&c[0..4])?), - f32::from_le_bytes(arr(&c[4..8])?), - f32::from_le_bytes(arr(&c[8..12])?), - ]) - }) - .collect() -} - -pub(super) fn decode_vec3_f64(payload: &[u8]) -> Result> { - if !payload.len().is_multiple_of(24) { - return Err(Error::InvalidData( - "vec3 payload length not divisible by 24".into(), - )); - } - payload - .chunks_exact(24) - .map(|c| { - Ok([ - f64::from_le_bytes(arr(&c[0..8])?), - f64::from_le_bytes(arr(&c[8..16])?), - f64::from_le_bytes(arr(&c[16..24])?), - ]) - }) - .collect() -} - -pub(super) fn decode_f32_array(payload: &[u8]) -> Result> { - if !payload.len().is_multiple_of(4) { - return Err(Error::InvalidData( - "f32 array payload length not divisible by 4".into(), - )); - } - payload - .chunks_exact(4) - .map(|c| Ok(f32::from_le_bytes(arr(c)?))) - .collect() -} - -pub(super) fn decode_f64_array(payload: &[u8]) -> Result> { - if !payload.len().is_multiple_of(8) { - return Err(Error::InvalidData( - "f64 array payload length not divisible by 8".into(), - )); +#[cfg(not(target_endian = "little"))] +fn extend_f64(buf: &mut Vec, values: &[f64]) { + for value in values { + buf.extend_from_slice(&f64::to_le_bytes(*value)); } - payload - .chunks_exact(8) - .map(|c| Ok(f64::from_le_bytes(arr(c)?))) - .collect() } -pub(super) fn decode_mat3x3_f32(payload: &[u8]) -> Result<[[f32; 3]; 3]> { - if payload.len() != 36 { - return Err(Error::InvalidData(format!( - "mat3x3 payload length {} (expected 36)", - payload.len() - ))); - } - let mut mat = [[0.0f32; 3]; 3]; - for (r, row) in mat.iter_mut().enumerate() { - for (c, cell) in row.iter_mut().enumerate() { - let o = (r * 3 + c) * 4; - *cell = f32::from_le_bytes(arr(&payload[o..o + 4])?); - } - } - Ok(mat) -} - -pub(super) fn decode_mat3x3_f64(payload: &[u8]) -> Result<[[f64; 3]; 3]> { - if payload.len() != 72 { - return Err(Error::InvalidData(format!( - "mat3x3 payload length {} (expected 72)", - payload.len() - ))); - } - let mut mat = [[0.0f64; 3]; 3]; - for (r, row) in mat.iter_mut().enumerate() { - for (c, cell) in row.iter_mut().enumerate() { - let o = (r * 3 + c) * 8; - *cell = f64::from_le_bytes(arr(&payload[o..o + 8])?); - } +#[cfg(not(target_endian = "little"))] +fn extend_f32(buf: &mut Vec, values: &[f32]) { + for value in values { + buf.extend_from_slice(&f32::to_le_bytes(*value)); } - Ok(mat) } fn write_vec3_section(buf: &mut Vec, key: &str, values: &Vec3Data) { @@ -388,52 +219,73 @@ pub(super) fn positions_stride(positions_type: u8) -> Result { } } +#[derive(Clone, Copy)] +pub(super) struct SoaLayout { + pub(super) positions_type: u8, + pub(super) positions_stride: usize, +} + +pub(super) fn resolve_layout( + record_format: u32, + positions_type_hint: Option, +) -> Result { + let positions_type = resolve_positions_type(record_format, positions_type_hint)?; + let positions_stride = positions_stride(positions_type)?; + Ok(SoaLayout { + positions_type, + positions_stride, + }) +} + fn validate_record_format_compat(molecule: &Molecule, record_format: u32) -> Result<()> { - match record_format { - RECORD_FORMAT_SOA_V3 => Ok(()), - RECORD_FORMAT_SOA_V2 => { - if matches!(molecule.positions, Vec3Data::F64(_)) { - return Err(Error::InvalidData( - "record format 2 does not support float64 positions".into(), - )); - } - if matches!(molecule.charges, Some(FloatArrayData::F32(_))) { - return Err(Error::InvalidData( - "record format 2 does not support float32 charges".into(), - )); - } - if matches!(molecule.cell, Some(Mat3Data::F32(_))) { - return Err(Error::InvalidData( - "record format 2 does not support float32 cell".into(), - )); - } - if matches!(molecule.energy, Some(FloatScalarData::F32(_))) { - return Err(Error::InvalidData( - "record format 2 does not support float32 energy".into(), - )); - } - if matches!(molecule.forces, Some(Vec3Data::F64(_))) { - return Err(Error::InvalidData( - "record format 2 does not support float64 forces".into(), - )); - } - if matches!(molecule.stress, Some(Mat3Data::F32(_))) { - return Err(Error::InvalidData( - "record format 2 does not support float32 stress".into(), - )); - } - if matches!(molecule.velocities, Some(Vec3Data::F64(_))) { - return Err(Error::InvalidData( - "record format 2 does not support float64 velocities".into(), - )); - } - Ok(()) - } - _ => Err(Error::InvalidData(format!( - "Unsupported record format {}", - record_format - ))), + resolve_positions_type(record_format, Some(positions_type_from_molecule(molecule)))?; + if record_format == RECORD_FORMAT_SOA_V3 { + return Ok(()); + } + + if let Some(charges) = &molecule.charges { + validate_builtin_type_tag_for_record_format( + record_format, + "charges", + float_array_data_type_tag(charges), + )?; + } + if let Some(cell) = &molecule.cell { + validate_builtin_type_tag_for_record_format( + record_format, + "cell", + mat3_data_type_tag(cell), + )?; } + if let Some(energy) = &molecule.energy { + validate_builtin_type_tag_for_record_format( + record_format, + "energy", + float_scalar_data_type_tag(energy), + )?; + } + if let Some(forces) = &molecule.forces { + validate_builtin_type_tag_for_record_format( + record_format, + "forces", + vec3_data_type_tag(forces), + )?; + } + if let Some(stress) = &molecule.stress { + validate_builtin_type_tag_for_record_format( + record_format, + "stress", + mat3_data_type_tag(stress), + )?; + } + if let Some(velocities) = &molecule.velocities { + validate_builtin_type_tag_for_record_format( + record_format, + "velocities", + vec3_data_type_tag(velocities), + )?; + } + Ok(()) } pub(super) fn minimum_record_format_for_molecule(molecule: &Molecule) -> u32 { @@ -509,39 +361,20 @@ fn section_size(key_len: usize, payload_len: usize) -> usize { } fn estimate_serialized_len(molecule: &Molecule) -> usize { - let positions_bytes = match &molecule.positions { - Vec3Data::F32(values) => values.len() * 12, - Vec3Data::F64(values) => values.len() * 24, - }; + let positions_bytes = vec3_payload_len(&molecule.positions); let mut total = 4 + positions_bytes + molecule.atomic_numbers.len() + 2; if let Some(charges) = &molecule.charges { - let payload_len = match charges { - FloatArrayData::F32(values) => values.len() * 4, - FloatArrayData::F64(values) => values.len() * 8, - }; - total += section_size("charges".len(), payload_len); + total += section_size("charges".len(), float_array_payload_len(charges)); } if let Some(cell) = &molecule.cell { - let payload_len = match cell { - Mat3Data::F32(_) => 36, - Mat3Data::F64(_) => 72, - }; - total += section_size("cell".len(), payload_len); + total += section_size("cell".len(), mat3_payload_len(cell)); } if let Some(energy) = &molecule.energy { - let payload_len = match energy { - FloatScalarData::F32(_) => 4, - FloatScalarData::F64(_) => 8, - }; - total += section_size("energy".len(), payload_len); + total += section_size("energy".len(), float_scalar_payload_len(energy)); } if let Some(forces) = &molecule.forces { - let payload_len = match forces { - Vec3Data::F32(values) => values.len() * 12, - Vec3Data::F64(values) => values.len() * 24, - }; - total += section_size("forces".len(), payload_len); + total += section_size("forces".len(), vec3_payload_len(forces)); } if let Some(name) = &molecule.name { total += section_size("name".len(), name.len()); @@ -550,18 +383,10 @@ fn estimate_serialized_len(molecule: &Molecule) -> usize { total += section_size("pbc".len(), 3); } if let Some(stress) = &molecule.stress { - let payload_len = match stress { - Mat3Data::F32(_) => 36, - Mat3Data::F64(_) => 72, - }; - total += section_size("stress".len(), payload_len); + total += section_size("stress".len(), mat3_payload_len(stress)); } if let Some(velocities) = &molecule.velocities { - let payload_len = match velocities { - Vec3Data::F32(values) => values.len() * 12, - Vec3Data::F64(values) => values.len() * 24, - }; - total += section_size("velocities".len(), payload_len); + total += section_size("velocities".len(), vec3_payload_len(velocities)); } for (key, value) in &molecule.atom_properties { @@ -630,89 +455,6 @@ fn write_sections(buf: &mut Vec, molecule: &Molecule) { } } -fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result { - Ok(match type_tag { - TYPE_FLOAT => { - if payload.len() < 8 { - return Err(Error::InvalidData("f64 property truncated".into())); - } - PropertyValue::Float(f64::from_le_bytes(arr(&payload[..8])?)) - } - TYPE_INT => { - if payload.len() < 8 { - return Err(Error::InvalidData("i64 property truncated".into())); - } - PropertyValue::Int(i64::from_le_bytes(arr(&payload[..8])?)) - } - TYPE_STRING => PropertyValue::String( - std::str::from_utf8(payload) - .map_err(|_| Error::InvalidData("Invalid UTF-8 in property".into()))? - .to_string(), - ), - TYPE_F64_ARRAY => PropertyValue::FloatArray(decode_f64_array(payload)?), - TYPE_VEC3_F32 => PropertyValue::Vec3Array(decode_vec3_f32(payload)?), - TYPE_I64_ARRAY => { - if !payload.len().is_multiple_of(8) { - return Err(Error::InvalidData( - "i64 array payload length not divisible by 8".into(), - )); - } - PropertyValue::IntArray( - payload - .chunks_exact(8) - .map(|c| Ok(i64::from_le_bytes(arr(c)?))) - .collect::>()?, - ) - } - TYPE_F32_ARRAY => { - if !payload.len().is_multiple_of(4) { - return Err(Error::InvalidData( - "f32 array payload length not divisible by 4".into(), - )); - } - PropertyValue::Float32Array( - payload - .chunks_exact(4) - .map(|c| Ok(f32::from_le_bytes(arr(c)?))) - .collect::>()?, - ) - } - TYPE_VEC3_F64 => { - if !payload.len().is_multiple_of(24) { - return Err(Error::InvalidData( - "vec3 payload length not divisible by 24".into(), - )); - } - PropertyValue::Vec3ArrayF64( - payload - .chunks_exact(24) - .map(|c| { - Ok([ - f64::from_le_bytes(arr(&c[0..8])?), - f64::from_le_bytes(arr(&c[8..16])?), - f64::from_le_bytes(arr(&c[16..24])?), - ]) - }) - .collect::>()?, - ) - } - TYPE_I32_ARRAY => { - if !payload.len().is_multiple_of(4) { - return Err(Error::InvalidData( - "i32 array payload length not divisible by 4".into(), - )); - } - PropertyValue::Int32Array( - payload - .chunks_exact(4) - .map(|c| Ok(i32::from_le_bytes(arr(c)?))) - .collect::>()?, - ) - } - _ => return Err(Error::InvalidData(format!("Unknown type tag {}", type_tag))), - }) -} - pub(super) fn serialize_molecule_soa(molecule: &Molecule, record_format: u32) -> Result> { validate_record_format_compat(molecule, record_format)?; @@ -732,10 +474,10 @@ fn decode_positions( bytes: &[u8], pos: &mut usize, n_atoms: usize, - positions_type: u8, + layout: SoaLayout, ) -> Result { let positions_len = n_atoms - .checked_mul(positions_stride(positions_type)?) + .checked_mul(layout.positions_stride) .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; let positions_end = pos .checked_add(positions_len) @@ -746,106 +488,11 @@ fn decode_positions( )); } let payload = &bytes[*pos..positions_end]; - let positions = match positions_type { - TYPE_VEC3_F32 => { - let mut values = Vec::with_capacity(n_atoms); - for chunk in payload.chunks_exact(12) { - values.push([ - f32::from_le_bytes(arr(&chunk[0..4])?), - f32::from_le_bytes(arr(&chunk[4..8])?), - f32::from_le_bytes(arr(&chunk[8..12])?), - ]); - } - Vec3Data::F32(values) - } - TYPE_VEC3_F64 => { - let mut values = Vec::with_capacity(n_atoms); - for chunk in payload.chunks_exact(24) { - values.push([ - f64::from_le_bytes(arr(&chunk[0..8])?), - f64::from_le_bytes(arr(&chunk[8..16])?), - f64::from_le_bytes(arr(&chunk[16..24])?), - ]); - } - Vec3Data::F64(values) - } - _ => { - return Err(Error::InvalidData(format!( - "Unsupported positions type tag {}", - positions_type - ))); - } - }; + let positions = decode_vec3_data(payload, layout.positions_type, "positions")?; *pos = positions_end; Ok(positions) } -fn decode_float_scalar_data( - payload: &[u8], - type_tag: u8, - field_name: &str, -) -> Result { - match type_tag { - TYPE_FLOAT => { - if payload.len() != 8 { - return Err(Error::InvalidData(format!( - "{field_name} f64 payload truncated" - ))); - } - Ok(FloatScalarData::F64(f64::from_le_bytes(arr(payload)?))) - } - TYPE_FLOAT32 => { - if payload.len() != 4 { - return Err(Error::InvalidData(format!( - "{field_name} f32 payload truncated" - ))); - } - Ok(FloatScalarData::F32(f32::from_le_bytes(arr(payload)?))) - } - _ => Err(Error::InvalidData(format!( - "Unsupported {field_name} type tag {}", - type_tag - ))), - } -} - -fn decode_vec3_data(payload: &[u8], type_tag: u8, field_name: &str) -> Result { - match type_tag { - TYPE_VEC3_F32 => Ok(Vec3Data::F32(decode_vec3_f32(payload)?)), - TYPE_VEC3_F64 => Ok(Vec3Data::F64(decode_vec3_f64(payload)?)), - _ => Err(Error::InvalidData(format!( - "Unsupported {field_name} type tag {}", - type_tag - ))), - } -} - -fn decode_float_array_data( - payload: &[u8], - type_tag: u8, - field_name: &str, -) -> Result { - match type_tag { - TYPE_F32_ARRAY => Ok(FloatArrayData::F32(decode_f32_array(payload)?)), - TYPE_F64_ARRAY => Ok(FloatArrayData::F64(decode_f64_array(payload)?)), - _ => Err(Error::InvalidData(format!( - "Unsupported {field_name} type tag {}", - type_tag - ))), - } -} - -fn decode_mat3_data(payload: &[u8], type_tag: u8, field_name: &str) -> Result { - match type_tag { - TYPE_MAT3X3_F32 => Ok(Mat3Data::F32(decode_mat3x3_f32(payload)?)), - TYPE_MAT3X3_F64 => Ok(Mat3Data::F64(decode_mat3x3_f64(payload)?)), - _ => Err(Error::InvalidData(format!( - "Unsupported {field_name} type tag {}", - type_tag - ))), - } -} - fn decode_builtin_section( mol: &mut Molecule, key: &str, @@ -877,7 +524,7 @@ fn decode_builtin_section( Ok(()) } -fn deserialize_molecule_soa_with_positions(bytes: &[u8], positions_type: u8) -> Result { +fn deserialize_molecule_soa_with_layout(bytes: &[u8], layout: SoaLayout) -> Result { if bytes.len() < 6 { return Err(Error::InvalidData("SOA record too small".into())); } @@ -886,7 +533,7 @@ fn deserialize_molecule_soa_with_positions(bytes: &[u8], positions_type: u8) -> let n = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; pos += 4; - let positions = decode_positions(bytes, &mut pos, n, positions_type)?; + let positions = decode_positions(bytes, &mut pos, n, layout)?; let z_end = pos .checked_add(n) .ok_or_else(|| Error::InvalidData("SOA atomic_numbers overflow".into()))?; @@ -961,6 +608,5 @@ pub(super) fn deserialize_molecule_soa( record_format: u32, positions_type_hint: Option, ) -> Result { - let positions_type = resolve_positions_type(record_format, positions_type_hint)?; - deserialize_molecule_soa_with_positions(bytes, positions_type) + deserialize_molecule_soa_with_layout(bytes, resolve_layout(record_format, positions_type_hint)?) } diff --git a/atompack/src/types.rs b/atompack/src/types.rs new file mode 100644 index 0000000..506e4af --- /dev/null +++ b/atompack/src/types.rs @@ -0,0 +1,162 @@ +// Copyright 2026 Entalpic +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum PropertyValue { + Float(f64), + Int(i64), + String(String), + FloatArray(Vec), + Vec3Array(Vec<[f32; 3]>), + IntArray(Vec), + Float32Array(Vec), + Vec3ArrayF64(Vec<[f64; 3]>), + Int32Array(Vec), +} + +impl PropertyValue { + pub fn len(&self) -> Option { + match self { + PropertyValue::FloatArray(values) => Some(values.len()), + PropertyValue::Vec3Array(values) => Some(values.len()), + PropertyValue::IntArray(values) => Some(values.len()), + PropertyValue::Float32Array(values) => Some(values.len()), + PropertyValue::Vec3ArrayF64(values) => Some(values.len()), + PropertyValue::Int32Array(values) => Some(values.len()), + _ => None, + } + } + + pub fn is_empty(&self) -> bool { + self.len() == Some(0) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum Vec3Data { + F32(Vec<[f32; 3]>), + F64(Vec<[f64; 3]>), +} + +impl Vec3Data { + pub fn len(&self) -> usize { + match self { + Self::F32(values) => values.len(), + Self::F64(values) => values.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn atom_position(&self, index: usize) -> Option<[f32; 3]> { + match self { + Self::F32(values) => values.get(index).copied(), + Self::F64(values) => values + .get(index) + .map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]), + } + } + + pub fn flatten_f32_lossy(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|value| [value[0], value[1], value[2]]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]) + .collect(), + } + } + + pub fn flatten_f64(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|value| [value[0] as f64, value[1] as f64, value[2] as f64]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|value| [value[0], value[1], value[2]]) + .collect(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum FloatScalarData { + F32(f32), + F64(f64), +} + +impl FloatScalarData { + pub fn as_f64(&self) -> f64 { + match self { + Self::F32(value) => *value as f64, + Self::F64(value) => *value, + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum FloatArrayData { + F32(Vec), + F64(Vec), +} + +impl FloatArrayData { + pub fn len(&self) -> usize { + match self { + Self::F32(values) => values.len(), + Self::F64(values) => values.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn to_f64_vec(&self) -> Vec { + match self { + Self::F32(values) => values.iter().map(|value| *value as f64).collect(), + Self::F64(values) => values.clone(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum Mat3Data { + F32([[f32; 3]; 3]), + F64([[f64; 3]; 3]), +} + +impl Mat3Data { + pub fn flatten_f32_lossy(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|row| [row[0], row[1], row[2]]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|row| [row[0] as f32, row[1] as f32, row[2] as f32]) + .collect(), + } + } + + pub fn flatten_f64(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|row| [row[0] as f64, row[1] as f64, row[2] as f64]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|row| [row[0], row[1], row[2]]) + .collect(), + } + } +} From a94bc99096e31b782dfdbff4a82cdd8b16086fdb Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sun, 10 May 2026 14:28:48 +0200 Subject: [PATCH 7/9] fix: satisfy ci lint on schema promotion guards --- atompack/src/storage/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atompack/src/storage/mod.rs b/atompack/src/storage/mod.rs index 783a26e..886aceb 100644 --- a/atompack/src/storage/mod.rs +++ b/atompack/src/storage/mod.rs @@ -424,7 +424,7 @@ impl AtomDatabase { fn resolved_record_format_for_schema(&self, schema: &SchemaLock) -> Result { match validate_schema_lock_for_record_format(self.record_format, schema) { Ok(()) => Ok(self.record_format), - Err(current_err) if self.can_promote_record_format() => { + Err(_) if self.can_promote_record_format() => { validate_schema_lock_for_record_format(RECORD_FORMAT_SOA_V3, schema)?; Ok(RECORD_FORMAT_SOA_V3) } @@ -448,7 +448,7 @@ impl AtomDatabase { let hint = positions_type_hint.or(lock.positions_type); let record = match record_schema(bytes, record_format, hint) { Ok(record) => record, - Err(current_err) if can_promote && record_format == RECORD_FORMAT_SOA_V2 => { + Err(_) if can_promote && record_format == RECORD_FORMAT_SOA_V2 => { let record = record_schema(bytes, RECORD_FORMAT_SOA_V3, hint)?; record_format = RECORD_FORMAT_SOA_V3; record From ecb3113b39fb55d5d0ffd33461e5a59c00457ee3 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sun, 10 May 2026 14:59:32 +0200 Subject: [PATCH 8/9] perf: simplify batched soa writer --- atompack-py/src/database_batch.rs | 605 +++++++++++++--------------- atompack-py/src/lib.rs | 5 +- atompack-py/src/molecule.rs | 5 +- atompack-py/src/molecule_helpers.rs | 241 ++++++----- 4 files changed, 430 insertions(+), 426 deletions(-) diff --git a/atompack-py/src/database_batch.rs b/atompack-py/src/database_batch.rs index 3a55bd3..75fef12 100644 --- a/atompack-py/src/database_batch.rs +++ b/atompack-py/src/database_batch.rs @@ -1,6 +1,10 @@ use super::*; -use crate::molecule::{SoaRecord, SoaSection, build_soa_record}; +use crate::molecule::{ + SoaBuiltinPayloads, SoaCustomSection, SoaTypedPayload, build_soa_record_unchecked, +}; +use crate::soa::is_per_atom; use atompack::storage::{DatabaseSchema, DatabaseSchemaSection}; +use numpy::PyReadonlyArray3; struct BatchSectionColumn { key: String, @@ -12,9 +16,9 @@ struct BatchSectionColumn { } impl BatchSectionColumn { - fn section_for<'a>(&'a self, index: usize) -> SoaSection<'a> { + fn custom_section_for<'a>(&'a self, index: usize) -> SoaCustomSection<'a> { if let Some(strings) = &self.strings { - return SoaSection { + return SoaCustomSection { kind: self.kind, key: self.key.as_str(), type_tag: self.type_tag, @@ -23,82 +27,41 @@ impl BatchSectionColumn { } let start = index * self.slot_bytes; let end = start + self.slot_bytes; - SoaSection { + SoaCustomSection { kind: self.kind, key: self.key.as_str(), type_tag: self.type_tag, payload: &self.payload[start..end], } } -} -fn batch_section_is_per_atom(kind: u8, key: &str) -> bool { - match kind { - KIND_ATOM_PROP => true, - KIND_MOL_PROP => false, - KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), - _ => false, - } -} - -fn database_schema_section(column: &BatchSectionColumn) -> PyResult { - let per_atom = batch_section_is_per_atom(column.kind, &column.key); - let elem_bytes = if column.type_tag == TYPE_STRING { - 0 - } else { - let elem_bytes = type_tag_elem_bytes(column.type_tag); - if elem_bytes == 0 { - return Err(PyValueError::new_err(format!( - "Unsupported section type tag {} for '{}'", - column.type_tag, column.key - ))); + fn payload_at<'a>(&'a self, index: usize) -> SoaTypedPayload<'a> { + let start = index * self.slot_bytes; + let end = start + self.slot_bytes; + SoaTypedPayload { + type_tag: self.type_tag, + payload: &self.payload[start..end], } - elem_bytes - }; - let slot_bytes = if column.type_tag == TYPE_STRING { - 0 - } else if per_atom { - elem_bytes - } else { - column.slot_bytes - }; + } - Ok(DatabaseSchemaSection { - kind: column.kind, - key: column.key.clone(), - type_tag: column.type_tag, - per_atom, - elem_bytes, - slot_bytes, - }) + fn string_at(&self, index: usize) -> Option<&str> { + self.strings.as_ref().map(|strings| strings[index].as_str()) + } } -fn build_batch_schema<'a, I>(positions_type: u8, columns: I) -> PyResult -where - I: IntoIterator, -{ - let sections = columns - .into_iter() - .map(database_schema_section) - .collect::>>()?; - Ok(DatabaseSchema { - positions_type: Some(positions_type), - sections, - }) +fn schema_section(kind: u8, key: &str, type_tag: u8, slot_bytes: usize) -> DatabaseSchemaSection { + DatabaseSchemaSection { + kind, + key: key.to_string(), + type_tag, + per_atom: is_per_atom(kind, key, type_tag), + elem_bytes: type_tag_elem_bytes(type_tag), + slot_bytes, + } } -fn push_builtin_section<'a>( - sections: &mut Vec>, - key: &'a str, - type_tag: u8, - payload: &'a [u8], -) { - sections.push(SoaSection { - kind: KIND_BUILTIN, - key, - type_tag, - payload, - }); +fn schema_section_from_column(column: &BatchSectionColumn) -> DatabaseSchemaSection { + schema_section(column.kind, &column.key, column.type_tag, column.slot_bytes) } fn reject_reserved_key(key: &str) -> PyResult<()> { @@ -473,74 +436,51 @@ fn extract_custom_columns( Ok(columns) } -struct FastMat3Column { - type_tag: u8, - slot_bytes: usize, - payload: Vec, +enum FastMat3Slice<'a> { + F32(&'a [f32]), + F64(&'a [f64]), } -impl FastMat3Column { - fn from_optional( - value: Option<&Bound<'_, PyAny>>, - batch: usize, - label: &str, - ) -> PyResult> { - let Some(value) = value else { - return Ok(None); - }; - if let Some(arr) = PyFloatArray3::from_any(value) { - match arr { - PyFloatArray3::F32(arr) => { - let ro = arr.readonly(); - if ro.as_array().shape() != [batch, 3, 3] { - return Err(PyValueError::new_err(format!( - "{label} must have shape ({}, 3, 3)", - batch - ))); - } - let slice = ro.as_slice().map_err(|_| { - PyValueError::new_err(format!("{label} must be C-contiguous")) - })?; - return Ok(Some(Self { - type_tag: TYPE_MAT3X3_F32, - slot_bytes: 36, - payload: bytemuck::cast_slice::(slice).to_vec(), - })); - } - PyFloatArray3::F64(arr) => { - let ro = arr.readonly(); - if ro.as_array().shape() != [batch, 3, 3] { - return Err(PyValueError::new_err(format!( - "{label} must have shape ({}, 3, 3)", - batch - ))); - } - let slice = ro.as_slice().map_err(|_| { - PyValueError::new_err(format!("{label} must be C-contiguous")) - })?; - return Ok(Some(Self { - type_tag: TYPE_MAT3X3_F64, - slot_bytes: 72, - payload: bytemuck::cast_slice::(slice).to_vec(), - })); - } - } +impl<'a> FastMat3Slice<'a> { + fn from_f32(ro: &'a PyReadonlyArray3<'_, f32>, batch: usize, label: &str) -> PyResult { + if ro.as_array().shape() != [batch, 3, 3] { + return Err(PyValueError::new_err(format!( + "{label} must have shape ({}, 3, 3)", + batch + ))); } - Ok(None) + let slice = ro + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{label} must be C-contiguous")))?; + Ok(Self::F32(slice)) } - fn type_tag(&self) -> u8 { - self.type_tag - } - - fn slot_bytes(&self) -> usize { - self.slot_bytes + fn from_f64(ro: &'a PyReadonlyArray3<'_, f64>, batch: usize, label: &str) -> PyResult { + if ro.as_array().shape() != [batch, 3, 3] { + return Err(PyValueError::new_err(format!( + "{label} must have shape ({}, 3, 3)", + batch + ))); + } + let slice = ro + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{label} must be C-contiguous")))?; + Ok(Self::F64(slice)) } - fn payload_bytes(&self, index: usize) -> &[u8] { - let start = index * self.slot_bytes; - let end = start + self.slot_bytes; - &self.payload[start..end] + fn payload_at(&self, index: usize) -> SoaTypedPayload<'_> { + let start = index * 9; + let end = start + 9; + match self { + Self::F32(slice) => SoaTypedPayload { + type_tag: TYPE_MAT3X3_F32, + payload: bytemuck::cast_slice::(&slice[start..end]), + }, + Self::F64(slice) => SoaTypedPayload { + type_tag: TYPE_MAT3X3_F64, + payload: bytemuck::cast_slice::(&slice[start..end]), + }, + } } } @@ -696,17 +636,45 @@ fn try_add_arrays_batch_fast_canonical( None }; - let cell_slice = match FastMat3Column::from_optional(cell, batch, "cell")? { - Some(column) => Some(column), - None if cell.is_some() => return Ok(false), + let cell_ro_f32 = match cell.and_then(PyFloatArray3::from_any) { + Some(PyFloatArray3::F32(arr)) => Some(arr.readonly()), + Some(PyFloatArray3::F64(_)) => None, + None => None, + }; + let cell_ro_f64 = match cell.and_then(PyFloatArray3::from_any) { + Some(PyFloatArray3::F64(arr)) => Some(arr.readonly()), + Some(PyFloatArray3::F32(_)) => None, None => None, }; + let cell_slice = if let Some(ro) = cell_ro_f64.as_ref() { + Some(FastMat3Slice::from_f64(ro, batch, "cell")?) + } else if let Some(ro) = cell_ro_f32.as_ref() { + Some(FastMat3Slice::from_f32(ro, batch, "cell")?) + } else if cell.is_some() { + return Ok(false); + } else { + None + }; - let stress_slice = match FastMat3Column::from_optional(stress, batch, "stress")? { - Some(column) => Some(column), - None if stress.is_some() => return Ok(false), + let stress_ro_f32 = match stress.and_then(PyFloatArray3::from_any) { + Some(PyFloatArray3::F32(arr)) => Some(arr.readonly()), + Some(PyFloatArray3::F64(_)) => None, None => None, }; + let stress_ro_f64 = match stress.and_then(PyFloatArray3::from_any) { + Some(PyFloatArray3::F64(arr)) => Some(arr.readonly()), + Some(PyFloatArray3::F32(_)) => None, + None => None, + }; + let stress_slice = if let Some(ro) = stress_ro_f64.as_ref() { + Some(FastMat3Slice::from_f64(ro, batch, "stress")?) + } else if let Some(ro) = stress_ro_f32.as_ref() { + Some(FastMat3Slice::from_f32(ro, batch, "stress")?) + } else if stress.is_some() { + return Ok(false); + } else { + None + }; let pbc_ro = pbc.map(|arr| arr.readonly()); let pbc_slice = if let Some(ro) = pbc_ro.as_ref() { @@ -736,120 +704,73 @@ fn try_add_arrays_batch_fast_canonical( } let custom_columns = extract_custom_columns(properties, atom_properties, batch, n_atoms)?; - let mut builtin_columns = Vec::new(); + let mut schema_sections = Vec::with_capacity(8 + custom_columns.len()); if energy_slice.is_some() { - builtin_columns.push(BatchSectionColumn { - key: "energy".to_string(), - kind: KIND_BUILTIN, - type_tag: TYPE_FLOAT, - slot_bytes: 8, - payload: Vec::new(), - strings: None, - }); + schema_sections.push(schema_section( + KIND_BUILTIN, + "energy", + TYPE_FLOAT, + std::mem::size_of::(), + )); } if forces_slice.is_some() { - builtin_columns.push(BatchSectionColumn { - key: "forces".to_string(), - kind: KIND_BUILTIN, - type_tag: TYPE_VEC3_F32, - slot_bytes: n_atoms * 12, - payload: Vec::new(), - strings: None, - }); + schema_sections.push(schema_section(KIND_BUILTIN, "forces", TYPE_VEC3_F32, 12)); } if charges_slice.is_some() { - builtin_columns.push(BatchSectionColumn { - key: "charges".to_string(), - kind: KIND_BUILTIN, - type_tag: TYPE_F64_ARRAY, - slot_bytes: n_atoms * 8, - payload: Vec::new(), - strings: None, - }); + schema_sections.push(schema_section(KIND_BUILTIN, "charges", TYPE_F64_ARRAY, 8)); } if velocities_slice.is_some() { - builtin_columns.push(BatchSectionColumn { - key: "velocities".to_string(), - kind: KIND_BUILTIN, - type_tag: TYPE_VEC3_F32, - slot_bytes: n_atoms * 12, - payload: Vec::new(), - strings: None, - }); + schema_sections.push(schema_section( + KIND_BUILTIN, + "velocities", + TYPE_VEC3_F32, + 12, + )); } if let Some(column) = cell_slice.as_ref() { - builtin_columns.push(BatchSectionColumn { - key: "cell".to_string(), - kind: KIND_BUILTIN, - type_tag: column.type_tag(), - slot_bytes: column.slot_bytes(), - payload: Vec::new(), - strings: None, - }); + let (type_tag, slot_bytes) = match column { + FastMat3Slice::F32(_) => (TYPE_MAT3X3_F32, 36), + FastMat3Slice::F64(_) => (TYPE_MAT3X3_F64, 72), + }; + schema_sections.push(schema_section(KIND_BUILTIN, "cell", type_tag, slot_bytes)); } if let Some(column) = stress_slice.as_ref() { - builtin_columns.push(BatchSectionColumn { - key: "stress".to_string(), - kind: KIND_BUILTIN, - type_tag: column.type_tag(), - slot_bytes: column.slot_bytes(), - payload: Vec::new(), - strings: None, - }); + let (type_tag, slot_bytes) = match column { + FastMat3Slice::F32(_) => (TYPE_MAT3X3_F32, 36), + FastMat3Slice::F64(_) => (TYPE_MAT3X3_F64, 72), + }; + schema_sections.push(schema_section(KIND_BUILTIN, "stress", type_tag, slot_bytes)); } if pbc_slice.is_some() { - builtin_columns.push(BatchSectionColumn { - key: "pbc".to_string(), - kind: KIND_BUILTIN, - type_tag: TYPE_BOOL3, - slot_bytes: 3, - payload: Vec::new(), - strings: None, - }); + schema_sections.push(schema_section(KIND_BUILTIN, "pbc", TYPE_BOOL3, 3)); } if name.is_some() { - builtin_columns.push(BatchSectionColumn { - key: "name".to_string(), - kind: KIND_BUILTIN, - type_tag: TYPE_STRING, - slot_bytes: 0, - payload: Vec::new(), - strings: None, - }); + schema_sections.push(schema_section(KIND_BUILTIN, "name", TYPE_STRING, 0)); } - - let batch_schema = build_batch_schema( - TYPE_VEC3_F32, - builtin_columns.iter().chain(custom_columns.iter()), - )?; - let record_format = inner - .record_format_for_schema(batch_schema.clone()) - .map_err(|e| PyValueError::new_err(format!("{}", e)))?; - let builtin_section_count = usize::from(energy_slice.is_some()) - + usize::from(forces_slice.is_some()) - + usize::from(charges_slice.is_some()) - + usize::from(velocities_slice.is_some()) - + usize::from(cell_slice.is_some()) - + usize::from(stress_slice.is_some()) - + usize::from(pbc_slice.is_some()) - + usize::from(name.is_some()); - + schema_sections.extend(custom_columns.iter().map(schema_section_from_column)); + let batch_schema = DatabaseSchema { + positions_type: Some(TYPE_VEC3_F32), + sections: schema_sections, + }; let build_record = |i: usize| { let pos_start = i * n_atoms * 3; let pos_end = pos_start + n_atoms * 3; let z_start = i * n_atoms; let z_end = z_start + n_atoms; - let forces_payload = forces_slice - .as_ref() - .map(|slice| bytemuck::cast_slice::(&slice[pos_start..pos_end])); - let charges_payload = charges_slice - .as_ref() - .map(|slice| bytemuck::cast_slice::(&slice[z_start..z_end])); - let velocities_payload = velocities_slice - .as_ref() - .map(|slice| bytemuck::cast_slice::(&slice[pos_start..pos_end])); - let cell_payload = cell_slice.as_ref().map(|column| column.payload_bytes(i)); - let stress_payload = stress_slice.as_ref().map(|column| column.payload_bytes(i)); + let forces_payload = forces_slice.as_ref().map(|slice| SoaTypedPayload { + type_tag: TYPE_VEC3_F32, + payload: bytemuck::cast_slice::(&slice[pos_start..pos_end]), + }); + let charges_payload = charges_slice.as_ref().map(|slice| SoaTypedPayload { + type_tag: TYPE_F64_ARRAY, + payload: bytemuck::cast_slice::(&slice[z_start..z_end]), + }); + let velocities_payload = velocities_slice.as_ref().map(|slice| SoaTypedPayload { + type_tag: TYPE_VEC3_F32, + payload: bytemuck::cast_slice::(&slice[pos_start..pos_end]), + }); + let cell_payload = cell_slice.as_ref().map(|column| column.payload_at(i)); + let stress_payload = stress_slice.as_ref().map(|column| column.payload_at(i)); let energy_bytes = energy_slice.as_ref().map(|slice| slice[i].to_le_bytes()); let pbc_payload = pbc_slice.as_ref().map(|slice| { [ @@ -858,57 +779,30 @@ fn try_add_arrays_batch_fast_canonical( slice[i * 3 + 2] as u8, ] }); + let custom_sections: Vec> = custom_columns + .iter() + .map(|column| column.custom_section_for(i)) + .collect(); - let mut sections = Vec::with_capacity(builtin_section_count + custom_columns.len()); - if let Some(payload) = charges_payload { - push_builtin_section(&mut sections, "charges", TYPE_F64_ARRAY, payload); - } - if let Some(payload) = cell_payload { - push_builtin_section( - &mut sections, - "cell", - cell_slice - .as_ref() - .map(FastMat3Column::type_tag) - .expect("cell type tag must exist when payload exists"), - payload, - ); - } - if let Some(bytes) = energy_bytes.as_ref() { - push_builtin_section(&mut sections, "energy", TYPE_FLOAT, bytes); - } - if let Some(payload) = forces_payload { - push_builtin_section(&mut sections, "forces", TYPE_VEC3_F32, payload); - } - if let Some(names) = name.as_ref() { - push_builtin_section(&mut sections, "name", TYPE_STRING, names[i].as_bytes()); - } - if let Some(payload) = pbc_payload.as_ref() { - push_builtin_section(&mut sections, "pbc", TYPE_BOOL3, payload); - } - if let Some(payload) = stress_payload { - push_builtin_section( - &mut sections, - "stress", - stress_slice - .as_ref() - .map(FastMat3Column::type_tag) - .expect("stress type tag must exist when payload exists"), - payload, - ); - } - if let Some(payload) = velocities_payload { - push_builtin_section(&mut sections, "velocities", TYPE_VEC3_F32, payload); - } - sections.extend(custom_columns.iter().map(|column| column.section_for(i))); - - build_soa_record(SoaRecord { - record_format, - positions_type: TYPE_VEC3_F32, - positions: bytemuck::cast_slice::(&pos_slice[pos_start..pos_end]), - atomic_numbers: &z_slice[z_start..z_end], - sections: §ions, - }) + build_soa_record_unchecked( + TYPE_VEC3_F32, + bytemuck::cast_slice::(&pos_slice[pos_start..pos_end]), + &z_slice[z_start..z_end], + SoaBuiltinPayloads { + energy: energy_bytes.as_ref().map(|bytes| SoaTypedPayload { + type_tag: TYPE_FLOAT, + payload: bytes.as_ref(), + }), + forces: forces_payload, + charges: charges_payload, + velocities: velocities_payload, + cell: cell_payload, + stress: stress_payload, + pbc: pbc_payload, + name: name.as_ref().map(|names| names[i].as_str()), + }, + &custom_sections, + ) .map(|record| (record, n_atoms as u32)) }; @@ -1193,142 +1087,191 @@ pub(super) fn add_arrays_batch_impl( let (batch, n_atoms, positions_type, positions_payload) = extract_positions_payload(positions)?; let atomic_numbers_payload = extract_atomic_numbers_payload(atomic_numbers, batch, n_atoms)?; - let mut builtin_columns = Vec::new(); - if let Some(energy) = energy { + let energy_column = if let Some(energy) = energy { let Some(array) = PyFloatArray1::from_any(energy) else { return Err(PyValueError::new_err( "energy must be a float32 or float64 ndarray with shape (batch,)", )); }; - builtin_columns.push(match array { + Some(match array { PyFloatArray1::F32(arr) => { extract_builtin_scalar_column(&arr, batch, "energy", TYPE_FLOAT32)? } PyFloatArray1::F64(arr) => { extract_builtin_scalar_column(&arr, batch, "energy", TYPE_FLOAT)? } - }); - } - if let Some(forces) = forces { + }) + } else { + None + }; + let forces_column = if let Some(forces) = forces { let Some(array) = PyFloatArray3::from_any(forces) else { return Err(PyValueError::new_err( "forces must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", )); }; - builtin_columns.push(match array { + Some(match array { PyFloatArray3::F32(arr) => { extract_builtin_vec3_column(&arr, batch, n_atoms, "forces", TYPE_VEC3_F32)? } PyFloatArray3::F64(arr) => { extract_builtin_vec3_column(&arr, batch, n_atoms, "forces", TYPE_VEC3_F64)? } - }); - } - if let Some(charges) = charges { + }) + } else { + None + }; + let charges_column = if let Some(charges) = charges { let Some(array) = PyFloatArray2::from_any(charges) else { return Err(PyValueError::new_err( "charges must be a float32 or float64 ndarray with shape (batch, n_atoms)", )); }; - builtin_columns.push(match array { + Some(match array { PyFloatArray2::F32(arr) => { extract_builtin_float_array_column(&arr, batch, n_atoms, "charges", TYPE_F32_ARRAY)? } PyFloatArray2::F64(arr) => { extract_builtin_float_array_column(&arr, batch, n_atoms, "charges", TYPE_F64_ARRAY)? } - }); - } - if let Some(velocities) = velocities { + }) + } else { + None + }; + let velocities_column = if let Some(velocities) = velocities { let Some(array) = PyFloatArray3::from_any(velocities) else { return Err(PyValueError::new_err( "velocities must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", )); }; - builtin_columns.push(match array { + Some(match array { PyFloatArray3::F32(arr) => { extract_builtin_vec3_column(&arr, batch, n_atoms, "velocities", TYPE_VEC3_F32)? } PyFloatArray3::F64(arr) => { extract_builtin_vec3_column(&arr, batch, n_atoms, "velocities", TYPE_VEC3_F64)? } - }); - } - if let Some(cell) = cell { + }) + } else { + None + }; + let cell_column = if let Some(cell) = cell { let Some(array) = PyFloatArray3::from_any(cell) else { return Err(PyValueError::new_err( "cell must be a float32 or float64 ndarray with shape (batch, 3, 3)", )); }; - builtin_columns.push(match array { + Some(match array { PyFloatArray3::F32(arr) => { extract_builtin_mat3_column(&arr, batch, "cell", TYPE_MAT3X3_F32)? } PyFloatArray3::F64(arr) => { extract_builtin_mat3_column(&arr, batch, "cell", TYPE_MAT3X3_F64)? } - }); - } - if let Some(stress) = stress { + }) + } else { + None + }; + let stress_column = if let Some(stress) = stress { let Some(array) = PyFloatArray3::from_any(stress) else { return Err(PyValueError::new_err( "stress must be a float32 or float64 ndarray with shape (batch, 3, 3)", )); }; - builtin_columns.push(match array { + Some(match array { PyFloatArray3::F32(arr) => { extract_builtin_mat3_column(&arr, batch, "stress", TYPE_MAT3X3_F32)? } PyFloatArray3::F64(arr) => { extract_builtin_mat3_column(&arr, batch, "stress", TYPE_MAT3X3_F64)? } - }); + }) + } else { + None + }; + let pbc_column = match pbc { + Some(pbc) => Some(extract_builtin_pbc_column(pbc, batch)?), + None => None, + }; + let name_column = extract_builtin_name_column(name, batch)?; + + let custom_columns = extract_custom_columns(properties, atom_properties, batch, n_atoms)?; + let mut schema_sections = Vec::with_capacity(8 + custom_columns.len()); + if let Some(column) = energy_column.as_ref() { + schema_sections.push(schema_section_from_column(column)); } - if let Some(pbc) = pbc { - builtin_columns.push(extract_builtin_pbc_column(pbc, batch)?); + if let Some(column) = forces_column.as_ref() { + schema_sections.push(schema_section_from_column(column)); } - if let Some(name) = extract_builtin_name_column(name, batch)? { - builtin_columns.push(name); + if let Some(column) = charges_column.as_ref() { + schema_sections.push(schema_section_from_column(column)); } - - let custom_columns = extract_custom_columns(properties, atom_properties, batch, n_atoms)?; + if let Some(column) = velocities_column.as_ref() { + schema_sections.push(schema_section_from_column(column)); + } + if let Some(column) = cell_column.as_ref() { + schema_sections.push(schema_section_from_column(column)); + } + if let Some(column) = stress_column.as_ref() { + schema_sections.push(schema_section_from_column(column)); + } + if let Some(column) = pbc_column.as_ref() { + schema_sections.push(schema_section_from_column(column)); + } + if let Some(column) = name_column.as_ref() { + schema_sections.push(schema_section_from_column(column)); + } + schema_sections.extend(custom_columns.iter().map(schema_section_from_column)); + let batch_schema = DatabaseSchema { + positions_type: Some(positions_type), + sections: schema_sections, + }; let positions_slot_bytes = n_atoms .checked_mul(type_tag_elem_bytes(positions_type)) .ok_or_else(|| PyValueError::new_err("positions byte length overflow"))?; - let batch_schema = build_batch_schema( - positions_type, - builtin_columns.iter().chain(custom_columns.iter()), - )?; - let record_format = inner - .record_format_for_schema(batch_schema.clone()) - .map_err(|e| PyValueError::new_err(format!("{}", e)))?; - let build_record = |index: usize| -> Result<(Vec, u32), String> { let pos_start = index * positions_slot_bytes; let pos_end = pos_start + positions_slot_bytes; let z_start = index * n_atoms; let z_end = z_start + n_atoms; - let mut sections = Vec::with_capacity(builtin_columns.len() + custom_columns.len()); - sections.extend( - builtin_columns - .iter() - .map(|column| column.section_for(index)), - ); - sections.extend( - custom_columns - .iter() - .map(|column| column.section_for(index)), - ); + let custom_sections: Vec> = custom_columns + .iter() + .map(|column| column.custom_section_for(index)) + .collect(); - let record = build_soa_record(SoaRecord { - record_format, + let record = build_soa_record_unchecked( positions_type, - positions: &positions_payload[pos_start..pos_end], - atomic_numbers: &atomic_numbers_payload[z_start..z_end], - sections: §ions, - })?; + &positions_payload[pos_start..pos_end], + &atomic_numbers_payload[z_start..z_end], + SoaBuiltinPayloads { + energy: energy_column + .as_ref() + .map(|column| column.payload_at(index)), + forces: forces_column + .as_ref() + .map(|column| column.payload_at(index)), + charges: charges_column + .as_ref() + .map(|column| column.payload_at(index)), + velocities: velocities_column + .as_ref() + .map(|column| column.payload_at(index)), + cell: cell_column.as_ref().map(|column| column.payload_at(index)), + stress: stress_column + .as_ref() + .map(|column| column.payload_at(index)), + pbc: pbc_column.as_ref().map(|column| { + let payload = column.payload_at(index); + [payload.payload[0], payload.payload[1], payload.payload[2]] + }), + name: name_column + .as_ref() + .and_then(|column| column.string_at(index)), + }, + &custom_sections, + )?; Ok((record, n_atoms as u32)) }; diff --git a/atompack-py/src/lib.rs b/atompack-py/src/lib.rs index 251fbd9..3e6b6e3 100644 --- a/atompack-py/src/lib.rs +++ b/atompack-py/src/lib.rs @@ -125,9 +125,8 @@ pub(crate) use self::py_dtypes::{ parse_mat3_field, parse_positions_field, parse_property_value, parse_vec3_field, }; pub(crate) use self::soa::{ - LazySection, SectionRef, SectionSchema, SoaContext, SoaMoleculeView, is_per_atom, - parse_mol_fast_soa, read_f64_scalar, read_i64_scalar, section_schema_from_ref, - type_tag_elem_bytes, validate_section_payload, + LazySection, SectionSchema, SoaContext, SoaMoleculeView, parse_mol_fast_soa, read_f64_scalar, + read_i64_scalar, section_schema_from_ref, type_tag_elem_bytes, }; mod database; diff --git a/atompack-py/src/molecule.rs b/atompack-py/src/molecule.rs index 81e128e..3d53029 100644 --- a/atompack-py/src/molecule.rs +++ b/atompack-py/src/molecule.rs @@ -59,8 +59,9 @@ enum MoleculeBacking { mod helpers; pub(crate) use self::helpers::{ - SoaRecord, SoaSection, build_soa_record, cast_or_decode_f32, cast_or_decode_f64, - cast_or_decode_i32, cast_or_decode_i64, pyarray1_from_cow, pyarray2_from_cow, + SoaBuiltinPayloads, SoaCustomSection, SoaTypedPayload, build_soa_record_unchecked, + cast_or_decode_f32, cast_or_decode_f64, cast_or_decode_i32, cast_or_decode_i64, + pyarray1_from_cow, pyarray2_from_cow, }; use self::helpers::{into_py_any, property_section_to_pyobject, property_value_to_pyobject}; diff --git a/atompack-py/src/molecule_helpers.rs b/atompack-py/src/molecule_helpers.rs index a71f6db..ad7e739 100644 --- a/atompack-py/src/molecule_helpers.rs +++ b/atompack-py/src/molecule_helpers.rs @@ -105,139 +105,200 @@ pub(crate) fn pyarray2_from_cow<'py, T: Element + Clone>( } } -pub(crate) struct SoaSection<'a> { - pub(crate) kind: u8, - pub(crate) key: &'a str, +pub(crate) struct SoaTypedPayload<'a> { pub(crate) type_tag: u8, pub(crate) payload: &'a [u8], } -pub(crate) struct SoaRecord<'a> { - pub(crate) record_format: u32, - pub(crate) positions_type: u8, - pub(crate) positions: &'a [u8], - pub(crate) atomic_numbers: &'a [u8], - pub(crate) sections: &'a [SoaSection<'a>], +pub(crate) struct SoaBuiltinPayloads<'a> { + pub(crate) energy: Option>, + pub(crate) forces: Option>, + pub(crate) charges: Option>, + pub(crate) velocities: Option>, + pub(crate) cell: Option>, + pub(crate) stress: Option>, + pub(crate) pbc: Option<[u8; 3]>, + pub(crate) name: Option<&'a str>, +} + +pub(crate) struct SoaCustomSection<'a> { + pub(crate) kind: u8, + pub(crate) key: &'a str, + pub(crate) type_tag: u8, + pub(crate) payload: &'a [u8], } -fn write_soa_section_raw(buf: &mut Vec, section: &SoaSection<'_>) -> Result<(), String> { - let key_len: u8 = section - .key +fn write_soa_section_raw( + buf: &mut Vec, + kind: u8, + key: &str, + type_tag: u8, + payload: &[u8], +) -> Result<(), String> { + let key_len: u8 = key .len() .try_into() - .map_err(|_| format!("Section key '{}' is too long", section.key))?; - let payload_len: u32 = section - .payload + .map_err(|_| format!("Section key '{}' is too long", key))?; + let payload_len: u32 = payload .len() .try_into() - .map_err(|_| format!("Section '{}' payload is too large", section.key))?; - buf.push(section.kind); + .map_err(|_| format!("Section '{}' payload is too large", key))?; + buf.push(kind); buf.push(key_len); - buf.extend_from_slice(section.key.as_bytes()); - buf.push(section.type_tag); + buf.extend_from_slice(key.as_bytes()); + buf.push(type_tag); buf.extend_from_slice(&payload_len.to_le_bytes()); - buf.extend_from_slice(section.payload); + buf.extend_from_slice(payload); Ok(()) } -pub(crate) fn build_soa_record(record: SoaRecord<'_>) -> Result, String> { - if !matches!( - record.record_format, - RECORD_FORMAT_SOA_V2 | RECORD_FORMAT_SOA_V3 - ) { - return Err(format!( - "Unsupported record format {}", - record.record_format - )); - } - let positions_elem_bytes = match record.positions_type { +pub(crate) fn build_soa_record_unchecked( + positions_type: u8, + positions: &[u8], + atomic_numbers: &[u8], + builtins: SoaBuiltinPayloads<'_>, + custom_sections: &[SoaCustomSection<'_>], +) -> Result, String> { + let positions_elem_bytes = match positions_type { TYPE_VEC3_F32 => 12usize, TYPE_VEC3_F64 => 24usize, other => { return Err(format!("Unsupported positions type tag {}", other)); } }; - if record.record_format == RECORD_FORMAT_SOA_V2 && record.positions_type != TYPE_VEC3_F32 { - return Err("record format 2 only supports float32 positions".to_string()); - } - if !record.positions.len().is_multiple_of(positions_elem_bytes) { + if !positions.len().is_multiple_of(positions_elem_bytes) { return Err(format!( "positions payload length ({}) is not a multiple of {}", - record.positions.len(), + positions.len(), positions_elem_bytes )); } - let n_atoms = record.positions.len() / positions_elem_bytes; - if record.atomic_numbers.len() != n_atoms { + let n_atoms = positions.len() / positions_elem_bytes; + if atomic_numbers.len() != n_atoms { return Err(format!( "Atomic numbers length ({}) doesn't match atom count ({})", - record.atomic_numbers.len(), + atomic_numbers.len(), n_atoms )); } - let mut n_sections = 0u16; let mut payload_bytes = 0usize; let mut section_overhead = 0usize; - let mut account_section = |payload_len: usize, key_len: usize| { - n_sections += 1; - payload_bytes += payload_len; - section_overhead += 1 + 1 + key_len + 1 + 4; + let mut section_count = 0usize; + let mut account_section = |key: &str, payload: &[u8]| { + section_count += 1; + payload_bytes += payload.len(); + section_overhead += 1 + 1 + key.len() + 1 + 4; }; - for section in record.sections { - let parsed = SectionRef { - kind: section.kind, - key: section.key, - type_tag: section.type_tag, - payload: section.payload, - }; - let per_atom = is_per_atom(parsed.kind, parsed.key, parsed.type_tag); - let elem_bytes = match parsed.type_tag { - TYPE_STRING => 0, - tag if per_atom => { - let elem_bytes = type_tag_elem_bytes(tag); - if elem_bytes == 0 { - return Err(format!( - "Unsupported per-atom section type tag {} for key '{}'", - tag, parsed.key - )); - } - elem_bytes - } - TYPE_FLOAT | TYPE_INT => 8, - TYPE_FLOAT32 => 4, - TYPE_BOOL3 => 3, - TYPE_MAT3X3_F32 => 36, - TYPE_MAT3X3_F64 => 72, - _ => parsed.payload.len(), - }; - let slot_bytes = if parsed.type_tag == TYPE_STRING { - 0 - } else if per_atom { - elem_bytes - } else { - parsed.payload.len() - }; - validate_section_payload(&parsed, per_atom, elem_bytes, slot_bytes, n_atoms) - .map_err(|e| format!("{}", e))?; - account_section(parsed.payload.len(), parsed.key.len()); + if let Some(payload) = builtins.charges.as_ref() { + account_section("charges", payload.payload); + } + if let Some(payload) = builtins.cell.as_ref() { + account_section("cell", payload.payload); + } + if let Some(payload) = builtins.energy.as_ref() { + account_section("energy", payload.payload); + } + if let Some(payload) = builtins.forces.as_ref() { + account_section("forces", payload.payload); + } + if let Some(name) = builtins.name { + account_section("name", name.as_bytes()); + } + if let Some(payload) = builtins.pbc.as_ref() { + account_section("pbc", payload); + } + if let Some(payload) = builtins.stress.as_ref() { + account_section("stress", payload.payload); + } + if let Some(payload) = builtins.velocities.as_ref() { + account_section("velocities", payload.payload); + } + for section in custom_sections { + account_section(section.key, section.payload); } + let n_sections: u16 = section_count + .try_into() + .map_err(|_| "Too many SOA sections".to_string())?; + let mut buf = Vec::with_capacity( - 4 + record.positions.len() - + record.atomic_numbers.len() - + 2 - + section_overhead - + payload_bytes, + 4 + positions.len() + atomic_numbers.len() + 2 + section_overhead + payload_bytes, ); buf.extend_from_slice(&(n_atoms as u32).to_le_bytes()); - buf.extend_from_slice(record.positions); - buf.extend_from_slice(record.atomic_numbers); + buf.extend_from_slice(positions); + buf.extend_from_slice(atomic_numbers); buf.extend_from_slice(&n_sections.to_le_bytes()); - for section in record.sections { - write_soa_section_raw(&mut buf, section)?; + if let Some(payload) = builtins.charges.as_ref() { + write_soa_section_raw( + &mut buf, + KIND_BUILTIN, + "charges", + payload.type_tag, + payload.payload, + )?; + } + if let Some(payload) = builtins.cell.as_ref() { + write_soa_section_raw( + &mut buf, + KIND_BUILTIN, + "cell", + payload.type_tag, + payload.payload, + )?; + } + if let Some(payload) = builtins.energy.as_ref() { + write_soa_section_raw( + &mut buf, + KIND_BUILTIN, + "energy", + payload.type_tag, + payload.payload, + )?; + } + if let Some(payload) = builtins.forces.as_ref() { + write_soa_section_raw( + &mut buf, + KIND_BUILTIN, + "forces", + payload.type_tag, + payload.payload, + )?; + } + if let Some(name) = builtins.name { + write_soa_section_raw(&mut buf, KIND_BUILTIN, "name", TYPE_STRING, name.as_bytes())?; + } + if let Some(payload) = builtins.pbc.as_ref() { + write_soa_section_raw(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, payload)?; + } + if let Some(payload) = builtins.stress.as_ref() { + write_soa_section_raw( + &mut buf, + KIND_BUILTIN, + "stress", + payload.type_tag, + payload.payload, + )?; + } + if let Some(payload) = builtins.velocities.as_ref() { + write_soa_section_raw( + &mut buf, + KIND_BUILTIN, + "velocities", + payload.type_tag, + payload.payload, + )?; + } + for section in custom_sections { + write_soa_section_raw( + &mut buf, + section.kind, + section.key, + section.type_tag, + section.payload, + )?; } Ok(buf) From fcbf0091a7da6f465fe2e0e88fa4d2ede29cbff6 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sun, 10 May 2026 15:08:03 +0200 Subject: [PATCH 9/9] refactor: simplify python batch setup helpers --- atompack-py/src/database_batch.rs | 240 ++++++++++++++++------------ atompack-py/src/molecule_helpers.rs | 211 ++++++++++++++---------- 2 files changed, 262 insertions(+), 189 deletions(-) diff --git a/atompack-py/src/database_batch.rs b/atompack-py/src/database_batch.rs index 75fef12..42d2c06 100644 --- a/atompack-py/src/database_batch.rs +++ b/atompack-py/src/database_batch.rs @@ -919,6 +919,28 @@ fn extract_builtin_scalar_column( }) } +fn extract_optional_builtin_scalar_column( + value: Option<&Bound<'_, PyAny>>, + batch: usize, + key: &str, + f32_type_tag: u8, + f64_type_tag: u8, +) -> PyResult> { + let Some(value) = value else { + return Ok(None); + }; + let Some(array) = PyFloatArray1::from_any(value) else { + return Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape (batch,)", + key + ))); + }; + Ok(Some(match array { + PyFloatArray1::F32(arr) => extract_builtin_scalar_column(&arr, batch, key, f32_type_tag)?, + PyFloatArray1::F64(arr) => extract_builtin_scalar_column(&arr, batch, key, f64_type_tag)?, + })) +} + fn extract_builtin_float_array_column( arr: &Bound<'_, PyArray2>, batch: usize, @@ -947,6 +969,33 @@ fn extract_builtin_float_array_column( }) } +fn extract_optional_builtin_float_array_column( + value: Option<&Bound<'_, PyAny>>, + batch: usize, + n_atoms: usize, + key: &str, + f32_type_tag: u8, + f64_type_tag: u8, +) -> PyResult> { + let Some(value) = value else { + return Ok(None); + }; + let Some(array) = PyFloatArray2::from_any(value) else { + return Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape (batch, n_atoms)", + key + ))); + }; + Ok(Some(match array { + PyFloatArray2::F32(arr) => { + extract_builtin_float_array_column(&arr, batch, n_atoms, key, f32_type_tag)? + } + PyFloatArray2::F64(arr) => { + extract_builtin_float_array_column(&arr, batch, n_atoms, key, f64_type_tag)? + } + })) +} + fn extract_builtin_vec3_column( arr: &Bound<'_, PyArray3>, batch: usize, @@ -975,6 +1024,33 @@ fn extract_builtin_vec3_column( }) } +fn extract_optional_builtin_vec3_column( + value: Option<&Bound<'_, PyAny>>, + batch: usize, + n_atoms: usize, + key: &str, + f32_type_tag: u8, + f64_type_tag: u8, +) -> PyResult> { + let Some(value) = value else { + return Ok(None); + }; + let Some(array) = PyFloatArray3::from_any(value) else { + return Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", + key + ))); + }; + Ok(Some(match array { + PyFloatArray3::F32(arr) => { + extract_builtin_vec3_column(&arr, batch, n_atoms, key, f32_type_tag)? + } + PyFloatArray3::F64(arr) => { + extract_builtin_vec3_column(&arr, batch, n_atoms, key, f64_type_tag)? + } + })) +} + fn extract_builtin_mat3_column( arr: &Bound<'_, PyArray3>, batch: usize, @@ -1002,6 +1078,28 @@ fn extract_builtin_mat3_column( }) } +fn extract_optional_builtin_mat3_column( + value: Option<&Bound<'_, PyAny>>, + batch: usize, + key: &str, + f32_type_tag: u8, + f64_type_tag: u8, +) -> PyResult> { + let Some(value) = value else { + return Ok(None); + }; + let Some(array) = PyFloatArray3::from_any(value) else { + return Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape (batch, 3, 3)", + key + ))); + }; + Ok(Some(match array { + PyFloatArray3::F32(arr) => extract_builtin_mat3_column(&arr, batch, key, f32_type_tag)?, + PyFloatArray3::F64(arr) => extract_builtin_mat3_column(&arr, batch, key, f64_type_tag)?, + })) +} + fn extract_builtin_pbc_column( pbc: &Bound<'_, PyArray2>, batch: usize, @@ -1087,108 +1185,46 @@ pub(super) fn add_arrays_batch_impl( let (batch, n_atoms, positions_type, positions_payload) = extract_positions_payload(positions)?; let atomic_numbers_payload = extract_atomic_numbers_payload(atomic_numbers, batch, n_atoms)?; - let energy_column = if let Some(energy) = energy { - let Some(array) = PyFloatArray1::from_any(energy) else { - return Err(PyValueError::new_err( - "energy must be a float32 or float64 ndarray with shape (batch,)", - )); - }; - Some(match array { - PyFloatArray1::F32(arr) => { - extract_builtin_scalar_column(&arr, batch, "energy", TYPE_FLOAT32)? - } - PyFloatArray1::F64(arr) => { - extract_builtin_scalar_column(&arr, batch, "energy", TYPE_FLOAT)? - } - }) - } else { - None - }; - let forces_column = if let Some(forces) = forces { - let Some(array) = PyFloatArray3::from_any(forces) else { - return Err(PyValueError::new_err( - "forces must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", - )); - }; - Some(match array { - PyFloatArray3::F32(arr) => { - extract_builtin_vec3_column(&arr, batch, n_atoms, "forces", TYPE_VEC3_F32)? - } - PyFloatArray3::F64(arr) => { - extract_builtin_vec3_column(&arr, batch, n_atoms, "forces", TYPE_VEC3_F64)? - } - }) - } else { - None - }; - let charges_column = if let Some(charges) = charges { - let Some(array) = PyFloatArray2::from_any(charges) else { - return Err(PyValueError::new_err( - "charges must be a float32 or float64 ndarray with shape (batch, n_atoms)", - )); - }; - Some(match array { - PyFloatArray2::F32(arr) => { - extract_builtin_float_array_column(&arr, batch, n_atoms, "charges", TYPE_F32_ARRAY)? - } - PyFloatArray2::F64(arr) => { - extract_builtin_float_array_column(&arr, batch, n_atoms, "charges", TYPE_F64_ARRAY)? - } - }) - } else { - None - }; - let velocities_column = if let Some(velocities) = velocities { - let Some(array) = PyFloatArray3::from_any(velocities) else { - return Err(PyValueError::new_err( - "velocities must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", - )); - }; - Some(match array { - PyFloatArray3::F32(arr) => { - extract_builtin_vec3_column(&arr, batch, n_atoms, "velocities", TYPE_VEC3_F32)? - } - PyFloatArray3::F64(arr) => { - extract_builtin_vec3_column(&arr, batch, n_atoms, "velocities", TYPE_VEC3_F64)? - } - }) - } else { - None - }; - let cell_column = if let Some(cell) = cell { - let Some(array) = PyFloatArray3::from_any(cell) else { - return Err(PyValueError::new_err( - "cell must be a float32 or float64 ndarray with shape (batch, 3, 3)", - )); - }; - Some(match array { - PyFloatArray3::F32(arr) => { - extract_builtin_mat3_column(&arr, batch, "cell", TYPE_MAT3X3_F32)? - } - PyFloatArray3::F64(arr) => { - extract_builtin_mat3_column(&arr, batch, "cell", TYPE_MAT3X3_F64)? - } - }) - } else { - None - }; - let stress_column = if let Some(stress) = stress { - let Some(array) = PyFloatArray3::from_any(stress) else { - return Err(PyValueError::new_err( - "stress must be a float32 or float64 ndarray with shape (batch, 3, 3)", - )); - }; - Some(match array { - PyFloatArray3::F32(arr) => { - extract_builtin_mat3_column(&arr, batch, "stress", TYPE_MAT3X3_F32)? - } - PyFloatArray3::F64(arr) => { - extract_builtin_mat3_column(&arr, batch, "stress", TYPE_MAT3X3_F64)? - } - }) - } else { - None - }; + let energy_column = + extract_optional_builtin_scalar_column(energy, batch, "energy", TYPE_FLOAT32, TYPE_FLOAT)?; + let forces_column = extract_optional_builtin_vec3_column( + forces, + batch, + n_atoms, + "forces", + TYPE_VEC3_F32, + TYPE_VEC3_F64, + )?; + let charges_column = extract_optional_builtin_float_array_column( + charges, + batch, + n_atoms, + "charges", + TYPE_F32_ARRAY, + TYPE_F64_ARRAY, + )?; + let velocities_column = extract_optional_builtin_vec3_column( + velocities, + batch, + n_atoms, + "velocities", + TYPE_VEC3_F32, + TYPE_VEC3_F64, + )?; + let cell_column = extract_optional_builtin_mat3_column( + cell, + batch, + "cell", + TYPE_MAT3X3_F32, + TYPE_MAT3X3_F64, + )?; + let stress_column = extract_optional_builtin_mat3_column( + stress, + batch, + "stress", + TYPE_MAT3X3_F32, + TYPE_MAT3X3_F64, + )?; let pbc_column = match pbc { Some(pbc) => Some(extract_builtin_pbc_column(pbc, batch)?), None => None, diff --git a/atompack-py/src/molecule_helpers.rs b/atompack-py/src/molecule_helpers.rs index ad7e739..45e1940 100644 --- a/atompack-py/src/molecule_helpers.rs +++ b/atompack-py/src/molecule_helpers.rs @@ -152,6 +152,62 @@ fn write_soa_section_raw( Ok(()) } +#[inline] +fn account_optional_typed_section( + count: &mut usize, + payload_bytes: &mut usize, + section_overhead: &mut usize, + key: &str, + payload: Option<&SoaTypedPayload<'_>>, +) { + if let Some(payload) = payload { + *count += 1; + *payload_bytes += payload.payload.len(); + *section_overhead += 1 + 1 + key.len() + 1 + 4; + } +} + +#[inline] +fn account_optional_string_section( + count: &mut usize, + payload_bytes: &mut usize, + section_overhead: &mut usize, + key: &str, + value: Option<&str>, +) { + if let Some(value) = value { + *count += 1; + *payload_bytes += value.len(); + *section_overhead += 1 + 1 + key.len() + 1 + 4; + } +} + +#[inline] +fn write_optional_typed_section( + buf: &mut Vec, + kind: u8, + key: &str, + payload: Option<&SoaTypedPayload<'_>>, +) -> Result<(), String> { + if let Some(payload) = payload { + write_soa_section_raw(buf, kind, key, payload.type_tag, payload.payload)?; + } + Ok(()) +} + +#[inline] +fn write_optional_string_section( + buf: &mut Vec, + kind: u8, + key: &str, + value: Option<&str>, +) -> Result<(), String> { + if let Some(value) = value { + write_soa_section_raw(buf, kind, key, TYPE_STRING, value.as_bytes())?; + } + Ok(()) +} + pub(crate) fn build_soa_record_unchecked( positions_type: u8, positions: &[u8], @@ -185,38 +241,64 @@ pub(crate) fn build_soa_record_unchecked( let mut payload_bytes = 0usize; let mut section_overhead = 0usize; let mut section_count = 0usize; - let mut account_section = |key: &str, payload: &[u8]| { + account_optional_typed_section( + &mut section_count, + &mut payload_bytes, + &mut section_overhead, + "charges", + builtins.charges.as_ref(), + ); + account_optional_typed_section( + &mut section_count, + &mut payload_bytes, + &mut section_overhead, + "cell", + builtins.cell.as_ref(), + ); + account_optional_typed_section( + &mut section_count, + &mut payload_bytes, + &mut section_overhead, + "energy", + builtins.energy.as_ref(), + ); + account_optional_typed_section( + &mut section_count, + &mut payload_bytes, + &mut section_overhead, + "forces", + builtins.forces.as_ref(), + ); + account_optional_string_section( + &mut section_count, + &mut payload_bytes, + &mut section_overhead, + "name", + builtins.name, + ); + if builtins.pbc.is_some() { section_count += 1; - payload_bytes += payload.len(); - section_overhead += 1 + 1 + key.len() + 1 + 4; - }; - - if let Some(payload) = builtins.charges.as_ref() { - account_section("charges", payload.payload); - } - if let Some(payload) = builtins.cell.as_ref() { - account_section("cell", payload.payload); - } - if let Some(payload) = builtins.energy.as_ref() { - account_section("energy", payload.payload); - } - if let Some(payload) = builtins.forces.as_ref() { - account_section("forces", payload.payload); - } - if let Some(name) = builtins.name { - account_section("name", name.as_bytes()); - } - if let Some(payload) = builtins.pbc.as_ref() { - account_section("pbc", payload); - } - if let Some(payload) = builtins.stress.as_ref() { - account_section("stress", payload.payload); - } - if let Some(payload) = builtins.velocities.as_ref() { - account_section("velocities", payload.payload); - } + payload_bytes += 3; + section_overhead += 1 + 1 + "pbc".len() + 1 + 4; + } + account_optional_typed_section( + &mut section_count, + &mut payload_bytes, + &mut section_overhead, + "stress", + builtins.stress.as_ref(), + ); + account_optional_typed_section( + &mut section_count, + &mut payload_bytes, + &mut section_overhead, + "velocities", + builtins.velocities.as_ref(), + ); for section in custom_sections { - account_section(section.key, section.payload); + section_count += 1; + payload_bytes += section.payload.len(); + section_overhead += 1 + 1 + section.key.len() + 1 + 4; } let n_sections: u16 = section_count @@ -231,66 +313,21 @@ pub(crate) fn build_soa_record_unchecked( buf.extend_from_slice(atomic_numbers); buf.extend_from_slice(&n_sections.to_le_bytes()); - if let Some(payload) = builtins.charges.as_ref() { - write_soa_section_raw( - &mut buf, - KIND_BUILTIN, - "charges", - payload.type_tag, - payload.payload, - )?; - } - if let Some(payload) = builtins.cell.as_ref() { - write_soa_section_raw( - &mut buf, - KIND_BUILTIN, - "cell", - payload.type_tag, - payload.payload, - )?; - } - if let Some(payload) = builtins.energy.as_ref() { - write_soa_section_raw( - &mut buf, - KIND_BUILTIN, - "energy", - payload.type_tag, - payload.payload, - )?; - } - if let Some(payload) = builtins.forces.as_ref() { - write_soa_section_raw( - &mut buf, - KIND_BUILTIN, - "forces", - payload.type_tag, - payload.payload, - )?; - } - if let Some(name) = builtins.name { - write_soa_section_raw(&mut buf, KIND_BUILTIN, "name", TYPE_STRING, name.as_bytes())?; - } + write_optional_typed_section(&mut buf, KIND_BUILTIN, "charges", builtins.charges.as_ref())?; + write_optional_typed_section(&mut buf, KIND_BUILTIN, "cell", builtins.cell.as_ref())?; + write_optional_typed_section(&mut buf, KIND_BUILTIN, "energy", builtins.energy.as_ref())?; + write_optional_typed_section(&mut buf, KIND_BUILTIN, "forces", builtins.forces.as_ref())?; + write_optional_string_section(&mut buf, KIND_BUILTIN, "name", builtins.name)?; if let Some(payload) = builtins.pbc.as_ref() { write_soa_section_raw(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, payload)?; } - if let Some(payload) = builtins.stress.as_ref() { - write_soa_section_raw( - &mut buf, - KIND_BUILTIN, - "stress", - payload.type_tag, - payload.payload, - )?; - } - if let Some(payload) = builtins.velocities.as_ref() { - write_soa_section_raw( - &mut buf, - KIND_BUILTIN, - "velocities", - payload.type_tag, - payload.payload, - )?; - } + write_optional_typed_section(&mut buf, KIND_BUILTIN, "stress", builtins.stress.as_ref())?; + write_optional_typed_section( + &mut buf, + KIND_BUILTIN, + "velocities", + builtins.velocities.as_ref(), + )?; for section in custom_sections { write_soa_section_raw( &mut buf,