Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions atompack-py/python/atompack/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,18 @@ class Database:
``properties`` and ``atom_properties`` dictionaries when present.
"""
...
def num_atoms(self, index: int) -> int:
"""
Get the atom count for one molecule without materializing it.

``index`` is the position of the molecule in the database. An index
that is out of bounds raises ``IndexError``.
"""
...
def atom_counts(self, indices: list[int] | None = None) -> npt.NDArray[np.uint32]:
"""
Get atom counts for a selection of molecules.

The result is a one-dimensional ``uint32`` array holding one count per
requested molecule, in the same order as ``indices``.

Passing ``indices=None`` returns counts for the whole database.
"""
...
def to_ase_batch(
self,
indices: list[int] | None = None,
Expand Down
12 changes: 12 additions & 0 deletions atompack-py/python/atompack/_atompack_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,18 @@ class PyAtomDatabase:
``properties`` and ``atom_properties`` dictionaries when present.
"""
...
def num_atoms(self, index: int) -> int:
"""
Get the atom count for one molecule without materializing it.

``index`` is the position of the molecule in the database. An index
that is out of bounds raises ``IndexError``.
"""
...
def atom_counts(self, indices: Sequence[int] | None = None) -> np.ndarray:
"""
Get atom counts for a selection of molecules as ``uint32`` values.

The result is a one-dimensional array holding one count per requested
molecule, in the same order as ``indices``. An out-of-bounds index
raises ``IndexError``.

Passing ``indices=None`` returns counts for the whole database.
"""
...

def __len__(self) -> int:
"""
Expand Down
26 changes: 26 additions & 0 deletions atompack-py/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,32 @@ impl PyAtomDatabase {
flat::get_molecules_flat_soa_impl(&self.inner, py, indices)
}

/// Return how many atoms the molecule at `index` holds, without
/// materializing the molecule itself.
///
/// Raises `IndexError` when `index` is out of bounds for the database.
fn num_atoms(&self, index: usize) -> PyResult<u32> {
    match self.inner.num_atoms(index) {
        Some(count) => Ok(count),
        None => {
            // Only look up the length on the failure path, where it is
            // needed for the error message.
            let total = self.inner.len();
            Err(PyIndexError::new_err(format!(
                "Index {} out of bounds for database of length {}",
                index, total
            )))
        }
    }
}

/// Get atom counts for a selection of molecules.
///
/// The result is a one-dimensional `uint32` array with one entry per
/// requested molecule, in the same order as `indices`. Passing
/// `indices=None` returns counts for the whole database in database order.
///
/// Raises `IndexError` if any requested index is out of bounds.
#[pyo3(signature = (indices=None))]
fn atom_counts<'py>(
    &self,
    py: Python<'py>,
    indices: Option<Vec<usize>>,
) -> PyResult<Bound<'py, PyArray1<u32>>> {
    // Walk the full range directly when no selection is given, so the
    // whole-database case does not allocate a temporary index vector.
    let counts = match indices {
        Some(selected) => selected
            .into_iter()
            .map(|index| self.num_atoms(index))
            .collect::<PyResult<Vec<_>>>()?,
        None => (0..self.inner.len())
            .map(|index| self.num_atoms(index))
            .collect::<PyResult<Vec<_>>>()?,
    };
    // Hand the owned buffer to NumPy instead of copying it out of a slice.
    Ok(PyArray1::from_vec(py, counts))
}

/// Get the number of molecules in the database
fn __len__(&self) -> usize {
self.inner.len()
Expand Down
35 changes: 35 additions & 0 deletions atompack-py/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,41 @@ def test_database_add_arrays_batch_promotes_to_float64_geometry_when_needed(
assert flat["forces"].dtype == np.float64


def test_database_atom_counts_expose_cheap_shape_metadata(tmp_path: Path) -> None:
# Writes three molecules (1, 3 and 0 atoms) to an on-disk database and checks
# the atom-count accessors on a reopened handle.
path = tmp_path / "atom_counts.atp"
molecules = [
atompack.Molecule.from_arrays(
np.zeros((1, 3), dtype=np.float32),
np.array([1], dtype=np.uint8),
),
atompack.Molecule.from_arrays(
np.zeros((3, 3), dtype=np.float32),
np.array([6, 8, 1], dtype=np.uint8),
),
# Zero-atom molecule: the accessors are expected to report a count of 0.
atompack.Molecule.from_arrays(
np.zeros((0, 3), dtype=np.float32),
np.zeros((0,), dtype=np.uint8),
),
]

db = atompack.Database(str(path))
db.add_molecules(molecules)
db.flush()

# Query through a freshly opened handle rather than the writer.
reopened = atompack.Database.open(str(path))
assert reopened.num_atoms(0) == 1
assert reopened.num_atoms(1) == 3
assert reopened.num_atoms(2) == 0
# With no selection, counts cover every molecule in insertion order.
np.testing.assert_array_equal(reopened.atom_counts(), np.array([1, 3, 0], dtype=np.uint32))
# An explicit selection returns counts in the order the indices were given.
np.testing.assert_array_equal(
reopened.atom_counts([2, 0]),
np.array([0, 1], dtype=np.uint32),
)

# Index 3 is one past the last stored molecule and must raise IndexError.
with pytest.raises(IndexError, match="out of bounds"):
reopened.num_atoms(3)


@pytest.mark.parametrize("mmap", [False, True])
@pytest.mark.parametrize("compression", ["none", "lz4", "zstd"])
def test_database_single_item_reads_are_view_compatible(
Expand Down
2 changes: 1 addition & 1 deletion atompack-py/tests/test_stub_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_private_stub_tracks_low_level_surface() -> None:

def test_public_stub_exposes_flat_batch_reader() -> None:
# The public stub's ``Database`` class must list the flat batch reader and
# the atom-count accessors among its methods.
database_methods = _class_method_names(PUBLIC_STUB, "Database")
assert "get_molecules_flat" in database_methods
assert {"get_molecules_flat", "num_atoms", "atom_counts"} <= database_methods


def test_hub_stub_has_public_docstrings() -> None:
Expand Down
Loading