Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 34 additions & 30 deletions atompack-py/python/atompack/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import numpy.typing as npt

from . import hub as hub

Float1D = npt.NDArray[np.float32] | npt.NDArray[np.float64]
Float2D = npt.NDArray[np.float32] | npt.NDArray[np.float64]
Float3D = npt.NDArray[np.float32] | npt.NDArray[np.float64]

class Atom:
"""
Represents a single atom with 3D coordinates and atomic number.
Expand Down Expand Up @@ -115,29 +119,29 @@ class Molecule:

def __init__(
self,
positions: npt.NDArray[np.float32],
positions: Float2D,
atomic_numbers: npt.NDArray[np.uint8],
*,
energy: float | None = ...,
forces: npt.NDArray[np.float32] | None = ...,
charges: npt.NDArray[np.float64] | None = ...,
velocities: npt.NDArray[np.float32] | None = ...,
cell: npt.NDArray[np.float64] | None = ...,
stress: npt.NDArray[np.float64] | npt.NDArray[np.float32] | None = ...,
forces: Float2D | None = ...,
charges: Float1D | None = ...,
velocities: Float2D | None = ...,
cell: Float2D | None = ...,
stress: Float2D | None = ...,
pbc: tuple[bool, bool, bool] | None = ...,
name: str | None = ...,
) -> None: ...
@staticmethod
def from_arrays(
positions: npt.NDArray[np.float32],
positions: Float2D,
atomic_numbers: npt.NDArray[np.uint8],
*,
energy: float | None = ...,
forces: npt.NDArray[np.float32] | None = ...,
charges: npt.NDArray[np.float64] | None = ...,
velocities: npt.NDArray[np.float32] | None = ...,
cell: npt.NDArray[np.float64] | None = ...,
stress: npt.NDArray[np.float64] | npt.NDArray[np.float32] | None = ...,
forces: Float2D | None = ...,
charges: Float1D | None = ...,
velocities: Float2D | None = ...,
cell: Float2D | None = ...,
stress: Float2D | None = ...,
pbc: tuple[bool, bool, bool] | None = ...,
name: str | None = ...,
) -> Molecule:
Expand Down Expand Up @@ -204,7 +208,7 @@ class Molecule:
...

@property
def forces(self) -> npt.NDArray[np.float32] | None:
def forces(self) -> Float2D | None:
"""
Per-atom forces.

Expand All @@ -216,7 +220,7 @@ class Molecule:
...

@forces.setter
def forces(self, value: npt.NDArray[np.float32]) -> None: ...
def forces(self, value: Float2D) -> None: ...
@property
def energy(self) -> float | None:
"""
Expand All @@ -232,7 +236,7 @@ class Molecule:
@energy.setter
def energy(self, value: float | None) -> None: ...
@property
def charges(self) -> npt.NDArray[np.float64] | None:
def charges(self) -> Float1D | None:
"""
Per-atom partial charges.

Expand All @@ -244,9 +248,9 @@ class Molecule:
...

@charges.setter
def charges(self, value: npt.NDArray[np.float64]) -> None: ...
def charges(self, value: Float1D) -> None: ...
@property
def velocities(self) -> npt.NDArray[np.float32] | None:
def velocities(self) -> Float2D | None:
"""
Per-atom velocities.

Expand All @@ -258,9 +262,9 @@ class Molecule:
...

@velocities.setter
def velocities(self, value: npt.NDArray[np.float32]) -> None: ...
def velocities(self, value: Float2D) -> None: ...
@property
def cell(self) -> npt.NDArray[np.float64] | None:
def cell(self) -> Float2D | None:
"""
Unit cell for periodic systems.

Expand All @@ -272,9 +276,9 @@ class Molecule:
...

@cell.setter
def cell(self, value: npt.NDArray[np.float64]) -> None: ...
def cell(self, value: Float2D) -> None: ...
@property
def stress(self) -> npt.NDArray[np.float64] | None:
def stress(self) -> Float2D | None:
"""
Virial stress tensor.

Expand All @@ -286,9 +290,9 @@ class Molecule:
...

@stress.setter
def stress(self, value: npt.NDArray[np.float32] | npt.NDArray[np.float64]) -> None: ...
def stress(self, value: Float2D) -> None: ...
@property
def positions(self) -> npt.NDArray[np.float32]:
def positions(self) -> Float2D:
"""
Atomic positions (read-only).

