From b2f8a26ff94f394f908aad44eb4ad14a35361c2e Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sat, 9 May 2026 16:21:57 +0200 Subject: [PATCH] feat: add flat batches for hub readers --- atompack-py/python/atompack/hub.py | 126 +++++++++++++++++++++++++ atompack-py/python/atompack/hub.pyi | 7 ++ atompack-py/tests/test_hub.py | 39 ++++++++ atompack-py/tests/test_stub_surface.py | 3 + 4 files changed, 175 insertions(+) diff --git a/atompack-py/python/atompack/hub.py b/atompack-py/python/atompack/hub.py index 0fb8a19..c8dc76d 100644 --- a/atompack-py/python/atompack/hub.py +++ b/atompack-py/python/atompack/hub.py @@ -5,12 +5,114 @@ from pathlib import Path, PurePosixPath from typing import Any, Sequence +import numpy as np + from ._atompack_rs import PyAtomDatabase as Database from ._atompack_rs import PyMolecule as Molecule from .ase_bridge import to_ase_batch as _to_ase_batch _DIRECTORY_ALLOW_PATTERNS = ["*.atp", "**/*.atp", "manifest.json", "**/manifest.json"] _XET_DISABLE_SINGLE_FILE_THRESHOLD_BYTES = 512 * 1024 * 1024 +_PER_ATOM_BATCH_KEYS = {"positions", "atomic_numbers", "forces", "charges", "velocities"} +_NESTED_BATCH_KEYS = {"properties", "atom_properties"} + + +def _slice_batch_value(value: Any, start: int, stop: int) -> Any: + return value[start:stop] + + +def _concat_batch_values(parts: list[Any]) -> Any: + if isinstance(parts[0], list): + merged: list[Any] = [] + for part in parts: + merged.extend(part) + return merged + return np.concatenate(parts, axis=0) + + +def _split_flat_batch_records(batch: dict[str, Any]) -> list[dict[str, Any]]: + counts = [int(count) for count in batch["n_atoms"]] + offsets = [0] + for count in counts: + offsets.append(offsets[-1] + count) + + records: list[dict[str, Any]] = [] + for mol_index, (start, stop) in enumerate(zip(offsets, offsets[1:])): + record: dict[str, Any] = { + "n_atoms": _slice_batch_value(batch["n_atoms"], mol_index, mol_index + 1), + "positions": _slice_batch_value(batch["positions"], start, stop), + "atomic_numbers": _slice_batch_value(batch["atomic_numbers"], start, stop), + } + for key, value in batch.items(): + if key in record: + continue + if key == "atom_properties": + record[key] = { + prop_key: _slice_batch_value(prop_value, start, stop) + for prop_key, prop_value in value.items() + } + continue + if key == "properties": + record[key] = { + prop_key: _slice_batch_value(prop_value, mol_index, mol_index + 1) + for prop_key, prop_value in value.items() + } + continue + if key in _PER_ATOM_BATCH_KEYS: + record[key] = _slice_batch_value(value, start, stop) + else: + record[key] = _slice_batch_value(value, mol_index, mol_index + 1) + records.append(record) + return records + + +def _merge_nested_batch_group(records: list[dict[str, Any]], key: str) -> dict[str, Any]: + present = [key in record for record in records] + if not any(present): + return {} + if not all(present): + raise ValueError(f"Selected molecules disagree on whether '{key}' is present") + + first_group = records[0][key] + expected_keys = list(first_group.keys()) + expected_set = set(expected_keys) + for record in records[1:]: + actual_keys = set(record[key].keys()) + if actual_keys != expected_set: + raise ValueError(f"Selected molecules disagree on nested '{key}' keys") + + return { + nested_key: _concat_batch_values([record[key][nested_key] for record in records]) + for nested_key in expected_keys + } + + +def _merge_flat_batch_records(records: list[dict[str, Any]]) -> dict[str, Any]: + result = { + "n_atoms": _concat_batch_values([record["n_atoms"] for record in records]), + "positions": _concat_batch_values([record["positions"] for record in records]), + "atomic_numbers": _concat_batch_values([record["atomic_numbers"] for record in records]), + } + + top_level_keys = { + key + for record in records + for key in record + if key not in result and key not in _NESTED_BATCH_KEYS + } + for key in top_level_keys: + present = [key in record for record in records] + if not all(present): + raise ValueError(f"Selected molecules disagree on whether '{key}' is present") + result[key] = _concat_batch_values([record[key] for record in records]) + + properties = _merge_nested_batch_group(records, "properties") + if properties: + result["properties"] = properties + atom_properties = _merge_nested_batch_group(records, "atom_properties") + if atom_properties: + result["atom_properties"] = atom_properties + return result def _require_hf_hub() -> Any: @@ -197,6 +299,30 @@ def get_molecules(self, indices: list[int]) -> list[Molecule]: return [molecule for molecule in molecules if molecule is not None] + def get_molecules_flat(self, indices: Sequence[int]) -> dict[str, Any]: + self._ensure_open() + selected_indices = list(indices) + if not selected_indices: + return self._databases[0].get_molecules_flat([]) + + grouped: dict[int, list[tuple[int, int]]] = {} + for output_index, index in enumerate(selected_indices): + db_index, local_index = self._locate(index) + grouped.setdefault(db_index, []).append((output_index, local_index)) + + records: list[dict[str, Any] | None] = [None] * len(selected_indices) + for db_index, pairs in grouped.items(): + local_indices = [local_index for _, local_index in pairs] + shard_batch = self._databases[db_index].get_molecules_flat(local_indices) + shard_records = _split_flat_batch_records(shard_batch) + for (output_index, _), record in zip(pairs, shard_records): + records[output_index] = record + + if any(record is None for record in records): + raise ValueError("Failed to reconstruct one or more requested flat-batch records") + ordered_records = [record for record in records if record is not None] + return _merge_flat_batch_records(ordered_records) + def to_ase_batch( self, indices: list[int] | None = None, diff --git a/atompack-py/python/atompack/hub.pyi b/atompack-py/python/atompack/hub.pyi index 134283d..f87af6a 100644 --- a/atompack-py/python/atompack/hub.pyi +++ b/atompack-py/python/atompack/hub.pyi @@ -56,6 +56,13 @@ class AtompackReader: supported. """ ... + def get_molecules_flat(self, indices: Sequence[int]) -> dict[str, Any]: + """ + Fetch many molecules as one flat batch while preserving input order. + + This mirrors ``Database.get_molecules_flat`` across a merged shard set. + """ + ... def to_ase_batch( self, indices: Sequence[int] | None = None, diff --git a/atompack-py/tests/test_hub.py b/atompack-py/tests/test_hub.py index c7db2a0..78a7811 100644 --- a/atompack-py/tests/test_hub.py +++ b/atompack-py/tests/test_hub.py @@ -455,6 +455,45 @@ def test_reader_to_ase_batch_preserves_requested_order(tmp_path: Path) -> None: assert atoms_list[1].get_potential_energy() == pytest.approx(-1.0) +def test_reader_get_molecules_flat_preserves_requested_order(tmp_path: Path) -> None: + shard_dir = tmp_path / "train" + shard_dir.mkdir() + + first = atompack.Database(str(shard_dir / "a.atp"), compression="none") + mol_a = atompack.Molecule( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32), + np.array([6, 8], dtype=np.uint8), + ) + mol_a.energy = -1.0 + mol_a.set_property("split", "train-a") + mol_b = atompack.Molecule( + np.array([[0.5, 0.0, 0.0]], dtype=np.float32), + np.array([1], dtype=np.uint8), + ) + mol_b.energy = -2.0 + mol_b.set_property("split", "train-b") + first.add_molecules([mol_a, mol_b]) + first.flush() + + second = atompack.Database(str(shard_dir / "b.atp"), compression="none") + mol_c = atompack.Molecule( + np.array([[1.5, 0.0, 0.0], [2.5, 0.0, 0.0], [3.5, 0.0, 0.0]], dtype=np.float32), + np.array([8, 1, 1], dtype=np.uint8), + ) + mol_c.energy = -3.0 + mol_c.set_property("split", "train-c") + second.add_molecule(mol_c) + second.flush() + + reader = atompack.hub.open_path(shard_dir) + batch = reader.get_molecules_flat([2, 0, 1]) + + np.testing.assert_array_equal(batch["n_atoms"], np.array([3, 2, 1], dtype=np.uint32)) + np.testing.assert_allclose(batch["energy"], np.array([-3.0, -1.0, -2.0], dtype=np.float64)) + np.testing.assert_array_equal(batch["atomic_numbers"], np.array([8, 1, 1, 6, 8, 1])) + assert batch["properties"]["split"] == ["train-c", "train-a", "train-b"] + + def test_import_atompack_does_not_require_huggingface_hub(tmp_path: Path) -> None: python_src = Path(__file__).resolve().parents[1] / "python" sitecustomize = tmp_path / "sitecustomize.py" diff --git a/atompack-py/tests/test_stub_surface.py b/atompack-py/tests/test_stub_surface.py index b883fcc..5b89c23 100644 --- a/atompack-py/tests/test_stub_surface.py +++ b/atompack-py/tests/test_stub_surface.py @@ -76,6 +76,9 @@ def test_public_stub_exposes_flat_batch_reader() -> None: def test_hub_stub_has_public_docstrings() -> None: + reader_methods = _class_method_names(HUB_STUB, "AtompackReader") + assert "get_molecules_flat" in reader_methods + reader_doc = _class_docstring(HUB_STUB, "AtompackReader") or "" assert "lexicographically ordered shard set" in reader_doc