diff --git a/atompack-py/python/atompack/__init__.pyi b/atompack-py/python/atompack/__init__.pyi index 67dddae..a28c97e 100644 --- a/atompack-py/python/atompack/__init__.pyi +++ b/atompack-py/python/atompack/__init__.pyi @@ -7,6 +7,10 @@ import numpy.typing as npt from . import hub as hub +Float1D = npt.NDArray[np.float32] | npt.NDArray[np.float64] +Float2D = npt.NDArray[np.float32] | npt.NDArray[np.float64] +Float3D = npt.NDArray[np.float32] | npt.NDArray[np.float64] + class Atom: """ Represents a single atom with 3D coordinates and atomic number. @@ -115,29 +119,29 @@ class Molecule: def __init__( self, - positions: npt.NDArray[np.float32], + positions: Float2D, atomic_numbers: npt.NDArray[np.uint8], *, energy: float | None = ..., - forces: npt.NDArray[np.float32] | None = ..., - charges: npt.NDArray[np.float64] | None = ..., - velocities: npt.NDArray[np.float32] | None = ..., - cell: npt.NDArray[np.float64] | None = ..., - stress: npt.NDArray[np.float64] | npt.NDArray[np.float32] | None = ..., + forces: Float2D | None = ..., + charges: Float1D | None = ..., + velocities: Float2D | None = ..., + cell: Float2D | None = ..., + stress: Float2D | None = ..., pbc: tuple[bool, bool, bool] | None = ..., name: str | None = ..., ) -> None: ... @staticmethod def from_arrays( - positions: npt.NDArray[np.float32], + positions: Float2D, atomic_numbers: npt.NDArray[np.uint8], *, energy: float | None = ..., - forces: npt.NDArray[np.float32] | None = ..., - charges: npt.NDArray[np.float64] | None = ..., - velocities: npt.NDArray[np.float32] | None = ..., - cell: npt.NDArray[np.float64] | None = ..., - stress: npt.NDArray[np.float64] | npt.NDArray[np.float32] | None = ..., + forces: Float2D | None = ..., + charges: Float1D | None = ..., + velocities: Float2D | None = ..., + cell: Float2D | None = ..., + stress: Float2D | None = ..., pbc: tuple[bool, bool, bool] | None = ..., name: str | None = ..., ) -> Molecule: @@ -204,7 +208,7 @@ class Molecule: ... @property - def forces(self) -> npt.NDArray[np.float32] | None: + def forces(self) -> Float2D | None: """ Per-atom forces. @@ -216,7 +220,7 @@ class Molecule: ... @forces.setter - def forces(self, value: npt.NDArray[np.float32]) -> None: ... + def forces(self, value: Float2D) -> None: ... @property def energy(self) -> float | None: """ @@ -232,7 +236,7 @@ class Molecule: @energy.setter def energy(self, value: float | None) -> None: ... @property - def charges(self) -> npt.NDArray[np.float64] | None: + def charges(self) -> Float1D | None: """ Per-atom partial charges. @@ -244,9 +248,9 @@ class Molecule: ... @charges.setter - def charges(self, value: npt.NDArray[np.float64]) -> None: ... + def charges(self, value: Float1D) -> None: ... @property - def velocities(self) -> npt.NDArray[np.float32] | None: + def velocities(self) -> Float2D | None: """ Per-atom velocities. @@ -258,9 +262,9 @@ class Molecule: ... @velocities.setter - def velocities(self, value: npt.NDArray[np.float32]) -> None: ... + def velocities(self, value: Float2D) -> None: ... @property - def cell(self) -> npt.NDArray[np.float64] | None: + def cell(self) -> Float2D | None: """ Unit cell for periodic systems. @@ -272,9 +276,9 @@ class Molecule: ... @cell.setter - def cell(self, value: npt.NDArray[np.float64]) -> None: ... + def cell(self, value: Float2D) -> None: ... @property - def stress(self) -> npt.NDArray[np.float64] | None: + def stress(self) -> Float2D | None: """ Virial stress tensor. @@ -286,9 +290,9 @@ class Molecule: ... @stress.setter - def stress(self, value: npt.NDArray[np.float32] | npt.NDArray[np.float64]) -> None: ... + def stress(self, value: Float2D) -> None: ... @property - def positions(self) -> npt.NDArray[np.float32]: + def positions(self) -> Float2D: """ Atomic positions (read-only). @@ -505,15 +509,15 @@ class Database: ... def add_arrays_batch( self, - positions: npt.NDArray[np.float32], + positions: Float3D, atomic_numbers: npt.NDArray[np.uint8], *, - energy: npt.NDArray[np.float64] | None = ..., - forces: npt.NDArray[np.float32] | None = ..., - charges: npt.NDArray[np.float64] | None = ..., - velocities: npt.NDArray[np.float32] | None = ..., - cell: npt.NDArray[np.float64] | None = ..., - stress: npt.NDArray[np.float64] | None = ..., + energy: Float1D | None = ..., + forces: Float3D | None = ..., + charges: Float2D | None = ..., + velocities: Float3D | None = ..., + cell: Float3D | None = ..., + stress: Float3D | None = ..., pbc: npt.NDArray[np.bool_] | None = ..., name: Sequence[str] | None = ..., properties: dict[str, Any] | None = ..., diff --git a/atompack-py/src/database.rs b/atompack-py/src/database.rs index f472cc9..ac3688d 100644 --- a/atompack-py/src/database.rs +++ b/atompack-py/src/database.rs @@ -13,6 +13,8 @@ pub(crate) struct PyAtomDatabase { impl PyAtomDatabase { fn single_molecule_view(&self, py: Python<'_>, index: usize) -> PyResult { let compression = self.inner.compression(); + let record_format = self.inner.record_format(); + let positions_type = self.inner.positions_type(); let use_mmap = self.inner.get_compressed_slice(0).is_some(); if use_mmap { @@ -22,7 +24,11 @@ impl PyAtomDatabase { let bytes = self.inner.get_shared_mmap_bytes(index).ok_or_else(|| { invalid_data(format!("Missing mmap bytes for molecule {}", index)) })?; - SoaMoleculeView::from_shared_bytes_inner(bytes) + SoaMoleculeView::from_shared_bytes_inner( + bytes, + record_format, + positions_type, + ) } else { let compressed = self.inner.get_compressed_slice(index).ok_or_else(|| { @@ -43,7 +49,11 @@ impl PyAtomDatabase { compression, Some(uncompressed_size), )?; - SoaMoleculeView::from_bytes_inner(decompressed) + SoaMoleculeView::from_bytes_inner( + decompressed, + record_format, + positions_type, + ) } }) .map_err(|e| PyValueError::new_err(format!("{}", e))); @@ -55,7 +65,7 @@ impl PyAtomDatabase { let raw = raw_bytes.pop().ok_or_else(|| { PyValueError::new_err(format!("Missing raw bytes for molecule {}", index)) })?; - SoaMoleculeView::from_bytes(raw) + SoaMoleculeView::from_bytes(raw, record_format, positions_type) } } @@ -117,10 +127,13 @@ impl PyAtomDatabase { /// Add a molecule to the database fn add_molecule(&mut self, molecule: &PyMolecule) -> PyResult<()> { - if let Some((soa_bytes, n_atoms)) = molecule.soa_bytes() { + if let Some(view) = molecule.as_view() { return self .inner - .add_raw_soa_records(&[(soa_bytes, n_atoms)]) + .add_raw_soa_records_with_schema( + &[(view.raw_bytes(), view.n_atoms as u32)], + view.database_schema()?, + ) .map_err(|e| PyValueError::new_err(format!("{}", e))); } let owned = molecule.clone_as_owned()?; @@ -131,22 +144,47 @@ impl PyAtomDatabase { /// Add multiple molecules (processed in parallel) fn add_molecules(&mut self, molecules: Vec>) -> PyResult<()> { - // Split into view-backed (fast path) and owned molecules let mut raw_records: Vec<(&[u8], u32)> = Vec::new(); + let mut raw_views: Vec<&SoaMoleculeView> = Vec::new(); let mut owned_molecules: Vec = Vec::new(); for m in &molecules { - if let Some((soa_bytes, n_atoms)) = m.soa_bytes() { - raw_records.push((soa_bytes, n_atoms)); + if let Some(view) = m.as_view() { + raw_records.push((view.raw_bytes(), view.n_atoms as u32)); + raw_views.push(view); } else { owned_molecules.push(m.clone_as_owned()?); } } if !raw_records.is_empty() { - self.inner - .add_raw_soa_records(&raw_records) - .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + let mut fast_schema = None; + if let Some((first, rest)) = raw_views.split_first() { + let mut all_match = true; + for view in rest { + if !view.same_schema_as(first)? { + all_match = false; + break; + } + } + if all_match { + fast_schema = Some(first.database_schema()?); + } + } + + if let Some(schema) = fast_schema { + self.inner + .add_raw_soa_records_with_schema(&raw_records, schema) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + } else { + let raw_records_with_positions = raw_views + .iter() + .map(|view| (view.raw_bytes(), view.n_atoms as u32, view.positions_type)) + .collect::>(); + self.inner + .add_raw_soa_records_with_positions_type(&raw_records_with_positions) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + } } if !owned_molecules.is_empty() { let mol_refs: Vec<&Molecule> = owned_molecules.iter().collect(); @@ -179,14 +217,14 @@ impl PyAtomDatabase { fn add_arrays_batch( &mut self, py: Python<'_>, - positions: &Bound<'_, PyArray3>, + positions: &Bound<'_, PyAny>, atomic_numbers: &Bound<'_, PyArray2>, - energy: Option<&Bound<'_, PyArray1>>, - forces: Option<&Bound<'_, PyArray3>>, - charges: Option<&Bound<'_, PyArray2>>, - velocities: Option<&Bound<'_, PyArray3>>, - cell: Option<&Bound<'_, PyArray3>>, - stress: Option<&Bound<'_, PyArray3>>, + energy: Option<&Bound<'_, PyAny>>, + forces: Option<&Bound<'_, PyAny>>, + charges: Option<&Bound<'_, PyAny>>, + velocities: Option<&Bound<'_, PyAny>>, + cell: Option<&Bound<'_, PyAny>>, + stress: Option<&Bound<'_, PyAny>>, pbc: Option<&Bound<'_, PyArray2>>, name: Option>, properties: Option<&Bound<'_, PyDict>>, @@ -238,6 +276,8 @@ impl PyAtomDatabase { } let compression = self.inner.compression(); + let record_format = self.inner.record_format(); + let positions_type = self.inner.positions_type(); let use_mmap = self.inner.get_compressed_slice(0).is_some(); let views: Vec = if use_mmap { @@ -250,7 +290,11 @@ impl PyAtomDatabase { let bytes = self.inner.get_shared_mmap_bytes(idx).ok_or_else(|| { invalid_data(format!("Missing mmap bytes for molecule {}", idx)) })?; - SoaMoleculeView::from_shared_bytes_inner(bytes) + SoaMoleculeView::from_shared_bytes_inner( + bytes, + record_format, + positions_type, + ) } else { let compressed = self.inner.get_compressed_slice(idx).ok_or_else(|| { @@ -271,7 +315,11 @@ impl PyAtomDatabase { compression, Some(uncompressed_size), )?; - SoaMoleculeView::from_bytes_inner(decompressed) + SoaMoleculeView::from_bytes_inner( + decompressed, + record_format, + positions_type, + ) } }) .collect() @@ -283,7 +331,7 @@ impl PyAtomDatabase { .map_err(|e| PyValueError::new_err(format!("{}", e)))?; raw_bytes .into_iter() - .map(SoaMoleculeView::from_bytes) + .map(|bytes| SoaMoleculeView::from_bytes(bytes, record_format, positions_type)) .collect::>>() .map_err(|e| invalid_data(format!("{}", e))) } diff --git a/atompack-py/src/database_batch.rs b/atompack-py/src/database_batch.rs index 477e600..40560d4 100644 --- a/atompack-py/src/database_batch.rs +++ b/atompack-py/src/database_batch.rs @@ -1,7 +1,8 @@ use super::*; -use crate::molecule::{SoaBuiltinPayloads, SoaCustomSection, build_soa_record_with_custom}; +use crate::molecule::{SoaRecord, SoaSection, build_soa_record}; +use atompack::storage::{DatabaseSchema, DatabaseSchemaSection}; -struct BatchCustomColumn { +struct BatchSectionColumn { key: String, kind: u8, type_tag: u8, @@ -10,10 +11,10 @@ struct BatchCustomColumn { strings: Option>, } -impl BatchCustomColumn { - fn section_for<'a>(&'a self, index: usize) -> SoaCustomSection<'a> { +impl BatchSectionColumn { + fn section_for<'a>(&'a self, index: usize) -> SoaSection<'a> { if let Some(strings) = &self.strings { - return SoaCustomSection { + return SoaSection { kind: self.kind, key: self.key.as_str(), type_tag: self.type_tag, @@ -22,7 +23,7 @@ impl BatchCustomColumn { } let start = index * self.slot_bytes; let end = start + self.slot_bytes; - SoaCustomSection { + SoaSection { kind: self.kind, key: self.key.as_str(), type_tag: self.type_tag, @@ -31,6 +32,75 @@ impl BatchCustomColumn { } } +fn batch_section_is_per_atom(kind: u8, key: &str) -> bool { + match kind { + KIND_ATOM_PROP => true, + KIND_MOL_PROP => false, + KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), + _ => false, + } +} + +fn database_schema_section(column: &BatchSectionColumn) -> PyResult { + let per_atom = batch_section_is_per_atom(column.kind, &column.key); + let elem_bytes = if column.type_tag == TYPE_STRING { + 0 + } else { + let elem_bytes = type_tag_elem_bytes(column.type_tag); + if elem_bytes == 0 { + return Err(PyValueError::new_err(format!( + "Unsupported section type tag {} for '{}'", + column.type_tag, column.key + ))); + } + elem_bytes + }; + let slot_bytes = if column.type_tag == TYPE_STRING { + 0 + } else if per_atom { + elem_bytes + } else { + column.slot_bytes + }; + + Ok(DatabaseSchemaSection { + kind: column.kind, + key: column.key.clone(), + type_tag: column.type_tag, + per_atom, + elem_bytes, + slot_bytes, + }) +} + +fn build_batch_schema<'a, I>(positions_type: u8, columns: I) -> PyResult +where + I: IntoIterator, +{ + let sections = columns + .into_iter() + .map(database_schema_section) + .collect::>>()?; + Ok(DatabaseSchema { + positions_type: Some(positions_type), + sections, + }) +} + +fn push_builtin_section<'a>( + sections: &mut Vec>, + key: &'a str, + type_tag: u8, + payload: &'a [u8], +) { + sections.push(SoaSection { + kind: KIND_BUILTIN, + key, + type_tag, + payload, + }); +} + fn reject_reserved_key(key: &str) -> PyResult<()> { if key == "stress" { return Err(PyValueError::new_err( @@ -67,7 +137,7 @@ fn extract_string_column( batch: usize, key: &str, kind: u8, -) -> PyResult> { +) -> PyResult> { let Ok(strings) = value.extract::>() else { return Ok(None); }; @@ -79,7 +149,7 @@ fn extract_string_column( batch ))); } - Ok(Some(BatchCustomColumn { + Ok(Some(BatchSectionColumn { key: key.to_string(), kind, type_tag: TYPE_STRING, @@ -94,7 +164,7 @@ fn extract_scalar_column_f64>( batch: usize, key: &str, kind: u8, -) -> PyResult { +) -> PyResult { let readonly = arr.readonly(); let view = readonly.as_array(); if view.len() != batch { @@ -106,7 +176,7 @@ fn extract_scalar_column_f64>( ))); } let values: Vec = view.iter().copied().map(Into::into).collect(); - Ok(BatchCustomColumn { + Ok(BatchSectionColumn { key: key.to_string(), kind, type_tag: TYPE_FLOAT, @@ -121,7 +191,7 @@ fn extract_scalar_column_i64>( batch: usize, key: &str, kind: u8, -) -> PyResult { +) -> PyResult { let readonly = arr.readonly(); let view = readonly.as_array(); if view.len() != batch { @@ -133,7 +203,7 @@ fn extract_scalar_column_i64>( ))); } let values: Vec = view.iter().copied().map(Into::into).collect(); - Ok(BatchCustomColumn { + Ok(BatchSectionColumn { key: key.to_string(), kind, type_tag: TYPE_INT, @@ -149,7 +219,7 @@ fn extract_matrix_column( key: &str, kind: u8, type_tag: u8, -) -> PyResult { +) -> PyResult { let readonly = arr.readonly(); let view = readonly.as_array(); let shape = view.shape(); @@ -162,7 +232,7 @@ fn extract_matrix_column( let slice = readonly.as_slice().map_err(|_| { PyValueError::new_err(format!("custom property '{}' must be C-contiguous", key)) })?; - Ok(BatchCustomColumn { + Ok(BatchSectionColumn { key: key.to_string(), kind, type_tag, @@ -180,7 +250,7 @@ fn extract_vec3_column( kind: u8, type_tag: u8, shape_label: &str, -) -> PyResult { +) -> PyResult { let readonly = arr.readonly(); let view = readonly.as_array(); let shape = view.shape(); @@ -193,7 +263,7 @@ fn extract_vec3_column( let slice = readonly.as_slice().map_err(|_| { PyValueError::new_err(format!("custom property '{}' must be C-contiguous", key)) })?; - Ok(BatchCustomColumn { + Ok(BatchSectionColumn { key: key.to_string(), kind, type_tag, @@ -208,7 +278,7 @@ fn extract_property_column( batch: usize, key: &str, kind: u8, -) -> PyResult> { +) -> PyResult> { if let Some(column) = extract_string_column(value, batch, key, kind)? { return Ok(Some(column)); } @@ -300,7 +370,7 @@ fn extract_atom_property_column( batch: usize, n_atoms: usize, key: &str, -) -> PyResult> { +) -> PyResult> { if let Ok(arr) = value.cast::>() { let column = extract_matrix_column(arr, batch, key, KIND_ATOM_PROP, TYPE_F64_ARRAY)?; if column.slot_bytes != n_atoms * std::mem::size_of::() { @@ -371,7 +441,7 @@ fn extract_custom_columns( atom_properties: Option<&Bound<'_, PyDict>>, batch: usize, n_atoms: usize, -) -> PyResult> { +) -> PyResult> { let mut columns = Vec::new(); let mut seen_keys = std::collections::HashSet::new(); @@ -408,23 +478,130 @@ fn extract_custom_columns( Ok(columns) } +struct FastMat3Column { + type_tag: u8, + slot_bytes: usize, + payload: Vec, +} + +impl FastMat3Column { + fn from_optional( + value: Option<&Bound<'_, PyAny>>, + batch: usize, + label: &str, + ) -> PyResult> { + let Some(value) = value else { + return Ok(None); + }; + if let Ok(arr) = value.cast::>() { + let ro = arr.readonly(); + if ro.as_array().shape() != [batch, 3, 3] { + return Err(PyValueError::new_err(format!( + "{label} must have shape ({}, 3, 3)", + batch + ))); + } + let slice = ro + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{label} must be C-contiguous")))?; + return Ok(Some(Self { + type_tag: TYPE_MAT3X3_F32, + slot_bytes: 36, + payload: bytemuck::cast_slice::(slice).to_vec(), + })); + } + if let Ok(arr) = value.cast::>() { + let ro = arr.readonly(); + if ro.as_array().shape() != [batch, 3, 3] { + return Err(PyValueError::new_err(format!( + "{label} must have shape ({}, 3, 3)", + batch + ))); + } + let slice = ro + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{label} must be C-contiguous")))?; + return Ok(Some(Self { + type_tag: TYPE_MAT3X3_F64, + slot_bytes: 72, + payload: bytemuck::cast_slice::(slice).to_vec(), + })); + } + Ok(None) + } + + fn type_tag(&self) -> u8 { + self.type_tag + } + + fn slot_bytes(&self) -> usize { + self.slot_bytes + } + + fn payload_bytes(&self, index: usize) -> &[u8] { + let start = index * self.slot_bytes; + let end = start + self.slot_bytes; + &self.payload[start..end] + } +} + #[allow(clippy::too_many_arguments)] -pub(super) fn add_arrays_batch_impl( +fn try_add_arrays_batch_fast_canonical( inner: &mut AtomDatabase, py: Python<'_>, - positions: &Bound<'_, PyArray3>, + positions: &Bound<'_, PyAny>, atomic_numbers: &Bound<'_, PyArray2>, - energy: Option<&Bound<'_, PyArray1>>, - forces: Option<&Bound<'_, PyArray3>>, - charges: Option<&Bound<'_, PyArray2>>, - velocities: Option<&Bound<'_, PyArray3>>, - cell: Option<&Bound<'_, PyArray3>>, - stress: Option<&Bound<'_, PyArray3>>, + energy: Option<&Bound<'_, PyAny>>, + forces: Option<&Bound<'_, PyAny>>, + charges: Option<&Bound<'_, PyAny>>, + velocities: Option<&Bound<'_, PyAny>>, + cell: Option<&Bound<'_, PyAny>>, + stress: Option<&Bound<'_, PyAny>>, pbc: Option<&Bound<'_, PyArray2>>, name: Option>, properties: Option<&Bound<'_, PyDict>>, atom_properties: Option<&Bound<'_, PyDict>>, -) -> PyResult<()> { +) -> PyResult { + let Ok(positions) = positions.cast::>() else { + return Ok(false); + }; + let energy = match energy { + Some(value) => { + let Ok(arr) = value.cast::>() else { + return Ok(false); + }; + Some(arr) + } + None => None, + }; + let forces = match forces { + Some(value) => { + let Ok(arr) = value.cast::>() else { + return Ok(false); + }; + Some(arr) + } + None => None, + }; + let charges = match charges { + Some(value) => { + let Ok(arr) = value.cast::>() else { + return Ok(false); + }; + Some(arr) + } + None => None, + }; + let velocities = match velocities { + Some(value) => { + let Ok(arr) = value.cast::>() else { + return Ok(false); + }; + Some(arr) + } + None => None, + }; + let pos = positions.readonly(); let pos_arr = pos.as_array(); let pos_shape = pos_arr.shape(); @@ -520,38 +697,16 @@ pub(super) fn add_arrays_batch_impl( None }; - let cell_ro = cell.map(|arr| arr.readonly()); - let cell_slice = if let Some(ro) = cell_ro.as_ref() { - let view = ro.as_array(); - if view.shape() != [batch, 3, 3] { - return Err(PyValueError::new_err(format!( - "cell must have shape ({}, 3, 3)", - batch - ))); - } - Some( - ro.as_slice() - .map_err(|_| PyValueError::new_err("cell must be C-contiguous"))?, - ) - } else { - None + let cell_slice = match FastMat3Column::from_optional(cell, batch, "cell")? { + Some(column) => Some(column), + None if cell.is_some() => return Ok(false), + None => None, }; - let stress_ro = stress.map(|arr| arr.readonly()); - let stress_slice = if let Some(ro) = stress_ro.as_ref() { - let view = ro.as_array(); - if view.shape() != [batch, 3, 3] { - return Err(PyValueError::new_err(format!( - "stress must have shape ({}, 3, 3)", - batch - ))); - } - Some( - ro.as_slice() - .map_err(|_| PyValueError::new_err("stress must be C-contiguous"))?, - ) - } else { - None + let stress_slice = match FastMat3Column::from_optional(stress, batch, "stress")? { + Some(column) => Some(column), + None if stress.is_some() => return Ok(false), + None => None, }; let pbc_ro = pbc.map(|arr| arr.readonly()); @@ -582,6 +737,103 @@ pub(super) fn add_arrays_batch_impl( } let custom_columns = extract_custom_columns(properties, atom_properties, batch, n_atoms)?; + let mut builtin_columns = Vec::new(); + if energy_slice.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "energy".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_FLOAT, + slot_bytes: 8, + payload: Vec::new(), + strings: None, + }); + } + if forces_slice.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "forces".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_VEC3_F32, + slot_bytes: n_atoms * 12, + payload: Vec::new(), + strings: None, + }); + } + if charges_slice.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "charges".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_F64_ARRAY, + slot_bytes: n_atoms * 8, + payload: Vec::new(), + strings: None, + }); + } + if velocities_slice.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "velocities".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_VEC3_F32, + slot_bytes: n_atoms * 12, + payload: Vec::new(), + strings: None, + }); + } + if let Some(column) = cell_slice.as_ref() { + builtin_columns.push(BatchSectionColumn { + key: "cell".to_string(), + kind: KIND_BUILTIN, + type_tag: column.type_tag(), + slot_bytes: column.slot_bytes(), + payload: Vec::new(), + strings: None, + }); + } + if let Some(column) = stress_slice.as_ref() { + builtin_columns.push(BatchSectionColumn { + key: "stress".to_string(), + kind: KIND_BUILTIN, + type_tag: column.type_tag(), + slot_bytes: column.slot_bytes(), + payload: Vec::new(), + strings: None, + }); + } + if pbc_slice.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "pbc".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_BOOL3, + slot_bytes: 3, + payload: Vec::new(), + strings: None, + }); + } + if name.is_some() { + builtin_columns.push(BatchSectionColumn { + key: "name".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_STRING, + slot_bytes: 0, + payload: Vec::new(), + strings: None, + }); + } + + let batch_schema = build_batch_schema( + TYPE_VEC3_F32, + builtin_columns.iter().chain(custom_columns.iter()), + )?; + let record_format = inner + .record_format_for_schema(batch_schema.clone()) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + let builtin_section_count = usize::from(energy_slice.is_some()) + + usize::from(forces_slice.is_some()) + + usize::from(charges_slice.is_some()) + + usize::from(velocities_slice.is_some()) + + usize::from(cell_slice.is_some()) + + usize::from(stress_slice.is_some()) + + usize::from(pbc_slice.is_some()) + + usize::from(name.is_some()); let build_record = |i: usize| { let pos_start = i * n_atoms * 3; @@ -597,42 +849,526 @@ pub(super) fn add_arrays_batch_impl( let velocities_payload = velocities_slice .as_ref() .map(|slice| bytemuck::cast_slice::(&slice[pos_start..pos_end])); - let mat_start = i * 9; - let mat_end = mat_start + 9; - let cell_payload = cell_slice - .as_ref() - .map(|slice| bytemuck::cast_slice::(&slice[mat_start..mat_end])); - let stress_payload = stress_slice - .as_ref() - .map(|slice| bytemuck::cast_slice::(&slice[mat_start..mat_end])); - let pbc_value = pbc_slice - .as_ref() - .map(|slice| [slice[i * 3], slice[i * 3 + 1], slice[i * 3 + 2]]); - let name_value = name.as_ref().map(|names| names[i].as_str()); - let custom_sections: Vec> = custom_columns - .iter() - .map(|column| column.section_for(i)) - .collect(); - - build_soa_record_with_custom( - &pos_slice[pos_start..pos_end], - &z_slice[z_start..z_end], - SoaBuiltinPayloads { - energy: energy_slice.as_ref().map(|slice| slice[i]), - forces: forces_payload, - charges: charges_payload, - velocities: velocities_payload, - cell: cell_payload, - stress: stress_payload, - pbc: pbc_value, - name: name_value, - }, - &custom_sections, - ) - .map(|record| record.into_parts()) + let cell_payload = cell_slice.as_ref().map(|column| column.payload_bytes(i)); + let stress_payload = stress_slice.as_ref().map(|column| column.payload_bytes(i)); + let energy_bytes = energy_slice.as_ref().map(|slice| slice[i].to_le_bytes()); + let pbc_payload = pbc_slice.as_ref().map(|slice| { + [ + slice[i * 3] as u8, + slice[i * 3 + 1] as u8, + slice[i * 3 + 2] as u8, + ] + }); + + let mut sections = Vec::with_capacity(builtin_section_count + custom_columns.len()); + if let Some(payload) = charges_payload { + push_builtin_section(&mut sections, "charges", TYPE_F64_ARRAY, payload); + } + if let Some(payload) = cell_payload { + push_builtin_section( + &mut sections, + "cell", + cell_slice + .as_ref() + .map(FastMat3Column::type_tag) + .expect("cell type tag must exist when payload exists"), + payload, + ); + } + if let Some(bytes) = energy_bytes.as_ref() { + push_builtin_section(&mut sections, "energy", TYPE_FLOAT, bytes); + } + if let Some(payload) = forces_payload { + push_builtin_section(&mut sections, "forces", TYPE_VEC3_F32, payload); + } + if let Some(names) = name.as_ref() { + push_builtin_section(&mut sections, "name", TYPE_STRING, names[i].as_bytes()); + } + if let Some(payload) = pbc_payload.as_ref() { + push_builtin_section(&mut sections, "pbc", TYPE_BOOL3, payload); + } + if let Some(payload) = stress_payload { + push_builtin_section( + &mut sections, + "stress", + stress_slice + .as_ref() + .map(FastMat3Column::type_tag) + .expect("stress type tag must exist when payload exists"), + payload, + ); + } + if let Some(payload) = velocities_payload { + push_builtin_section(&mut sections, "velocities", TYPE_VEC3_F32, payload); + } + sections.extend(custom_columns.iter().map(|column| column.section_for(i))); + + build_soa_record(SoaRecord { + record_format, + positions_type: TYPE_VEC3_F32, + positions: bytemuck::cast_slice::(&pos_slice[pos_start..pos_end]), + atomic_numbers: &z_slice[z_start..z_end], + sections: §ions, + }) + .map(|record| (record, n_atoms as u32)) + }; + + let records: Vec<(Vec, u32)> = if batch >= 1024 { + use rayon::prelude::*; + (0..batch) + .into_par_iter() + .map(build_record) + .collect::, _>>() + .map_err(PyValueError::new_err)? + } else { + (0..batch) + .map(build_record) + .collect::, _>>() + .map_err(PyValueError::new_err)? + }; + + py.detach(move || inner.add_owned_soa_records_with_schema(records, batch_schema)) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + Ok(true) +} + +fn extract_positions_payload(value: &Bound<'_, PyAny>) -> PyResult<(usize, usize, u8, Vec)> { + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape().len() != 3 || view.shape()[2] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (batch, n_atoms, 3)", + )); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err("positions must be C-contiguous"))?; + return Ok(( + view.shape()[0], + view.shape()[1], + TYPE_VEC3_F32, + bytemuck::cast_slice::(slice).to_vec(), + )); + } + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape().len() != 3 || view.shape()[2] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (batch, n_atoms, 3)", + )); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err("positions must be C-contiguous"))?; + return Ok(( + view.shape()[0], + view.shape()[1], + TYPE_VEC3_F64, + bytemuck::cast_slice::(slice).to_vec(), + )); + } + Err(PyValueError::new_err( + "positions must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", + )) +} + +fn extract_atomic_numbers_payload( + atomic_numbers: &Bound<'_, PyArray2>, + batch: usize, + n_atoms: usize, +) -> PyResult> { + let readonly = atomic_numbers.readonly(); + let view = readonly.as_array(); + if view.shape() != [batch, n_atoms] { + return Err(PyValueError::new_err(format!( + "atomic_numbers must have shape ({}, {})", + batch, n_atoms + ))); + } + Ok(readonly + .as_slice() + .map_err(|_| PyValueError::new_err("atomic_numbers must be C-contiguous"))? + .to_vec()) +} + +fn extract_builtin_scalar_column( + arr: &Bound<'_, PyArray1>, + batch: usize, + key: &str, + type_tag: u8, +) -> PyResult { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.len() != batch { + return Err(PyValueError::new_err(format!( + "{} length ({}) doesn't match batch size ({})", + key, + view.len(), + batch + ))); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", key)))?; + Ok(BatchSectionColumn { + key: key.to_string(), + kind: KIND_BUILTIN, + type_tag, + slot_bytes: std::mem::size_of::(), + payload: bytemuck::cast_slice::(slice).to_vec(), + strings: None, + }) +} + +fn extract_builtin_float_array_column( + arr: &Bound<'_, PyArray2>, + batch: usize, + n_atoms: usize, + key: &str, + type_tag: u8, +) -> PyResult { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [batch, n_atoms] { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, {})", + key, batch, n_atoms + ))); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", key)))?; + Ok(BatchSectionColumn { + key: key.to_string(), + kind: KIND_BUILTIN, + type_tag, + slot_bytes: n_atoms * std::mem::size_of::(), + payload: bytemuck::cast_slice::(slice).to_vec(), + strings: None, + }) +} + +fn extract_builtin_vec3_column( + arr: &Bound<'_, PyArray3>, + batch: usize, + n_atoms: usize, + key: &str, + type_tag: u8, +) -> PyResult { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [batch, n_atoms, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, {}, 3)", + key, batch, n_atoms + ))); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", key)))?; + Ok(BatchSectionColumn { + key: key.to_string(), + kind: KIND_BUILTIN, + type_tag, + slot_bytes: n_atoms * 3 * std::mem::size_of::(), + payload: bytemuck::cast_slice::(slice).to_vec(), + strings: None, + }) +} + +fn extract_builtin_mat3_column( + arr: &Bound<'_, PyArray3>, + batch: usize, + key: &str, + type_tag: u8, +) -> PyResult { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [batch, 3, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, 3, 3)", + key, batch + ))); + } + let slice = readonly + .as_slice() + .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", key)))?; + Ok(BatchSectionColumn { + key: key.to_string(), + kind: KIND_BUILTIN, + type_tag, + slot_bytes: 9 * std::mem::size_of::(), + payload: bytemuck::cast_slice::(slice).to_vec(), + strings: None, + }) +} + +fn extract_builtin_pbc_column( + pbc: &Bound<'_, PyArray2>, + batch: usize, +) -> PyResult { + let readonly = pbc.readonly(); + let view = readonly.as_array(); + if view.shape() != [batch, 3] { + return Err(PyValueError::new_err(format!( + "pbc must have shape ({}, 3)", + batch + ))); + } + Ok(BatchSectionColumn { + key: "pbc".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_BOOL3, + slot_bytes: 3, + payload: view.iter().map(|value| u8::from(*value)).collect(), + strings: None, + }) +} + +fn extract_builtin_name_column( + name: Option>, + batch: usize, +) -> PyResult> { + let Some(names) = name else { + return Ok(None); + }; + if names.len() != batch { + return Err(PyValueError::new_err(format!( + "name length ({}) doesn't match batch size ({})", + names.len(), + batch + ))); + } + Ok(Some(BatchSectionColumn { + key: "name".to_string(), + kind: KIND_BUILTIN, + type_tag: TYPE_STRING, + slot_bytes: 0, + payload: Vec::new(), + strings: Some(names), + })) +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn add_arrays_batch_impl( + inner: &mut AtomDatabase, + py: Python<'_>, + positions: &Bound<'_, PyAny>, + atomic_numbers: &Bound<'_, PyArray2>, + energy: Option<&Bound<'_, PyAny>>, + forces: Option<&Bound<'_, PyAny>>, + charges: Option<&Bound<'_, PyAny>>, + velocities: Option<&Bound<'_, PyAny>>, + cell: Option<&Bound<'_, PyAny>>, + stress: Option<&Bound<'_, PyAny>>, + pbc: Option<&Bound<'_, PyArray2>>, + name: Option>, + properties: Option<&Bound<'_, PyDict>>, + atom_properties: Option<&Bound<'_, PyDict>>, +) -> PyResult<()> { + if try_add_arrays_batch_fast_canonical( + inner, + py, + positions, + atomic_numbers, + energy, + forces, + charges, + velocities, + cell, + stress, + pbc, + name.clone(), + properties, + atom_properties, + )? { + return Ok(()); + } + + let (batch, n_atoms, positions_type, positions_payload) = extract_positions_payload(positions)?; + let atomic_numbers_payload = extract_atomic_numbers_payload(atomic_numbers, batch, n_atoms)?; + + let mut builtin_columns = Vec::new(); + if let Some(energy) = energy { + if let Ok(arr) = energy.cast::>() { + builtin_columns.push(extract_builtin_scalar_column( + arr, + batch, + "energy", + TYPE_FLOAT32, + )?); + } else if let Ok(arr) = energy.cast::>() { + builtin_columns.push(extract_builtin_scalar_column( + arr, batch, "energy", TYPE_FLOAT, + )?); + } else { + return Err(PyValueError::new_err( + "energy must be a float32 or float64 ndarray with shape (batch,)", + )); + } + } + if let Some(forces) = forces { + if let Ok(arr) = forces.cast::>() { + builtin_columns.push(extract_builtin_vec3_column( + arr, + batch, + n_atoms, + "forces", + TYPE_VEC3_F32, + )?); + } else if let Ok(arr) = forces.cast::>() { + builtin_columns.push(extract_builtin_vec3_column( + arr, + batch, + n_atoms, + "forces", + TYPE_VEC3_F64, + )?); + } else { + return Err(PyValueError::new_err( + "forces must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", + )); + } + } + if let Some(charges) = charges { + if let Ok(arr) = charges.cast::>() { + builtin_columns.push(extract_builtin_float_array_column( + arr, + batch, + n_atoms, + "charges", + TYPE_F32_ARRAY, + )?); + } else if let Ok(arr) = charges.cast::>() { + builtin_columns.push(extract_builtin_float_array_column( + arr, + batch, + n_atoms, + "charges", + TYPE_F64_ARRAY, + )?); + } else { + return Err(PyValueError::new_err( + "charges must be a float32 or float64 ndarray with shape (batch, n_atoms)", + )); + } + } + if let Some(velocities) = velocities { + if let Ok(arr) = velocities.cast::>() { + builtin_columns.push(extract_builtin_vec3_column( + arr, + batch, + n_atoms, + "velocities", + TYPE_VEC3_F32, + )?); + } else if let Ok(arr) = velocities.cast::>() { + builtin_columns.push(extract_builtin_vec3_column( + arr, + batch, + n_atoms, + "velocities", + TYPE_VEC3_F64, + )?); + } else { + return Err(PyValueError::new_err( + "velocities must be a float32 or float64 ndarray with shape (batch, n_atoms, 3)", + )); + } + } + if let Some(cell) = cell { + if let Ok(arr) = cell.cast::>() { + builtin_columns.push(extract_builtin_mat3_column( + arr, + batch, + "cell", + TYPE_MAT3X3_F32, + )?); + } else if let Ok(arr) = cell.cast::>() { + builtin_columns.push(extract_builtin_mat3_column( + arr, + batch, + "cell", + TYPE_MAT3X3_F64, + )?); + } else { + return Err(PyValueError::new_err( + "cell must be a float32 or float64 ndarray with shape (batch, 3, 3)", + )); + } + } + if let Some(stress) = stress { + if let Ok(arr) = stress.cast::>() { + builtin_columns.push(extract_builtin_mat3_column( + arr, + batch, + "stress", + TYPE_MAT3X3_F32, + )?); + } else if let Ok(arr) = stress.cast::>() { + builtin_columns.push(extract_builtin_mat3_column( + arr, + batch, + "stress", + TYPE_MAT3X3_F64, + )?); + } else { + return Err(PyValueError::new_err( + "stress must be a float32 or float64 ndarray with shape (batch, 3, 3)", + )); + } + } + if let Some(pbc) = pbc { + builtin_columns.push(extract_builtin_pbc_column(pbc, batch)?); + } + if let Some(name) = extract_builtin_name_column(name, batch)? { + builtin_columns.push(name); + } + + let custom_columns = extract_custom_columns(properties, atom_properties, batch, n_atoms)?; + let positions_slot_bytes = n_atoms + .checked_mul(type_tag_elem_bytes(positions_type)) + .ok_or_else(|| PyValueError::new_err("positions byte length overflow"))?; + + let batch_schema = build_batch_schema( + positions_type, + builtin_columns.iter().chain(custom_columns.iter()), + )?; + let record_format = inner + .record_format_for_schema(batch_schema.clone()) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + + let build_record = |index: usize| -> Result<(Vec, u32), String> { + let pos_start = index * positions_slot_bytes; + let pos_end = pos_start + positions_slot_bytes; + let z_start = index * n_atoms; + let z_end = z_start + n_atoms; + + let mut sections = Vec::with_capacity(builtin_columns.len() + custom_columns.len()); + sections.extend( + builtin_columns + .iter() + .map(|column| column.section_for(index)), + ); + sections.extend( + custom_columns + .iter() + .map(|column| column.section_for(index)), + ); + + let record = build_soa_record(SoaRecord { + record_format, + positions_type, + positions: &positions_payload[pos_start..pos_end], + atomic_numbers: &atomic_numbers_payload[z_start..z_end], + sections: §ions, + })?; + Ok((record, n_atoms as u32)) }; - let serialized: Vec<(Vec, u32)> = if batch >= 1024 { + let records: Vec<(Vec, u32)> = if batch >= 1024 { use rayon::prelude::*; (0..batch) .into_par_iter() @@ -646,6 +1382,6 @@ pub(super) fn add_arrays_batch_impl( .map_err(PyValueError::new_err)? }; - py.detach(|| inner.add_owned_soa_records(serialized)) + py.detach(move || inner.add_owned_soa_records_with_schema(records, batch_schema)) .map_err(|e| PyValueError::new_err(format!("{}", e))) } diff --git a/atompack-py/src/database_flat.rs b/atompack-py/src/database_flat.rs index c02b187..56de991 100644 --- a/atompack-py/src/database_flat.rs +++ b/atompack-py/src/database_flat.rs @@ -4,6 +4,11 @@ use crate::molecule::{ pyarray1_from_cow, pyarray2_from_cow, }; +enum FlatPositions { + F32(Vec), + F64(Vec), +} + pub(super) fn get_molecules_flat_soa_impl<'py>( inner: &AtomDatabase, py: Python<'py>, @@ -45,22 +50,129 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let compression = inner.compression(); let use_mmap = inner.get_compressed_slice(0).is_some(); - + let record_format = inner.record_format(); + let positions_type = match record_format { + RECORD_FORMAT_SOA_V2 => TYPE_VEC3_F32, + RECORD_FORMAT_SOA_V3 => inner + .positions_type() + .ok_or_else(|| invalid_data("Missing position dtype for batch"))?, + _ => { + return Err(invalid_data(format!( + "Unsupported record format {}", + record_format + ))); + } + }; + let schema_info = inner.schema_info(); let raw_bytes_owned: Option>>; let schema: Vec; + let use_ordered_schema: bool; + + let ordered_schema_from_first = |bytes: &[u8]| -> atompack::Result> { + let first_md = parse_mol_fast_soa(bytes, record_format, Some(positions_type))?; + let n = first_md.n_atoms; + first_md + .sections + .iter() + .map(|s| section_schema_from_ref(s, n)) + .collect::>() + }; + + let schema_matches_ordered = + |ordered: &[SectionSchema], from_lock: &[SectionSchema]| { + if ordered.len() != from_lock.len() { + return false; + } + let ordered_lookup: std::collections::HashMap<(u8, &str), &SectionSchema> = + ordered + .iter() + .map(|entry| ((entry.kind, entry.key.as_str()), entry)) + .collect(); + from_lock.iter().all(|entry| { + ordered_lookup + .get(&(entry.kind, entry.key.as_str())) + .is_some_and(|candidate| { + candidate.type_tag == entry.type_tag + && candidate.per_atom == entry.per_atom + && candidate.elem_bytes == entry.elem_bytes + && candidate.slot_bytes == entry.slot_bytes + }) + }) + }; - if use_mmap { + if let Some(schema_info) = schema_info { + let schema_from_lock: Vec = schema_info + .sections + .into_iter() + .map(|section| SectionSchema { + kind: section.kind, + key: section.key, + type_tag: section.type_tag, + per_atom: section.per_atom, + elem_bytes: section.elem_bytes, + slot_bytes: section.slot_bytes, + }) + .collect(); + if use_mmap { + if compression == CompressionType::None { + let shared = inner.get_shared_mmap_bytes(indices[0]).ok_or_else(|| { + invalid_data(format!("Missing mmap bytes for molecule {}", indices[0])) + })?; + let ordered = ordered_schema_from_first(shared.as_slice())?; + use_ordered_schema = schema_matches_ordered(&ordered, &schema_from_lock); + schema = if use_ordered_schema { + ordered + } else { + schema_from_lock + }; + raw_bytes_owned = None; + } else { + let compressed = + inner.get_compressed_slice(indices[0]).ok_or_else(|| { + invalid_data(format!( + "Missing compressed bytes for molecule {}", + indices[0] + )) + })?; + let uncompressed_size = + inner.uncompressed_size(indices[0]).ok_or_else(|| { + invalid_data(format!( + "Missing uncompressed size for molecule {}", + indices[0] + )) + })? as usize; + let first_bytes = atompack::decompress_bytes( + compressed, + compression, + Some(uncompressed_size), + )?; + let ordered = ordered_schema_from_first(&first_bytes)?; + use_ordered_schema = schema_matches_ordered(&ordered, &schema_from_lock); + schema = if use_ordered_schema { + ordered + } else { + schema_from_lock + }; + raw_bytes_owned = None; + } + } else { + let (raw_bytes, _) = inner.read_decompress_parallel(&indices)?; + let ordered = ordered_schema_from_first(&raw_bytes[0])?; + use_ordered_schema = schema_matches_ordered(&ordered, &schema_from_lock); + schema = if use_ordered_schema { + ordered + } else { + schema_from_lock + }; + raw_bytes_owned = Some(raw_bytes); + } + } else if use_mmap { if compression == CompressionType::None { let shared = inner.get_shared_mmap_bytes(indices[0]).ok_or_else(|| { invalid_data(format!("Missing mmap bytes for molecule {}", indices[0])) })?; - let first_md = parse_mol_fast_soa(shared.as_slice())?; - let n = first_md.n_atoms; - schema = first_md - .sections - .iter() - .map(|s| section_schema_from_ref(s, n)) - .collect::>()?; + schema = ordered_schema_from_first(shared.as_slice())?; + use_ordered_schema = true; } else { let compressed = inner.get_compressed_slice(indices[0]).ok_or_else(|| { invalid_data(format!( @@ -80,31 +192,52 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( compression, Some(uncompressed_size), )?; - let first_md = parse_mol_fast_soa(&first_bytes)?; - let n = first_md.n_atoms; - schema = first_md - .sections - .iter() - .map(|s| section_schema_from_ref(s, n)) - .collect::>()?; + schema = ordered_schema_from_first(&first_bytes)?; + use_ordered_schema = true; } raw_bytes_owned = None; } else { let (raw_bytes, _) = inner.read_decompress_parallel(&indices)?; - let first_md = parse_mol_fast_soa(&raw_bytes[0])?; - let n = first_md.n_atoms; - schema = first_md - .sections - .iter() - .map(|s| section_schema_from_ref(s, n)) - .collect::>()?; + schema = ordered_schema_from_first(&raw_bytes[0])?; + use_ordered_schema = true; raw_bytes_owned = Some(raw_bytes); } - let schema_keys: Vec<(u8, &[u8])> = - schema.iter().map(|s| (s.kind, s.key.as_bytes())).collect(); - - let mut positions = vec![0f32; total_atoms * 3]; + let schema_keys: Vec<(u8, &[u8])> = if use_ordered_schema { + schema + .iter() + .map(|entry| (entry.kind, entry.key.as_bytes())) + .collect() + } else { + Vec::new() + }; + let mut schema_lookup: std::collections::HashMap< + u8, + std::collections::HashMap, + > = std::collections::HashMap::new(); + if !use_ordered_schema { + for (index, entry) in schema.iter().enumerate() { + schema_lookup + .entry(entry.kind) + .or_default() + .insert(entry.key.clone(), index); + } + } + let positions_stride = match positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + _ => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + positions_type + ))); + } + }; + let mut positions = match positions_type { + TYPE_VEC3_F32 => FlatPositions::F32(vec![0f32; total_atoms * 3]), + TYPE_VEC3_F64 => FlatPositions::F64(vec![0u8; total_atoms * positions_stride]), + _ => unreachable!(), + }; let mut atomic_numbers = vec![0u8; total_atoms]; let mut section_buffers: Vec> = schema @@ -131,7 +264,14 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( }) .collect(); - let pos_buf = RawBuf::new(&mut positions); + let pos_buf_f32 = match &mut positions { + FlatPositions::F32(values) => Some(RawBuf::new(values)), + FlatPositions::F64(_) => None, + }; + let pos_buf_f64 = match &mut positions { + FlatPositions::F32(_) => None, + FlatPositions::F64(values) => Some(RawBuf::new(values)), + }; let z_buf = RawBuf::new(&mut atomic_numbers); let sec_bufs: Vec> = section_buffers .iter_mut() @@ -153,24 +293,39 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .collect(); let process_mol = |i: usize, mol_bytes: &[u8]| -> atompack::Result<()> { - let md = parse_mol_fast_soa(mol_bytes)?; + let md = parse_mol_fast_soa(mol_bytes, record_format, Some(positions_type))?; let atom_off = offsets[i]; let n = md.n_atoms; - if md.sections.len() != schema.len() { - return Err(invalid_data(format!( - "SOA schema mismatch for molecule {}: expected {} sections, got {}", - i, - schema.len(), - md.sections.len() - ))); - } unsafe { - std::ptr::copy_nonoverlapping( - md.positions_bytes.as_ptr(), - pos_buf.at(atom_off * 3) as *mut u8, - n * 12, - ); + match positions_type { + TYPE_VEC3_F32 => { + let pos_buf = pos_buf_f32 + .as_ref() + .ok_or_else(|| invalid_data("missing f32 position buffer"))?; + std::ptr::copy_nonoverlapping( + md.positions_bytes.as_ptr(), + pos_buf.at(atom_off * 3) as *mut u8, + n * 12, + ); + } + TYPE_VEC3_F64 => { + let pos_buf = pos_buf_f64 + .as_ref() + .ok_or_else(|| invalid_data("missing f64 position buffer"))?; + std::ptr::copy_nonoverlapping( + md.positions_bytes.as_ptr(), + pos_buf.at(atom_off * positions_stride), + n * positions_stride, + ); + } + other => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + other + ))); + } + } std::ptr::copy_nonoverlapping( md.atomic_numbers_bytes.as_ptr(), z_buf.at(atom_off), @@ -178,59 +333,172 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( ); } - for (section_idx, sec) in md.sections.iter().enumerate() { - let schema_entry = &schema[section_idx]; - let expected_key = &schema_keys[section_idx]; - if sec.kind != expected_key.0 || sec.key.as_bytes() != expected_key.1 { + if use_ordered_schema { + if md.sections.len() != schema.len() { return Err(invalid_data(format!( - "SOA schema order mismatch at molecule {} for section '{}'", - i, sec.key + "SOA schema mismatch for molecule {}: expected {} sections, got {}", + i, + schema.len(), + md.sections.len() ))); } - if schema_entry.per_atom { - let expected = n.checked_mul(schema_entry.elem_bytes).ok_or_else(|| { - invalid_data(format!("Section '{}' payload length overflow", sec.key)) - })?; - if sec.payload.len() != expected { + for (section_idx, sec) in md.sections.iter().enumerate() { + let schema_entry = &schema[section_idx]; + let expected_key = &schema_keys[section_idx]; + + if sec.kind != expected_key.0 || sec.key.as_bytes() != expected_key.1 { + return Err(invalid_data(format!( + "SOA schema order mismatch at molecule {} for section '{}'", + i, sec.key + ))); + } + + if sec.type_tag != schema_entry.type_tag { + return Err(invalid_data(format!( + "SOA schema mismatch at molecule {} for section '{}'", + i, sec.key + ))); + } + + if schema_entry.per_atom { + let expected = + n.checked_mul(schema_entry.elem_bytes).ok_or_else(|| { + invalid_data(format!( + "Section '{}' payload length overflow", + sec.key + )) + })?; + if sec.payload.len() != expected { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + sec.key, + sec.payload.len(), + expected + ))); + } + } else if schema_entry.slot_bytes != 0 + && sec.payload.len() != schema_entry.slot_bytes + { return Err(invalid_data(format!( "Section '{}' has invalid payload length {} (expected {})", sec.key, sec.payload.len(), - expected + schema_entry.slot_bytes ))); } + + if schema_entry.slot_bytes == 0 { + if let Some(ref mtx) = string_mutexes[section_idx] { + let val = Some( + std::str::from_utf8(sec.payload) + .map_err(|_| { + invalid_data(format!( + "Invalid UTF-8 in section '{}'", + sec.key + )) + })? + .to_string(), + ); + let mut guard = mtx + .lock() + .map_err(|_| invalid_data("string section mutex poisoned"))?; + guard[i] = val; + } + } else { + let buf = &sec_bufs[section_idx]; + let offset = if schema_entry.per_atom { + atom_off * schema_entry.elem_bytes + } else { + i * schema_entry.slot_bytes + }; + unsafe { + std::ptr::copy_nonoverlapping( + sec.payload.as_ptr(), + buf.at(offset), + sec.payload.len(), + ); + } + } } + } else { + for sec in &md.sections { + let section_idx = schema_lookup + .get(&sec.kind) + .and_then(|entries| entries.get(sec.key)) + .copied() + .ok_or_else(|| { + invalid_data(format!( + "Unexpected SOA section '{}' in molecule {}", + sec.key, i + )) + })?; + let schema_entry = &schema[section_idx]; - if schema_entry.slot_bytes == 0 { - if let Some(ref mtx) = string_mutexes[section_idx] { - let val = Some( - std::str::from_utf8(sec.payload) - .map_err(|_| { - invalid_data(format!( - "Invalid UTF-8 in section '{}'", - sec.key - )) - })? - .to_string(), - ); - let mut guard = mtx - .lock() - .map_err(|_| invalid_data("string section mutex poisoned"))?; - guard[i] = val; + if sec.type_tag != schema_entry.type_tag { + return Err(invalid_data(format!( + "SOA schema mismatch at molecule {} for section '{}'", + i, sec.key + ))); } - } else { - let buf = &sec_bufs[section_idx]; - let offset = if schema_entry.per_atom { - atom_off * schema_entry.elem_bytes - } else { - i * schema_entry.slot_bytes - }; - unsafe { - std::ptr::copy_nonoverlapping( - sec.payload.as_ptr(), - buf.at(offset), + + if schema_entry.per_atom { + let expected = + n.checked_mul(schema_entry.elem_bytes).ok_or_else(|| { + invalid_data(format!( + "Section '{}' payload length overflow", + sec.key + )) + })?; + if sec.payload.len() != expected { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + sec.key, + sec.payload.len(), + expected + ))); + } + } else if schema_entry.slot_bytes != 0 + && sec.payload.len() != schema_entry.slot_bytes + { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + sec.key, sec.payload.len(), - ); + schema_entry.slot_bytes + ))); + } + + if schema_entry.slot_bytes == 0 { + if let Some(ref mtx) = string_mutexes[section_idx] { + let val = Some( + std::str::from_utf8(sec.payload) + .map_err(|_| { + invalid_data(format!( + "Invalid UTF-8 in section '{}'", + sec.key + )) + })? + .to_string(), + ); + let mut guard = mtx + .lock() + .map_err(|_| invalid_data("string section mutex poisoned"))?; + guard[i] = val; + } + } else { + let buf = &sec_bufs[section_idx]; + let offset = if schema_entry.per_atom { + atom_off * schema_entry.elem_bytes + } else { + i * schema_entry.slot_bytes + }; + unsafe { + std::ptr::copy_nonoverlapping( + sec.payload.as_ptr(), + buf.at(offset), + sec.payload.len(), + ); + } } } } @@ -271,7 +539,7 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( }) .collect() } else { - let raw_bytes = raw_bytes_owned.unwrap(); + let raw_bytes = raw_bytes_owned.expect("raw bytes must exist without mmap"); raw_bytes .par_iter() .enumerate() @@ -287,7 +555,6 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( schema, section_buffers, string_sections, - n_mols, total_atoms, ))) }) @@ -300,7 +567,6 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( schema, section_buffers, string_results, - _n_mols, total_atoms, ) = match result { None => { @@ -320,12 +586,20 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let dict = PyDict::new(py); dict.set_item("n_atoms", PyArray1::from_vec(py, n_atoms_vec))?; - dict.set_item( - "positions", - PyArray1::from_vec(py, positions) - .reshape([total_atoms, 3]) - .map_err(|e| PyValueError::new_err(format!("{}", e)))?, - )?; + match positions { + FlatPositions::F32(values) => { + dict.set_item( + "positions", + PyArray1::from_vec(py, values) + .reshape([total_atoms, 3]) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?, + )?; + } + FlatPositions::F64(bytes) => { + let arr = cast_or_decode_f64(&bytes)?; + dict.set_item("positions", pyarray2_from_cow(py, arr, total_atoms, 3)?)?; + } + } dict.set_item("atomic_numbers", PyArray1::from_vec(py, atomic_numbers))?; let atom_props_dict = PyDict::new(py); @@ -361,6 +635,10 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let arr = cast_or_decode_f64(&buf)?; target.set_item(&s.key, pyarray1_from_cow(py, arr))?; } + TYPE_FLOAT32 => { + let arr = cast_or_decode_f32(&buf)?; + target.set_item(&s.key, pyarray1_from_cow(py, arr))?; + } TYPE_INT => { let arr = cast_or_decode_i64(&buf)?; target.set_item(&s.key, pyarray1_from_cow(py, arr))?; @@ -411,6 +689,16 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .map_err(|e| PyValueError::new_err(format!("{}", e)))?, )?; } + TYPE_MAT3X3_F32 => { + let arr = cast_or_decode_f32(&buf)?; + let n = arr.len() / 9; + target.set_item( + &s.key, + pyarray1_from_cow(py, arr) + .reshape([n, 3, 3]) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?, + )?; + } _ => {} } } diff --git a/atompack-py/src/lib.rs b/atompack-py/src/lib.rs index dcca0c5..e20e0d1 100644 --- a/atompack-py/src/lib.rs +++ b/atompack-py/src/lib.rs @@ -8,8 +8,8 @@ #![allow(unsafe_op_in_unsafe_fn)] use atompack::{ - Atom, AtomDatabase, Molecule, SharedMmapBytes, atom::PropertyValue, - compression::CompressionType, + Atom, AtomDatabase, FloatArrayData, FloatScalarData, Mat3Data, Molecule, SharedMmapBytes, + Vec3Data, atom::PropertyValue, compression::CompressionType, }; use numpy::{Element, PyArray1, PyArray2, PyArray3, PyArrayMethods}; use pyo3::exceptions::{PyFileExistsError, PyIndexError, PyKeyError, PyTypeError, PyValueError}; @@ -111,704 +111,19 @@ const TYPE_VEC3_F64: u8 = 7; const TYPE_I32_ARRAY: u8 = 8; const TYPE_BOOL3: u8 = 9; const TYPE_MAT3X3_F64: u8 = 10; +const TYPE_FLOAT32: u8 = 11; +const TYPE_MAT3X3_F32: u8 = 12; -/// A single parsed section reference (zero-copy into decompressed bytes). -#[derive(Clone)] -struct SectionRef<'a> { - kind: u8, - key: &'a str, - type_tag: u8, - payload: &'a [u8], -} - -/// Per-molecule extracted data (references into decompressed bytes). -struct MolData<'a> { - n_atoms: usize, - positions_bytes: &'a [u8], // n_atoms * 12 - atomic_numbers_bytes: &'a [u8], // n_atoms - sections: Vec>, -} - -#[derive(Clone)] -struct SectionSchema { - kind: u8, - key: String, - type_tag: u8, - per_atom: bool, - elem_bytes: usize, - slot_bytes: usize, -} - -/// Parse SOA format bytes into MolData without allocation. -/// -/// Layout: -/// [n_atoms:u32][positions:n*12][atomic_numbers:n] -/// [n_sections:u16] -/// per section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload] -fn parse_mol_fast_soa(bytes: &[u8]) -> atompack::Result> { - let mut pos = 0usize; - let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; - let positions_len = n_atoms - .checked_mul(12) - .ok_or_else(|| invalid_data("SOA positions byte length overflow"))?; - let positions_bytes = read_bytes_at(bytes, &mut pos, positions_len, "SOA positions")?; - let atomic_numbers_bytes = read_bytes_at(bytes, &mut pos, n_atoms, "SOA atomic_numbers")?; - let n_sections = read_u16_le_at(bytes, &mut pos, "SOA n_sections")? as usize; - - let mut sections = Vec::with_capacity(n_sections); - for _ in 0..n_sections { - let kind = read_u8_at(bytes, &mut pos, "SOA section kind")?; - let key_len = read_u8_at(bytes, &mut pos, "SOA section key length")? as usize; - let key_bytes = read_bytes_at(bytes, &mut pos, key_len, "SOA section key")?; - let key = std::str::from_utf8(key_bytes) - .map_err(|_| invalid_data("Invalid UTF-8 in SOA section key"))?; - let type_tag = read_u8_at(bytes, &mut pos, "SOA section type tag")?; - let payload_len = read_u32_le_at(bytes, &mut pos, "SOA section payload length")? as usize; - let payload = read_bytes_at(bytes, &mut pos, payload_len, "SOA section payload")?; - sections.push(SectionRef { - kind, - key, - type_tag, - payload, - }); - } - - Ok(MolData { - n_atoms, - positions_bytes, - atomic_numbers_bytes, - sections, - }) -} - -fn section_schema_from_ref( - section: &SectionRef<'_>, - n_atoms: usize, -) -> atompack::Result { - let per_atom = is_per_atom(section.kind, section.key, section.type_tag); - let elem_bytes = match section.type_tag { - TYPE_STRING => 0, - tag if per_atom => { - let elem_bytes = type_tag_elem_bytes(tag); - if elem_bytes == 0 { - return Err(invalid_data(format!( - "Unsupported per-atom section type tag {} for key '{}'", - tag, section.key - ))); - } - elem_bytes - } - TYPE_FLOAT | TYPE_INT => 8, - TYPE_BOOL3 => 3, - TYPE_MAT3X3_F64 => 72, - _ => section.payload.len(), - }; - let slot_bytes = if section.type_tag == TYPE_STRING { - 0 - } else if per_atom { - elem_bytes - } else { - section.payload.len() - }; - - validate_section_payload(section, per_atom, elem_bytes, slot_bytes, n_atoms)?; - - Ok(SectionSchema { - kind: section.kind, - key: section.key.to_string(), - type_tag: section.type_tag, - per_atom, - elem_bytes, - slot_bytes, - }) -} - -fn validate_section_payload( - section: &SectionRef<'_>, - per_atom: bool, - elem_bytes: usize, - slot_bytes: usize, - n_atoms: usize, -) -> atompack::Result<()> { - match section.type_tag { - TYPE_STRING => { - std::str::from_utf8(section.payload) - .map_err(|_| invalid_data(format!("Invalid UTF-8 in section '{}'", section.key)))?; - if per_atom { - return Err(invalid_data(format!( - "String section '{}' cannot be per-atom in flat extraction", - section.key - ))); - } - } - TYPE_FLOAT | TYPE_INT | TYPE_BOOL3 | TYPE_MAT3X3_F64 => { - if section.payload.len() != slot_bytes { - return Err(invalid_data(format!( - "Section '{}' has invalid payload length {} (expected {})", - section.key, - section.payload.len(), - slot_bytes - ))); - } - } - _ if per_atom => { - let expected = n_atoms.checked_mul(elem_bytes).ok_or_else(|| { - invalid_data(format!("Section '{}' payload length overflow", section.key)) - })?; - if section.payload.len() != expected { - return Err(invalid_data(format!( - "Section '{}' has invalid payload length {} (expected {})", - section.key, - section.payload.len(), - expected - ))); - } - } - _ => { - if slot_bytes != 0 && section.payload.len() != slot_bytes { - return Err(invalid_data(format!( - "Section '{}' has invalid payload length {} (expected {})", - section.key, - section.payload.len(), - slot_bytes - ))); - } - if elem_bytes != 0 && !section.payload.len().is_multiple_of(elem_bytes) { - return Err(invalid_data(format!( - "Section '{}' has invalid payload length {} for element size {}", - section.key, - section.payload.len(), - elem_bytes - ))); - } - } - } - Ok(()) -} +const RECORD_FORMAT_SOA_V2: u32 = 2; +const RECORD_FORMAT_SOA_V3: u32 = 3; -/// Element size in bytes for a given type tag. Returns 0 for variable-length types. -fn type_tag_elem_bytes(tag: u8) -> usize { - match tag { - TYPE_FLOAT => 8, - TYPE_INT => 8, - TYPE_STRING => 0, // variable - TYPE_F64_ARRAY => 8, - TYPE_VEC3_F32 => 12, - TYPE_I64_ARRAY => 8, - TYPE_F32_ARRAY => 4, - TYPE_VEC3_F64 => 24, - TYPE_I32_ARRAY => 4, - TYPE_BOOL3 => 3, - TYPE_MAT3X3_F64 => 72, - _ => 0, - } -} +mod soa; -/// Whether a section with the given kind/key/type_tag is per-atom (vs per-molecule). -fn is_per_atom(kind: u8, key: &str, _type_tag: u8) -> bool { - match kind { - KIND_ATOM_PROP => true, - KIND_MOL_PROP => false, - KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), - _ => false, - } -} - -/// Lightweight section descriptor — stores byte offsets into the parent `bytes` buffer. -/// Key is NOT parsed eagerly; use `key()` to read it lazily from `bytes`. -#[derive(Clone, Copy)] -struct LazySection { - kind: u8, - key_start: usize, - key_len: u8, - type_tag: u8, - payload_start: usize, - payload_len: usize, -} - -/// Byte-offset pair for a known builtin section (payload_start, payload_len, type_tag). -type BuiltinSlot = (usize, usize, u8); - -enum SoaBytes { - Owned(Vec), - Shared(SharedMmapBytes), -} - -impl SoaBytes { - #[inline] - fn as_slice(&self) -> &[u8] { - match self { - Self::Owned(bytes) => bytes, - Self::Shared(bytes) => bytes.as_slice(), - } - } -} - -impl std::ops::Deref for SoaBytes { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - self.as_slice() - } -} - -struct SoaMoleculeView { - bytes: SoaBytes, - n_atoms: usize, - positions_start: usize, - atomic_numbers_start: usize, - // Known builtins — zero-alloc, set during from_bytes - forces: Option, - energy: Option, - cell: Option, - stress: Option, - charges: Option, - velocities: Option, - pbc: Option, - name: Option, - // Custom properties — lazy, no String parsing until accessed - custom_sections: Vec, -} - -impl SoaMoleculeView { - /// Pure-Rust parser — no Python dependency, safe to call from rayon threads. - fn from_storage_inner(bytes: SoaBytes) -> atompack::Result { - if bytes.len() < 6 { - return Err(invalid_data("SOA record too small")); - } - - let n_atoms = u32::from_le_bytes(slice_to_array(&bytes[0..4], "SOA atom count")?) as usize; - let mut pos = 4usize; - let positions_start = pos; - pos = n_atoms - .checked_mul(12) - .and_then(|n| pos.checked_add(n)) - .ok_or_else(|| invalid_data("SOA positions overflow"))?; - if pos > bytes.len() { - return Err(invalid_data("SOA record truncated at positions")); - } - - let atomic_numbers_start = pos; - pos = pos - .checked_add(n_atoms) - .ok_or_else(|| invalid_data("SOA atomic_numbers overflow"))?; - if pos + 2 > bytes.len() { - return Err(invalid_data("SOA record truncated at atomic_numbers")); - } - - let n_sections = - u16::from_le_bytes(slice_to_array(&bytes[pos..pos + 2], "SOA section count")?) as usize; - pos += 2; - - let mut forces = None; - let mut energy = None; - let mut cell = None; - let mut stress = None; - let mut charges = None; - let mut velocities = None; - let mut pbc = None; - let mut name = None; - let mut custom_sections = Vec::new(); - - for _ in 0..n_sections { - if pos + 2 > bytes.len() { - return Err(invalid_data("SOA section header truncated")); - } - let kind = bytes[pos]; - pos += 1; - let key_len = bytes[pos] as usize; - pos += 1; - if pos + key_len > bytes.len() { - return Err(invalid_data("SOA section key truncated")); - } - let key_start = pos; - pos += key_len; - if pos + 5 > bytes.len() { - return Err(invalid_data("SOA section header truncated")); - } - let type_tag = bytes[pos]; - pos += 1; - let payload_len = u32::from_le_bytes(slice_to_array( - &bytes[pos..pos + 4], - "SOA section payload length", - )?) as usize; - pos += 4; - let payload_start = pos; - pos = pos - .checked_add(payload_len) - .ok_or_else(|| invalid_data("SOA section payload overflow"))?; - if pos > bytes.len() { - return Err(invalid_data("SOA section payload truncated")); - } - - let key_bytes = &bytes[key_start..key_start + key_len]; - if kind == KIND_BUILTIN { - let slot = (payload_start, payload_len, type_tag); - match key_bytes { - b"forces" => forces = Some(slot), - b"energy" => energy = Some(slot), - b"cell" => cell = Some(slot), - b"stress" => stress = Some(slot), - b"charges" => charges = Some(slot), - b"velocities" => velocities = Some(slot), - b"pbc" => pbc = Some(slot), - b"name" => name = Some(slot), - _ => { - custom_sections.push(LazySection { - kind, - key_start, - key_len: key_len as u8, - type_tag, - payload_start, - payload_len, - }); - } - } - } else { - custom_sections.push(LazySection { - kind, - key_start, - key_len: key_len as u8, - type_tag, - payload_start, - payload_len, - }); - } - } - - Ok(Self { - bytes, - n_atoms, - positions_start, - atomic_numbers_start, - forces, - energy, - cell, - stress, - charges, - velocities, - pbc, - name, - custom_sections, - }) - } - - fn from_bytes_inner(bytes: Vec) -> atompack::Result { - Self::from_storage_inner(SoaBytes::Owned(bytes)) - } - - fn from_shared_bytes_inner(bytes: SharedMmapBytes) -> atompack::Result { - Self::from_storage_inner(SoaBytes::Shared(bytes)) - } - - /// Thin wrapper for call sites that need PyResult. - fn from_bytes(bytes: Vec) -> PyResult { - Self::from_bytes_inner(bytes).map_err(|e| PyValueError::new_err(format!("{}", e))) - } - - fn positions_bytes(&self) -> &[u8] { - &self.bytes[self.positions_start..self.positions_start + self.n_atoms * 12] - } - - fn atomic_numbers_bytes(&self) -> &[u8] { - &self.bytes[self.atomic_numbers_start..self.atomic_numbers_start + self.n_atoms] - } - - #[inline] - fn builtin_payload(&self, slot: BuiltinSlot) -> &[u8] { - &self.bytes[slot.0..slot.0 + slot.1] - } - - fn lazy_section_key(&self, s: &LazySection) -> PyResult<&str> { - std::str::from_utf8(&self.bytes[s.key_start..s.key_start + s.key_len as usize]) - .map_err(|_| PyValueError::new_err("Invalid UTF-8 in section key")) - } - - fn lazy_section_payload(&self, s: &LazySection) -> &[u8] { - &self.bytes[s.payload_start..s.payload_start + s.payload_len] - } - - fn find_custom_section(&self, kind: u8, key: &str) -> PyResult> { - for section in &self.custom_sections { - if section.kind == kind && self.lazy_section_key(section)? == key { - return Ok(Some(section)); - } - } - Ok(None) - } - - fn property_keys(&self) -> PyResult> { - self.custom_sections - .iter() - .filter(|s| s.kind == KIND_MOL_PROP) - .map(|s| Ok(self.lazy_section_key(s)?.to_string())) - .collect() - } - - fn atom_at(&self, index: usize) -> PyResult> { - if index >= self.n_atoms { - return Ok(None); - } - let pos = &self.positions_bytes()[index * 12..(index + 1) * 12]; - Ok(Some(Atom::new( - f32::from_le_bytes(py_slice_to_array(&pos[0..4], "atom x")?), - f32::from_le_bytes(py_slice_to_array(&pos[4..8], "atom y")?), - f32::from_le_bytes(py_slice_to_array(&pos[8..12], "atom z")?), - self.atomic_numbers_bytes()[index], - ))) - } - - fn energy(&self) -> PyResult> { - match self.energy { - Some(slot) => Ok(Some(read_f64_scalar(self.builtin_payload(slot))?)), - None => Ok(None), - } - } - - fn pbc(&self) -> PyResult> { - match self.pbc { - Some(slot) => { - let payload = self.builtin_payload(slot); - if payload.len() != 3 { - return Err(PyValueError::new_err("Invalid pbc payload length")); - } - Ok(Some((payload[0] != 0, payload[1] != 0, payload[2] != 0))) - } - None => Ok(None), - } - } - - fn materialize(&self) -> PyResult { - let positions: Vec<[f32; 3]> = (0..self.n_atoms) - .map(|i| { - let pos = &self.positions_bytes()[i * 12..(i + 1) * 12]; - Ok([ - f32::from_le_bytes(py_slice_to_array(&pos[0..4], "position x")?), - f32::from_le_bytes(py_slice_to_array(&pos[4..8], "position y")?), - f32::from_le_bytes(py_slice_to_array(&pos[8..12], "position z")?), - ]) - }) - .collect::>()?; - let atomic_numbers = self.atomic_numbers_bytes().to_vec(); - let mut molecule = - Molecule::new(positions, atomic_numbers).map_err(PyValueError::new_err)?; - - // Builtins - if let Some(slot) = self.charges { - molecule.charges = Some(decode_f64_array(self.builtin_payload(slot))?); - } - if let Some(slot) = self.cell { - molecule.cell = Some(decode_mat3x3_f64(self.builtin_payload(slot))?); - } - if let Some(slot) = self.energy { - molecule.energy = Some(read_f64_scalar(self.builtin_payload(slot))?); - } - if let Some(slot) = self.forces { - molecule.forces = Some(decode_vec3_f32(self.builtin_payload(slot))?); - } - if let Some(slot) = self.name { - let payload = self.builtin_payload(slot); - molecule.name = Some( - std::str::from_utf8(payload) - .map_err(|_| PyValueError::new_err("Invalid UTF-8 in name"))? - .to_string(), - ); - } - if let Some(slot) = self.pbc { - let payload = self.builtin_payload(slot); - if payload.len() != 3 { - return Err(PyValueError::new_err("Invalid pbc payload length")); - } - molecule.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); - } - if let Some(slot) = self.stress { - molecule.stress = Some(decode_mat3x3_f64(self.builtin_payload(slot))?); - } - if let Some(slot) = self.velocities { - molecule.velocities = Some(decode_vec3_f32(self.builtin_payload(slot))?); - } - - // Custom properties (lazy — key parsed here only) - for s in &self.custom_sections { - let key = self.lazy_section_key(s)?.to_string(); - let payload = self.lazy_section_payload(s); - match s.kind { - KIND_ATOM_PROP => { - molecule - .atom_properties - .insert(key, decode_property_value(s.type_tag, payload)?); - } - KIND_MOL_PROP => { - molecule - .properties - .insert(key, decode_property_value(s.type_tag, payload)?); - } - _ => {} - } - } - - Ok(molecule) - } -} - -fn read_f64_scalar(payload: &[u8]) -> PyResult { - if payload.len() != 8 { - return Err(PyValueError::new_err("Invalid f64 payload length")); - } - Ok(f64::from_le_bytes(py_slice_to_array( - payload, - "f64 payload", - )?)) -} - -fn read_i64_scalar(payload: &[u8]) -> PyResult { - if payload.len() != 8 { - return Err(PyValueError::new_err("Invalid i64 payload length")); - } - Ok(i64::from_le_bytes(py_slice_to_array( - payload, - "i64 payload", - )?)) -} - -fn decode_f64_array(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(8) { - return Err(PyValueError::new_err("Invalid f64 array payload length")); - } - payload - .chunks_exact(8) - .map(|chunk| { - Ok(f64::from_le_bytes(py_slice_to_array( - chunk, - "f64 array chunk", - )?)) - }) - .collect() -} - -fn decode_i64_array(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(8) { - return Err(PyValueError::new_err("Invalid i64 array payload length")); - } - payload - .chunks_exact(8) - .map(|chunk| { - Ok(i64::from_le_bytes(py_slice_to_array( - chunk, - "i64 array chunk", - )?)) - }) - .collect() -} - -fn decode_i32_array(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(4) { - return Err(PyValueError::new_err("Invalid i32 array payload length")); - } - payload - .chunks_exact(4) - .map(|chunk| { - Ok(i32::from_le_bytes(py_slice_to_array( - chunk, - "i32 array chunk", - )?)) - }) - .collect() -} - -fn decode_f32_array(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(4) { - return Err(PyValueError::new_err("Invalid f32 array payload length")); - } - payload - .chunks_exact(4) - .map(|chunk| { - Ok(f32::from_le_bytes(py_slice_to_array( - chunk, - "f32 array chunk", - )?)) - }) - .collect() -} - -fn decode_vec3_f32(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(12) { - return Err(PyValueError::new_err("Invalid vec3 payload length")); - } - payload - .chunks_exact(12) - .map(|chunk| { - Ok([ - f32::from_le_bytes(py_slice_to_array(&chunk[0..4], "vec3 x")?), - f32::from_le_bytes(py_slice_to_array(&chunk[4..8], "vec3 y")?), - f32::from_le_bytes(py_slice_to_array(&chunk[8..12], "vec3 z")?), - ]) - }) - .collect() -} - -fn decode_vec3_f64(payload: &[u8]) -> PyResult> { - if !payload.len().is_multiple_of(24) { - return Err(PyValueError::new_err("Invalid vec3 payload length")); - } - payload - .chunks_exact(24) - .map(|chunk| { - Ok([ - f64::from_le_bytes(py_slice_to_array(&chunk[0..8], "vec3 x")?), - f64::from_le_bytes(py_slice_to_array(&chunk[8..16], "vec3 y")?), - f64::from_le_bytes(py_slice_to_array(&chunk[16..24], "vec3 z")?), - ]) - }) - .collect() -} - -fn decode_mat3x3_f64(payload: &[u8]) -> PyResult<[[f64; 3]; 3]> { - if payload.len() != 72 { - return Err(PyValueError::new_err("Invalid mat3x3 payload length")); - } - Ok([ - [ - f64::from_le_bytes(py_slice_to_array(&payload[0..8], "mat3x3 [0][0]")?), - f64::from_le_bytes(py_slice_to_array(&payload[8..16], "mat3x3 [0][1]")?), - f64::from_le_bytes(py_slice_to_array(&payload[16..24], "mat3x3 [0][2]")?), - ], - [ - f64::from_le_bytes(py_slice_to_array(&payload[24..32], "mat3x3 [1][0]")?), - f64::from_le_bytes(py_slice_to_array(&payload[32..40], "mat3x3 [1][1]")?), - f64::from_le_bytes(py_slice_to_array(&payload[40..48], "mat3x3 [1][2]")?), - ], - [ - f64::from_le_bytes(py_slice_to_array(&payload[48..56], "mat3x3 [2][0]")?), - f64::from_le_bytes(py_slice_to_array(&payload[56..64], "mat3x3 [2][1]")?), - f64::from_le_bytes(py_slice_to_array(&payload[64..72], "mat3x3 [2][2]")?), - ], - ]) -} - -fn decode_property_value(type_tag: u8, payload: &[u8]) -> PyResult { - Ok(match type_tag { - TYPE_FLOAT => PropertyValue::Float(read_f64_scalar(payload)?), - TYPE_INT => PropertyValue::Int(read_i64_scalar(payload)?), - TYPE_STRING => PropertyValue::String( - std::str::from_utf8(payload) - .map_err(|_| PyValueError::new_err("Invalid UTF-8 in string property"))? - .to_string(), - ), - TYPE_F64_ARRAY => PropertyValue::FloatArray(decode_f64_array(payload)?), - TYPE_VEC3_F32 => PropertyValue::Vec3Array(decode_vec3_f32(payload)?), - TYPE_I64_ARRAY => PropertyValue::IntArray(decode_i64_array(payload)?), - TYPE_F32_ARRAY => PropertyValue::Float32Array(decode_f32_array(payload)?), - TYPE_VEC3_F64 => PropertyValue::Vec3ArrayF64(decode_vec3_f64(payload)?), - TYPE_I32_ARRAY => PropertyValue::Int32Array(decode_i32_array(payload)?), - _ => { - return Err(PyValueError::new_err(format!( - "Unsupported property type tag {}", - type_tag - ))); - } - }) -} +pub(crate) use self::soa::{ + LazySection, SectionRef, SectionSchema, SoaMoleculeView, is_per_atom, parse_mol_fast_soa, + read_f64_scalar, read_i64_scalar, section_schema_from_ref, type_tag_elem_bytes, + validate_section_payload, +}; mod database; mod molecule; diff --git a/atompack-py/src/molecule.rs b/atompack-py/src/molecule.rs index 78a83e9..3e082f5 100644 --- a/atompack-py/src/molecule.rs +++ b/atompack-py/src/molecule.rs @@ -59,13 +59,11 @@ enum MoleculeBacking { mod helpers; pub(crate) use self::helpers::{ - SoaBuiltinPayloads, SoaCustomSection, build_soa_record_with_custom, cast_or_decode_f32, - cast_or_decode_f64, cast_or_decode_i32, cast_or_decode_i64, pyarray1_from_cow, - pyarray2_from_cow, -}; -use self::helpers::{ - into_py_any, property_section_to_pyobject, property_value_to_pyobject, pyarray2_from_flat, + SoaRecord, SoaSection, build_soa_record, cast_or_decode_f32, cast_or_decode_f64, + cast_or_decode_i32, cast_or_decode_i64, parse_float_array_field, parse_mat3_field, + parse_vec3_field, pyarray1_from_cow, pyarray2_from_cow, }; +use self::helpers::{into_py_any, property_section_to_pyobject, property_value_to_pyobject}; #[pymethods] impl PyMolecule { @@ -87,13 +85,13 @@ impl PyMolecule { #[allow(clippy::too_many_arguments)] fn new( py: Python<'_>, - positions: &Bound<'_, PyArray2>, + positions: &Bound<'_, PyAny>, atomic_numbers: &Bound<'_, PyArray1>, energy: Option, - forces: Option<&Bound<'_, PyArray2>>, - charges: Option<&Bound<'_, PyArray1>>, - velocities: Option<&Bound<'_, PyArray2>>, - cell: Option<&Bound<'_, PyArray2>>, + forces: Option>, + charges: Option>, + velocities: Option>, + cell: Option>, stress: Option>, pbc: Option<(bool, bool, bool)>, name: Option, @@ -113,17 +111,7 @@ impl PyMolecule { ) } - /// Create a molecule from numpy arrays (fast path). - /// - /// Parameters: - /// - positions: float32 array of shape (n_atoms, 3) - /// - atomic_numbers: uint8 array of shape (n_atoms,) - /// - builtins: optional keyword arguments such as energy, forces, charges, - /// velocities, cell, stress, pbc, and name - /// - /// Builds an SOA view directly from the numpy buffers — no intermediate - /// Atom structs are created. If you later mutate the molecule (e.g. set - /// energy), it will be materialized on demand. + /// Create a molecule from numpy arrays. #[staticmethod] #[pyo3(signature = ( positions, @@ -141,13 +129,13 @@ impl PyMolecule { #[allow(clippy::too_many_arguments)] fn from_arrays( py: Python<'_>, - positions: &Bound<'_, PyArray2>, + positions: &Bound<'_, PyAny>, atomic_numbers: &Bound<'_, PyArray1>, energy: Option, - forces: Option<&Bound<'_, PyArray2>>, - charges: Option<&Bound<'_, PyArray1>>, - velocities: Option<&Bound<'_, PyArray2>>, - cell: Option<&Bound<'_, PyArray2>>, + forces: Option>, + charges: Option>, + velocities: Option>, + cell: Option>, stress: Option>, pbc: Option<(bool, bool, bool)>, name: Option, @@ -214,9 +202,9 @@ impl PyMolecule { copy_arrays: bool, ) -> PyResult> { let numbers = self.atomic_numbers_py(py)?.into_any().unbind(); - let positions = self.positions_py(py)?.into_any().unbind(); + let positions = self.positions_py(py)?; let cell = match self.cell_py(py)? { - Some(value) => value.into_any().unbind(), + Some(value) => value, None => py.None(), }; let pbc = match self.pbc()? { @@ -224,7 +212,7 @@ impl PyMolecule { None => py.None(), }; let velocities = match self.velocities_py(py)? { - Some(value) => value.into_any().unbind(), + Some(value) => value, None => py.None(), }; let energy = match self.energy()? { @@ -232,15 +220,15 @@ impl PyMolecule { None => py.None(), }; let forces = match self.forces_py(py)? { - Some(value) => value.into_any().unbind(), + Some(value) => value, None => py.None(), }; let stress = match self.stress_py(py)? { - Some(value) => value.into_any().unbind(), + Some(value) => value, None => py.None(), }; let charges = match self.charges_py(py)? { - Some(value) => value.into_any().unbind(), + Some(value) => value, None => py.None(), }; @@ -299,59 +287,16 @@ impl PyMolecule { /// forces property (getter) #[getter] - fn forces<'py>(slf: Bound<'py, Self>) -> PyResult>>> { + fn forces<'py>(slf: Bound<'py, Self>) -> PyResult>> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - return inner - .forces - .as_ref() - .map(|forces| { - let n_atoms = forces.len(); - let flat: Vec = forces.iter().flat_map(|f| [f[0], f[1], f[2]]).collect(); - pyarray2_from_flat(py, flat, n_atoms, 3) - }) - .transpose(); - } - let Some(view) = molecule.as_view() else { - return Ok(None); - }; - let Some(slot) = view.forces else { - return Ok(None); - }; - if slot.2 != TYPE_VEC3_F32 || slot.1 != view.n_atoms * 12 { - return Err(PyValueError::new_err("Invalid forces section")); - } - let payload = view.builtin_payload(slot); - let data = cast_or_decode_f32(payload)?; - Ok(Some(pyarray2_from_cow(py, data, view.n_atoms, 3)?)) + slf.borrow().forces_py(py) } /// forces property (setter) #[setter] - fn set_forces(&mut self, forces: &Bound<'_, PyArray2>) -> PyResult<()> { - let readonly = forces.readonly(); - let arr = readonly.as_array(); - let shape = arr.shape(); - - if shape[1] != 3 { - return Err(PyValueError::new_err("Forces must have shape (n_atoms, 3)")); - } - - let forces_vec: Vec<[f32; 3]> = arr - .outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(); - - if forces_vec.len() != self.len() { - return Err(PyValueError::new_err(format!( - "Forces length ({}) doesn't match atom count ({})", - forces_vec.len(), - self.len() - ))); - } - - self.ensure_owned()?.forces = Some(forces_vec); + fn set_forces(&mut self, py: Python<'_>, forces: Py) -> PyResult<()> { + let n_atoms = self.len(); + self.ensure_owned()?.forces = Some(parse_vec3_field(forces.bind(py), "Forces", n_atoms)?); Ok(()) } @@ -359,7 +304,7 @@ impl PyMolecule { #[getter] fn energy(&self) -> PyResult> { if let Some(inner) = self.as_owned() { - Ok(inner.energy) + Ok(inner.energy.as_ref().map(FloatScalarData::as_f64)) } else if let Some(view) = self.as_view() { view.energy() } else { @@ -370,242 +315,76 @@ impl PyMolecule { /// Set energy #[setter] fn set_energy(&mut self, energy: Option) -> PyResult<()> { - self.ensure_owned()?.energy = energy; + self.ensure_owned()?.energy = energy.map(FloatScalarData::F64); Ok(()) } /// charges property (getter) #[getter] - fn charges<'py>(slf: Bound<'py, Self>) -> PyResult>>> { + fn charges<'py>(slf: Bound<'py, Self>) -> PyResult>> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - return Ok(inner - .charges - .as_ref() - .map(|charges| PyArray1::from_slice(py, charges))); - } - let Some(view) = molecule.as_view() else { - return Ok(None); - }; - let Some(slot) = view.charges else { - return Ok(None); - }; - if slot.2 != TYPE_F64_ARRAY || slot.1 != view.n_atoms * 8 { - return Err(PyValueError::new_err("Invalid charges section")); - } - let payload = view.builtin_payload(slot); - let data = cast_or_decode_f64(payload)?; - Ok(Some(pyarray1_from_cow(py, data))) + slf.borrow().charges_py(py) } /// charges property (setter) #[setter] - fn set_charges(&mut self, charges: &Bound<'_, PyArray1>) -> PyResult<()> { - let charges_vec: Vec = charges.readonly().as_array().to_vec(); - - if charges_vec.len() != self.len() { - return Err(PyValueError::new_err(format!( - "Charges length ({}) doesn't match atom count ({})", - charges_vec.len(), - self.len() - ))); - } - - self.ensure_owned()?.charges = Some(charges_vec); + fn set_charges(&mut self, py: Python<'_>, charges: Py) -> PyResult<()> { + let n_atoms = self.len(); + self.ensure_owned()?.charges = Some(parse_float_array_field( + charges.bind(py), + "Charges", + n_atoms, + )?); Ok(()) } /// velocities property (getter) #[getter] - fn velocities<'py>(slf: Bound<'py, Self>) -> PyResult>>> { + fn velocities<'py>(slf: Bound<'py, Self>) -> PyResult>> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - return inner - .velocities - .as_ref() - .map(|vels| { - let n_atoms = vels.len(); - let flat: Vec = vels.iter().flat_map(|v| [v[0], v[1], v[2]]).collect(); - pyarray2_from_flat(py, flat, n_atoms, 3) - }) - .transpose(); - } - let Some(view) = molecule.as_view() else { - return Ok(None); - }; - let Some(slot) = view.velocities else { - return Ok(None); - }; - if slot.2 != TYPE_VEC3_F32 || slot.1 != view.n_atoms * 12 { - return Err(PyValueError::new_err("Invalid velocities section")); - } - let payload = view.builtin_payload(slot); - let data = cast_or_decode_f32(payload)?; - Ok(Some(pyarray2_from_cow(py, data, view.n_atoms, 3)?)) + slf.borrow().velocities_py(py) } /// velocities property (setter) #[setter] - fn set_velocities(&mut self, velocities: &Bound<'_, PyArray2>) -> PyResult<()> { - let readonly = velocities.readonly(); - let arr = readonly.as_array(); - let shape = arr.shape(); - - if shape[1] != 3 { - return Err(PyValueError::new_err( - "Velocities must have shape (n_atoms, 3)", - )); - } - - let vel_vec: Vec<[f32; 3]> = arr - .outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(); - - if vel_vec.len() != self.len() { - return Err(PyValueError::new_err(format!( - "Velocities length ({}) doesn't match atom count ({})", - vel_vec.len(), - self.len() - ))); - } - - self.ensure_owned()?.velocities = Some(vel_vec); + fn set_velocities(&mut self, py: Python<'_>, velocities: Py) -> PyResult<()> { + let n_atoms = self.len(); + self.ensure_owned()?.velocities = Some(parse_vec3_field( + velocities.bind(py), + "Velocities", + n_atoms, + )?); Ok(()) } /// cell property (getter) #[getter] - fn cell<'py>(slf: Bound<'py, Self>) -> PyResult>>> { + fn cell<'py>(slf: Bound<'py, Self>) -> PyResult>> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - return inner - .cell - .as_ref() - .map(|cell| { - let flat: Vec = cell - .iter() - .flat_map(|row| [row[0], row[1], row[2]]) - .collect(); - pyarray2_from_flat(py, flat, 3, 3) - }) - .transpose(); - } - let Some(view) = molecule.as_view() else { - return Ok(None); - }; - let Some(slot) = view.cell else { - return Ok(None); - }; - if slot.2 != TYPE_MAT3X3_F64 || slot.1 != 72 { - return Err(PyValueError::new_err("Invalid cell section")); - } - let payload = view.builtin_payload(slot); - let data = cast_or_decode_f64(payload)?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?)) + slf.borrow().cell_py(py) } /// cell property (setter) #[setter] - fn set_cell(&mut self, cell: &Bound<'_, PyArray2>) -> PyResult<()> { - let readonly = cell.readonly(); - let arr = readonly.as_array(); - let shape = arr.shape(); - - if shape != [3, 3] { - return Err(PyValueError::new_err("Cell must have shape (3, 3)")); - } - - let cell_array: [[f64; 3]; 3] = [ - [arr[[0, 0]], arr[[0, 1]], arr[[0, 2]]], - [arr[[1, 0]], arr[[1, 1]], arr[[1, 2]]], - [arr[[2, 0]], arr[[2, 1]], arr[[2, 2]]], - ]; - - self.ensure_owned()?.cell = Some(cell_array); + fn set_cell(&mut self, py: Python<'_>, cell: Py) -> PyResult<()> { + self.ensure_owned()?.cell = Some(parse_mat3_field(cell.bind(py), "Cell")?); Ok(()) } /// stress property (getter) #[getter] - fn stress<'py>(slf: Bound<'py, Self>) -> PyResult>>> { + fn stress<'py>(slf: Bound<'py, Self>) -> PyResult>> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - return inner - .stress - .as_ref() - .map(|stress| { - let flat: Vec = stress - .iter() - .flat_map(|row| [row[0], row[1], row[2]]) - .collect(); - pyarray2_from_flat(py, flat, 3, 3) - }) - .transpose(); - } - let Some(view) = molecule.as_view() else { - return Ok(None); - }; - let Some(slot) = view.stress else { - return Ok(None); - }; - if slot.2 != TYPE_MAT3X3_F64 || slot.1 != 72 { - return Err(PyValueError::new_err("Invalid stress section")); - } - let payload = view.builtin_payload(slot); - let data = cast_or_decode_f64(payload)?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?)) + slf.borrow().stress_py(py) } /// stress property (setter) #[setter] fn set_stress(&mut self, py: Python<'_>, stress: Py) -> PyResult<()> { - let stress = stress.bind(py); - if let Ok(arr) = stress.cast::>() { - let readonly = arr.readonly(); - let arr = readonly.as_array(); - let shape = arr.shape(); - - if shape != [3, 3] { - return Err(PyValueError::new_err("Stress must have shape (3, 3)")); - } - - let inner = self.ensure_owned()?; - inner.stress = Some([ - [arr[[0, 0]], arr[[0, 1]], arr[[0, 2]]], - [arr[[1, 0]], arr[[1, 1]], arr[[1, 2]]], - [arr[[2, 0]], arr[[2, 1]], arr[[2, 2]]], - ]); - inner.properties.remove("stress"); - return Ok(()); - } - - if let Ok(arr) = stress.cast::>() { - let readonly = arr.readonly(); - let arr = readonly.as_array(); - let shape = arr.shape(); - - if shape != [3, 3] { - return Err(PyValueError::new_err("Stress must have shape (3, 3)")); - } - - let inner = self.ensure_owned()?; - inner.stress = Some([ - [arr[[0, 0]] as f64, arr[[0, 1]] as f64, arr[[0, 2]] as f64], - [arr[[1, 0]] as f64, arr[[1, 1]] as f64, arr[[1, 2]] as f64], - [arr[[2, 0]] as f64, arr[[2, 1]] as f64, arr[[2, 2]] as f64], - ]); - inner.properties.remove("stress"); - return Ok(()); - } - - Err(PyValueError::new_err( - "Stress must be a float32 or float64 ndarray with shape (3, 3)", - )) + let inner = self.ensure_owned()?; + inner.stress = Some(parse_mat3_field(stress.bind(py), "Stress")?); + inner.properties.remove("stress"); + Ok(()) } /// pbc property (getter) @@ -629,19 +408,9 @@ impl PyMolecule { /// positions property (read-only) #[getter] - fn positions<'py>(slf: Bound<'py, Self>) -> PyResult>> { + fn positions<'py>(slf: Bound<'py, Self>) -> PyResult> { let py = slf.py(); - let molecule = slf.borrow(); - if let Some(inner) = molecule.as_owned() { - let flat = inner.positions_flat(); - let n_atoms = inner.len(); - return pyarray2_from_flat(py, flat, n_atoms, 3); - } - let view = molecule.as_view().ok_or_else(|| { - PyValueError::new_err("Molecule is missing both owned and view state") - })?; - let pos_f32 = cast_or_decode_f32(view.positions_bytes())?; - pyarray2_from_cow(py, pos_f32, view.n_atoms, 3) + slf.borrow().positions_py(py) } /// atomic_numbers property (read-only) diff --git a/atompack-py/src/molecule_helpers.rs b/atompack-py/src/molecule_helpers.rs index 4512f7f..cb2f768 100644 --- a/atompack-py/src/molecule_helpers.rs +++ b/atompack-py/src/molecule_helpers.rs @@ -105,157 +105,75 @@ pub(crate) fn pyarray2_from_cow<'py, T: Element + Clone>( } } -struct WrittenSoaSection { - slot: BuiltinSlot, - section: LazySection, -} - -fn write_soa_section_raw( - buf: &mut Vec, - kind: u8, - key: &str, - type_tag: u8, - payload: &[u8], -) -> Result { - let key_len: u8 = key - .len() - .try_into() - .map_err(|_| format!("Section key '{}' is too long", key))?; - let payload_len: u32 = payload - .len() - .try_into() - .map_err(|_| format!("Section '{}' payload is too large", key))?; - buf.push(kind); - buf.push(key_len); - let key_start = buf.len(); - buf.extend_from_slice(key.as_bytes()); - buf.push(type_tag); - buf.extend_from_slice(&payload_len.to_le_bytes()); - let payload_start = buf.len(); - buf.extend_from_slice(payload); - Ok(WrittenSoaSection { - slot: (payload_start, payload.len(), type_tag), - section: LazySection { - kind, - key_start, - key_len, - type_tag, - payload_start, - payload_len: payload.len(), - }, - }) -} - -pub(crate) struct SoaBuiltinPayloads<'a> { - pub(crate) energy: Option, - pub(crate) forces: Option<&'a [u8]>, - pub(crate) charges: Option<&'a [u8]>, - pub(crate) velocities: Option<&'a [u8]>, - pub(crate) cell: Option<&'a [u8]>, - pub(crate) stress: Option<&'a [u8]>, - pub(crate) pbc: Option<[bool; 3]>, - pub(crate) name: Option<&'a str>, -} - -pub(crate) struct SoaCustomSection<'a> { +pub(crate) struct SoaSection<'a> { pub(crate) kind: u8, pub(crate) key: &'a str, pub(crate) type_tag: u8, pub(crate) payload: &'a [u8], } -pub(crate) struct BuiltSoaRecord { - bytes: Vec, - n_atoms: usize, - positions_start: usize, - atomic_numbers_start: usize, - forces: Option, - energy: Option, - cell: Option, - stress: Option, - charges: Option, - velocities: Option, - pbc: Option, - name: Option, - custom_sections: Vec, +pub(crate) struct SoaRecord<'a> { + pub(crate) record_format: u32, + pub(crate) positions_type: u8, + pub(crate) positions: &'a [u8], + pub(crate) atomic_numbers: &'a [u8], + pub(crate) sections: &'a [SoaSection<'a>], } -impl BuiltSoaRecord { - pub(crate) fn into_parts(self) -> (Vec, u32) { - (self.bytes, self.n_atoms as u32) - } - - pub(crate) fn into_view(self) -> SoaMoleculeView { - SoaMoleculeView { - bytes: SoaBytes::Owned(self.bytes), - n_atoms: self.n_atoms, - positions_start: self.positions_start, - atomic_numbers_start: self.atomic_numbers_start, - forces: self.forces, - energy: self.energy, - cell: self.cell, - stress: self.stress, - charges: self.charges, - velocities: self.velocities, - pbc: self.pbc, - name: self.name, - custom_sections: self.custom_sections, - } - } -} - -pub(crate) fn build_soa_record( - positions: &[f32], - atomic_numbers: &[u8], - builtins: SoaBuiltinPayloads<'_>, -) -> Result { - build_soa_record_with_custom(positions, atomic_numbers, builtins, &[]) +fn write_soa_section_raw(buf: &mut Vec, section: &SoaSection<'_>) -> Result<(), String> { + let key_len: u8 = section + .key + .len() + .try_into() + .map_err(|_| format!("Section key '{}' is too long", section.key))?; + let payload_len: u32 = section + .payload + .len() + .try_into() + .map_err(|_| format!("Section '{}' payload is too large", section.key))?; + buf.push(section.kind); + buf.push(key_len); + buf.extend_from_slice(section.key.as_bytes()); + buf.push(section.type_tag); + buf.extend_from_slice(&payload_len.to_le_bytes()); + buf.extend_from_slice(section.payload); + Ok(()) } -pub(crate) fn build_soa_record_with_custom( - positions: &[f32], - atomic_numbers: &[u8], - builtins: SoaBuiltinPayloads<'_>, - custom_sections: &[SoaCustomSection<'_>], -) -> Result { - if !positions.len().is_multiple_of(3) { - return Err("positions length must be divisible by 3".to_string()); - } - let n_atoms = positions.len() / 3; - if atomic_numbers.len() != n_atoms { +pub(crate) fn build_soa_record(record: SoaRecord<'_>) -> Result, String> { + if !matches!( + record.record_format, + RECORD_FORMAT_SOA_V2 | RECORD_FORMAT_SOA_V3 + ) { return Err(format!( - "Atomic numbers length ({}) doesn't match atom count ({})", - atomic_numbers.len(), - n_atoms + "Unsupported record format {}", + record.record_format )); } - - let validate_bytes = |payload: &[u8], expected: usize, label: &str| -> Result<(), String> { - if payload.len() != expected { - return Err(format!( - "{} payload length ({}) doesn't match expected byte length ({})", - label, - payload.len(), - expected - )); + let positions_elem_bytes = match record.positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + other => { + return Err(format!("Unsupported positions type tag {}", other)); } - Ok(()) }; - - if let Some(payload) = builtins.forces { - validate_bytes(payload, n_atoms * 12, "forces")?; + if record.record_format == RECORD_FORMAT_SOA_V2 && record.positions_type != TYPE_VEC3_F32 { + return Err("record format 2 only supports float32 positions".to_string()); } - if let Some(payload) = builtins.charges { - validate_bytes(payload, n_atoms * 8, "charges")?; - } - if let Some(payload) = builtins.velocities { - validate_bytes(payload, n_atoms * 12, "velocities")?; - } - if let Some(payload) = builtins.cell { - validate_bytes(payload, 72, "cell")?; + if !record.positions.len().is_multiple_of(positions_elem_bytes) { + return Err(format!( + "positions payload length ({}) is not a multiple of {}", + record.positions.len(), + positions_elem_bytes + )); } - if let Some(payload) = builtins.stress { - validate_bytes(payload, 72, "stress")?; + let n_atoms = record.positions.len() / positions_elem_bytes; + if record.atomic_numbers.len() != n_atoms { + return Err(format!( + "Atomic numbers length ({}) doesn't match atom count ({})", + record.atomic_numbers.len(), + n_atoms + )); } let mut n_sections = 0u16; @@ -267,31 +185,7 @@ pub(crate) fn build_soa_record_with_custom( section_overhead += 1 + 1 + key_len + 1 + 4; }; - if let Some(payload) = builtins.charges { - account_section(payload.len(), "charges".len()); - } - if let Some(payload) = builtins.cell { - account_section(payload.len(), "cell".len()); - } - if builtins.energy.is_some() { - account_section(std::mem::size_of::(), "energy".len()); - } - if let Some(payload) = builtins.forces { - account_section(payload.len(), "forces".len()); - } - if let Some(value) = builtins.name { - account_section(value.len(), "name".len()); - } - if builtins.pbc.is_some() { - account_section(3, "pbc".len()); - } - if let Some(payload) = builtins.stress { - account_section(payload.len(), "stress".len()); - } - if let Some(payload) = builtins.velocities { - account_section(payload.len(), "velocities".len()); - } - for section in custom_sections { + for section in record.sections { let parsed = SectionRef { kind: section.kind, key: section.key, @@ -312,7 +206,9 @@ pub(crate) fn build_soa_record_with_custom( elem_bytes } TYPE_FLOAT | TYPE_INT => 8, + TYPE_FLOAT32 => 4, TYPE_BOOL3 => 3, + TYPE_MAT3X3_F32 => 36, TYPE_MAT3X3_F64 => 72, _ => parsed.payload.len(), }; @@ -328,183 +224,175 @@ pub(crate) fn build_soa_record_with_custom( account_section(parsed.payload.len(), parsed.key.len()); } - let positions_start = 4usize; - let atomic_numbers_start = positions_start + positions.len() * 4; let mut buf = Vec::with_capacity( - 4 + positions.len() * 4 + atomic_numbers.len() + 2 + section_overhead + payload_bytes, + 4 + record.positions.len() + + record.atomic_numbers.len() + + 2 + + section_overhead + + payload_bytes, ); buf.extend_from_slice(&(n_atoms as u32).to_le_bytes()); - buf.extend_from_slice(bytemuck::cast_slice::(positions)); - buf.extend_from_slice(atomic_numbers); + buf.extend_from_slice(record.positions); + buf.extend_from_slice(record.atomic_numbers); buf.extend_from_slice(&n_sections.to_le_bytes()); - let mut charges = None; - let mut cell = None; - let mut energy_slot = None; - let mut forces = None; - let mut name = None; - let mut pbc = None; - let mut stress = None; - let mut velocities = None; - let mut custom_slots = Vec::with_capacity(custom_sections.len()); - - if let Some(payload) = builtins.charges { - charges = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "charges", TYPE_F64_ARRAY, payload)?.slot, - ); - } - if let Some(payload) = builtins.cell { - cell = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "cell", TYPE_MAT3X3_F64, payload)?.slot, - ); - } - if let Some(value) = builtins.energy { - let payload = value.to_le_bytes(); - energy_slot = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "energy", TYPE_FLOAT, &payload)?.slot, - ); - } - if let Some(payload) = builtins.forces { - forces = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "forces", TYPE_VEC3_F32, payload)?.slot, - ); - } - if let Some(value) = builtins.name { - name = Some( - write_soa_section_raw( - &mut buf, - KIND_BUILTIN, - "name", - TYPE_STRING, - value.as_bytes(), - )? - .slot, - ); - } - if let Some([a, b, c]) = builtins.pbc { - let payload = [a as u8, b as u8, c as u8]; - pbc = - Some(write_soa_section_raw(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload)?.slot); - } - if let Some(payload) = builtins.stress { - stress = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "stress", TYPE_MAT3X3_F64, payload)?.slot, - ); - } - if let Some(payload) = builtins.velocities { - velocities = Some( - write_soa_section_raw(&mut buf, KIND_BUILTIN, "velocities", TYPE_VEC3_F32, payload)? - .slot, - ); - } - for section in custom_sections { - custom_slots.push( - write_soa_section_raw( - &mut buf, - section.kind, - section.key, - section.type_tag, - section.payload, - )? - .section, - ); - } - - Ok(BuiltSoaRecord { - bytes: buf, - n_atoms, - positions_start, - atomic_numbers_start, - forces, - energy: energy_slot, - cell, - stress, - charges, - velocities, - pbc, - name, - custom_sections: custom_slots, - }) + for section in record.sections { + write_soa_section_raw(&mut buf, section)?; + } + + Ok(buf) +} + +fn molecule_from_positions( + positions: Vec3Data, + atomic_numbers: Vec, +) -> Result { + match positions { + Vec3Data::F32(values) => Molecule::new(values, atomic_numbers), + Vec3Data::F64(values) => Molecule::new_f64(values, atomic_numbers), + } } -fn vec3_f32_payload<'py>( - readonly: &'py numpy::PyReadonlyArray2<'py, f32>, +pub(crate) fn parse_vec3_field( + value: &Bound<'_, PyAny>, label: &str, expected_rows: usize, -) -> PyResult<&'py [u8]> { - let arr = readonly.as_array(); - let shape = arr.shape(); - if shape != [expected_rows, 3] { - return Err(PyValueError::new_err(format!( - "{} must have shape ({}, 3)", - label, expected_rows - ))); +) -> PyResult { + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [expected_rows, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, 3)", + label, expected_rows + ))); + } + return Ok(Vec3Data::F32( + view.outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + )); + } + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [expected_rows, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape ({}, 3)", + label, expected_rows + ))); + } + return Ok(Vec3Data::F64( + view.outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + )); } - let slice = readonly - .as_slice() - .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", label)))?; - Ok(bytemuck::cast_slice::(slice)) + Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape ({}, 3)", + label, expected_rows + ))) } -fn vec1_f64_payload<'py>( - readonly: &'py numpy::PyReadonlyArray1<'py, f64>, - label: &str, - expected_len: usize, -) -> PyResult<&'py [u8]> { - let arr = readonly.as_array(); - if arr.len() != expected_len { - return Err(PyValueError::new_err(format!( - "{} length ({}) doesn't match atom count ({})", - label, - arr.len(), - expected_len - ))); +fn parse_positions_field(value: &Bound<'_, PyAny>) -> PyResult { + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape().len() != 2 || view.shape()[1] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (n_atoms, 3)", + )); + } + return Ok(Vec3Data::F32( + view.outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + )); + } + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape().len() != 2 || view.shape()[1] != 3 { + return Err(PyValueError::new_err( + "positions must have shape (n_atoms, 3)", + )); + } + return Ok(Vec3Data::F64( + view.outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), + )); } - let slice = readonly - .as_slice() - .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", label)))?; - Ok(bytemuck::cast_slice::(slice)) + Err(PyValueError::new_err( + "positions must be a float32 or float64 ndarray with shape (n_atoms, 3)", + )) } -fn mat3x3_f64_payload<'py>( - readonly: &'py numpy::PyReadonlyArray2<'py, f64>, +pub(crate) fn parse_float_array_field( + value: &Bound<'_, PyAny>, label: &str, -) -> PyResult<&'py [u8]> { - let arr = readonly.as_array(); - if arr.shape() != [3, 3] { - return Err(PyValueError::new_err(format!( - "{} must have shape (3, 3)", - label - ))); + expected_len: usize, +) -> PyResult { + if let Ok(arr) = value.cast::>() { + let values = arr.readonly().as_array().to_vec(); + if values.len() != expected_len { + return Err(PyValueError::new_err(format!( + "{} length ({}) doesn't match atom count ({})", + label, + values.len(), + expected_len + ))); + } + return Ok(FloatArrayData::F32(values)); + } + if let Ok(arr) = value.cast::>() { + let values = arr.readonly().as_array().to_vec(); + if values.len() != expected_len { + return Err(PyValueError::new_err(format!( + "{} length ({}) doesn't match atom count ({})", + label, + values.len(), + expected_len + ))); + } + return Ok(FloatArrayData::F64(values)); } - let slice = readonly - .as_slice() - .map_err(|_| PyValueError::new_err(format!("{} must be C-contiguous", label)))?; - Ok(bytemuck::cast_slice::(slice)) + Err(PyValueError::new_err(format!( + "{} must be a float32 or float64 ndarray with shape ({},)", + label, expected_len + ))) } -fn mat3x3_f64_payload_from_any(py: Python<'_>, value: Py, label: &str) -> PyResult> { - let value = value.bind(py); - if let Ok(arr) = value.cast::>() { - let readonly = arr.readonly(); - return Ok(mat3x3_f64_payload(&readonly, label)?.to_vec()); - } +pub(crate) fn parse_mat3_field(value: &Bound<'_, PyAny>, label: &str) -> PyResult { if let Ok(arr) = value.cast::>() { let readonly = arr.readonly(); - let arr = readonly.as_array(); - if arr.shape() != [3, 3] { + let view = readonly.as_array(); + if view.shape() != [3, 3] { return Err(PyValueError::new_err(format!( "{} must have shape (3, 3)", label ))); } - let mut payload = Vec::with_capacity(72); - for row in arr.outer_iter() { - for value in row { - payload.extend_from_slice(&(*value as f64).to_le_bytes()); - } + return Ok(Mat3Data::F32([ + [view[[0, 0]], view[[0, 1]], view[[0, 2]]], + [view[[1, 0]], view[[1, 1]], view[[1, 2]]], + [view[[2, 0]], view[[2, 1]], view[[2, 2]]], + ])); + } + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + if view.shape() != [3, 3] { + return Err(PyValueError::new_err(format!( + "{} must have shape (3, 3)", + label + ))); } - return Ok(payload); + return Ok(Mat3Data::F64([ + [view[[0, 0]], view[[0, 1]], view[[0, 2]]], + [view[[1, 0]], view[[1, 1]], view[[1, 2]]], + [view[[2, 0]], view[[2, 1]], view[[2, 2]]], + ])); } Err(PyValueError::new_err(format!( "{} must be a float32 or float64 ndarray with shape (3, 3)", @@ -512,6 +400,17 @@ fn mat3x3_f64_payload_from_any(py: Python<'_>, value: Py, label: &str) -> ))) } +pub(crate) fn molecule_from_numpy_arrays( + positions: &Bound<'_, PyAny>, + atomic_numbers: &Bound<'_, PyArray1>, +) -> PyResult { + let z = atomic_numbers.readonly(); + let z_arr = z.as_array(); + let atomic_numbers_vec = z_arr.to_vec(); + let positions = parse_positions_field(positions)?; + molecule_from_positions(positions, atomic_numbers_vec).map_err(PyValueError::new_err) +} + pub(super) fn into_py_any<'py, T>(py: Python<'py>, value: T) -> PyResult> where T: IntoPyObject<'py>, @@ -624,31 +523,58 @@ fn is_reserved_ase_array_key(key: &str) -> bool { matches!(key, "numbers" | "positions") } -fn owned_vec3_array<'py>( - py: Python<'py>, - values: &[[f32; 3]], -) -> PyResult>> { - let n_atoms = values.len(); - let flat: Vec = values - .iter() - .flat_map(|value| [value[0], value[1], value[2]]) - .collect(); - pyarray2_from_flat(py, flat, n_atoms, 3) +fn owned_vec3_array<'py>(py: Python<'py>, values: &Vec3Data) -> PyResult> { + Ok(match values { + Vec3Data::F32(values) => { + let n_atoms = values.len(); + let flat: Vec = values + .iter() + .flat_map(|value| [value[0], value[1], value[2]]) + .collect(); + pyarray2_from_flat(py, flat, n_atoms, 3)? + .into_any() + .unbind() + } + Vec3Data::F64(values) => { + let n_atoms = values.len(); + let flat: Vec = values + .iter() + .flat_map(|value| [value[0], value[1], value[2]]) + .collect(); + pyarray2_from_flat(py, flat, n_atoms, 3)? + .into_any() + .unbind() + } + }) } -fn owned_mat3x3_array<'py>( - py: Python<'py>, - values: &[[f64; 3]; 3], -) -> PyResult>> { - let flat: Vec = values - .iter() - .flat_map(|row| [row[0], row[1], row[2]]) - .collect(); - pyarray2_from_flat(py, flat, 3, 3) +fn owned_float_array<'py>(py: Python<'py>, values: &FloatArrayData) -> Py { + match values { + FloatArrayData::F32(values) => PyArray1::from_slice(py, values).into_any().unbind(), + FloatArrayData::F64(values) => PyArray1::from_slice(py, values).into_any().unbind(), + } +} + +fn owned_mat3x3_array<'py>(py: Python<'py>, values: &Mat3Data) -> PyResult> { + Ok(match values { + Mat3Data::F32(values) => { + let flat: Vec = values + .iter() + .flat_map(|row| [row[0], row[1], row[2]]) + .collect(); + pyarray2_from_flat(py, flat, 3, 3)?.into_any().unbind() + } + Mat3Data::F64(values) => { + let flat: Vec = values + .iter() + .flat_map(|row| [row[0], row[1], row[2]]) + .collect(); + pyarray2_from_flat(py, flat, 3, 3)?.into_any().unbind() + } + }) } impl PyMolecule { - #[allow(dead_code)] pub(crate) fn from_owned(inner: Molecule) -> Self { Self { backing: MoleculeBacking::Owned(inner), @@ -661,14 +587,14 @@ impl PyMolecule { } } - pub(super) fn as_owned(&self) -> Option<&Molecule> { + pub(crate) fn as_owned(&self) -> Option<&Molecule> { match &self.backing { MoleculeBacking::Owned(inner) => Some(inner), MoleculeBacking::View(_) => None, } } - pub(super) fn as_view(&self) -> Option<&SoaMoleculeView> { + pub(crate) fn as_view(&self) -> Option<&SoaMoleculeView> { match &self.backing { MoleculeBacking::Owned(_) => None, MoleculeBacking::View(view) => Some(view), @@ -701,22 +627,31 @@ impl PyMolecule { } } - pub(crate) fn soa_bytes(&self) -> Option<(&[u8], u32)> { - match &self.backing { - MoleculeBacking::View(view) => Some((view.bytes.as_slice(), view.n_atoms as u32)), - MoleculeBacking::Owned(_) => None, - } - } - - pub(super) fn positions_py<'py>(&self, py: Python<'py>) -> PyResult>> { + pub(super) fn positions_py<'py>(&self, py: Python<'py>) -> PyResult> { if let Some(inner) = self.as_owned() { - return pyarray2_from_flat(py, inner.positions_flat(), inner.len(), 3); + return owned_vec3_array(py, &inner.positions); } let view = self.as_view().ok_or_else(|| { PyValueError::new_err("Molecule is missing both owned and view state") })?; - let positions = cast_or_decode_f32(view.positions_bytes())?; - pyarray2_from_cow(py, positions, view.n_atoms, 3) + match view.positions_type { + TYPE_VEC3_F32 => { + let positions = cast_or_decode_f32(view.positions_bytes())?; + Ok(pyarray2_from_cow(py, positions, view.n_atoms, 3)? + .into_any() + .unbind()) + } + TYPE_VEC3_F64 => { + let positions = cast_or_decode_f64(view.positions_bytes())?; + Ok(pyarray2_from_cow(py, positions, view.n_atoms, 3)? + .into_any() + .unbind()) + } + other => Err(PyValueError::new_err(format!( + "Unsupported positions type tag {}", + other + ))), + } } pub(super) fn atomic_numbers_py<'py>( @@ -732,10 +667,7 @@ impl PyMolecule { Ok(PyArray1::from_slice(py, view.atomic_numbers_bytes())) } - pub(super) fn forces_py<'py>( - &self, - py: Python<'py>, - ) -> PyResult>>> { + pub(super) fn forces_py<'py>(&self, py: Python<'py>) -> PyResult>> { if let Some(inner) = self.as_owned() { return inner .forces @@ -749,22 +681,39 @@ impl PyMolecule { let Some(slot) = view.forces else { return Ok(None); }; - if slot.2 != TYPE_VEC3_F32 || slot.1 != view.n_atoms * 12 { - return Err(PyValueError::new_err("Invalid forces section")); + match slot.2 { + TYPE_VEC3_F32 => { + if slot.1 != view.n_atoms * 12 { + return Err(PyValueError::new_err("Invalid forces section")); + } + let data = cast_or_decode_f32(view.builtin_payload(slot))?; + Ok(Some( + pyarray2_from_cow(py, data, view.n_atoms, 3)? + .into_any() + .unbind(), + )) + } + TYPE_VEC3_F64 => { + if slot.1 != view.n_atoms * 24 { + return Err(PyValueError::new_err("Invalid forces section")); + } + let data = cast_or_decode_f64(view.builtin_payload(slot))?; + Ok(Some( + pyarray2_from_cow(py, data, view.n_atoms, 3)? + .into_any() + .unbind(), + )) + } + _ => Err(PyValueError::new_err("Invalid forces section")), } - let data = cast_or_decode_f32(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, view.n_atoms, 3)?)) } - pub(super) fn charges_py<'py>( - &self, - py: Python<'py>, - ) -> PyResult>>> { + pub(super) fn charges_py<'py>(&self, py: Python<'py>) -> PyResult>> { if let Some(inner) = self.as_owned() { return Ok(inner .charges .as_ref() - .map(|charges| PyArray1::from_slice(py, charges))); + .map(|charges| owned_float_array(py, charges))); } let view = self.as_view().ok_or_else(|| { PyValueError::new_err("Molecule is missing both owned and view state") @@ -772,17 +721,26 @@ impl PyMolecule { let Some(slot) = view.charges else { return Ok(None); }; - if slot.2 != TYPE_F64_ARRAY || slot.1 != view.n_atoms * 8 { - return Err(PyValueError::new_err("Invalid charges section")); + match slot.2 { + TYPE_F32_ARRAY => { + if slot.1 != view.n_atoms * 4 { + return Err(PyValueError::new_err("Invalid charges section")); + } + let data = cast_or_decode_f32(view.builtin_payload(slot))?; + Ok(Some(pyarray1_from_cow(py, data).into_any().unbind())) + } + TYPE_F64_ARRAY => { + if slot.1 != view.n_atoms * 8 { + return Err(PyValueError::new_err("Invalid charges section")); + } + let data = cast_or_decode_f64(view.builtin_payload(slot))?; + Ok(Some(pyarray1_from_cow(py, data).into_any().unbind())) + } + _ => Err(PyValueError::new_err("Invalid charges section")), } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some(pyarray1_from_cow(py, data))) } - pub(super) fn velocities_py<'py>( - &self, - py: Python<'py>, - ) -> PyResult>>> { + pub(super) fn velocities_py<'py>(&self, py: Python<'py>) -> PyResult>> { if let Some(inner) = self.as_owned() { return inner .velocities @@ -796,17 +754,34 @@ impl PyMolecule { let Some(slot) = view.velocities else { return Ok(None); }; - if slot.2 != TYPE_VEC3_F32 || slot.1 != view.n_atoms * 12 { - return Err(PyValueError::new_err("Invalid velocities section")); + match slot.2 { + TYPE_VEC3_F32 => { + if slot.1 != view.n_atoms * 12 { + return Err(PyValueError::new_err("Invalid velocities section")); + } + let data = cast_or_decode_f32(view.builtin_payload(slot))?; + Ok(Some( + pyarray2_from_cow(py, data, view.n_atoms, 3)? + .into_any() + .unbind(), + )) + } + TYPE_VEC3_F64 => { + if slot.1 != view.n_atoms * 24 { + return Err(PyValueError::new_err("Invalid velocities section")); + } + let data = cast_or_decode_f64(view.builtin_payload(slot))?; + Ok(Some( + pyarray2_from_cow(py, data, view.n_atoms, 3)? + .into_any() + .unbind(), + )) + } + _ => Err(PyValueError::new_err("Invalid velocities section")), } - let data = cast_or_decode_f32(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, view.n_atoms, 3)?)) } - pub(super) fn cell_py<'py>( - &self, - py: Python<'py>, - ) -> PyResult>>> { + pub(super) fn cell_py<'py>(&self, py: Python<'py>) -> PyResult>> { if let Some(inner) = self.as_owned() { return inner .cell @@ -820,17 +795,26 @@ impl PyMolecule { let Some(slot) = view.cell else { return Ok(None); }; - if slot.2 != TYPE_MAT3X3_F64 || slot.1 != 72 { - return Err(PyValueError::new_err("Invalid cell section")); + match slot.2 { + TYPE_MAT3X3_F32 => { + if slot.1 != 36 { + return Err(PyValueError::new_err("Invalid cell section")); + } + let data = cast_or_decode_f32(view.builtin_payload(slot))?; + Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) + } + TYPE_MAT3X3_F64 => { + if slot.1 != 72 { + return Err(PyValueError::new_err("Invalid cell section")); + } + let data = cast_or_decode_f64(view.builtin_payload(slot))?; + Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) + } + _ => Err(PyValueError::new_err("Invalid cell section")), } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?)) } - pub(super) fn stress_py<'py>( - &self, - py: Python<'py>, - ) -> PyResult>>> { + pub(super) fn stress_py<'py>(&self, py: Python<'py>) -> PyResult>> { if let Some(inner) = self.as_owned() { return inner .stress @@ -844,11 +828,23 @@ impl PyMolecule { let Some(slot) = view.stress else { return Ok(None); }; - if slot.2 != TYPE_MAT3X3_F64 || slot.1 != 72 { - return Err(PyValueError::new_err("Invalid stress section")); + match slot.2 { + TYPE_MAT3X3_F32 => { + if slot.1 != 36 { + return Err(PyValueError::new_err("Invalid stress section")); + } + let data = cast_or_decode_f32(view.builtin_payload(slot))?; + Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) + } + TYPE_MAT3X3_F64 => { + if slot.1 != 72 { + return Err(PyValueError::new_err("Invalid stress section")); + } + let data = cast_or_decode_f64(view.builtin_payload(slot))?; + Ok(Some(pyarray2_from_cow(py, data, 3, 3)?.into_any().unbind())) + } + _ => Err(PyValueError::new_err("Invalid stress section")), } - let data = cast_or_decode_f64(view.builtin_payload(slot))?; - Ok(Some(pyarray2_from_cow(py, data, 3, 3)?)) } pub(super) fn append_owned_ase_properties<'py>( @@ -920,84 +916,59 @@ impl PyMolecule { #[allow(clippy::too_many_arguments)] pub(super) fn from_arrays_impl( py: Python<'_>, - positions: &Bound<'_, PyArray2>, + positions: &Bound<'_, PyAny>, atomic_numbers: &Bound<'_, PyArray1>, energy: Option, - forces: Option<&Bound<'_, PyArray2>>, - charges: Option<&Bound<'_, PyArray1>>, - velocities: Option<&Bound<'_, PyArray2>>, - cell: Option<&Bound<'_, PyArray2>>, + forces: Option>, + charges: Option>, + velocities: Option>, + cell: Option>, stress: Option>, pbc: Option<(bool, bool, bool)>, name: Option, ) -> PyResult { - let pos = positions.readonly(); - let pos_arr = pos.as_array(); - let shape = pos_arr.shape(); - if shape.len() != 2 || shape[1] != 3 { - return Err(PyValueError::new_err( - "positions must have shape (n_atoms, 3)", - )); + let mut molecule = molecule_from_numpy_arrays(positions, atomic_numbers)?; + + if let Some(name) = name { + molecule.name = Some(name); + } + if let Some(energy) = energy { + molecule.energy = Some(FloatScalarData::F64(energy)); + } + if let Some((a, b, c)) = pbc { + molecule.pbc = Some([a, b, c]); } - let z = atomic_numbers.readonly(); - let z_arr = z.as_array(); - let n_atoms = shape[0]; - if z_arr.len() != n_atoms { - return Err(PyValueError::new_err(format!( - "Atomic numbers length ({}) doesn't match atom count ({})", - z_arr.len(), - n_atoms - ))); + let n_atoms = molecule.len(); + + if let Some(forces) = forces { + molecule.forces = Some(parse_vec3_field(forces.bind(py), "forces", n_atoms)?); + } + + if let Some(charges) = charges { + molecule.charges = Some(parse_float_array_field( + charges.bind(py), + "charges", + n_atoms, + )?); + } + + if let Some(velocities) = velocities { + molecule.velocities = Some(parse_vec3_field( + velocities.bind(py), + "velocities", + n_atoms, + )?); + } + + if let Some(cell) = cell { + molecule.cell = Some(parse_mat3_field(cell.bind(py), "cell")?); + } + + if let Some(stress) = stress { + molecule.stress = Some(parse_mat3_field(stress.bind(py), "stress")?); } - let pos_bytes = pos_arr - .as_slice() - .ok_or_else(|| PyValueError::new_err("positions must be C-contiguous"))?; - let z_bytes = z_arr - .as_slice() - .ok_or_else(|| PyValueError::new_err("atomic_numbers must be C-contiguous"))?; - - let forces_readonly = forces.map(|value| value.readonly()); - let charges_readonly = charges.map(|value| value.readonly()); - let velocities_readonly = velocities.map(|value| value.readonly()); - let cell_readonly = cell.map(|value| value.readonly()); - - let forces_payload = forces_readonly - .as_ref() - .map(|readonly| vec3_f32_payload(readonly, "forces", n_atoms)) - .transpose()?; - let charges_payload = charges_readonly - .as_ref() - .map(|readonly| vec1_f64_payload(readonly, "charges", n_atoms)) - .transpose()?; - let velocities_payload = velocities_readonly - .as_ref() - .map(|readonly| vec3_f32_payload(readonly, "velocities", n_atoms)) - .transpose()?; - let cell_payload = cell_readonly - .as_ref() - .map(|readonly| mat3x3_f64_payload(readonly, "cell")) - .transpose()?; - let stress_payload = stress - .map(|value| mat3x3_f64_payload_from_any(py, value, "stress")) - .transpose()?; - let record = build_soa_record( - pos_bytes, - z_bytes, - SoaBuiltinPayloads { - energy, - forces: forces_payload, - charges: charges_payload, - velocities: velocities_payload, - cell: cell_payload, - stress: stress_payload.as_deref(), - pbc: pbc.map(|(a, b, c)| [a, b, c]), - name: name.as_deref(), - }, - ) - .map_err(PyValueError::new_err)?; - - Ok(Self::from_view(record.into_view())) + Ok(Self::from_owned(molecule)) } } diff --git a/atompack-py/src/soa.rs b/atompack-py/src/soa.rs new file mode 100644 index 0000000..ce401e2 --- /dev/null +++ b/atompack-py/src/soa.rs @@ -0,0 +1,1272 @@ +use super::*; +use atompack::storage::{DatabaseSchema, DatabaseSchemaSection}; + +/// A single parsed section reference (zero-copy into decompressed bytes). +#[derive(Clone)] +pub(crate) struct SectionRef<'a> { + pub(crate) kind: u8, + pub(crate) key: &'a str, + pub(crate) type_tag: u8, + pub(crate) payload: &'a [u8], +} + +/// Per-molecule extracted data (references into decompressed bytes). +pub(crate) struct MolData<'a> { + pub(crate) n_atoms: usize, + pub(crate) positions_bytes: &'a [u8], + pub(crate) atomic_numbers_bytes: &'a [u8], + pub(crate) sections: Vec>, +} + +#[derive(Clone)] +pub(crate) struct SectionSchema { + pub(crate) kind: u8, + pub(crate) key: String, + pub(crate) type_tag: u8, + pub(crate) per_atom: bool, + pub(crate) elem_bytes: usize, + pub(crate) slot_bytes: usize, +} + +pub(crate) fn parse_mol_fast_soa_v2(bytes: &[u8]) -> atompack::Result> { + let mut pos = 0usize; + let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; + let positions_len = n_atoms + .checked_mul(12) + .ok_or_else(|| invalid_data("SOA positions byte length overflow"))?; + let positions_bytes = read_bytes_at(bytes, &mut pos, positions_len, "SOA positions")?; + let atomic_numbers_bytes = read_bytes_at(bytes, &mut pos, n_atoms, "SOA atomic_numbers")?; + let n_sections = read_u16_le_at(bytes, &mut pos, "SOA n_sections")? as usize; + + let mut sections = Vec::with_capacity(n_sections); + for _ in 0..n_sections { + let kind = read_u8_at(bytes, &mut pos, "SOA section kind")?; + let key_len = read_u8_at(bytes, &mut pos, "SOA section key length")? as usize; + let key_bytes = read_bytes_at(bytes, &mut pos, key_len, "SOA section key")?; + let key = std::str::from_utf8(key_bytes) + .map_err(|_| invalid_data("Invalid UTF-8 in SOA section key"))?; + let type_tag = read_u8_at(bytes, &mut pos, "SOA section type tag")?; + let payload_len = read_u32_le_at(bytes, &mut pos, "SOA section payload length")? as usize; + let payload = read_bytes_at(bytes, &mut pos, payload_len, "SOA section payload")?; + sections.push(SectionRef { + kind, + key, + type_tag, + payload, + }); + } + + Ok(MolData { + n_atoms, + positions_bytes, + atomic_numbers_bytes, + sections, + }) +} + +/// Parse SOA format bytes into MolData without allocation. +/// +/// Layout: +/// [n_atoms:u32][positions:n*(12|24)][atomic_numbers:n] +/// [n_sections:u16] +/// per section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload] +pub(crate) fn parse_mol_fast_soa( + bytes: &[u8], + record_format: u32, + positions_type_hint: Option, +) -> atompack::Result> { + if record_format == RECORD_FORMAT_SOA_V2 { + return parse_mol_fast_soa_v2(bytes); + } + + let mut pos = 0usize; + let n_atoms = read_u32_le_at(bytes, &mut pos, "SOA n_atoms")? as usize; + let positions_type = match record_format { + RECORD_FORMAT_SOA_V3 => positions_type_hint + .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, + _ => { + return Err(invalid_data(format!( + "Unsupported record format {}", + record_format + ))); + } + }; + let positions_stride = match positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + _ => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + positions_type + ))); + } + }; + let positions_len = n_atoms + .checked_mul(positions_stride) + .ok_or_else(|| invalid_data("SOA positions byte length overflow"))?; + let positions_bytes = read_bytes_at(bytes, &mut pos, positions_len, "SOA positions")?; + let atomic_numbers_bytes = read_bytes_at(bytes, &mut pos, n_atoms, "SOA atomic_numbers")?; + let n_sections = read_u16_le_at(bytes, &mut pos, "SOA n_sections")? as usize; + + let mut sections = Vec::with_capacity(n_sections); + for _ in 0..n_sections { + let kind = read_u8_at(bytes, &mut pos, "SOA section kind")?; + let key_len = read_u8_at(bytes, &mut pos, "SOA section key length")? as usize; + let key_bytes = read_bytes_at(bytes, &mut pos, key_len, "SOA section key")?; + let key = std::str::from_utf8(key_bytes) + .map_err(|_| invalid_data("Invalid UTF-8 in SOA section key"))?; + let type_tag = read_u8_at(bytes, &mut pos, "SOA section type tag")?; + let payload_len = read_u32_le_at(bytes, &mut pos, "SOA section payload length")? as usize; + let payload = read_bytes_at(bytes, &mut pos, payload_len, "SOA section payload")?; + sections.push(SectionRef { + kind, + key, + type_tag, + payload, + }); + } + + Ok(MolData { + n_atoms, + positions_bytes, + atomic_numbers_bytes, + sections, + }) +} + +pub(crate) fn section_schema_from_ref( + section: &SectionRef<'_>, + n_atoms: usize, +) -> atompack::Result { + let per_atom = is_per_atom(section.kind, section.key, section.type_tag); + let elem_bytes = match section.type_tag { + TYPE_STRING => 0, + tag if per_atom => { + let elem_bytes = type_tag_elem_bytes(tag); + if elem_bytes == 0 { + return Err(invalid_data(format!( + "Unsupported per-atom section type tag {} for key '{}'", + tag, section.key + ))); + } + elem_bytes + } + TYPE_FLOAT | TYPE_INT => 8, + TYPE_FLOAT32 => 4, + TYPE_BOOL3 => 3, + TYPE_MAT3X3_F32 => 36, + TYPE_MAT3X3_F64 => 72, + _ => section.payload.len(), + }; + let slot_bytes = if section.type_tag == TYPE_STRING { + 0 + } else if per_atom { + elem_bytes + } else { + section.payload.len() + }; + + validate_section_payload(section, per_atom, elem_bytes, slot_bytes, n_atoms)?; + + Ok(SectionSchema { + kind: section.kind, + key: section.key.to_string(), + type_tag: section.type_tag, + per_atom, + elem_bytes, + slot_bytes, + }) +} + +pub(crate) fn validate_section_payload( + section: &SectionRef<'_>, + per_atom: bool, + elem_bytes: usize, + slot_bytes: usize, + n_atoms: usize, +) -> atompack::Result<()> { + match section.type_tag { + TYPE_STRING => { + std::str::from_utf8(section.payload) + .map_err(|_| invalid_data(format!("Invalid UTF-8 in section '{}'", section.key)))?; + if per_atom { + return Err(invalid_data(format!( + "String section '{}' cannot be per-atom in flat extraction", + section.key + ))); + } + } + TYPE_FLOAT | TYPE_INT | TYPE_FLOAT32 | TYPE_BOOL3 | TYPE_MAT3X3_F32 | TYPE_MAT3X3_F64 => { + if section.payload.len() != slot_bytes { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + section.key, + section.payload.len(), + slot_bytes + ))); + } + } + _ if per_atom => { + let expected = n_atoms.checked_mul(elem_bytes).ok_or_else(|| { + invalid_data(format!("Section '{}' payload length overflow", section.key)) + })?; + if section.payload.len() != expected { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + section.key, + section.payload.len(), + expected + ))); + } + } + _ => { + if slot_bytes != 0 && section.payload.len() != slot_bytes { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected {})", + section.key, + section.payload.len(), + slot_bytes + ))); + } + if elem_bytes != 0 && !section.payload.len().is_multiple_of(elem_bytes) { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} for element size {}", + section.key, + section.payload.len(), + elem_bytes + ))); + } + } + } + Ok(()) +} + +/// Element size in bytes for a given type tag. Returns 0 for variable-length types. +pub(crate) fn type_tag_elem_bytes(tag: u8) -> usize { + match tag { + TYPE_FLOAT => 8, + TYPE_INT => 8, + TYPE_STRING => 0, + TYPE_F64_ARRAY => 8, + TYPE_VEC3_F32 => 12, + TYPE_I64_ARRAY => 8, + TYPE_F32_ARRAY => 4, + TYPE_VEC3_F64 => 24, + TYPE_I32_ARRAY => 4, + TYPE_BOOL3 => 3, + TYPE_FLOAT32 => 4, + TYPE_MAT3X3_F32 => 36, + TYPE_MAT3X3_F64 => 72, + _ => 0, + } +} + +/// Whether a section with the given kind/key/type_tag is per-atom (vs per-molecule). +pub(crate) fn is_per_atom(kind: u8, key: &str, _type_tag: u8) -> bool { + match kind { + KIND_ATOM_PROP => true, + KIND_MOL_PROP => false, + KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), + _ => false, + } +} + +fn database_schema_section( + kind: u8, + key: &str, + type_tag: u8, + payload_len: usize, + n_atoms: usize, +) -> PyResult { + let per_atom = is_per_atom(kind, key, type_tag); + let elem_bytes = if type_tag == TYPE_STRING { + 0 + } else { + let elem_bytes = type_tag_elem_bytes(type_tag); + if elem_bytes == 0 { + return Err(PyValueError::new_err(format!( + "Unsupported section type tag {} for '{}'", + type_tag, key + ))); + } + elem_bytes + }; + let slot_bytes = if type_tag == TYPE_STRING { + 0 + } else if per_atom { + let expected = n_atoms.checked_mul(elem_bytes).ok_or_else(|| { + PyValueError::new_err(format!("Section '{}' payload length overflow", key)) + })?; + if payload_len != expected { + return Err(PyValueError::new_err(format!( + "Section '{}' has invalid payload length {} (expected {})", + key, payload_len, expected + ))); + } + elem_bytes + } else { + payload_len + }; + + Ok(DatabaseSchemaSection { + kind, + key: key.to_string(), + type_tag, + per_atom, + elem_bytes, + slot_bytes, + }) +} + +fn py_decode_float_array_data( + payload: &[u8], + type_tag: u8, + field_name: &str, +) -> PyResult { + match type_tag { + TYPE_F32_ARRAY => Ok(FloatArrayData::F32(decode_f32_array(payload)?)), + TYPE_F64_ARRAY => Ok(FloatArrayData::F64(decode_f64_array(payload)?)), + other => Err(PyValueError::new_err(format!( + "Unsupported {field_name} type tag {}", + other + ))), + } +} + +fn py_decode_mat3_data(payload: &[u8], type_tag: u8, field_name: &str) -> PyResult { + match type_tag { + TYPE_MAT3X3_F32 => Ok(Mat3Data::F32(decode_mat3x3_f32(payload)?)), + TYPE_MAT3X3_F64 => Ok(Mat3Data::F64(decode_mat3x3_f64(payload)?)), + other => Err(PyValueError::new_err(format!( + "Unsupported {field_name} type tag {}", + other + ))), + } +} + +fn py_decode_float_scalar_data( + payload: &[u8], + type_tag: u8, + field_name: &str, +) -> PyResult { + match type_tag { + TYPE_FLOAT => Ok(FloatScalarData::F64(read_f64_scalar(payload)?)), + TYPE_FLOAT32 => Ok(FloatScalarData::F32(read_f32_scalar(payload)?)), + other => Err(PyValueError::new_err(format!( + "Unsupported {field_name} type tag {}", + other + ))), + } +} + +fn py_decode_vec3_data(payload: &[u8], type_tag: u8, field_name: &str) -> PyResult { + match type_tag { + TYPE_VEC3_F32 => Ok(Vec3Data::F32(decode_vec3_f32(payload)?)), + TYPE_VEC3_F64 => Ok(Vec3Data::F64(decode_vec3_f64(payload)?)), + other => Err(PyValueError::new_err(format!( + "Unsupported {field_name} type tag {}", + other + ))), + } +} + +/// Lightweight section descriptor — stores byte offsets into the parent `bytes` buffer. +/// Key is NOT parsed eagerly; use `key()` to read it lazily from `bytes`. +#[derive(Clone, Copy)] +pub(crate) struct LazySection { + pub(crate) kind: u8, + pub(crate) key_start: usize, + pub(crate) key_len: u8, + pub(crate) type_tag: u8, + pub(crate) payload_start: usize, + pub(crate) payload_len: usize, +} + +/// Byte-offset pair for a known builtin section (payload_start, payload_len, type_tag). +pub(crate) type BuiltinSlot = (usize, usize, u8); + +enum SoaBytes { + Owned(Vec), + Shared(SharedMmapBytes), +} + +impl SoaBytes { + #[inline] + fn as_slice(&self) -> &[u8] { + match self { + Self::Owned(bytes) => bytes, + Self::Shared(bytes) => bytes.as_slice(), + } + } +} + +impl std::ops::Deref for SoaBytes { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.as_slice() + } +} + +pub(crate) struct SoaMoleculeView { + bytes: SoaBytes, + pub(crate) n_atoms: usize, + pub(crate) positions_type: u8, + positions_start: usize, + positions_len: usize, + atomic_numbers_start: usize, + pub(crate) forces: Option, + pub(crate) energy: Option, + pub(crate) cell: Option, + pub(crate) stress: Option, + pub(crate) charges: Option, + pub(crate) velocities: Option, + pbc: Option, + name: Option, + pub(crate) custom_sections: Vec, +} + +impl SoaMoleculeView { + fn from_storage_v2(bytes: SoaBytes) -> atompack::Result { + if bytes.len() < 6 { + return Err(invalid_data("SOA record too small")); + } + + let n_atoms = u32::from_le_bytes(slice_to_array(&bytes[0..4], "SOA atom count")?) as usize; + let mut pos = 4usize; + let positions_start = pos; + let positions_len = n_atoms + .checked_mul(12) + .ok_or_else(|| invalid_data("SOA positions overflow"))?; + pos = pos + .checked_add(positions_len) + .ok_or_else(|| invalid_data("SOA positions overflow"))?; + if pos > bytes.len() { + return Err(invalid_data("SOA record truncated at positions")); + } + + let atomic_numbers_start = pos; + pos = pos + .checked_add(n_atoms) + .ok_or_else(|| invalid_data("SOA atomic_numbers overflow"))?; + if pos + 2 > bytes.len() { + return Err(invalid_data("SOA record truncated at atomic_numbers")); + } + + let n_sections = + u16::from_le_bytes(slice_to_array(&bytes[pos..pos + 2], "SOA section count")?) as usize; + pos += 2; + + let mut forces = None; + let mut energy = None; + let mut cell = None; + let mut stress = None; + let mut charges = None; + let mut velocities = None; + let mut pbc = None; + let mut name = None; + let mut custom_sections = Vec::new(); + + for _ in 0..n_sections { + if pos + 2 > bytes.len() { + return Err(invalid_data("SOA section header truncated")); + } + let kind = bytes[pos]; + pos += 1; + let key_len = bytes[pos] as usize; + pos += 1; + if pos + key_len > bytes.len() { + return Err(invalid_data("SOA section key truncated")); + } + let key_start = pos; + pos += key_len; + if pos + 5 > bytes.len() { + return Err(invalid_data("SOA section header truncated")); + } + let type_tag = bytes[pos]; + pos += 1; + let payload_len = u32::from_le_bytes(slice_to_array( + &bytes[pos..pos + 4], + "SOA section payload length", + )?) as usize; + pos += 4; + let payload_start = pos; + pos = pos + .checked_add(payload_len) + .ok_or_else(|| invalid_data("SOA section payload overflow"))?; + if pos > bytes.len() { + return Err(invalid_data("SOA section payload truncated")); + } + + let key_bytes = &bytes[key_start..key_start + key_len]; + if kind == KIND_BUILTIN { + let slot = (payload_start, payload_len, type_tag); + match key_bytes { + b"forces" => forces = Some(slot), + b"energy" => energy = Some(slot), + b"cell" => cell = Some(slot), + b"stress" => stress = Some(slot), + b"charges" => charges = Some(slot), + b"velocities" => velocities = Some(slot), + b"pbc" => pbc = Some(slot), + b"name" => name = Some(slot), + _ => { + custom_sections.push(LazySection { + kind, + key_start, + key_len: key_len as u8, + type_tag, + payload_start, + payload_len, + }); + } + } + } else { + custom_sections.push(LazySection { + kind, + key_start, + key_len: key_len as u8, + type_tag, + payload_start, + payload_len, + }); + } + } + + Ok(Self { + bytes, + n_atoms, + positions_type: TYPE_VEC3_F32, + positions_start, + positions_len, + atomic_numbers_start, + forces, + energy, + cell, + stress, + charges, + velocities, + pbc, + name, + custom_sections, + }) + } + + /// Pure-Rust parser — no Python dependency, safe to call from rayon threads. + fn from_storage_inner( + bytes: SoaBytes, + record_format: u32, + positions_type_hint: Option, + ) -> atompack::Result { + if record_format == RECORD_FORMAT_SOA_V2 { + return Self::from_storage_v2(bytes); + } + + if bytes.len() < 6 { + return Err(invalid_data("SOA record too small")); + } + + let n_atoms = u32::from_le_bytes(slice_to_array(&bytes[0..4], "SOA atom count")?) as usize; + let mut pos = 4usize; + let positions_type = match record_format { + RECORD_FORMAT_SOA_V3 => positions_type_hint + .ok_or_else(|| invalid_data("Missing positions dtype for record format 3"))?, + _ => { + return Err(invalid_data(format!( + "Unsupported record format {}", + record_format + ))); + } + }; + let positions_stride = match positions_type { + TYPE_VEC3_F32 => 12usize, + TYPE_VEC3_F64 => 24usize, + _ => { + return Err(invalid_data(format!( + "Unsupported positions type tag {}", + positions_type + ))); + } + }; + let positions_start = pos; + let positions_len = n_atoms + .checked_mul(positions_stride) + .ok_or_else(|| invalid_data("SOA positions overflow"))?; + pos = pos + .checked_add(positions_len) + .ok_or_else(|| invalid_data("SOA positions overflow"))?; + if pos > bytes.len() { + return Err(invalid_data("SOA record truncated at positions")); + } + + let atomic_numbers_start = pos; + pos = pos + .checked_add(n_atoms) + .ok_or_else(|| invalid_data("SOA atomic_numbers overflow"))?; + if pos + 2 > bytes.len() { + return Err(invalid_data("SOA record truncated at atomic_numbers")); + } + + let n_sections = + u16::from_le_bytes(slice_to_array(&bytes[pos..pos + 2], "SOA section count")?) as usize; + pos += 2; + + let mut forces = None; + let mut energy = None; + let mut cell = None; + let mut stress = None; + let mut charges = None; + let mut velocities = None; + let mut pbc = None; + let mut name = None; + let mut custom_sections = Vec::new(); + + for _ in 0..n_sections { + if pos + 2 > bytes.len() { + return Err(invalid_data("SOA section header truncated")); + } + let kind = bytes[pos]; + pos += 1; + let key_len = bytes[pos] as usize; + pos += 1; + if pos + key_len > bytes.len() { + return Err(invalid_data("SOA section key truncated")); + } + let key_start = pos; + pos += key_len; + if pos + 5 > bytes.len() { + return Err(invalid_data("SOA section header truncated")); + } + let type_tag = bytes[pos]; + pos += 1; + let payload_len = u32::from_le_bytes(slice_to_array( + &bytes[pos..pos + 4], + "SOA section payload length", + )?) as usize; + pos += 4; + let payload_start = pos; + pos = pos + .checked_add(payload_len) + .ok_or_else(|| invalid_data("SOA section payload overflow"))?; + if pos > bytes.len() { + return Err(invalid_data("SOA section payload truncated")); + } + + let key_bytes = &bytes[key_start..key_start + key_len]; + if kind == KIND_BUILTIN { + let slot = (payload_start, payload_len, type_tag); + match key_bytes { + b"forces" => forces = Some(slot), + b"energy" => energy = Some(slot), + b"cell" => cell = Some(slot), + b"stress" => stress = Some(slot), + b"charges" => charges = Some(slot), + b"velocities" => velocities = Some(slot), + b"pbc" => pbc = Some(slot), + b"name" => name = Some(slot), + _ => { + custom_sections.push(LazySection { + kind, + key_start, + key_len: key_len as u8, + type_tag, + payload_start, + payload_len, + }); + } + } + } else { + custom_sections.push(LazySection { + kind, + key_start, + key_len: key_len as u8, + type_tag, + payload_start, + payload_len, + }); + } + } + + Ok(Self { + bytes, + n_atoms, + positions_type, + positions_start, + positions_len, + atomic_numbers_start, + forces, + energy, + cell, + stress, + charges, + velocities, + pbc, + name, + custom_sections, + }) + } + + pub(crate) fn from_bytes_inner( + bytes: Vec, + record_format: u32, + positions_type_hint: Option, + ) -> atompack::Result { + Self::from_storage_inner(SoaBytes::Owned(bytes), record_format, positions_type_hint) + } + + pub(crate) fn from_shared_bytes_inner( + bytes: SharedMmapBytes, + record_format: u32, + positions_type_hint: Option, + ) -> atompack::Result { + Self::from_storage_inner(SoaBytes::Shared(bytes), record_format, positions_type_hint) + } + + pub(crate) fn from_bytes( + bytes: Vec, + record_format: u32, + positions_type_hint: Option, + ) -> PyResult { + Self::from_bytes_inner(bytes, record_format, positions_type_hint) + .map_err(|e| PyValueError::new_err(format!("{}", e))) + } + + pub(crate) fn positions_bytes(&self) -> &[u8] { + &self.bytes[self.positions_start..self.positions_start + self.positions_len] + } + + #[inline] + pub(crate) fn raw_bytes(&self) -> &[u8] { + self.bytes.as_slice() + } + + pub(crate) fn database_schema(&self) -> PyResult { + let mut sections = Vec::with_capacity(8 + self.custom_sections.len()); + + if let Some(slot) = self.energy { + sections.push(database_schema_section( + KIND_BUILTIN, + "energy", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.forces { + sections.push(database_schema_section( + KIND_BUILTIN, + "forces", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.charges { + sections.push(database_schema_section( + KIND_BUILTIN, + "charges", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.velocities { + sections.push(database_schema_section( + KIND_BUILTIN, + "velocities", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.cell { + sections.push(database_schema_section( + KIND_BUILTIN, + "cell", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.stress { + sections.push(database_schema_section( + KIND_BUILTIN, + "stress", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.pbc { + sections.push(database_schema_section( + KIND_BUILTIN, + "pbc", + slot.2, + slot.1, + self.n_atoms, + )?); + } + if let Some(slot) = self.name { + sections.push(database_schema_section( + KIND_BUILTIN, + "name", + slot.2, + slot.1, + self.n_atoms, + )?); + } + for section in &self.custom_sections { + sections.push(database_schema_section( + section.kind, + self.lazy_section_key(section)?, + section.type_tag, + section.payload_len, + self.n_atoms, + )?); + } + + Ok(DatabaseSchema { + positions_type: Some(self.positions_type), + sections, + }) + } + + pub(crate) fn same_schema_as(&self, other: &Self) -> PyResult { + if self.positions_type != other.positions_type + || self.energy != other.energy + || self.forces != other.forces + || self.charges != other.charges + || self.velocities != other.velocities + || self.cell != other.cell + || self.stress != other.stress + || self.pbc != other.pbc + || self.name != other.name + || self.custom_sections.len() != other.custom_sections.len() + { + return Ok(false); + } + + for (left, right) in self + .custom_sections + .iter() + .zip(other.custom_sections.iter()) + { + if left.kind != right.kind + || left.type_tag != right.type_tag + || left.payload_len != right.payload_len + || self.lazy_section_key(left)? != other.lazy_section_key(right)? + { + return Ok(false); + } + } + + Ok(true) + } + + pub(crate) fn atomic_numbers_bytes(&self) -> &[u8] { + &self.bytes[self.atomic_numbers_start..self.atomic_numbers_start + self.n_atoms] + } + + #[inline] + pub(crate) fn builtin_payload(&self, slot: BuiltinSlot) -> &[u8] { + &self.bytes[slot.0..slot.0 + slot.1] + } + + pub(crate) fn lazy_section_key(&self, s: &LazySection) -> PyResult<&str> { + std::str::from_utf8(&self.bytes[s.key_start..s.key_start + s.key_len as usize]) + .map_err(|_| PyValueError::new_err("Invalid UTF-8 in section key")) + } + + pub(crate) fn lazy_section_payload(&self, s: &LazySection) -> &[u8] { + &self.bytes[s.payload_start..s.payload_start + s.payload_len] + } + + pub(crate) fn find_custom_section( + &self, + kind: u8, + key: &str, + ) -> PyResult> { + for section in &self.custom_sections { + if section.kind == kind && self.lazy_section_key(section)? == key { + return Ok(Some(section)); + } + } + Ok(None) + } + + pub(crate) fn property_keys(&self) -> PyResult> { + self.custom_sections + .iter() + .filter(|s| s.kind == KIND_MOL_PROP) + .map(|s| Ok(self.lazy_section_key(s)?.to_string())) + .collect() + } + + pub(crate) fn atom_at(&self, index: usize) -> PyResult> { + if index >= self.n_atoms { + return Ok(None); + } + let atomic_number = self.atomic_numbers_bytes()[index]; + Ok(Some(match self.positions_type { + TYPE_VEC3_F32 => { + let pos = &self.positions_bytes()[index * 12..(index + 1) * 12]; + Atom::new( + f32::from_le_bytes(py_slice_to_array(&pos[0..4], "atom x")?), + f32::from_le_bytes(py_slice_to_array(&pos[4..8], "atom y")?), + f32::from_le_bytes(py_slice_to_array(&pos[8..12], "atom z")?), + atomic_number, + ) + } + TYPE_VEC3_F64 => { + let pos = &self.positions_bytes()[index * 24..(index + 1) * 24]; + Atom::new( + f64::from_le_bytes(py_slice_to_array(&pos[0..8], "atom x")?) as f32, + f64::from_le_bytes(py_slice_to_array(&pos[8..16], "atom y")?) as f32, + f64::from_le_bytes(py_slice_to_array(&pos[16..24], "atom z")?) as f32, + atomic_number, + ) + } + other => { + return Err(PyValueError::new_err(format!( + "Unsupported positions type tag {}", + other + ))); + } + })) + } + + pub(crate) fn energy(&self) -> PyResult> { + match self.energy { + Some(slot) => match slot.2 { + TYPE_FLOAT => Ok(Some(read_f64_scalar(self.builtin_payload(slot))?)), + TYPE_FLOAT32 => Ok(Some(read_f32_scalar(self.builtin_payload(slot))? as f64)), + other => Err(PyValueError::new_err(format!( + "Unsupported energy type tag {}", + other + ))), + }, + None => Ok(None), + } + } + + pub(crate) fn pbc(&self) -> PyResult> { + match self.pbc { + Some(slot) => { + let payload = self.builtin_payload(slot); + if payload.len() != 3 { + return Err(PyValueError::new_err("Invalid pbc payload length")); + } + Ok(Some((payload[0] != 0, payload[1] != 0, payload[2] != 0))) + } + None => Ok(None), + } + } + + pub(crate) fn materialize(&self) -> PyResult { + let atomic_numbers = self.atomic_numbers_bytes().to_vec(); + let mut molecule = match self.positions_type { + TYPE_VEC3_F32 => { + let positions: Vec<[f32; 3]> = (0..self.n_atoms) + .map(|i| { + let pos = &self.positions_bytes()[i * 12..(i + 1) * 12]; + Ok([ + f32::from_le_bytes(py_slice_to_array(&pos[0..4], "position x")?), + f32::from_le_bytes(py_slice_to_array(&pos[4..8], "position y")?), + f32::from_le_bytes(py_slice_to_array(&pos[8..12], "position z")?), + ]) + }) + .collect::>()?; + Molecule::new(positions, atomic_numbers).map_err(PyValueError::new_err)? + } + TYPE_VEC3_F64 => { + let positions: Vec<[f64; 3]> = (0..self.n_atoms) + .map(|i| { + let pos = &self.positions_bytes()[i * 24..(i + 1) * 24]; + Ok([ + f64::from_le_bytes(py_slice_to_array(&pos[0..8], "position x")?), + f64::from_le_bytes(py_slice_to_array(&pos[8..16], "position y")?), + f64::from_le_bytes(py_slice_to_array(&pos[16..24], "position z")?), + ]) + }) + .collect::>()?; + Molecule::new_f64(positions, atomic_numbers).map_err(PyValueError::new_err)? + } + other => { + return Err(PyValueError::new_err(format!( + "Unsupported positions type tag {}", + other + ))); + } + }; + + if let Some(slot) = self.charges { + molecule.charges = Some(py_decode_float_array_data( + self.builtin_payload(slot), + slot.2, + "charges", + )?); + } + if let Some(slot) = self.cell { + molecule.cell = Some(py_decode_mat3_data( + self.builtin_payload(slot), + slot.2, + "cell", + )?); + } + if let Some(slot) = self.energy { + molecule.energy = Some(py_decode_float_scalar_data( + self.builtin_payload(slot), + slot.2, + "energy", + )?); + } + if let Some(slot) = self.forces { + molecule.forces = Some(py_decode_vec3_data( + self.builtin_payload(slot), + slot.2, + "forces", + )?); + } + if let Some(slot) = self.name { + let payload = self.builtin_payload(slot); + molecule.name = Some( + std::str::from_utf8(payload) + .map_err(|_| PyValueError::new_err("Invalid UTF-8 in name"))? + .to_string(), + ); + } + if let Some(slot) = self.pbc { + let payload = self.builtin_payload(slot); + if payload.len() != 3 { + return Err(PyValueError::new_err("Invalid pbc payload length")); + } + molecule.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); + } + if let Some(slot) = self.stress { + molecule.stress = Some(py_decode_mat3_data( + self.builtin_payload(slot), + slot.2, + "stress", + )?); + } + if let Some(slot) = self.velocities { + molecule.velocities = Some(py_decode_vec3_data( + self.builtin_payload(slot), + slot.2, + "velocities", + )?); + } + + for s in &self.custom_sections { + let key = self.lazy_section_key(s)?.to_string(); + let payload = self.lazy_section_payload(s); + match s.kind { + KIND_ATOM_PROP => { + molecule + .atom_properties + .insert(key, decode_property_value(s.type_tag, payload)?); + } + KIND_MOL_PROP => { + molecule + .properties + .insert(key, decode_property_value(s.type_tag, payload)?); + } + _ => {} + } + } + + Ok(molecule) + } +} + +pub(crate) fn read_f64_scalar(payload: &[u8]) -> PyResult { + if payload.len() != 8 { + return Err(PyValueError::new_err("Invalid f64 payload length")); + } + Ok(f64::from_le_bytes(py_slice_to_array( + payload, + "f64 payload", + )?)) +} + +fn read_f32_scalar(payload: &[u8]) -> PyResult { + if payload.len() != 4 { + return Err(PyValueError::new_err("Invalid f32 payload length")); + } + Ok(f32::from_le_bytes(py_slice_to_array( + payload, + "f32 payload", + )?)) +} + +pub(crate) fn read_i64_scalar(payload: &[u8]) -> PyResult { + if payload.len() != 8 { + return Err(PyValueError::new_err("Invalid i64 payload length")); + } + Ok(i64::from_le_bytes(py_slice_to_array( + payload, + "i64 payload", + )?)) +} + +fn decode_f64_array(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(8) { + return Err(PyValueError::new_err("Invalid f64 array payload length")); + } + payload + .chunks_exact(8) + .map(|chunk| { + Ok(f64::from_le_bytes(py_slice_to_array( + chunk, + "f64 array chunk", + )?)) + }) + .collect() +} + +fn decode_i64_array(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(8) { + return Err(PyValueError::new_err("Invalid i64 array payload length")); + } + payload + .chunks_exact(8) + .map(|chunk| { + Ok(i64::from_le_bytes(py_slice_to_array( + chunk, + "i64 array chunk", + )?)) + }) + .collect() +} + +fn decode_i32_array(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(4) { + return Err(PyValueError::new_err("Invalid i32 array payload length")); + } + payload + .chunks_exact(4) + .map(|chunk| { + Ok(i32::from_le_bytes(py_slice_to_array( + chunk, + "i32 array chunk", + )?)) + }) + .collect() +} + +fn decode_f32_array(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(4) { + return Err(PyValueError::new_err("Invalid f32 array payload length")); + } + payload + .chunks_exact(4) + .map(|chunk| { + Ok(f32::from_le_bytes(py_slice_to_array( + chunk, + "f32 array chunk", + )?)) + }) + .collect() +} + +fn decode_vec3_f32(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(12) { + return Err(PyValueError::new_err("Invalid vec3 payload length")); + } + payload + .chunks_exact(12) + .map(|chunk| { + Ok([ + f32::from_le_bytes(py_slice_to_array(&chunk[0..4], "vec3 x")?), + f32::from_le_bytes(py_slice_to_array(&chunk[4..8], "vec3 y")?), + f32::from_le_bytes(py_slice_to_array(&chunk[8..12], "vec3 z")?), + ]) + }) + .collect() +} + +fn decode_vec3_f64(payload: &[u8]) -> PyResult> { + if !payload.len().is_multiple_of(24) { + return Err(PyValueError::new_err("Invalid vec3 payload length")); + } + payload + .chunks_exact(24) + .map(|chunk| { + Ok([ + f64::from_le_bytes(py_slice_to_array(&chunk[0..8], "vec3 x")?), + f64::from_le_bytes(py_slice_to_array(&chunk[8..16], "vec3 y")?), + f64::from_le_bytes(py_slice_to_array(&chunk[16..24], "vec3 z")?), + ]) + }) + .collect() +} + +fn decode_mat3x3_f64(payload: &[u8]) -> PyResult<[[f64; 3]; 3]> { + if payload.len() != 72 { + return Err(PyValueError::new_err("Invalid mat3x3 payload length")); + } + Ok([ + [ + f64::from_le_bytes(py_slice_to_array(&payload[0..8], "mat3x3 [0][0]")?), + f64::from_le_bytes(py_slice_to_array(&payload[8..16], "mat3x3 [0][1]")?), + f64::from_le_bytes(py_slice_to_array(&payload[16..24], "mat3x3 [0][2]")?), + ], + [ + f64::from_le_bytes(py_slice_to_array(&payload[24..32], "mat3x3 [1][0]")?), + f64::from_le_bytes(py_slice_to_array(&payload[32..40], "mat3x3 [1][1]")?), + f64::from_le_bytes(py_slice_to_array(&payload[40..48], "mat3x3 [1][2]")?), + ], + [ + f64::from_le_bytes(py_slice_to_array(&payload[48..56], "mat3x3 [2][0]")?), + f64::from_le_bytes(py_slice_to_array(&payload[56..64], "mat3x3 [2][1]")?), + f64::from_le_bytes(py_slice_to_array(&payload[64..72], "mat3x3 [2][2]")?), + ], + ]) +} + +fn decode_mat3x3_f32(payload: &[u8]) -> PyResult<[[f32; 3]; 3]> { + if payload.len() != 36 { + return Err(PyValueError::new_err("Invalid mat3x3 payload length")); + } + Ok([ + [ + f32::from_le_bytes(py_slice_to_array(&payload[0..4], "mat3x3 [0][0]")?), + f32::from_le_bytes(py_slice_to_array(&payload[4..8], "mat3x3 [0][1]")?), + f32::from_le_bytes(py_slice_to_array(&payload[8..12], "mat3x3 [0][2]")?), + ], + [ + f32::from_le_bytes(py_slice_to_array(&payload[12..16], "mat3x3 [1][0]")?), + f32::from_le_bytes(py_slice_to_array(&payload[16..20], "mat3x3 [1][1]")?), + f32::from_le_bytes(py_slice_to_array(&payload[20..24], "mat3x3 [1][2]")?), + ], + [ + f32::from_le_bytes(py_slice_to_array(&payload[24..28], "mat3x3 [2][0]")?), + f32::from_le_bytes(py_slice_to_array(&payload[28..32], "mat3x3 [2][1]")?), + f32::from_le_bytes(py_slice_to_array(&payload[32..36], "mat3x3 [2][2]")?), + ], + ]) +} + +fn decode_property_value(type_tag: u8, payload: &[u8]) -> PyResult { + Ok(match type_tag { + TYPE_FLOAT => PropertyValue::Float(read_f64_scalar(payload)?), + TYPE_INT => PropertyValue::Int(read_i64_scalar(payload)?), + TYPE_STRING => PropertyValue::String( + std::str::from_utf8(payload) + .map_err(|_| PyValueError::new_err("Invalid UTF-8 in string property"))? + .to_string(), + ), + TYPE_F64_ARRAY => PropertyValue::FloatArray(decode_f64_array(payload)?), + TYPE_VEC3_F32 => PropertyValue::Vec3Array(decode_vec3_f32(payload)?), + TYPE_I64_ARRAY => PropertyValue::IntArray(decode_i64_array(payload)?), + TYPE_F32_ARRAY => PropertyValue::Float32Array(decode_f32_array(payload)?), + TYPE_VEC3_F64 => PropertyValue::Vec3ArrayF64(decode_vec3_f64(payload)?), + TYPE_I32_ARRAY => PropertyValue::Int32Array(decode_i32_array(payload)?), + _ => { + return Err(PyValueError::new_err(format!( + "Unsupported property type tag {}", + type_tag + ))); + } + }) +} diff --git a/atompack-py/tests/test_atom_molecule.py b/atompack-py/tests/test_atom_molecule.py index 7dc3e49..5674ca0 100644 --- a/atompack-py/tests/test_atom_molecule.py +++ b/atompack-py/tests/test_atom_molecule.py @@ -127,8 +127,8 @@ def test_molecule_ml_properties_roundtrip() -> None: stress32 = np.eye(3, dtype=np.float32) * 5.0 mol.stress = stress32 - assert mol.stress.dtype == np.float64 - np.testing.assert_allclose(mol.stress, stress32.astype(np.float64)) + assert mol.stress.dtype == np.float32 + np.testing.assert_allclose(mol.stress, stress32) def test_molecule_ml_property_validation() -> None: @@ -136,7 +136,7 @@ def test_molecule_ml_property_validation() -> None: with pytest.raises(ValueError, match=r"Forces must have shape"): mol.forces = np.zeros((2, 2), dtype=np.float32) - with pytest.raises(ValueError, match=r"Forces length"): + with pytest.raises(ValueError, match=r"Forces must have shape"): mol.forces = np.zeros((1, 3), dtype=np.float32) with pytest.raises(ValueError, match=r"Charges length"): @@ -144,7 +144,7 @@ def test_molecule_ml_property_validation() -> None: with pytest.raises(ValueError, match=r"Velocities must have shape"): mol.velocities = np.zeros((2, 2), dtype=np.float32) - with pytest.raises(ValueError, match=r"Velocities length"): + with pytest.raises(ValueError, match=r"Velocities must have shape"): mol.velocities = np.zeros((1, 3), dtype=np.float32) with pytest.raises(ValueError, match=r"Cell must have shape"): @@ -276,13 +276,11 @@ def test_molecule_getitem_validation() -> None: _ = mol[1.5] -def test_from_arrays_rejects_wrong_dtype_positions() -> None: - # positions must be float32; passing float64 should be rejected cleanly, - # not silently truncated or panic across the FFI boundary. - positions = np.zeros((2, 3), dtype=np.float64) # wrong dtype +def test_from_arrays_preserves_float64_positions() -> None: + positions = np.zeros((2, 3), dtype=np.float64) atomic_numbers = np.array([6, 8], dtype=np.uint8) - with pytest.raises((TypeError, ValueError)): - atompack.Molecule.from_arrays(positions, atomic_numbers) + molecule = atompack.Molecule.from_arrays(positions, atomic_numbers) + assert molecule.positions.dtype == np.float64 def test_from_arrays_rejects_wrong_dtype_atomic_numbers() -> None: diff --git a/atompack-py/tests/test_database.py b/atompack-py/tests/test_database.py index 6973049..4d83dfd 100644 --- a/atompack-py/tests/test_database.py +++ b/atompack-py/tests/test_database.py @@ -62,6 +62,48 @@ def test_database_rejects_invalid_compression(tmp_path: Path) -> None: with pytest.raises(ValueError, match=r"Invalid compression"): atompack.Database(str(tmp_path / "bad.atp"), compression="definitely-not-a-codec") +def test_database_add_arrays_batch_rejects_v2_incompatible_builtin_dtype(tmp_path: Path) -> None: + path = tmp_path / "batch_arrays_v2_compat.atp" + db = atompack.Database(str(path)) + db.add_molecule( + atompack.Molecule( + np.array([[0.0, 0.0, 0.0]], dtype=np.float32), + np.array([6], dtype=np.uint8), + ) + ) + db.flush() + + db = atompack.Database.open(str(path), mmap=False) + positions = np.array([[[0.0, 0.0, 0.0]]], dtype=np.float32) + atomic_numbers = np.array([[6]], dtype=np.uint8) + cell = np.eye(3, dtype=np.float32)[None, ...] + + with pytest.raises(ValueError, match="record format 2 does not support float32 cell"): + db.add_arrays_batch(positions, atomic_numbers, cell=cell) + + +def test_database_add_molecules_view_passthrough_preserves_v3_positions_dtype( + tmp_path: Path, +) -> None: + source_path = tmp_path / "source_v3.atp" + target_path = tmp_path / "target_v3.atp" + + positions = np.array([[0.0, 0.0, 0.0], [1.25, 0.0, 0.0]], dtype=np.float64) + atomic_numbers = np.array([6, 8], dtype=np.uint8) + + source = atompack.Database(str(source_path)) + source.add_molecule(atompack.Molecule.from_arrays(positions, atomic_numbers)) + source.flush() + + source_view_db = atompack.Database.open(str(source_path)) + target = atompack.Database(str(target_path)) + target.add_molecules([source_view_db[0]]) + target.flush() + + roundtrip = atompack.Database.open(str(target_path))[0] + np.testing.assert_allclose(roundtrip.positions, positions) + assert roundtrip.positions.dtype == np.float64 + def test_database_roundtrip_from_arrays_with_builtins(tmp_path: Path) -> None: path = tmp_path / "from_arrays_builtins.atp" @@ -248,6 +290,40 @@ def test_database_add_arrays_batch_roundtrip_with_custom_properties(tmp_path: Pa assert second.get_property("phase") == "valid" +def test_database_add_arrays_batch_promotes_to_float64_geometry_when_needed( + tmp_path: Path, +) -> None: + path = tmp_path / "batch_arrays_float64_geometry.atp" + positions = np.array( + [ + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], + [[0.5, 0.1, 0.2], [1.5, 0.1, 0.2]], + ], + dtype=np.float64, + ) + atomic_numbers = np.array([[6, 8], [1, 8]], dtype=np.uint8) + forces = np.array( + [ + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + [[0.6, 0.5, 0.4], [0.3, 0.2, 0.1]], + ], + dtype=np.float64, + ) + + db = atompack.Database(str(path)) + db.add_arrays_batch(positions, atomic_numbers, forces=forces) + db.flush() + + reopened = atompack.Database.open(str(path), mmap=False) + first = reopened[0] + flat = reopened.get_molecules_flat([0, 1]) + + assert first.positions.dtype == np.float64 + assert first.forces.dtype == np.float64 + assert flat["positions"].dtype == np.float64 + assert flat["forces"].dtype == np.float64 + + @pytest.mark.parametrize("mmap", [False, True]) @pytest.mark.parametrize("compression", ["none", "lz4", "zstd"]) def test_database_single_item_reads_are_view_compatible( @@ -579,6 +655,69 @@ def test_get_molecules_flat_empty(tmp_path: Path) -> None: assert batch["atomic_numbers"].shape == (0,) +def test_database_roundtrip_preserves_float64_geometry(tmp_path: Path) -> None: + path = tmp_path / "float64_geometry.atp" + db = atompack.Database(str(path), overwrite=True) + + positions = np.array([[0.0, 0.1, 0.2], [1.0, 1.1, 1.2]], dtype=np.float64) + atomic_numbers = np.array([6, 8], dtype=np.uint8) + forces = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float64) + charges = np.array([-0.1, 0.1], dtype=np.float32) + cell = np.eye(3, dtype=np.float32) + stress = np.eye(3, dtype=np.float64) * 2.0 + + mol = atompack.Molecule.from_arrays( + positions, + atomic_numbers, + forces=forces, + charges=charges, + cell=cell, + stress=stress, + ) + db.add_molecule(mol) + db.flush() + + reopened = atompack.Database.open(str(path), mmap=False) + got = reopened.get_molecule(0) + flat = reopened.get_molecules_flat([0]) + + assert got.positions.dtype == np.float64 + assert got.forces.dtype == np.float64 + assert got.charges.dtype == np.float32 + assert got.cell.dtype == np.float32 + assert got.stress.dtype == np.float64 + assert flat["positions"].dtype == np.float64 + assert flat["forces"].dtype == np.float64 + assert flat["charges"].dtype == np.float32 + assert flat["cell"].dtype == np.float32 + assert flat["stress"].dtype == np.float64 + + +def test_get_molecules_flat_late_optional_builtin_zero_fills(tmp_path: Path) -> None: + path = tmp_path / "late_optional_builtin.atp" + db = atompack.Database(str(path), overwrite=True) + + positions = np.zeros((2, 3), dtype=np.float64) + atomic_numbers = np.array([6, 8], dtype=np.uint8) + db.add_molecule(atompack.Molecule.from_arrays(positions, atomic_numbers)) + db.add_molecule( + atompack.Molecule.from_arrays( + positions + 1.0, + atomic_numbers, + forces=np.ones((2, 3), dtype=np.float64), + ) + ) + db.flush() + + reopened = atompack.Database.open(str(path), mmap=False) + batch = reopened.get_molecules_flat([0, 1]) + + assert batch["positions"].dtype == np.float64 + assert batch["forces"].dtype == np.float64 + np.testing.assert_allclose(batch["forces"][:2], np.zeros((2, 3), dtype=np.float64)) + np.testing.assert_allclose(batch["forces"][2:], np.ones((2, 3), dtype=np.float64)) + + def test_database_open_mmap_populate(tmp_path: Path) -> None: # Smoke test for the documented populate=True path. On Linux this # prefaults mapped pages via memmap2's PopulateRead advise; on macOS the diff --git a/atompack/examples/basic_usage.rs b/atompack/examples/basic_usage.rs index a5fecbe..9269657 100644 --- a/atompack/examples/basic_usage.rs +++ b/atompack/examples/basic_usage.rs @@ -3,7 +3,7 @@ //! //! Run with: cargo run -p atompack --example basic_usage -use atompack::{Atom, AtomDatabase, Molecule, compression::CompressionType}; +use atompack::{Atom, AtomDatabase, FloatScalarData, Molecule, compression::CompressionType}; fn main() -> atompack::Result<()> { let db_path = std::env::temp_dir().join(format!("atompack-basic-{}.atp", std::process::id())); @@ -14,7 +14,7 @@ fn main() -> atompack::Result<()> { Atom::new(-0.24, 0.93, 0.0, 1), ]); water.name = Some("water".to_string()); - water.energy = Some(-76.4); + water.energy = Some(FloatScalarData::F64(-76.4)); let methane = Molecule::from_atoms(vec![ Atom::new(0.0, 0.0, 0.0, 6), @@ -33,7 +33,7 @@ fn main() -> atompack::Result<()> { let roundtrip = db.get_molecule(0)?; assert_eq!(roundtrip.name.as_deref(), Some("water")); - assert_eq!(roundtrip.energy, Some(-76.4)); + assert_eq!(roundtrip.energy, Some(FloatScalarData::F64(-76.4))); assert_eq!(roundtrip.len(), 3); println!("wrote {}", db_path.display()); diff --git a/atompack/src/atom.rs b/atompack/src/atom.rs index b890ba3..3edec41 100644 --- a/atompack/src/atom.rs +++ b/atompack/src/atom.rs @@ -69,27 +69,156 @@ impl PropertyValue { } } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum Vec3Data { + F32(Vec<[f32; 3]>), + F64(Vec<[f64; 3]>), +} + +impl Vec3Data { + pub fn len(&self) -> usize { + match self { + Self::F32(values) => values.len(), + Self::F64(values) => values.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn atom_position(&self, index: usize) -> Option<[f32; 3]> { + match self { + Self::F32(values) => values.get(index).copied(), + Self::F64(values) => values + .get(index) + .map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]), + } + } + + pub fn flatten_f32_lossy(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|value| [value[0], value[1], value[2]]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]) + .collect(), + } + } + + pub fn flatten_f64(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|value| [value[0] as f64, value[1] as f64, value[2] as f64]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|value| [value[0], value[1], value[2]]) + .collect(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum FloatScalarData { + F32(f32), + F64(f64), +} + +impl FloatScalarData { + pub fn as_f64(&self) -> f64 { + match self { + Self::F32(value) => *value as f64, + Self::F64(value) => *value, + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum FloatArrayData { + F32(Vec), + F64(Vec), +} + +impl FloatArrayData { + pub fn len(&self) -> usize { + match self { + Self::F32(values) => values.len(), + Self::F64(values) => values.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn to_f64_vec(&self) -> Vec { + match self { + Self::F32(values) => values.iter().map(|value| *value as f64).collect(), + Self::F64(values) => values.clone(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum Mat3Data { + F32([[f32; 3]; 3]), + F64([[f64; 3]; 3]), +} + +impl Mat3Data { + pub fn flatten_f32_lossy(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|row| [row[0], row[1], row[2]]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|row| [row[0] as f32, row[1] as f32, row[2] as f32]) + .collect(), + } + } + + pub fn flatten_f64(&self) -> Vec { + match self { + Self::F32(values) => values + .iter() + .flat_map(|row| [row[0] as f64, row[1] as f64, row[2] as f64]) + .collect(), + Self::F64(values) => values + .iter() + .flat_map(|row| [row[0], row[1], row[2]]) + .collect(), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Molecule { pub name: Option, - pub positions: Vec<[f32; 3]>, + pub positions: Vec3Data, pub atomic_numbers: Vec, - pub forces: Option>, + pub forces: Option, - pub energy: Option, + pub energy: Option, - pub charges: Option>, + pub charges: Option, - pub velocities: Option>, + pub velocities: Option, - pub cell: Option<[[f64; 3]; 3]>, + pub cell: Option, pub pbc: Option<[bool; 3]>, - pub stress: Option<[[f64; 3]; 3]>, + pub stress: Option, /// Per-atom properties (features or targets) pub atom_properties: HashMap, @@ -100,7 +229,11 @@ pub struct Molecule { impl Molecule { pub fn new(positions: Vec<[f32; 3]>, atomic_numbers: Vec) -> Result { - Self::with_name_internal(None, positions, atomic_numbers) + Self::with_name_internal(None, Vec3Data::F32(positions), atomic_numbers) + } + + pub fn new_f64(positions: Vec<[f64; 3]>, atomic_numbers: Vec) -> Result { + Self::with_name_internal(None, Vec3Data::F64(positions), atomic_numbers) } pub fn with_name( @@ -108,12 +241,20 @@ impl Molecule { positions: Vec<[f32; 3]>, atomic_numbers: Vec, ) -> Result { - Self::with_name_internal(Some(name), positions, atomic_numbers) + Self::with_name_internal(Some(name), Vec3Data::F32(positions), atomic_numbers) + } + + pub fn with_name_f64( + name: String, + positions: Vec<[f64; 3]>, + atomic_numbers: Vec, + ) -> Result { + Self::with_name_internal(Some(name), Vec3Data::F64(positions), atomic_numbers) } fn with_name_internal( name: Option, - positions: Vec<[f32; 3]>, + positions: Vec3Data, atomic_numbers: Vec, ) -> Result { if positions.len() != atomic_numbers.len() { @@ -140,11 +281,15 @@ impl Molecule { } pub fn from_atoms(atoms: Vec) -> Self { - let positions = atoms.iter().map(|atom| atom.position()).collect(); - let atomic_numbers = atoms.iter().map(|atom| atom.atomic_number).collect(); + let mut positions = Vec::with_capacity(atoms.len()); + let mut atomic_numbers = Vec::with_capacity(atoms.len()); + for atom in atoms { + positions.push([atom.x, atom.y, atom.z]); + atomic_numbers.push(atom.atomic_number); + } Self { name: None, - positions, + positions: Vec3Data::F32(positions), atomic_numbers, forces: None, energy: None, @@ -167,38 +312,76 @@ impl Molecule { } pub fn atom(&self, index: usize) -> Option { + let position = self.positions.atom_position(index)?; Some(Atom::new( - self.positions.get(index)?[0], - self.positions.get(index)?[1], - self.positions.get(index)?[2], + position[0], + position[1], + position[2], *self.atomic_numbers.get(index)?, )) } pub fn to_atoms(&self) -> Vec { - self.positions + self.atomic_numbers .iter() - .zip(&self.atomic_numbers) - .map(|(position, &atomic_number)| { - Atom::new(position[0], position[1], position[2], atomic_number) + .enumerate() + .filter_map(|(index, &atomic_number)| { + self.positions + .atom_position(index) + .map(|position| Atom::new(position[0], position[1], position[2], atomic_number)) }) .collect() } pub fn forces_mut(&mut self) -> &mut Vec<[f32; 3]> { let n_atoms = self.len(); - self.forces.get_or_insert_with(|| vec![[0.0; 3]; n_atoms]) + let forces = self + .forces + .get_or_insert_with(|| Vec3Data::F32(vec![[0.0; 3]; n_atoms])); + if let Vec3Data::F64(values) = forces { + let converted: Vec<[f32; 3]> = values + .iter() + .map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]) + .collect(); + *forces = Vec3Data::F32(converted); + } + match forces { + Vec3Data::F32(values) => values, + Vec3Data::F64(_) => unreachable!(), + } } pub fn velocities_mut(&mut self) -> &mut Vec<[f32; 3]> { let n_atoms = self.len(); - self.velocities - .get_or_insert_with(|| vec![[0.0; 3]; n_atoms]) + let velocities = self + .velocities + .get_or_insert_with(|| Vec3Data::F32(vec![[0.0; 3]; n_atoms])); + if let Vec3Data::F64(values) = velocities { + let converted: Vec<[f32; 3]> = values + .iter() + .map(|value| [value[0] as f32, value[1] as f32, value[2] as f32]) + .collect(); + *velocities = Vec3Data::F32(converted); + } + match velocities { + Vec3Data::F32(values) => values, + Vec3Data::F64(_) => unreachable!(), + } } pub fn charges_mut(&mut self) -> &mut Vec { let n_atoms = self.len(); - self.charges.get_or_insert_with(|| vec![0.0; n_atoms]) + let charges = self + .charges + .get_or_insert_with(|| FloatArrayData::F64(vec![0.0; n_atoms])); + if let FloatArrayData::F32(values) = charges { + let converted: Vec = values.iter().map(|value| *value as f64).collect(); + *charges = FloatArrayData::F64(converted); + } + match charges { + FloatArrayData::F64(values) => values, + FloatArrayData::F32(_) => unreachable!(), + } } /// Add a per-atom floating-point property @@ -317,10 +500,7 @@ impl Molecule { /// /// Returns: [x1, y1, z1, x2, y2, z2, ...] pub fn positions_flat(&self) -> Vec { - self.positions - .iter() - .flat_map(|position| [position[0], position[1], position[2]]) - .collect() + self.positions.flatten_f32_lossy() } /// Get atomic numbers as an array @@ -332,25 +512,19 @@ impl Molecule { /// /// Returns: Some([fx1, fy1, fz1, fx2, fy2, fz2, ...]) or None pub fn forces_flat(&self) -> Option> { - self.forces - .as_ref() - .map(|f| f.iter().flat_map(|v| [v[0], v[1], v[2]]).collect()) + self.forces.as_ref().map(Vec3Data::flatten_f32_lossy) } /// Get velocities as a flat array (if present) pub fn velocities_flat(&self) -> Option> { - self.velocities - .as_ref() - .map(|v| v.iter().flat_map(|vec| [vec[0], vec[1], vec[2]]).collect()) + self.velocities.as_ref().map(Vec3Data::flatten_f32_lossy) } /// Get cell as a flat array (if present) /// /// Returns: Some([a_x, a_y, a_z, b_x, b_y, b_z, c_x, c_y, c_z]) or None pub fn cell_flat(&self) -> Option> { - self.cell - .as_ref() - .map(|c| c.iter().flat_map(|row| [row[0], row[1], row[2]]).collect()) + self.cell.as_ref().map(Mat3Data::flatten_f64) } } @@ -393,7 +567,7 @@ mod tests { let mut mol = Molecule::new(vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], vec![6, 8]).unwrap(); // Add forces - mol.forces = Some(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]); + mol.forces = Some(Vec3Data::F32(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])); // Test flat forces let forces_flat = mol.forces_flat().unwrap(); @@ -422,13 +596,13 @@ mod tests { let mut mol = Molecule::new(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); // Set energy - mol.energy = Some(-100.5); + mol.energy = Some(FloatScalarData::F64(-100.5)); // Set custom molecular property mol.set_property("homo_lumo_gap".to_string(), PropertyValue::Float(5.2)); // Retrieve - assert_eq!(mol.energy.unwrap(), -100.5); + assert_eq!(mol.energy.as_ref(), Some(&FloatScalarData::F64(-100.5))); assert_eq!(mol.get_scalar("homo_lumo_gap").unwrap(), 5.2); } @@ -448,7 +622,11 @@ mod tests { let mut mol = Molecule::new(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); // Set unit cell (cubic 10x10x10) - mol.cell = Some([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]); + mol.cell = Some(Mat3Data::F64([ + [10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0], + ])); mol.pbc = Some([true, true, true]); // Test cell_flat diff --git a/atompack/src/bin/atompack-bench.rs b/atompack/src/bin/atompack-bench.rs index db5dca4..8c21bda 100644 --- a/atompack/src/bin/atompack-bench.rs +++ b/atompack/src/bin/atompack-bench.rs @@ -4,7 +4,9 @@ //! Run with: //! cargo run -p atompack --release --bin atompack-bench -- --help -use atompack::{Atom, AtomDatabase, Molecule, compression::CompressionType}; +use atompack::{ + Atom, AtomDatabase, FloatScalarData, Molecule, Vec3Data, compression::CompressionType, +}; use std::env; use std::time::{Duration, Instant}; @@ -101,15 +103,15 @@ fn create_synthetic_molecule(atoms_per_molecule: usize, id: u64, with_props: boo if with_props { let n = atoms_per_molecule; - mol.energy = Some(-1000.0 - (id as f64) * 1e-3); - mol.forces = Some( + mol.energy = Some(FloatScalarData::F64(-1000.0 - (id as f64) * 1e-3)); + mol.forces = Some(Vec3Data::F32( (0..n) .map(|i| { let t = (id as f32) * 0.002 + (i as f32) * 0.02; [(t * 0.7).sin(), (t * 0.9).cos(), (t * 1.1).sin()] }) .collect(), - ); + )); } mol diff --git a/atompack/src/lib.rs b/atompack/src/lib.rs index 8704db5..04a4dfe 100644 --- a/atompack/src/lib.rs +++ b/atompack/src/lib.rs @@ -12,7 +12,9 @@ pub mod atom; pub mod compression; pub mod storage; -pub use atom::{Atom, Molecule}; +pub use atom::{ + Atom, FloatArrayData, FloatScalarData, Mat3Data, Molecule, PropertyValue, Vec3Data, +}; pub use compression::decompress as decompress_bytes; pub use storage::{AtomDatabase, SharedMmapBytes}; diff --git a/atompack/src/storage/header.rs b/atompack/src/storage/header.rs index 63dc4d9..6e9565f 100644 --- a/atompack/src/storage/header.rs +++ b/atompack/src/storage/header.rs @@ -7,6 +7,8 @@ pub(super) struct Header { pub(super) num_molecules: u64, pub(super) compression: CompressionType, pub(super) record_format: u32, + pub(super) schema_offset: u64, + pub(super) schema_len: u64, pub(super) index_offset: u64, pub(super) index_len: u64, } @@ -43,6 +45,11 @@ pub(super) fn encode_header_slot(header: Header) -> [u8; HEADER_SLOT_SIZE] { slot[52..56].copy_from_slice(&compression_level.to_le_bytes()); slot[56..60].copy_from_slice(&header.record_format.to_le_bytes()); + // Keep the legacy v2 header field offsets intact; schema metadata lives in + // bytes that were previously unused. + slot[60..68].copy_from_slice(&header.schema_offset.to_le_bytes()); + slot[68..76].copy_from_slice(&header.schema_len.to_le_bytes()); + let checksum = adler32(&slot[..HEADER_SLOT_SIZE - 4]); slot[HEADER_SLOT_SIZE - 4..HEADER_SLOT_SIZE].copy_from_slice(&checksum.to_le_bytes()); @@ -76,6 +83,8 @@ fn decode_header_slot(slot: &[u8; HEADER_SLOT_SIZE], file_size: u64) -> Option CompressionType::None, @@ -88,6 +97,17 @@ fn decode_header_slot(slot: &[u8; HEADER_SLOT_SIZE], file_size: u64) -> Option file_size { + return None; + } + } + if index_offset == 0 || index_len == 0 { if num_molecules != 0 || index_offset != 0 || index_len != 0 { return None; @@ -113,6 +133,8 @@ fn decode_header_slot(slot: &[u8; HEADER_SLOT_SIZE], file_size: u64) -> Option const TYPE_I32_ARRAY: u8 = 8; const TYPE_BOOL3: u8 = 9; // [bool; 3] const TYPE_MAT3X3_F64: u8 = 10; // [[f64; 3]; 3] +const TYPE_FLOAT32: u8 = 11; // f32 scalar +const TYPE_MAT3X3_F32: u8 = 12; // [[f32; 3]; 3] // Two redundant page-aligned header slots for crash safety. const HEADER_SLOT_SIZE: usize = 4096; @@ -97,6 +108,63 @@ impl SharedMmapBytes { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DatabaseSchemaSection { + pub kind: u8, + pub key: String, + pub type_tag: u8, + pub per_atom: bool, + pub elem_bytes: usize, + pub slot_bytes: usize, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct DatabaseSchema { + pub positions_type: Option, + pub sections: Vec, +} + +fn database_schema_from_lock(lock: &SchemaLock) -> DatabaseSchema { + let sections = lock + .sections + .iter() + .map(|((kind, key), entry)| DatabaseSchemaSection { + kind: *kind, + key: key.clone(), + type_tag: entry.type_tag, + per_atom: entry.per_atom, + elem_bytes: entry.elem_bytes, + slot_bytes: entry.slot_bytes, + }) + .collect(); + DatabaseSchema { + positions_type: lock.positions_type, + sections, + } +} + +fn schema_lock_from_database_schema(schema: DatabaseSchema) -> SchemaLock { + let sections = schema + .sections + .into_iter() + .map(|section| { + ( + (section.kind, section.key), + SchemaEntry { + type_tag: section.type_tag, + per_atom: section.per_atom, + elem_bytes: section.elem_bytes, + slot_bytes: section.slot_bytes, + }, + ) + }) + .collect(); + SchemaLock { + positions_type: schema.positions_type, + sections, + } +} + pub struct AtomDatabase { path: PathBuf, compression: CompressionType, @@ -107,6 +175,7 @@ pub struct AtomDatabase { committed_end: u64, truncate_tail_on_next_write: bool, index: IndexStorage, + schema_lock: Option, file: Option, data_mmap: Option>, } @@ -133,6 +202,8 @@ impl AtomDatabase { num_molecules: 0, compression, record_format: RECORD_FORMAT_SOA, + schema_offset: 0, + schema_len: 0, index_offset: 0, index_len: 0, }; @@ -151,6 +222,7 @@ impl AtomDatabase { committed_end: HEADER_REGION_SIZE, truncate_tail_on_next_write: false, index: IndexStorage::InMemory(Vec::new()), + schema_lock: None, file: None, data_mmap: None, }) @@ -199,9 +271,11 @@ impl AtomDatabase { fn open_v1(path: PathBuf, mut file: File, use_mmap: bool, populate: bool) -> Result { // Read the best valid header slot (crash-safe). let header = read_best_header(&mut file)?; - if header.record_format != RECORD_FORMAT_SOA { + if header.record_format != RECORD_FORMAT_SOA_V2 + && header.record_format != RECORD_FORMAT_SOA_V3 + { return Err(Error::InvalidData(format!( - "Unsupported legacy record format {}. Atompack no longer supports bincode databases.", + "Unsupported record format {}.", header.record_format ))); } @@ -250,6 +324,15 @@ impl AtomDatabase { IndexStorage::InMemory(Vec::new()) }; + let schema_lock = if header.schema_offset > 0 && header.schema_len > 0 { + file.seek(SeekFrom::Start(header.schema_offset))?; + let mut schema_bytes = vec![0u8; header.schema_len as usize]; + file.read_exact(&mut schema_bytes)?; + Some(decode_schema_lock(&schema_bytes)?) + } else { + None + }; + Ok(Self { path, compression: header.compression, @@ -258,6 +341,7 @@ impl AtomDatabase { committed_end, truncate_tail_on_next_write, index, + schema_lock, file: Some(file), data_mmap, }) @@ -281,76 +365,122 @@ impl AtomDatabase { Ok(()) } - // -- Writing ------------------------------------------------------------- - - /// Add a single molecule. - pub fn add_molecule(&mut self, molecule: &Molecule) -> Result<()> { - self.add_molecules(&[molecule]) - } + fn rebuild_schema_lock(&self) -> Result { + let mut lock = SchemaLock::default(); + let compression = self.compression; + let positions_type_hint = self.positions_type(); - /// Add multiple molecules. Serialization and compression run in parallel - /// (rayon); the compressed blobs are then appended sequentially. - pub fn add_molecules(&mut self, molecules: &[&Molecule]) -> Result<()> { - if molecules.is_empty() { - return Ok(()); + if let Some(ref mmap) = self.data_mmap { + for index in 0..self.index.len() { + let entry = self + .index + .get(index) + .ok_or_else(|| Error::InvalidData(format!("Index {} out of bounds", index)))?; + let start = entry.offset as usize; + let end = start + entry.compressed_size as usize; + let bytes = decompress( + &mmap[start..end], + compression, + Some(entry.uncompressed_size as usize), + )?; + let record = record_schema(&bytes, self.record_format, positions_type_hint)?; + merge_schema_lock(&mut lock, &record)?; + } + return Ok(lock); } - let serialized: Vec<(Vec, u32)> = molecules - .par_iter() - .map(|mol| { - let bytes = serialize_molecule_soa(mol)?; - let num_atoms = mol.len() as u32; - Ok((bytes, num_atoms)) - }) - .collect::>>()?; - - self.append_owned_soa_records(serialized) + let mut file = File::open(&self.path)?; + for index in 0..self.index.len() { + let entry = self + .index + .get(index) + .ok_or_else(|| Error::InvalidData(format!("Index {} out of bounds", index)))?; + file.seek(SeekFrom::Start(entry.offset))?; + let mut compressed = vec![0u8; entry.compressed_size as usize]; + file.read_exact(&mut compressed)?; + let bytes = decompress( + &compressed, + compression, + Some(entry.uncompressed_size as usize), + )?; + let record = record_schema(&bytes, self.record_format, positions_type_hint)?; + merge_schema_lock(&mut lock, &record)?; + } + Ok(lock) } - /// Add pre-serialized SOA records, compressing in parallel and appending to the file. - /// - /// Each entry is `(soa_bytes, num_atoms)`. The bytes must be valid SOA-encoded molecule - /// records (the same format `serialize_molecule_soa` produces). This skips serialization - /// entirely — useful when the caller already has SOA bytes (e.g. from a View or - /// direct numpy-to-SOA construction). - pub fn add_raw_soa_records(&mut self, records: &[(&[u8], u32)]) -> Result<()> { - if records.is_empty() { - return Ok(()); - } - self.append_soa_records(records) + fn can_promote_record_format(&self) -> bool { + self.record_format == RECORD_FORMAT_SOA_V2 + && self.index.is_empty() + && self.schema_lock.is_none() } - #[doc(hidden)] - pub fn add_owned_soa_records(&mut self, records: Vec<(Vec, u32)>) -> Result<()> { - if records.is_empty() { - return Ok(()); + fn resolved_record_format_for_schema(&self, schema: &SchemaLock) -> Result { + match validate_schema_lock_for_record_format(self.record_format, schema) { + Ok(()) => Ok(self.record_format), + Err(current_err) if self.can_promote_record_format() => { + validate_schema_lock_for_record_format(RECORD_FORMAT_SOA_V3, schema)?; + Ok(RECORD_FORMAT_SOA_V3) + } + Err(current_err) => Err(current_err), } - self.append_owned_soa_records(records) } - fn append_soa_records(&mut self, records: &[(&[u8], u32)]) -> Result<()> { - if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { - return Err(Error::InvalidData( - "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." - .into(), - )); + fn ensure_schema_compatible<'a, I>(&mut self, records: I) -> Result<()> + where + I: IntoIterator)>, + { + let mut lock = match &self.schema_lock { + Some(lock) => lock.clone(), + None if self.index.is_empty() => SchemaLock::default(), + None => self.rebuild_schema_lock()?, + }; + let mut record_format = self.record_format; + let can_promote = self.can_promote_record_format(); + + for (bytes, positions_type_hint) in records { + let hint = positions_type_hint.or(lock.positions_type); + let record = match record_schema(bytes, record_format, hint) { + Ok(record) => record, + Err(current_err) if can_promote && record_format == RECORD_FORMAT_SOA_V2 => { + let record = record_schema(bytes, RECORD_FORMAT_SOA_V3, hint)?; + record_format = RECORD_FORMAT_SOA_V3; + record + } + Err(current_err) => return Err(current_err), + }; + merge_schema_lock(&mut lock, &record)?; } - self.truncate_uncommitted_tail_if_needed()?; + self.record_format = record_format; + self.schema_lock = Some(lock); + Ok(()) + } + fn ensure_schema_lock(&mut self, incoming: &SchemaLock) -> Result<()> { + self.record_format = self.resolved_record_format_for_schema(incoming)?; + let mut lock = match &self.schema_lock { + Some(lock) => lock.clone(), + None if self.index.is_empty() => SchemaLock::default(), + None => self.rebuild_schema_lock()?, + }; + merge_schema_lock(&mut lock, incoming)?; + self.schema_lock = Some(lock); + Ok(()) + } + + fn write_owned_records(&mut self, records: Vec<(Vec, u32)>) -> Result<()> { let compression = self.compression; - // Step 1: Compress all records in parallel. let compressed_records: Vec<(Vec, u32, u32)> = records - .par_iter() + .into_par_iter() .map(|(bytes, num_atoms)| { let uncompressed_size = bytes.len() as u32; - let compressed = compress(bytes, compression)?; - Ok((compressed, uncompressed_size, *num_atoms)) + let compressed = compress(&bytes, compression)?; + Ok((compressed, uncompressed_size, num_atoms)) }) .collect::>>()?; - // Step 2: Write all compressed data sequentially (file I/O must be sequential) let mut file = OpenOptions::new().append(true).open(&self.path)?; let mut offset = file.seek(SeekFrom::End(0))?; @@ -375,24 +505,15 @@ impl AtomDatabase { Ok(()) } - fn append_owned_soa_records(&mut self, records: Vec<(Vec, u32)>) -> Result<()> { - if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { - return Err(Error::InvalidData( - "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." - .into(), - )); - } - - self.truncate_uncommitted_tail_if_needed()?; - + fn write_borrowed_records(&mut self, records: &[(&[u8], u32)]) -> Result<()> { let compression = self.compression; let compressed_records: Vec<(Vec, u32, u32)> = records - .into_par_iter() + .par_iter() .map(|(bytes, num_atoms)| { let uncompressed_size = bytes.len() as u32; - let compressed = compress(&bytes, compression)?; - Ok((compressed, uncompressed_size, num_atoms)) + let compressed = compress(bytes, compression)?; + Ok((compressed, uncompressed_size, *num_atoms)) }) .collect::>>()?; @@ -420,6 +541,200 @@ impl AtomDatabase { Ok(()) } + // -- Writing ------------------------------------------------------------- + + /// Add a single molecule. + pub fn add_molecule(&mut self, molecule: &Molecule) -> Result<()> { + self.add_molecules(&[molecule]) + } + + /// Add multiple molecules. Serialization and compression run in parallel + /// (rayon); the compressed blobs are then appended sequentially. + pub fn add_molecules(&mut self, molecules: &[&Molecule]) -> Result<()> { + if molecules.is_empty() { + return Ok(()); + } + + let target_format = if self.can_promote_record_format() + && molecules.iter().any(|molecule| { + minimum_record_format_for_molecule(molecule) == RECORD_FORMAT_SOA_V3 + }) { + RECORD_FORMAT_SOA_V3 + } else { + self.record_format + }; + + let serialized: Vec<(Vec, u32, SchemaLock)> = molecules + .par_iter() + .map(|mol| { + let bytes = serialize_molecule_soa(mol, target_format)?; + let num_atoms = mol.len() as u32; + Ok((bytes, num_atoms, schema_from_molecule(mol)?)) + }) + .collect::>>()?; + + let mut batch_schema = SchemaLock::default(); + let mut records = Vec::with_capacity(serialized.len()); + for (bytes, num_atoms, schema) in serialized { + merge_schema_lock(&mut batch_schema, &schema)?; + records.push((bytes, num_atoms)); + } + + self.append_owned_soa_records_prevalidated(records, batch_schema) + } + + /// Add pre-serialized SOA records, compressing in parallel and appending to the file. + /// + /// Each entry is `(soa_bytes, num_atoms)`. The bytes must be valid SOA-encoded molecule + /// records (the same format `serialize_molecule_soa` produces). This skips serialization + /// entirely — useful when the caller already has SOA bytes (e.g. from a View or + /// direct numpy-to-SOA construction). + pub fn add_raw_soa_records(&mut self, records: &[(&[u8], u32)]) -> Result<()> { + if records.is_empty() { + return Ok(()); + } + self.append_soa_records( + records + .iter() + .map(|(bytes, num_atoms)| (*bytes, *num_atoms, None)), + ) + } + + #[doc(hidden)] + pub fn add_raw_soa_records_with_positions_type( + &mut self, + records: &[(&[u8], u32, u8)], + ) -> Result<()> { + if records.is_empty() { + return Ok(()); + } + self.append_soa_records( + records.iter().map(|(bytes, num_atoms, positions_type)| { + (*bytes, *num_atoms, Some(*positions_type)) + }), + ) + } + + #[doc(hidden)] + pub fn add_raw_soa_records_with_schema( + &mut self, + records: &[(&[u8], u32)], + schema: DatabaseSchema, + ) -> Result<()> { + if records.is_empty() { + return Ok(()); + } + self.append_raw_soa_records_prevalidated(records, schema_lock_from_database_schema(schema)) + } + + #[doc(hidden)] + pub fn add_owned_soa_records(&mut self, records: Vec<(Vec, u32, u8)>) -> Result<()> { + if records.is_empty() { + return Ok(()); + } + self.append_owned_soa_records(records) + } + + #[doc(hidden)] + pub fn add_owned_soa_records_with_schema( + &mut self, + records: Vec<(Vec, u32)>, + schema: DatabaseSchema, + ) -> Result<()> { + if records.is_empty() { + return Ok(()); + } + self.append_owned_soa_records_prevalidated( + records, + schema_lock_from_database_schema(schema), + ) + } + + fn append_soa_records<'a, I>(&mut self, records: I) -> Result<()> + where + I: IntoIterator)>, + { + if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { + return Err(Error::InvalidData( + "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." + .into(), + )); + } + + let records: Vec<(&[u8], u32, Option)> = records.into_iter().collect(); + if records.is_empty() { + return Ok(()); + } + + self.truncate_uncommitted_tail_if_needed()?; + self.ensure_schema_compatible( + records + .iter() + .map(|(bytes, _, positions_type_hint)| (*bytes, *positions_type_hint)), + )?; + let borrowed = records + .iter() + .map(|(bytes, num_atoms, _)| (*bytes, *num_atoms)) + .collect::>(); + self.write_borrowed_records(&borrowed) + } + + fn append_owned_soa_records(&mut self, records: Vec<(Vec, u32, u8)>) -> Result<()> { + if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { + return Err(Error::InvalidData( + "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." + .into(), + )); + } + + self.truncate_uncommitted_tail_if_needed()?; + self.ensure_schema_compatible( + records + .iter() + .map(|(bytes, _, positions_type)| (bytes.as_slice(), Some(*positions_type))), + )?; + + let records = records + .into_iter() + .map(|(bytes, num_atoms, _positions_type)| (bytes, num_atoms)) + .collect(); + self.write_owned_records(records) + } + + fn append_owned_soa_records_prevalidated( + &mut self, + records: Vec<(Vec, u32)>, + batch_schema: SchemaLock, + ) -> Result<()> { + if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { + return Err(Error::InvalidData( + "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." + .into(), + )); + } + + self.truncate_uncommitted_tail_if_needed()?; + self.ensure_schema_lock(&batch_schema)?; + self.write_owned_records(records) + } + + fn append_raw_soa_records_prevalidated( + &mut self, + records: &[(&[u8], u32)], + batch_schema: SchemaLock, + ) -> Result<()> { + if matches!(&self.index, IndexStorage::MemoryMapped { .. }) { + return Err(Error::InvalidData( + "Cannot add molecules to a database opened with a memory-mapped index (read-only); reopen without mmap to write." + .into(), + )); + } + + self.truncate_uncommitted_tail_if_needed()?; + self.ensure_schema_lock(&batch_schema)?; + self.write_borrowed_records(records) + } + // -- Reading ------------------------------------------------------------- /// Read a single molecule by index (seek + decompress + deserialize). @@ -446,14 +761,16 @@ impl AtomDatabase { self.compression, Some(mol_index.uncompressed_size as usize), )?; - deserialize_molecule_soa(&decompressed) + deserialize_molecule_soa(&decompressed, self.record_format, self.positions_type()) } /// Read multiple molecules in parallel. pub fn get_molecules(&self, indices: &[usize]) -> Result> { let raw = self.get_raw_bytes(indices)?; raw.into_par_iter() - .map(|bytes| deserialize_molecule_soa(&bytes)) + .map(|bytes| { + deserialize_molecule_soa(&bytes, self.record_format, self.positions_type()) + }) .collect() } @@ -541,8 +858,20 @@ impl AtomDatabase { }; let index_bytes = encode_index(index_vec); + let schema_bytes = self + .schema_lock + .as_ref() + .map(encode_schema_lock) + .transpose()?; let mut file = OpenOptions::new().write(true).open(&self.path)?; + let schema_offset = file.seek(SeekFrom::End(0))?; + let schema_len = if let Some(schema_bytes) = &schema_bytes { + file.write_all(schema_bytes)?; + schema_bytes.len() as u64 + } else { + 0 + }; let index_offset = file.seek(SeekFrom::End(0))?; file.write_all(&index_bytes)?; file.flush()?; @@ -556,6 +885,8 @@ impl AtomDatabase { num_molecules: self.index.len() as u64, compression: self.compression, record_format: self.record_format, + schema_offset: if schema_len > 0 { schema_offset } else { 0 }, + schema_len, index_offset, index_len: index_bytes.len() as u64, }; @@ -593,6 +924,25 @@ impl AtomDatabase { self.compression } + pub fn record_format(&self) -> u32 { + self.record_format + } + + pub fn positions_type(&self) -> Option { + self.schema_lock + .as_ref() + .and_then(|lock| lock.positions_type) + } + + pub fn schema_info(&self) -> Option { + self.schema_lock.as_ref().map(database_schema_from_lock) + } + + #[doc(hidden)] + pub fn record_format_for_schema(&self, schema: DatabaseSchema) -> Result { + self.resolved_record_format_for_schema(&schema_lock_from_database_schema(schema)) + } + /// Compressed bytes for a molecule from the mmap (None if no mmap). pub fn get_compressed_slice(&self, index: usize) -> Option<&[u8]> { let mol_index = self.index.get(index)?; @@ -636,9 +986,44 @@ impl AtomDatabase { #[cfg(test)] mod tests { use super::*; - use crate::Atom; + use crate::{Atom, FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; use tempfile::NamedTempFile; + fn adler32_for_test(bytes: &[u8]) -> u32 { + const MOD_ADLER: u32 = 65_521; + let mut a: u32 = 1; + let mut b: u32 = 0; + for &byte in bytes { + a = (a + (byte as u32)) % MOD_ADLER; + b = (b + a) % MOD_ADLER; + } + (b << 16) | a + } + + fn encode_legacy_v2_header_slot(header: Header) -> [u8; HEADER_SLOT_SIZE] { + let mut slot = [0u8; HEADER_SLOT_SIZE]; + slot[0..4].copy_from_slice(MAGIC); + slot[4..8].copy_from_slice(&FILE_FORMAT_VERSION.to_le_bytes()); + slot[8..16].copy_from_slice(&header.generation.to_le_bytes()); + slot[16..24].copy_from_slice(&header.data_start.to_le_bytes()); + slot[24..32].copy_from_slice(&header.index_offset.to_le_bytes()); + slot[32..40].copy_from_slice(&header.index_len.to_le_bytes()); + slot[40..48].copy_from_slice(&header.num_molecules.to_le_bytes()); + + let (compression_type, compression_level) = match header.compression { + CompressionType::None => (0u8, 0i32), + CompressionType::Lz4 => (1u8, 0i32), + CompressionType::Zstd(level) => (2u8, level), + }; + slot[48] = compression_type; + slot[52..56].copy_from_slice(&compression_level.to_le_bytes()); + slot[56..60].copy_from_slice(&header.record_format.to_le_bytes()); + + let checksum = adler32_for_test(&slot[..HEADER_SLOT_SIZE - 4]); + slot[HEADER_SLOT_SIZE - 4..HEADER_SLOT_SIZE].copy_from_slice(&checksum.to_le_bytes()); + slot + } + fn molecule_from_atoms(atoms: Vec) -> Molecule { Molecule::from_atoms(atoms) } @@ -689,10 +1074,39 @@ mod tests { let mut db = AtomDatabase::open(&path).unwrap(); assert_eq!(db.len(), 1); let mol = db.get_molecule(0).unwrap(); - assert_eq!(mol.positions[0], [1.0, 2.0, 3.0]); + assert_eq!(mol.atom(0).unwrap().position(), [1.0, 2.0, 3.0]); } } + #[test] + fn test_database_open_legacy_v2_header_layout() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + + let header = Header { + generation: 0, + data_start: HEADER_REGION_SIZE, + num_molecules: 0, + compression: CompressionType::None, + record_format: RECORD_FORMAT_SOA_V2, + schema_offset: 0, + schema_len: 0, + index_offset: 0, + index_len: 0, + }; + let slot = encode_legacy_v2_header_slot(header); + let mut file = File::create(&path).unwrap(); + file.write_all(&slot).unwrap(); + file.write_all(&slot).unwrap(); + file.flush().unwrap(); + file.sync_all().unwrap(); + drop(file); + + let db = AtomDatabase::open(&path).unwrap(); + assert_eq!(db.len(), 0); + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V2); + } + #[test] fn test_database_with_forces_and_energy() { let temp = NamedTempFile::new().unwrap(); @@ -703,8 +1117,8 @@ mod tests { Atom::new(0.0, 0.0, 0.0, 6), Atom::new(1.0, 0.0, 0.0, 8), ]); - mol.forces = Some(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]); - mol.energy = Some(-123.456); + mol.forces = Some(Vec3Data::F32(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])); + mol.energy = Some(FloatScalarData::F64(-123.456)); // Write to database { @@ -721,13 +1135,14 @@ mod tests { // Check forces are preserved assert!(retrieved.forces.is_some()); let forces = retrieved.forces.unwrap(); - assert_eq!(forces.len(), 2); - assert_eq!(forces[0], [0.1, 0.2, 0.3]); - assert_eq!(forces[1], [0.4, 0.5, 0.6]); + assert_eq!( + forces, + Vec3Data::F32(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + ); // Check energy is preserved assert!(retrieved.energy.is_some()); - assert_eq!(retrieved.energy.unwrap(), -123.456); + assert_eq!(retrieved.energy.unwrap(), FloatScalarData::F64(-123.456)); } } @@ -803,7 +1218,7 @@ mod tests { // Opening + writing should truncate the uncommitted tail before appending new data. let mut db = AtomDatabase::open(&path).unwrap(); let mol2 = molecule_from_atoms(vec![Atom::new(1.0, 2.0, 3.0, 8)]); - let mol2_bytes = serialize_molecule_soa(&mol2).unwrap(); + let mol2_bytes = serialize_molecule_soa(&mol2, db.record_format()).unwrap(); let mol2_compressed = compress(&mol2_bytes, compression).unwrap(); db.add_molecule(&mol2).unwrap(); @@ -825,13 +1240,21 @@ mod tests { Atom::new(4.0, 5.0, 6.0, 8), ]); mol.name = Some("water".to_string()); - mol.energy = Some(-42.5); - mol.forces = Some(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]); - mol.charges = Some(vec![-0.5, 0.5]); - mol.velocities = Some(vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - mol.cell = Some([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]); + mol.energy = Some(FloatScalarData::F64(-42.5)); + mol.forces = Some(Vec3Data::F32(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])); + mol.charges = Some(FloatArrayData::F64(vec![-0.5, 0.5])); + mol.velocities = Some(Vec3Data::F32(vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])); + mol.cell = Some(Mat3Data::F64([ + [10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0], + ])); mol.pbc = Some([true, true, false]); - mol.stress = Some([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); + mol.stress = Some(Mat3Data::F64([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ])); // atom_properties mol.atom_properties.insert( @@ -866,28 +1289,31 @@ mod tests { // Verify all fields assert_eq!(r.name.as_deref(), Some("water")); assert_eq!(r.len(), 2); - assert_eq!(r.positions[0], [1.0, 2.0, 3.0]); + assert_eq!(r.atom(0).unwrap().position(), [1.0, 2.0, 3.0]); assert_eq!(r.atomic_numbers[0], 6); - assert_eq!(r.positions[1], [4.0, 5.0, 6.0]); + assert_eq!(r.atom(1).unwrap().position(), [4.0, 5.0, 6.0]); assert_eq!(r.atomic_numbers[1], 8); - assert_eq!(r.energy, Some(-42.5)); + assert_eq!(r.energy, Some(FloatScalarData::F64(-42.5))); assert_eq!( r.forces.as_ref().unwrap(), - &[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + &Vec3Data::F32(vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + ); + assert_eq!( + r.charges.as_ref().unwrap(), + &FloatArrayData::F64(vec![-0.5, 0.5]) ); - assert_eq!(r.charges.as_ref().unwrap(), &[-0.5, 0.5]); assert_eq!( r.velocities.as_ref().unwrap(), - &[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + &Vec3Data::F32(vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) ); assert_eq!( r.cell.unwrap(), - [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]] + Mat3Data::F64([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) ); assert_eq!(r.pbc, Some([true, true, false])); assert_eq!( r.stress.unwrap(), - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + Mat3Data::F64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) ); // atom_properties @@ -930,7 +1356,7 @@ mod tests { let mut db = AtomDatabase::open(&path).unwrap(); let r = db.get_molecule(0).unwrap(); assert_eq!(r.len(), 1); - assert_eq!(r.positions[0], [1.0, 2.0, 3.0]); + assert_eq!(r.atom(0).unwrap().position(), [1.0, 2.0, 3.0]); assert_eq!(r.atomic_numbers[0], 6); assert!(r.name.is_none()); assert!(r.energy.is_none()); @@ -1056,4 +1482,115 @@ mod tests { other => panic!("expected String, got {:?}", other), } } + + #[test] + fn test_schema_lock_rejects_position_dtype_mismatch() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let mol_f64 = Molecule::new_f64(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + let mol_f32 = Molecule::new(vec![[1.0, 1.0, 1.0]], vec![8]).unwrap(); + + db.add_molecule(&mol_f64).unwrap(); + let err = db.add_molecule(&mol_f32).unwrap_err(); + assert!(format!("{}", err).contains("Position dtype mismatch")); + } + + #[test] + fn test_empty_database_stays_v2_for_v2_compatible_first_write() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let mut mol = Molecule::new(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + mol.energy = Some(FloatScalarData::F64(-1.0)); + mol.forces = Some(Vec3Data::F32(vec![[0.1, 0.2, 0.3]])); + + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V2); + db.add_molecule(&mol).unwrap(); + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V2); + } + + #[test] + fn test_empty_database_promotes_to_v3_when_first_write_requires_it() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let mut mol = Molecule::new_f64(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + mol.forces = Some(Vec3Data::F64(vec![[0.1, 0.2, 0.3]])); + + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V2); + db.add_molecule(&mol).unwrap(); + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V3); + } + + #[test] + fn test_schema_lock_rejects_custom_shape_mismatch() { + use crate::atom::PropertyValue; + + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let mut mol1 = Molecule::new(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + mol1.set_property( + "spectrum".to_string(), + PropertyValue::FloatArray(vec![1.0, 2.0]), + ); + + let mut mol2 = Molecule::new(vec![[1.0, 1.0, 1.0]], vec![8]).unwrap(); + mol2.set_property("spectrum".to_string(), PropertyValue::FloatArray(vec![3.0])); + + db.add_molecule(&mol1).unwrap(); + let err = db.add_molecule(&mol2).unwrap_err(); + assert!(format!("{}", err).contains("Schema mismatch for section 'spectrum'")); + } + + #[test] + fn test_add_owned_soa_records_rejects_v2_incompatible_builtin_dtype() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let legacy = Molecule::new(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + db.add_molecule(&legacy).unwrap(); + assert_eq!(db.record_format(), RECORD_FORMAT_SOA_V2); + + let mut mol = molecule_from_atoms(vec![Atom::new(0.0, 0.0, 0.0, 6)]); + mol.cell = Some(Mat3Data::F32([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ])); + + let bytes = serialize_molecule_soa(&mol, RECORD_FORMAT_SOA_V3).unwrap(); + let err = db + .add_owned_soa_records(vec![(bytes, mol.len() as u32, TYPE_VEC3_F32)]) + .unwrap_err(); + + assert!( + err.to_string() + .contains("record format 2 does not support float32 cell") + ); + } + + #[test] + fn test_schema_lock_allows_late_optional_builtin() { + let temp = NamedTempFile::new().unwrap(); + let path = temp.path().to_path_buf(); + let mut db = AtomDatabase::create(&path, CompressionType::None).unwrap(); + + let mol1 = Molecule::new_f64(vec![[0.0, 0.0, 0.0]], vec![6]).unwrap(); + let mut mol2 = Molecule::new_f64(vec![[1.0, 1.0, 1.0]], vec![8]).unwrap(); + mol2.forces = Some(Vec3Data::F64(vec![[0.1, 0.2, 0.3]])); + + db.add_molecule(&mol1).unwrap(); + db.add_molecule(&mol2).unwrap(); + db.flush().unwrap(); + + let retrieved = db.get_molecule(1).unwrap(); + assert_eq!(retrieved.forces, Some(Vec3Data::F64(vec![[0.1, 0.2, 0.3]]))); + } } diff --git a/atompack/src/storage/schema.rs b/atompack/src/storage/schema.rs new file mode 100644 index 0000000..14da316 --- /dev/null +++ b/atompack/src/storage/schema.rs @@ -0,0 +1,468 @@ +use super::soa::{arr, positions_stride, property_value_type_tag, resolve_positions_type}; +use super::*; +use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; +use std::collections::BTreeMap; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(super) struct SchemaLock { + pub(super) positions_type: Option, + pub(super) sections: BTreeMap<(u8, String), SchemaEntry>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct SchemaEntry { + pub(super) type_tag: u8, + pub(super) per_atom: bool, + pub(super) elem_bytes: usize, + pub(super) slot_bytes: usize, +} + +const SCHEMA_BLOB_VERSION: u32 = 1; + +pub(super) fn positions_type_from_molecule(molecule: &Molecule) -> u8 { + match molecule.positions { + Vec3Data::F32(_) => TYPE_VEC3_F32, + Vec3Data::F64(_) => TYPE_VEC3_F64, + } +} + +pub(super) fn encode_schema_lock(lock: &SchemaLock) -> Result> { + let mut buf = Vec::new(); + buf.extend_from_slice(&SCHEMA_BLOB_VERSION.to_le_bytes()); + buf.push(lock.positions_type.unwrap_or(255)); + buf.extend_from_slice(&(lock.sections.len() as u32).to_le_bytes()); + for ((kind, key), entry) in &lock.sections { + let key_len: u16 = key + .len() + .try_into() + .map_err(|_| Error::InvalidData(format!("Schema key '{}' is too long", key)))?; + let elem_bytes: u32 = entry + .elem_bytes + .try_into() + .map_err(|_| Error::InvalidData(format!("Schema elem_bytes overflow for '{}'", key)))?; + let slot_bytes: u32 = entry + .slot_bytes + .try_into() + .map_err(|_| Error::InvalidData(format!("Schema slot_bytes overflow for '{}'", key)))?; + buf.push(*kind); + buf.push(entry.type_tag); + buf.push(u8::from(entry.per_atom)); + buf.extend_from_slice(&key_len.to_le_bytes()); + buf.extend_from_slice(&elem_bytes.to_le_bytes()); + buf.extend_from_slice(&slot_bytes.to_le_bytes()); + buf.extend_from_slice(key.as_bytes()); + } + Ok(buf) +} + +pub(super) fn decode_schema_lock(bytes: &[u8]) -> Result { + if bytes.len() < 9 { + return Err(Error::InvalidData("Schema blob too small".into())); + } + let version = u32::from_le_bytes(arr(&bytes[0..4])?); + if version != SCHEMA_BLOB_VERSION { + return Err(Error::InvalidData(format!( + "Unsupported schema blob version {}", + version + ))); + } + let positions_type = match bytes[4] { + 255 => None, + TYPE_VEC3_F32 | TYPE_VEC3_F64 => Some(bytes[4]), + other => { + return Err(Error::InvalidData(format!( + "Unsupported schema positions type tag {}", + other + ))); + } + }; + let count = u32::from_le_bytes(arr(&bytes[5..9])?) as usize; + let mut pos = 9usize; + let mut sections = BTreeMap::new(); + for _ in 0..count { + if pos + 13 > bytes.len() { + return Err(Error::InvalidData("Schema blob truncated".into())); + } + let kind = bytes[pos]; + pos += 1; + let type_tag = bytes[pos]; + pos += 1; + let per_atom = match bytes[pos] { + 0 => false, + 1 => true, + _ => return Err(Error::InvalidData("Invalid schema per_atom flag".into())), + }; + pos += 1; + let key_len = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; + pos += 2; + let elem_bytes = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + let slot_bytes = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + if pos + key_len > bytes.len() { + return Err(Error::InvalidData("Schema blob truncated at key".into())); + } + let key = std::str::from_utf8(&bytes[pos..pos + key_len]) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in schema key".into()))? + .to_string(); + pos += key_len; + sections.insert( + (kind, key), + SchemaEntry { + type_tag, + per_atom, + elem_bytes, + slot_bytes, + }, + ); + } + if pos != bytes.len() { + return Err(Error::InvalidData("Schema blob trailing bytes".into())); + } + Ok(SchemaLock { + positions_type, + sections, + }) +} + +fn schema_type_tag_elem_bytes(tag: u8) -> Result { + match tag { + TYPE_FLOAT => Ok(8), + TYPE_INT => Ok(8), + TYPE_STRING => Ok(0), + TYPE_F64_ARRAY => Ok(8), + TYPE_VEC3_F32 => Ok(12), + TYPE_I64_ARRAY => Ok(8), + TYPE_F32_ARRAY => Ok(4), + TYPE_VEC3_F64 => Ok(24), + TYPE_I32_ARRAY => Ok(4), + TYPE_BOOL3 => Ok(3), + TYPE_MAT3X3_F64 => Ok(72), + TYPE_FLOAT32 => Ok(4), + TYPE_MAT3X3_F32 => Ok(36), + _ => Err(Error::InvalidData(format!( + "Unsupported section type tag {}", + tag + ))), + } +} + +fn schema_is_per_atom(kind: u8, key: &str) -> bool { + match kind { + KIND_ATOM_PROP => true, + KIND_MOL_PROP => false, + KIND_BUILTIN => matches!(key, "forces" | "charges" | "velocities"), + _ => false, + } +} + +fn schema_entry( + kind: u8, + key: &str, + type_tag: u8, + payload_len: usize, + n_atoms: usize, +) -> Result { + let per_atom = schema_is_per_atom(kind, key); + let elem_bytes = schema_type_tag_elem_bytes(type_tag)?; + let slot_bytes = if type_tag == TYPE_STRING { + 0 + } else if per_atom { + match type_tag { + TYPE_F64_ARRAY | TYPE_I64_ARRAY | TYPE_F32_ARRAY | TYPE_I32_ARRAY => elem_bytes, + TYPE_VEC3_F32 | TYPE_VEC3_F64 => elem_bytes, + _ => payload_len + .checked_div(n_atoms.max(1)) + .unwrap_or(elem_bytes), + } + } else { + payload_len + }; + + if per_atom { + let expected = elem_bytes + .checked_mul(n_atoms) + .ok_or_else(|| Error::InvalidData(format!("Schema overflow for section '{}'", key)))?; + if payload_len != expected { + return Err(Error::InvalidData(format!( + "Section '{}' payload length {} does not match expected {}", + key, payload_len, expected + ))); + } + } + + Ok(SchemaEntry { + type_tag, + per_atom, + elem_bytes, + slot_bytes, + }) +} + +fn validate_builtin_type_tag_for_record_format( + record_format: u32, + key: &str, + type_tag: u8, +) -> Result<()> { + match record_format { + RECORD_FORMAT_SOA_V3 => Ok(()), + RECORD_FORMAT_SOA_V2 => match key { + "charges" if type_tag != TYPE_F64_ARRAY => Err(Error::InvalidData( + "record format 2 does not support float32 charges".into(), + )), + "cell" if type_tag != TYPE_MAT3X3_F64 => Err(Error::InvalidData( + "record format 2 does not support float32 cell".into(), + )), + "energy" if type_tag != TYPE_FLOAT => Err(Error::InvalidData( + "record format 2 does not support float32 energy".into(), + )), + "forces" if type_tag != TYPE_VEC3_F32 => Err(Error::InvalidData( + "record format 2 does not support float64 forces".into(), + )), + "stress" if type_tag != TYPE_MAT3X3_F64 => Err(Error::InvalidData( + "record format 2 does not support float32 stress".into(), + )), + "velocities" if type_tag != TYPE_VEC3_F32 => Err(Error::InvalidData( + "record format 2 does not support float64 velocities".into(), + )), + _ => Ok(()), + }, + _ => Err(Error::InvalidData(format!( + "Unsupported record format {}", + record_format + ))), + } +} + +pub(super) fn validate_schema_lock_for_record_format( + record_format: u32, + schema: &SchemaLock, +) -> Result<()> { + let _ = resolve_positions_type(record_format, schema.positions_type)?; + for ((kind, key), entry) in &schema.sections { + if *kind == KIND_BUILTIN { + validate_builtin_type_tag_for_record_format(record_format, key, entry.type_tag)?; + } + } + Ok(()) +} + +fn property_value_payload_len(value: &PropertyValue) -> usize { + match value { + PropertyValue::Float(_) | PropertyValue::Int(_) => 8, + PropertyValue::String(value) => value.len(), + PropertyValue::FloatArray(values) => values.len() * 8, + PropertyValue::Vec3Array(values) => values.len() * 12, + PropertyValue::IntArray(values) => values.len() * 8, + PropertyValue::Float32Array(values) => values.len() * 4, + PropertyValue::Vec3ArrayF64(values) => values.len() * 24, + PropertyValue::Int32Array(values) => values.len() * 4, + } +} + +pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { + let n_atoms = molecule.len(); + let mut schema = SchemaLock { + positions_type: Some(positions_type_from_molecule(molecule)), + sections: BTreeMap::new(), + }; + + let mut insert = |kind: u8, key: &str, type_tag: u8, payload_len: usize| -> Result<()> { + let entry = schema_entry(kind, key, type_tag, payload_len, n_atoms)?; + schema.sections.insert((kind, key.to_string()), entry); + Ok(()) + }; + + if let Some(charges) = &molecule.charges { + let (type_tag, payload_len) = match charges { + FloatArrayData::F32(values) => (TYPE_F32_ARRAY, values.len() * 4), + FloatArrayData::F64(values) => (TYPE_F64_ARRAY, values.len() * 8), + }; + insert(KIND_BUILTIN, "charges", type_tag, payload_len)?; + } + if let Some(cell) = &molecule.cell { + let (type_tag, payload_len) = match cell { + Mat3Data::F32(_) => (TYPE_MAT3X3_F32, 36), + Mat3Data::F64(_) => (TYPE_MAT3X3_F64, 72), + }; + insert(KIND_BUILTIN, "cell", type_tag, payload_len)?; + } + if let Some(energy) = &molecule.energy { + let (type_tag, payload_len) = match energy { + FloatScalarData::F32(_) => (TYPE_FLOAT32, 4), + FloatScalarData::F64(_) => (TYPE_FLOAT, 8), + }; + insert(KIND_BUILTIN, "energy", type_tag, payload_len)?; + } + if let Some(forces) = &molecule.forces { + let (type_tag, payload_len) = match forces { + Vec3Data::F32(values) => (TYPE_VEC3_F32, values.len() * 12), + Vec3Data::F64(values) => (TYPE_VEC3_F64, values.len() * 24), + }; + insert(KIND_BUILTIN, "forces", type_tag, payload_len)?; + } + if let Some(name) = &molecule.name { + insert(KIND_BUILTIN, "name", TYPE_STRING, name.len())?; + } + if molecule.pbc.is_some() { + insert(KIND_BUILTIN, "pbc", TYPE_BOOL3, 3)?; + } + if let Some(stress) = &molecule.stress { + let (type_tag, payload_len) = match stress { + Mat3Data::F32(_) => (TYPE_MAT3X3_F32, 36), + Mat3Data::F64(_) => (TYPE_MAT3X3_F64, 72), + }; + insert(KIND_BUILTIN, "stress", type_tag, payload_len)?; + } + if let Some(velocities) = &molecule.velocities { + let (type_tag, payload_len) = match velocities { + Vec3Data::F32(values) => (TYPE_VEC3_F32, values.len() * 12), + Vec3Data::F64(values) => (TYPE_VEC3_F64, values.len() * 24), + }; + insert(KIND_BUILTIN, "velocities", type_tag, payload_len)?; + } + + for (key, value) in &molecule.atom_properties { + insert( + KIND_ATOM_PROP, + key, + property_value_type_tag(value), + property_value_payload_len(value), + )?; + } + for (key, value) in &molecule.properties { + insert( + KIND_MOL_PROP, + key, + property_value_type_tag(value), + property_value_payload_len(value), + )?; + } + + Ok(schema) +} + +fn parse_record_schema_with_positions( + bytes: &[u8], + record_format: u32, + positions_type: u8, +) -> Result { + if bytes.len() < 6 { + return Err(Error::InvalidData("SOA record too small".into())); + } + + let mut pos = 0usize; + let n_atoms = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + + let positions_end = pos + .checked_add( + n_atoms + .checked_mul(positions_stride(positions_type)?) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?, + ) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; + if positions_end > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at positions".into(), + )); + } + pos = positions_end; + + let z_end = pos + .checked_add(n_atoms) + .ok_or_else(|| Error::InvalidData("SOA atomic_numbers overflow".into()))?; + if z_end > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at atomic_numbers".into(), + )); + } + pos = z_end; + + if pos + 2 > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at n_sections".into(), + )); + } + let n_sections = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; + pos += 2; + + let mut schema = SchemaLock { + positions_type: Some(positions_type), + sections: BTreeMap::new(), + }; + + for _ in 0..n_sections { + if pos + 7 > bytes.len() { + return Err(Error::InvalidData("SOA section header truncated".into())); + } + let kind = bytes[pos]; + pos += 1; + let key_len = bytes[pos] as usize; + pos += 1; + if pos + key_len > bytes.len() { + return Err(Error::InvalidData("SOA section key truncated".into())); + } + let key = std::str::from_utf8(&bytes[pos..pos + key_len]) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in section key".into()))? + .to_string(); + pos += key_len; + let type_tag = bytes[pos]; + pos += 1; + let payload_len = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; + pos += 4; + let payload_end = pos + .checked_add(payload_len) + .ok_or_else(|| Error::InvalidData("SOA section payload overflow".into()))?; + if payload_end > bytes.len() { + return Err(Error::InvalidData("SOA section payload truncated".into())); + } + pos = payload_end; + if kind == KIND_BUILTIN { + validate_builtin_type_tag_for_record_format(record_format, &key, type_tag)?; + } + let entry = schema_entry(kind, &key, type_tag, payload_len, n_atoms)?; + schema.sections.insert((kind, key), entry); + } + + Ok(schema) +} + +pub(super) fn record_schema( + bytes: &[u8], + record_format: u32, + positions_type_hint: Option, +) -> Result { + let positions_type = resolve_positions_type(record_format, positions_type_hint)?; + parse_record_schema_with_positions(bytes, record_format, positions_type) +} + +pub(super) fn merge_schema_lock(lock: &mut SchemaLock, record: &SchemaLock) -> Result<()> { + match (lock.positions_type, record.positions_type) { + (None, Some(tag)) => lock.positions_type = Some(tag), + (Some(expected), Some(actual)) if expected != actual => { + return Err(Error::InvalidData(format!( + "Position dtype mismatch: expected type tag {}, got {}", + expected, actual + ))); + } + _ => {} + } + + for ((kind, key), entry) in &record.sections { + match lock.sections.get(&(*kind, key.clone())) { + Some(expected) if expected != entry => { + return Err(Error::InvalidData(format!( + "Schema mismatch for section '{}': expected {:?}, got {:?}", + key, expected, entry + ))); + } + Some(_) => {} + None => { + lock.sections.insert((*kind, key.clone()), entry.clone()); + } + } + } + + Ok(()) +} diff --git a/atompack/src/storage/soa.rs b/atompack/src/storage/soa.rs index 6741929..6c52d7b 100644 --- a/atompack/src/storage/soa.rs +++ b/atompack/src/storage/soa.rs @@ -1,4 +1,5 @@ use super::*; +use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, Vec3Data}; /// Write a single tagged section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload] fn write_section(buf: &mut Vec, kind: u8, key: &str, type_tag: u8, payload: &[u8]) { @@ -10,7 +11,7 @@ fn write_section(buf: &mut Vec, kind: u8, key: &str, type_tag: u8, payload: buf.extend_from_slice(payload); } -fn property_value_type_tag(value: &PropertyValue) -> u8 { +pub(super) fn property_value_type_tag(value: &PropertyValue) -> u8 { match value { PropertyValue::Float(_) => TYPE_FLOAT, PropertyValue::Int(_) => TYPE_INT, @@ -24,6 +25,19 @@ fn property_value_type_tag(value: &PropertyValue) -> u8 { } } +fn property_value_payload_len(value: &PropertyValue) -> usize { + match value { + PropertyValue::Float(_) | PropertyValue::Int(_) => 8, + PropertyValue::String(value) => value.len(), + PropertyValue::FloatArray(values) => values.len() * 8, + PropertyValue::Vec3Array(values) => values.len() * 12, + PropertyValue::IntArray(values) => values.len() * 8, + PropertyValue::Float32Array(values) => values.len() * 4, + PropertyValue::Vec3ArrayF64(values) => values.len() * 24, + PropertyValue::Int32Array(values) => values.len() * 4, + } +} + fn extend_f64(b: &mut Vec, v: &[f64]) { for x in v { b.extend_from_slice(&f64::to_le_bytes(*x)); @@ -45,7 +59,7 @@ fn extend_i32(b: &mut Vec, v: &[i32]) { } } -fn property_value_to_bytes(value: &PropertyValue) -> Vec { +pub(super) fn property_value_to_bytes(value: &PropertyValue) -> Vec { match value { PropertyValue::Float(v) => v.to_le_bytes().to_vec(), PropertyValue::Int(v) => v.to_le_bytes().to_vec(), @@ -112,6 +126,36 @@ pub(super) fn decode_vec3_f32(payload: &[u8]) -> Result> { .collect() } +pub(super) fn decode_vec3_f64(payload: &[u8]) -> Result> { + if !payload.len().is_multiple_of(24) { + return Err(Error::InvalidData( + "vec3 payload length not divisible by 24".into(), + )); + } + payload + .chunks_exact(24) + .map(|c| { + Ok([ + f64::from_le_bytes(arr(&c[0..8])?), + f64::from_le_bytes(arr(&c[8..16])?), + f64::from_le_bytes(arr(&c[16..24])?), + ]) + }) + .collect() +} + +pub(super) fn decode_f32_array(payload: &[u8]) -> Result> { + if !payload.len().is_multiple_of(4) { + return Err(Error::InvalidData( + "f32 array payload length not divisible by 4".into(), + )); + } + payload + .chunks_exact(4) + .map(|c| Ok(f32::from_le_bytes(arr(c)?))) + .collect() +} + pub(super) fn decode_f64_array(payload: &[u8]) -> Result> { if !payload.len().is_multiple_of(8) { return Err(Error::InvalidData( @@ -124,6 +168,23 @@ pub(super) fn decode_f64_array(payload: &[u8]) -> Result> { .collect() } +pub(super) fn decode_mat3x3_f32(payload: &[u8]) -> Result<[[f32; 3]; 3]> { + if payload.len() != 36 { + return Err(Error::InvalidData(format!( + "mat3x3 payload length {} (expected 36)", + payload.len() + ))); + } + let mut mat = [[0.0f32; 3]; 3]; + for (r, row) in mat.iter_mut().enumerate() { + for (c, cell) in row.iter_mut().enumerate() { + let o = (r * 3 + c) * 4; + *cell = f32::from_le_bytes(arr(&payload[o..o + 4])?); + } + } + Ok(mat) +} + pub(super) fn decode_mat3x3_f64(payload: &[u8]) -> Result<[[f64; 3]; 3]> { if payload.len() != 72 { return Err(Error::InvalidData(format!( @@ -141,6 +202,434 @@ pub(super) fn decode_mat3x3_f64(payload: &[u8]) -> Result<[[f64; 3]; 3]> { Ok(mat) } +fn write_vec3_section(buf: &mut Vec, key: &str, values: &Vec3Data) { + match values { + Vec3Data::F32(values) => { + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_VEC3_F32, + bytemuck::cast_slice::<[f32; 3], u8>(values), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(values.len() * 12); + for value in values { + extend_f32(&mut payload, value); + } + write_section(buf, KIND_BUILTIN, key, TYPE_VEC3_F32, &payload); + } + } + Vec3Data::F64(values) => { + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_VEC3_F64, + bytemuck::cast_slice::<[f64; 3], u8>(values), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(values.len() * 24); + for value in values { + extend_f64(&mut payload, value); + } + write_section(buf, KIND_BUILTIN, key, TYPE_VEC3_F64, &payload); + } + } + } +} + +fn write_float_array_section(buf: &mut Vec, key: &str, values: &FloatArrayData) { + match values { + FloatArrayData::F32(values) => { + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_F32_ARRAY, + bytemuck::cast_slice::(values), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(values.len() * 4); + extend_f32(&mut payload, values); + write_section(buf, KIND_BUILTIN, key, TYPE_F32_ARRAY, &payload); + } + } + FloatArrayData::F64(values) => { + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_F64_ARRAY, + bytemuck::cast_slice::(values), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(values.len() * 8); + extend_f64(&mut payload, values); + write_section(buf, KIND_BUILTIN, key, TYPE_F64_ARRAY, &payload); + } + } + } +} + +fn write_mat3_section(buf: &mut Vec, key: &str, values: &Mat3Data) { + match values { + Mat3Data::F32(values) => { + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_MAT3X3_F32, + bytemuck::cast_slice::<[f32; 3], u8>(values.as_slice()), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(36); + for row in values { + extend_f32(&mut payload, row); + } + write_section(buf, KIND_BUILTIN, key, TYPE_MAT3X3_F32, &payload); + } + } + Mat3Data::F64(values) => { + #[cfg(target_endian = "little")] + { + write_section( + buf, + KIND_BUILTIN, + key, + TYPE_MAT3X3_F64, + bytemuck::cast_slice::<[f64; 3], u8>(values.as_slice()), + ); + } + #[cfg(not(target_endian = "little"))] + { + let mut payload = Vec::with_capacity(72); + for row in values { + extend_f64(&mut payload, row); + } + write_section(buf, KIND_BUILTIN, key, TYPE_MAT3X3_F64, &payload); + } + } + } +} + +fn write_energy_section(buf: &mut Vec, value: &FloatScalarData) { + match value { + FloatScalarData::F32(value) => { + write_section( + buf, + KIND_BUILTIN, + "energy", + TYPE_FLOAT32, + &value.to_le_bytes(), + ); + } + FloatScalarData::F64(value) => { + write_section( + buf, + KIND_BUILTIN, + "energy", + TYPE_FLOAT, + &value.to_le_bytes(), + ); + } + } +} + +pub(super) fn resolve_positions_type( + record_format: u32, + positions_type_hint: Option, +) -> Result { + match record_format { + RECORD_FORMAT_SOA_V2 => match positions_type_hint { + Some(TYPE_VEC3_F64) => Err(Error::InvalidData( + "record format 2 does not support float64 positions".into(), + )), + _ => Ok(TYPE_VEC3_F32), + }, + RECORD_FORMAT_SOA_V3 => positions_type_hint.ok_or_else(|| { + Error::InvalidData("Missing positions dtype for record format 3".into()) + }), + _ => Err(Error::InvalidData(format!( + "Unsupported record format {}", + record_format + ))), + } +} + +pub(super) fn positions_stride(positions_type: u8) -> Result { + match positions_type { + TYPE_VEC3_F32 => Ok(12), + TYPE_VEC3_F64 => Ok(24), + _ => Err(Error::InvalidData(format!( + "Unsupported positions type tag {}", + positions_type + ))), + } +} + +fn validate_record_format_compat(molecule: &Molecule, record_format: u32) -> Result<()> { + match record_format { + RECORD_FORMAT_SOA_V3 => Ok(()), + RECORD_FORMAT_SOA_V2 => { + if matches!(molecule.positions, Vec3Data::F64(_)) { + return Err(Error::InvalidData( + "record format 2 does not support float64 positions".into(), + )); + } + if matches!(molecule.charges, Some(FloatArrayData::F32(_))) { + return Err(Error::InvalidData( + "record format 2 does not support float32 charges".into(), + )); + } + if matches!(molecule.cell, Some(Mat3Data::F32(_))) { + return Err(Error::InvalidData( + "record format 2 does not support float32 cell".into(), + )); + } + if matches!(molecule.energy, Some(FloatScalarData::F32(_))) { + return Err(Error::InvalidData( + "record format 2 does not support float32 energy".into(), + )); + } + if matches!(molecule.forces, Some(Vec3Data::F64(_))) { + return Err(Error::InvalidData( + "record format 2 does not support float64 forces".into(), + )); + } + if matches!(molecule.stress, Some(Mat3Data::F32(_))) { + return Err(Error::InvalidData( + "record format 2 does not support float32 stress".into(), + )); + } + if matches!(molecule.velocities, Some(Vec3Data::F64(_))) { + return Err(Error::InvalidData( + "record format 2 does not support float64 velocities".into(), + )); + } + Ok(()) + } + _ => Err(Error::InvalidData(format!( + "Unsupported record format {}", + record_format + ))), + } +} + +pub(super) fn minimum_record_format_for_molecule(molecule: &Molecule) -> u32 { + if validate_record_format_compat(molecule, RECORD_FORMAT_SOA_V2).is_ok() { + RECORD_FORMAT_SOA_V2 + } else { + RECORD_FORMAT_SOA_V3 + } +} + +fn write_positions(buf: &mut Vec, positions: &Vec3Data) { + match positions { + Vec3Data::F32(values) => { + #[cfg(target_endian = "little")] + { + buf.extend_from_slice(bytemuck::cast_slice::<[f32; 3], u8>(values)); + } + #[cfg(not(target_endian = "little"))] + { + for value in values { + extend_f32(buf, value); + } + } + } + Vec3Data::F64(values) => { + #[cfg(target_endian = "little")] + { + buf.extend_from_slice(bytemuck::cast_slice::<[f64; 3], u8>(values)); + } + #[cfg(not(target_endian = "little"))] + { + for value in values { + extend_f64(buf, value); + } + } + } + } +} + +fn count_sections(molecule: &Molecule) -> u16 { + let mut n_sections = 0; + if molecule.charges.is_some() { + n_sections += 1; + } + if molecule.cell.is_some() { + n_sections += 1; + } + if molecule.energy.is_some() { + n_sections += 1; + } + if molecule.forces.is_some() { + n_sections += 1; + } + if molecule.name.is_some() { + n_sections += 1; + } + if molecule.pbc.is_some() { + n_sections += 1; + } + if molecule.stress.is_some() { + n_sections += 1; + } + if molecule.velocities.is_some() { + n_sections += 1; + } + n_sections += molecule.atom_properties.len() as u16; + n_sections += molecule.properties.len() as u16; + n_sections +} + +fn section_size(key_len: usize, payload_len: usize) -> usize { + 1 + 1 + key_len + 1 + 4 + payload_len +} + +fn estimate_serialized_len(molecule: &Molecule) -> usize { + let positions_bytes = match &molecule.positions { + Vec3Data::F32(values) => values.len() * 12, + Vec3Data::F64(values) => values.len() * 24, + }; + let mut total = 4 + positions_bytes + molecule.atomic_numbers.len() + 2; + + if let Some(charges) = &molecule.charges { + let payload_len = match charges { + FloatArrayData::F32(values) => values.len() * 4, + FloatArrayData::F64(values) => values.len() * 8, + }; + total += section_size("charges".len(), payload_len); + } + if let Some(cell) = &molecule.cell { + let payload_len = match cell { + Mat3Data::F32(_) => 36, + Mat3Data::F64(_) => 72, + }; + total += section_size("cell".len(), payload_len); + } + if let Some(energy) = &molecule.energy { + let payload_len = match energy { + FloatScalarData::F32(_) => 4, + FloatScalarData::F64(_) => 8, + }; + total += section_size("energy".len(), payload_len); + } + if let Some(forces) = &molecule.forces { + let payload_len = match forces { + Vec3Data::F32(values) => values.len() * 12, + Vec3Data::F64(values) => values.len() * 24, + }; + total += section_size("forces".len(), payload_len); + } + if let Some(name) = &molecule.name { + total += section_size("name".len(), name.len()); + } + if molecule.pbc.is_some() { + total += section_size("pbc".len(), 3); + } + if let Some(stress) = &molecule.stress { + let payload_len = match stress { + Mat3Data::F32(_) => 36, + Mat3Data::F64(_) => 72, + }; + total += section_size("stress".len(), payload_len); + } + if let Some(velocities) = &molecule.velocities { + let payload_len = match velocities { + Vec3Data::F32(values) => values.len() * 12, + Vec3Data::F64(values) => values.len() * 24, + }; + total += section_size("velocities".len(), payload_len); + } + + for (key, value) in &molecule.atom_properties { + total += section_size(key.len(), property_value_payload_len(value)); + } + for (key, value) in &molecule.properties { + total += section_size(key.len(), property_value_payload_len(value)); + } + + total +} + +fn write_sections(buf: &mut Vec, molecule: &Molecule) { + if let Some(ref charges) = molecule.charges { + write_float_array_section(buf, "charges", charges); + } + if let Some(ref cell) = molecule.cell { + write_mat3_section(buf, "cell", cell); + } + if let Some(ref energy) = molecule.energy { + write_energy_section(buf, energy); + } + if let Some(ref forces) = molecule.forces { + write_vec3_section(buf, "forces", forces); + } + if let Some(ref name) = molecule.name { + write_section(buf, KIND_BUILTIN, "name", TYPE_STRING, name.as_bytes()); + } + if let Some(ref pbc) = molecule.pbc { + let payload = [pbc[0] as u8, pbc[1] as u8, pbc[2] as u8]; + write_section(buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload); + } + if let Some(ref stress) = molecule.stress { + write_mat3_section(buf, "stress", stress); + } + if let Some(ref velocities) = molecule.velocities { + write_vec3_section(buf, "velocities", velocities); + } + + let mut atom_keys: Vec<&String> = molecule.atom_properties.keys().collect(); + atom_keys.sort(); + for key in atom_keys { + let value = &molecule.atom_properties[key]; + let payload = property_value_to_bytes(value); + write_section( + buf, + KIND_ATOM_PROP, + key, + property_value_type_tag(value), + &payload, + ); + } + + let mut prop_keys: Vec<&String> = molecule.properties.keys().collect(); + prop_keys.sort(); + for key in prop_keys { + let value = &molecule.properties[key]; + let payload = property_value_to_bytes(value); + write_section( + buf, + KIND_MOL_PROP, + key, + property_value_type_tag(value), + &payload, + ); + } +} + fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result { Ok(match type_tag { TYPE_FLOAT => { @@ -224,135 +713,171 @@ fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result }) } -pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result> { +pub(super) fn serialize_molecule_soa(molecule: &Molecule, record_format: u32) -> Result> { + validate_record_format_compat(molecule, record_format)?; + let n = molecule.len(); - let mut buf = Vec::new(); + let mut buf = Vec::with_capacity(estimate_serialized_len(molecule)); buf.extend_from_slice(&(n as u32).to_le_bytes()); - for position in &molecule.positions { - buf.extend_from_slice(&position[0].to_le_bytes()); - buf.extend_from_slice(&position[1].to_le_bytes()); - buf.extend_from_slice(&position[2].to_le_bytes()); - } + write_positions(&mut buf, &molecule.positions); buf.extend_from_slice(&molecule.atomic_numbers); + buf.extend_from_slice(&count_sections(molecule).to_le_bytes()); + write_sections(&mut buf, molecule); - let mut n_sections: u16 = 0; - if molecule.charges.is_some() { - n_sections += 1; - } - if molecule.cell.is_some() { - n_sections += 1; - } - if molecule.energy.is_some() { - n_sections += 1; - } - if molecule.forces.is_some() { - n_sections += 1; - } - if molecule.name.is_some() { - n_sections += 1; - } - if molecule.pbc.is_some() { - n_sections += 1; - } - if molecule.stress.is_some() { - n_sections += 1; - } - if molecule.velocities.is_some() { - n_sections += 1; - } - n_sections += molecule.atom_properties.len() as u16; - n_sections += molecule.properties.len() as u16; - buf.extend_from_slice(&n_sections.to_le_bytes()); + Ok(buf) +} - if let Some(ref charges) = molecule.charges { - let mut payload = Vec::with_capacity(charges.len() * 8); - extend_f64(&mut payload, charges); - write_section(&mut buf, KIND_BUILTIN, "charges", TYPE_F64_ARRAY, &payload); +fn decode_positions( + bytes: &[u8], + pos: &mut usize, + n_atoms: usize, + positions_type: u8, +) -> Result { + let positions_len = n_atoms + .checked_mul(positions_stride(positions_type)?) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; + let positions_end = pos + .checked_add(positions_len) + .ok_or_else(|| Error::InvalidData("SOA positions overflow".into()))?; + if positions_end > bytes.len() { + return Err(Error::InvalidData( + "SOA record truncated at positions".into(), + )); } - if let Some(ref cell) = molecule.cell { - let mut payload = Vec::with_capacity(72); - for row in cell { - extend_f64(&mut payload, row); + let payload = &bytes[*pos..positions_end]; + let positions = match positions_type { + TYPE_VEC3_F32 => { + let mut values = Vec::with_capacity(n_atoms); + for chunk in payload.chunks_exact(12) { + values.push([ + f32::from_le_bytes(arr(&chunk[0..4])?), + f32::from_le_bytes(arr(&chunk[4..8])?), + f32::from_le_bytes(arr(&chunk[8..12])?), + ]); + } + Vec3Data::F32(values) } - write_section(&mut buf, KIND_BUILTIN, "cell", TYPE_MAT3X3_F64, &payload); - } - if let Some(energy) = molecule.energy { - write_section( - &mut buf, - KIND_BUILTIN, - "energy", - TYPE_FLOAT, - &energy.to_le_bytes(), - ); - } - if let Some(ref forces) = molecule.forces { - let mut payload = Vec::with_capacity(forces.len() * 12); - for f in forces { - extend_f32(&mut payload, f); + TYPE_VEC3_F64 => { + let mut values = Vec::with_capacity(n_atoms); + for chunk in payload.chunks_exact(24) { + values.push([ + f64::from_le_bytes(arr(&chunk[0..8])?), + f64::from_le_bytes(arr(&chunk[8..16])?), + f64::from_le_bytes(arr(&chunk[16..24])?), + ]); + } + Vec3Data::F64(values) } - write_section(&mut buf, KIND_BUILTIN, "forces", TYPE_VEC3_F32, &payload); - } - if let Some(ref name) = molecule.name { - write_section(&mut buf, KIND_BUILTIN, "name", TYPE_STRING, name.as_bytes()); - } - if let Some(ref pbc) = molecule.pbc { - let payload = [pbc[0] as u8, pbc[1] as u8, pbc[2] as u8]; - write_section(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload); - } - if let Some(ref stress) = molecule.stress { - let mut payload = Vec::with_capacity(72); - for row in stress { - extend_f64(&mut payload, row); + _ => { + return Err(Error::InvalidData(format!( + "Unsupported positions type tag {}", + positions_type + ))); } - write_section(&mut buf, KIND_BUILTIN, "stress", TYPE_MAT3X3_F64, &payload); - } - if let Some(ref velocities) = molecule.velocities { - let mut payload = Vec::with_capacity(velocities.len() * 12); - for v in velocities { - extend_f32(&mut payload, v); + }; + *pos = positions_end; + Ok(positions) +} + +fn decode_float_scalar_data( + payload: &[u8], + type_tag: u8, + field_name: &str, +) -> Result { + match type_tag { + TYPE_FLOAT => { + if payload.len() != 8 { + return Err(Error::InvalidData(format!( + "{field_name} f64 payload truncated" + ))); + } + Ok(FloatScalarData::F64(f64::from_le_bytes(arr(payload)?))) } - write_section( - &mut buf, - KIND_BUILTIN, - "velocities", - TYPE_VEC3_F32, - &payload, - ); + TYPE_FLOAT32 => { + if payload.len() != 4 { + return Err(Error::InvalidData(format!( + "{field_name} f32 payload truncated" + ))); + } + Ok(FloatScalarData::F32(f32::from_le_bytes(arr(payload)?))) + } + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), } +} - let mut atom_keys: Vec<&String> = molecule.atom_properties.keys().collect(); - atom_keys.sort(); - for key in atom_keys { - let value = &molecule.atom_properties[key]; - let payload = property_value_to_bytes(value); - write_section( - &mut buf, - KIND_ATOM_PROP, - key, - property_value_type_tag(value), - &payload, - ); +fn decode_vec3_data(payload: &[u8], type_tag: u8, field_name: &str) -> Result { + match type_tag { + TYPE_VEC3_F32 => Ok(Vec3Data::F32(decode_vec3_f32(payload)?)), + TYPE_VEC3_F64 => Ok(Vec3Data::F64(decode_vec3_f64(payload)?)), + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), } +} - let mut prop_keys: Vec<&String> = molecule.properties.keys().collect(); - prop_keys.sort(); - for key in prop_keys { - let value = &molecule.properties[key]; - let payload = property_value_to_bytes(value); - write_section( - &mut buf, - KIND_MOL_PROP, - key, - property_value_type_tag(value), - &payload, - ); +fn decode_float_array_data( + payload: &[u8], + type_tag: u8, + field_name: &str, +) -> Result { + match type_tag { + TYPE_F32_ARRAY => Ok(FloatArrayData::F32(decode_f32_array(payload)?)), + TYPE_F64_ARRAY => Ok(FloatArrayData::F64(decode_f64_array(payload)?)), + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), } +} - Ok(buf) +fn decode_mat3_data(payload: &[u8], type_tag: u8, field_name: &str) -> Result { + match type_tag { + TYPE_MAT3X3_F32 => Ok(Mat3Data::F32(decode_mat3x3_f32(payload)?)), + TYPE_MAT3X3_F64 => Ok(Mat3Data::F64(decode_mat3x3_f64(payload)?)), + _ => Err(Error::InvalidData(format!( + "Unsupported {field_name} type tag {}", + type_tag + ))), + } +} + +fn decode_builtin_section( + mol: &mut Molecule, + key: &str, + type_tag: u8, + payload: &[u8], +) -> Result<()> { + match key { + "energy" => mol.energy = Some(decode_float_scalar_data(payload, type_tag, "energy")?), + "forces" => mol.forces = Some(decode_vec3_data(payload, type_tag, "forces")?), + "charges" => mol.charges = Some(decode_float_array_data(payload, type_tag, "charges")?), + "velocities" => mol.velocities = Some(decode_vec3_data(payload, type_tag, "velocities")?), + "cell" => mol.cell = Some(decode_mat3_data(payload, type_tag, "cell")?), + "pbc" => { + if payload.len() < 3 { + return Err(Error::InvalidData("pbc payload truncated".into())); + } + mol.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); + } + "stress" => mol.stress = Some(decode_mat3_data(payload, type_tag, "stress")?), + "name" => { + mol.name = Some( + std::str::from_utf8(payload) + .map_err(|_| Error::InvalidData("Invalid UTF-8 in name".into()))? + .to_string(), + ); + } + _ => {} + } + Ok(()) } -pub(super) fn deserialize_molecule_soa(bytes: &[u8]) -> Result { +fn deserialize_molecule_soa_with_positions(bytes: &[u8], positions_type: u8) -> Result { if bytes.len() < 6 { return Err(Error::InvalidData("SOA record too small".into())); } @@ -361,23 +886,10 @@ pub(super) fn deserialize_molecule_soa(bytes: &[u8]) -> Result { let n = u32::from_le_bytes(arr(&bytes[pos..pos + 4])?) as usize; pos += 4; - let positions_end = pos + n * 12; - if positions_end > bytes.len() { - return Err(Error::InvalidData( - "SOA record truncated at positions".into(), - )); - } - let mut positions = Vec::with_capacity(n); - for i in 0..n { - let o = pos + i * 12; - let x = f32::from_le_bytes(arr(&bytes[o..o + 4])?); - let y = f32::from_le_bytes(arr(&bytes[o + 4..o + 8])?); - let z = f32::from_le_bytes(arr(&bytes[o + 8..o + 12])?); - positions.push([x, y, z]); - } - pos = positions_end; - - let z_end = pos + n; + let positions = decode_positions(bytes, &mut pos, n, positions_type)?; + let z_end = pos + .checked_add(n) + .ok_or_else(|| Error::InvalidData("SOA atomic_numbers overflow".into()))?; if z_end > bytes.len() { return Err(Error::InvalidData( "SOA record truncated at atomic_numbers".into(), @@ -394,7 +906,14 @@ pub(super) fn deserialize_molecule_soa(bytes: &[u8]) -> Result { let n_sections = u16::from_le_bytes(arr(&bytes[pos..pos + 2])?) as usize; pos += 2; - let mut mol = Molecule::new(positions, atomic_numbers).map_err(Error::InvalidData)?; + let mut mol = match positions { + Vec3Data::F32(values) => { + Molecule::new(values, atomic_numbers).map_err(Error::InvalidData)? + } + Vec3Data::F64(values) => { + Molecule::new_f64(values, atomic_numbers).map_err(Error::InvalidData)? + } + }; for _ in 0..n_sections { if pos + 7 > bytes.len() { @@ -421,33 +940,7 @@ pub(super) fn deserialize_molecule_soa(bytes: &[u8]) -> Result { pos += payload_len; match kind { - KIND_BUILTIN => match key { - "energy" => { - if payload.len() < 8 { - return Err(Error::InvalidData("energy payload truncated".into())); - } - mol.energy = Some(f64::from_le_bytes(arr(&payload[..8])?)); - } - "forces" => mol.forces = Some(decode_vec3_f32(payload)?), - "charges" => mol.charges = Some(decode_f64_array(payload)?), - "velocities" => mol.velocities = Some(decode_vec3_f32(payload)?), - "cell" => mol.cell = Some(decode_mat3x3_f64(payload)?), - "pbc" => { - if payload.len() < 3 { - return Err(Error::InvalidData("pbc payload truncated".into())); - } - mol.pbc = Some([payload[0] != 0, payload[1] != 0, payload[2] != 0]); - } - "stress" => mol.stress = Some(decode_mat3x3_f64(payload)?), - "name" => { - mol.name = Some( - std::str::from_utf8(payload) - .map_err(|_| Error::InvalidData("Invalid UTF-8 in name".into()))? - .to_string(), - ) - } - _ => {} - }, + KIND_BUILTIN => decode_builtin_section(&mut mol, key, type_tag, payload)?, KIND_ATOM_PROP => { mol.atom_properties .insert(key.to_string(), decode_property_value(type_tag, payload)?); @@ -462,3 +955,12 @@ pub(super) fn deserialize_molecule_soa(bytes: &[u8]) -> Result { Ok(mol) } + +pub(super) fn deserialize_molecule_soa( + bytes: &[u8], + record_format: u32, + positions_type_hint: Option, +) -> Result { + let positions_type = resolve_positions_type(record_format, positions_type_hint)?; + deserialize_molecule_soa_with_positions(bytes, positions_type) +} diff --git a/atompack/tests/throughput_smoke.rs b/atompack/tests/throughput_smoke.rs index 1c58da4..28f58b6 100644 --- a/atompack/tests/throughput_smoke.rs +++ b/atompack/tests/throughput_smoke.rs @@ -1,5 +1,7 @@ // Copyright 2026 Entalpic -use atompack::{Atom, AtomDatabase, Molecule, compression::CompressionType}; +use atompack::{ + Atom, AtomDatabase, FloatScalarData, Molecule, Vec3Data, compression::CompressionType, +}; use std::hint::black_box; use std::time::{Duration, Instant}; @@ -179,15 +181,15 @@ fn synthetic_molecule(id: usize) -> Molecule { }) .collect(); let mut molecule = Molecule::from_atoms(atoms); - molecule.energy = Some(-1_000.0 - id as f64 * 1e-3); - molecule.forces = Some( + molecule.energy = Some(FloatScalarData::F64(-1_000.0 - id as f64 * 1e-3)); + molecule.forces = Some(Vec3Data::F32( (0..ATOMS_PER_MOLECULE) .map(|i| { let t = id as f32 * 0.002 + i as f32 * 0.02; [(t * 0.7).sin(), (t * 0.9).cos(), (t * 1.1).sin()] }) .collect(), - ); + )); molecule } @@ -231,7 +233,7 @@ fn run_smoke() -> atompack::Result { for _ in 0..RANDOM_READS { let molecule = db.get_molecule(rng.index(db.len()))?; rand_atoms += molecule.len(); - black_box(molecule.forces.as_ref().map(Vec::len)); + black_box(molecule.forces.as_ref().map(Vec3Data::len)); } let rand_elapsed = start_rand.elapsed(); assert_eq!(rand_atoms, RANDOM_READS * ATOMS_PER_MOLECULE);