Expand Down Expand Up @@ -505,15 +509,15 @@ class Database:
...
def add_arrays_batch(
self,
positions: npt.NDArray[np.float32],
positions: Float3D,
atomic_numbers: npt.NDArray[np.uint8],
*,
energy: npt.NDArray[np.float64] | None = ...,
forces: npt.NDArray[np.float32] | None = ...,
charges: npt.NDArray[np.float64] | None = ...,
velocities: npt.NDArray[np.float32] | None = ...,
cell: npt.NDArray[np.float64] | None = ...,
stress: npt.NDArray[np.float64] | None = ...,
energy: Float1D | None = ...,
forces: Float3D | None = ...,
charges: Float2D | None = ...,
velocities: Float3D | None = ...,
cell: Float3D | None = ...,
stress: Float3D | None = ...,
pbc: npt.NDArray[np.bool_] | None = ...,
name: Sequence[str] | None = ...,
properties: dict[str, Any] | None = ...,
Expand Down
66 changes: 34 additions & 32 deletions atompack-py/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub(crate) struct PyAtomDatabase {
impl PyAtomDatabase {
fn single_molecule_view(&self, py: Python<'_>, index: usize) -> PyResult<SoaMoleculeView> {
let compression = self.inner.compression();
let record_format = self.inner.record_format();
let positions_type = self.inner.positions_type();
let use_mmap = self.inner.get_compressed_slice(0).is_some();

if use_mmap {
Expand All @@ -22,7 +24,11 @@ impl PyAtomDatabase {
let bytes = self.inner.get_shared_mmap_bytes(index).ok_or_else(|| {
invalid_data(format!("Missing mmap bytes for molecule {}", index))
})?;
SoaMoleculeView::from_shared_bytes_inner(bytes)
SoaMoleculeView::from_shared_bytes_inner(
bytes,
record_format,
positions_type,
)
} else {
let compressed =
self.inner.get_compressed_slice(index).ok_or_else(|| {
Expand All @@ -43,7 +49,11 @@ impl PyAtomDatabase {
compression,
Some(uncompressed_size),
)?;
SoaMoleculeView::from_bytes_inner(decompressed)
SoaMoleculeView::from_bytes_inner(
decompressed,
record_format,
positions_type,
)
}
})
.map_err(|e| PyValueError::new_err(format!("{}", e)));
Expand All @@ -55,7 +65,7 @@ impl PyAtomDatabase {
let raw = raw_bytes.pop().ok_or_else(|| {
PyValueError::new_err(format!("Missing raw bytes for molecule {}", index))
})?;
SoaMoleculeView::from_bytes(raw)
SoaMoleculeView::from_bytes(raw, record_format, positions_type)
}
}

Expand Down Expand Up @@ -117,12 +127,6 @@ impl PyAtomDatabase {

/// Add a molecule to the database
fn add_molecule(&mut self, molecule: &PyMolecule) -> PyResult<()> {
if let Some((soa_bytes, n_atoms)) = molecule.soa_bytes() {
return self
.inner
.add_raw_soa_records(&[(soa_bytes, n_atoms)])
.map_err(|e| PyValueError::new_err(format!("{}", e)));
}
let owned = molecule.clone_as_owned()?;
self.inner
.add_molecule(&owned)
Expand All @@ -131,22 +135,10 @@ impl PyAtomDatabase {

/// Add multiple molecules (processed in parallel)
fn add_molecules(&mut self, molecules: Vec<PyRef<PyMolecule>>) -> PyResult<()> {
// Split into view-backed (fast path) and owned molecules
let mut raw_records: Vec<(&[u8], u32)> = Vec::new();
let mut owned_molecules: Vec<Molecule> = Vec::new();

for m in &molecules {
if let Some((soa_bytes, n_atoms)) = m.soa_bytes() {
raw_records.push((soa_bytes, n_atoms));
} else {
owned_molecules.push(m.clone_as_owned()?);
}
}

if !raw_records.is_empty() {
self.inner
.add_raw_soa_records(&raw_records)
.map_err(|e| PyValueError::new_err(format!("{}", e)))?;
owned_molecules.push(m.clone_as_owned()?);
}
if !owned_molecules.is_empty() {
let mol_refs: Vec<&Molecule> = owned_molecules.iter().collect();
Expand Down Expand Up @@ -179,14 +171,14 @@ impl PyAtomDatabase {
fn add_arrays_batch(
&mut self,
py: Python<'_>,
positions: &Bound<'_, PyArray3<f32>>,
positions: &Bound<'_, PyAny>,
atomic_numbers: &Bound<'_, PyArray2<u8>>,
energy: Option<&Bound<'_, PyArray1<f64>>>,
forces: Option<&Bound<'_, PyArray3<f32>>>,
charges: Option<&Bound<'_, PyArray2<f64>>>,
velocities: Option<&Bound<'_, PyArray3<f32>>>,
cell: Option<&Bound<'_, PyArray3<f64>>>,
stress: Option<&Bound<'_, PyArray3<f64>>>,
energy: Option<&Bound<'_, PyAny>>,
forces: Option<&Bound<'_, PyAny>>,
charges: Option<&Bound<'_, PyAny>>,
velocities: Option<&Bound<'_, PyAny>>,
cell: Option<&Bound<'_, PyAny>>,
stress: Option<&Bound<'_, PyAny>>,
pbc: Option<&Bound<'_, PyArray2<bool>>>,
name: Option<Vec<String>>,
properties: Option<&Bound<'_, PyDict>>,
Expand Down Expand Up @@ -238,6 +230,8 @@ impl PyAtomDatabase {
}

let compression = self.inner.compression();
let record_format = self.inner.record_format();
let positions_type = self.inner.positions_type();
let use_mmap = self.inner.get_compressed_slice(0).is_some();

let views: Vec<SoaMoleculeView> = if use_mmap {
Expand All @@ -250,7 +244,11 @@ impl PyAtomDatabase {
let bytes = self.inner.get_shared_mmap_bytes(idx).ok_or_else(|| {
invalid_data(format!("Missing mmap bytes for molecule {}", idx))
})?;
SoaMoleculeView::from_shared_bytes_inner(bytes)
SoaMoleculeView::from_shared_bytes_inner(
bytes,
record_format,
positions_type,
)
} else {
let compressed =
self.inner.get_compressed_slice(idx).ok_or_else(|| {
Expand All @@ -271,7 +269,11 @@ impl PyAtomDatabase {
compression,
Some(uncompressed_size),
)?;
SoaMoleculeView::from_bytes_inner(decompressed)
SoaMoleculeView::from_bytes_inner(
decompressed,
record_format,
positions_type,
)
}
})
.collect()
Expand All @@ -283,7 +285,7 @@ impl PyAtomDatabase {
.map_err(|e| PyValueError::new_err(format!("{}", e)))?;
raw_bytes
.into_iter()
.map(SoaMoleculeView::from_bytes)
.map(|bytes| SoaMoleculeView::from_bytes(bytes, record_format, positions_type))
.collect::<PyResult<Vec<_>>>()
.map_err(|e| invalid_data(format!("{}", e)))
}
Expand Down
Loading
Loading