From 83f51d3e0cfc244b896fec12a4feef4b3f9c0d5c Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sat, 9 May 2026 16:21:51 +0200 Subject: [PATCH] feat: expose database atom count metadata --- atompack-py/python/atompack/__init__.pyi | 12 +++++++ atompack-py/python/atompack/_atompack_rs.pyi | 12 +++++++ atompack-py/src/database.rs | 26 +++++++++++++++ atompack-py/tests/test_database.py | 35 ++++++++++++++++++++ atompack-py/tests/test_stub_surface.py | 2 +- 5 files changed, 86 insertions(+), 1 deletion(-) diff --git a/atompack-py/python/atompack/__init__.pyi b/atompack-py/python/atompack/__init__.pyi index a28c97e..14712a4 100644 --- a/atompack-py/python/atompack/__init__.pyi +++ b/atompack-py/python/atompack/__init__.pyi @@ -569,6 +569,18 @@ class Database: ``properties`` and ``atom_properties`` dictionaries when present. """ ... + def num_atoms(self, index: int) -> int: + """ + Get the atom count for one molecule without materializing it. + """ + ... + def atom_counts(self, indices: list[int] | None = None) -> npt.NDArray[np.uint32]: + """ + Get atom counts for a selection of molecules. + + Passing ``indices=None`` returns counts for the whole database. + """ + ... def to_ase_batch( self, indices: list[int] | None = None, diff --git a/atompack-py/python/atompack/_atompack_rs.pyi b/atompack-py/python/atompack/_atompack_rs.pyi index 99e8730..7ed4b12 100644 --- a/atompack-py/python/atompack/_atompack_rs.pyi +++ b/atompack-py/python/atompack/_atompack_rs.pyi @@ -508,6 +508,18 @@ class PyAtomDatabase: ``properties`` and ``atom_properties`` dictionaries when present. """ ... + def num_atoms(self, index: int) -> int: + """ + Get the atom count for one molecule without materializing it. + """ + ... + def atom_counts(self, indices: Sequence[int] | None = None) -> np.ndarray: + """ + Get atom counts for a selection of molecules as ``uint32`` values. + + Passing ``indices=None`` returns counts for the whole database. + """ + ... 
def __len__(self) -> int: """ diff --git a/atompack-py/src/database.rs b/atompack-py/src/database.rs index c890ae8..e11a844 100644 --- a/atompack-py/src/database.rs +++ b/atompack-py/src/database.rs @@ -332,6 +332,32 @@ impl PyAtomDatabase { flat::get_molecules_flat_soa_impl(&self.inner, py, indices) } + /// Get the atom count for one molecule without materializing it. + fn num_atoms(&self, index: usize) -> PyResult<u32> { + self.inner.num_atoms(index).ok_or_else(|| { + PyIndexError::new_err(format!( + "Index {} out of bounds for database of length {}", + index, + self.inner.len() + )) + }) + } + + /// Get atom counts for a selection of molecules. + #[pyo3(signature = (indices=None))] + fn atom_counts<'py>( + &self, + py: Python<'py>, + indices: Option<Vec<usize>>, + ) -> PyResult<Bound<'py, PyArray1<u32>>> { + let selected = indices.unwrap_or_else(|| (0..self.inner.len()).collect()); + let counts = selected + .into_iter() + .map(|index| self.num_atoms(index)) + .collect::<PyResult<Vec<u32>>>()?; + Ok(PyArray1::from_slice(py, &counts)) + } + /// Get the number of molecules in the database fn __len__(&self) -> usize { self.inner.len() diff --git a/atompack-py/tests/test_database.py b/atompack-py/tests/test_database.py index 4d83dfd..a883f90 100644 --- a/atompack-py/tests/test_database.py +++ b/atompack-py/tests/test_database.py @@ -324,6 +324,41 @@ def test_database_add_arrays_batch_promotes_to_float64_geometry_when_needed( assert flat["forces"].dtype == np.float64 +def test_database_atom_counts_expose_cheap_shape_metadata(tmp_path: Path) -> None: + path = tmp_path / "atom_counts.atp" + molecules = [ + atompack.Molecule.from_arrays( + np.zeros((1, 3), dtype=np.float32), + np.array([1], dtype=np.uint8), + ), + atompack.Molecule.from_arrays( + np.zeros((3, 3), dtype=np.float32), + np.array([6, 8, 1], dtype=np.uint8), + ), + atompack.Molecule.from_arrays( + np.zeros((0, 3), dtype=np.float32), + np.zeros((0,), dtype=np.uint8), + ), + ] + + db = atompack.Database(str(path)) + db.add_molecules(molecules) + db.flush() + + reopened 
= atompack.Database.open(str(path)) + assert reopened.num_atoms(0) == 1 + assert reopened.num_atoms(1) == 3 + assert reopened.num_atoms(2) == 0 + np.testing.assert_array_equal(reopened.atom_counts(), np.array([1, 3, 0], dtype=np.uint32)) + np.testing.assert_array_equal( + reopened.atom_counts([2, 0]), + np.array([0, 1], dtype=np.uint32), + ) + + with pytest.raises(IndexError, match="out of bounds"): + reopened.num_atoms(3) + + @pytest.mark.parametrize("mmap", [False, True]) @pytest.mark.parametrize("compression", ["none", "lz4", "zstd"]) def test_database_single_item_reads_are_view_compatible( diff --git a/atompack-py/tests/test_stub_surface.py b/atompack-py/tests/test_stub_surface.py index b883fcc..51bf110 100644 --- a/atompack-py/tests/test_stub_surface.py +++ b/atompack-py/tests/test_stub_surface.py @@ -72,7 +72,7 @@ def test_private_stub_tracks_low_level_surface() -> None: def test_public_stub_exposes_flat_batch_reader() -> None: database_methods = _class_method_names(PUBLIC_STUB, "Database") - assert "get_molecules_flat" in database_methods + assert {"get_molecules_flat", "num_atoms", "atom_counts"} <= database_methods def test_hub_stub_has_public_docstrings() -> None: