Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions atompack-py/python/atompack/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,114 @@
from pathlib import Path, PurePosixPath
from typing import Any, Sequence

import numpy as np

from ._atompack_rs import PyAtomDatabase as Database
from ._atompack_rs import PyMolecule as Molecule
from .ase_bridge import to_ase_batch as _to_ase_batch

# Glob patterns selecting which files to fetch from a hub dataset directory:
# the .atp shards themselves plus any manifest.json, at any nesting depth.
_DIRECTORY_ALLOW_PATTERNS = ["*.atp", "**/*.atp", "manifest.json", "**/manifest.json"]
# 512 MiB threshold; NOTE(review): the Xet-disable logic consuming this lives
# outside this view — confirm whether files above or below the limit skip Xet.
_XET_DISABLE_SINGLE_FILE_THRESHOLD_BYTES = 512 * 1024 * 1024
# Flat-batch columns with one row per *atom*; sliced by cumulative atom offsets.
_PER_ATOM_BATCH_KEYS = {"positions", "atomic_numbers", "forces", "charges", "velocities"}
# Flat-batch keys whose values are nested dicts of sub-columns, merged key-by-key.
_NESTED_BATCH_KEYS = {"properties", "atom_properties"}


def _slice_batch_value(value: Any, start: int, stop: int) -> Any:
return value[start:stop]


def _concat_batch_values(parts: list[Any]) -> Any:
if isinstance(parts[0], list):
merged: list[Any] = []
for part in parts:
merged.extend(part)
return merged
return np.concatenate(parts, axis=0)


def _split_flat_batch_records(batch: dict[str, Any]) -> list[dict[str, Any]]:
    """Explode a flat batch dict into one record dict per molecule.

    Per-atom columns (``_PER_ATOM_BATCH_KEYS`` and every ``atom_properties``
    entry) are sliced by the running atom offset derived from ``n_atoms``;
    everything else is sliced by molecule index.  Record order matches the
    batch's molecule order.
    """
    atom_counts = [int(n) for n in batch["n_atoms"]]

    records: list[dict[str, Any]] = []
    start = 0
    for mol_index, count in enumerate(atom_counts):
        stop = start + count
        record: dict[str, Any] = {
            "n_atoms": _slice_batch_value(batch["n_atoms"], mol_index, mol_index + 1),
            "positions": _slice_batch_value(batch["positions"], start, stop),
            "atomic_numbers": _slice_batch_value(batch["atomic_numbers"], start, stop),
        }
        for key, column in batch.items():
            if key in record:
                continue  # mandatory columns already sliced above
            if key == "atom_properties":
                record[key] = {
                    name: _slice_batch_value(values, start, stop)
                    for name, values in column.items()
                }
            elif key == "properties":
                record[key] = {
                    name: _slice_batch_value(values, mol_index, mol_index + 1)
                    for name, values in column.items()
                }
            elif key in _PER_ATOM_BATCH_KEYS:
                record[key] = _slice_batch_value(column, start, stop)
            else:
                record[key] = _slice_batch_value(column, mol_index, mol_index + 1)
        records.append(record)
        start = stop
    return records


def _merge_nested_batch_group(records: list[dict[str, Any]], key: str) -> dict[str, Any]:
    """Merge the nested dict group *key* (e.g. ``properties``) across records.

    Returns an empty dict when no record carries the group.  Raises
    ``ValueError`` if the group exists in only some records, or if the
    records' nested key sets differ — partial merges would silently drop data.
    """
    flags = [key in record for record in records]
    if not any(flags):
        return {}
    if not all(flags):
        raise ValueError(f"Selected molecules disagree on whether '{key}' is present")

    # Preserve the first record's nested-key order in the output.
    ordered_keys = list(records[0][key])
    reference_set = set(ordered_keys)
    if any(set(record[key]) != reference_set for record in records[1:]):
        raise ValueError(f"Selected molecules disagree on nested '{key}' keys")

    merged: dict[str, Any] = {}
    for nested_key in ordered_keys:
        merged[nested_key] = _concat_batch_values(
            [record[key][nested_key] for record in records]
        )
    return merged


def _merge_flat_batch_records(records: list[dict[str, Any]]) -> dict[str, Any]:
    """Re-assemble per-molecule record dicts into one flat batch dict.

    The three mandatory columns are always merged.  Any other top-level
    column must be present in *every* record or a ``ValueError`` is raised.
    Nested groups (``properties`` / ``atom_properties``) are merged via
    ``_merge_nested_batch_group`` and included only when non-empty.
    """
    merged: dict[str, Any] = {
        name: _concat_batch_values([record[name] for record in records])
        for name in ("n_atoms", "positions", "atomic_numbers")
    }

    optional_keys: set[str] = set()
    for record in records:
        for key in record:
            if key not in merged and key not in _NESTED_BATCH_KEYS:
                optional_keys.add(key)
    for key in optional_keys:
        if not all(key in record for record in records):
            raise ValueError(f"Selected molecules disagree on whether '{key}' is present")
        merged[key] = _concat_batch_values([record[key] for record in records])

    for group in ("properties", "atom_properties"):
        group_values = _merge_nested_batch_group(records, group)
        if group_values:
            merged[group] = group_values
    return merged


def _require_hf_hub() -> Any:
Expand Down Expand Up @@ -197,6 +299,30 @@ def get_molecules(self, indices: list[int]) -> list[Molecule]:

return [molecule for molecule in molecules if molecule is not None]

def get_molecules_flat(self, indices: Sequence[int]) -> dict[str, Any]:
self._ensure_open()
selected_indices = list(indices)
if not selected_indices:
return self._databases[0].get_molecules_flat([])

grouped: dict[int, list[tuple[int, int]]] = {}
for output_index, index in enumerate(selected_indices):
db_index, local_index = self._locate(index)
grouped.setdefault(db_index, []).append((output_index, local_index))

records: list[dict[str, Any] | None] = [None] * len(selected_indices)
for db_index, pairs in grouped.items():
local_indices = [local_index for _, local_index in pairs]
shard_batch = self._databases[db_index].get_molecules_flat(local_indices)
shard_records = _split_flat_batch_records(shard_batch)
for (output_index, _), record in zip(pairs, shard_records):
records[output_index] = record

if any(record is None for record in records):
raise ValueError("Failed to reconstruct one or more requested flat-batch records")
ordered_records = [record for record in records if record is not None]
return _merge_flat_batch_records(ordered_records)

def to_ase_batch(
self,
indices: list[int] | None = None,
Expand Down
7 changes: 7 additions & 0 deletions atompack-py/python/atompack/hub.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ class AtompackReader:
supported.
"""
...
    def get_molecules_flat(self, indices: Sequence[int]) -> dict[str, Any]:
        """
        Fetch many molecules as one flat batch while preserving input order.

        This mirrors ``Database.get_molecules_flat`` across a merged shard
        set: results for ``indices`` come back in the order requested, as a
        single dict of batch columns rather than one object per molecule.
        """
        ...
def to_ase_batch(
self,
indices: Sequence[int] | None = None,
Expand Down
39 changes: 39 additions & 0 deletions atompack-py/tests/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,45 @@ def test_reader_to_ase_batch_preserves_requested_order(tmp_path: Path) -> None:
assert atoms_list[1].get_potential_energy() == pytest.approx(-1.0)


def test_reader_get_molecules_flat_preserves_requested_order(tmp_path: Path) -> None:
    """Flat batches fetched across two shards follow the requested index order."""
    shard_dir = tmp_path / "train"
    shard_dir.mkdir()

    # Shard "a": two molecules (a CO diatomic, then a lone H atom).
    db_a = atompack.Database(str(shard_dir / "a.atp"), compression="none")
    shard_a_specs = [
        (
            np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32),
            np.array([6, 8], dtype=np.uint8),
            -1.0,
            "train-a",
        ),
        (
            np.array([[0.5, 0.0, 0.0]], dtype=np.float32),
            np.array([1], dtype=np.uint8),
            -2.0,
            "train-b",
        ),
    ]
    shard_a_molecules = []
    for positions, numbers, energy, split in shard_a_specs:
        molecule = atompack.Molecule(positions, numbers)
        molecule.energy = energy
        molecule.set_property("split", split)
        shard_a_molecules.append(molecule)
    db_a.add_molecules(shard_a_molecules)
    db_a.flush()

    # Shard "b": one three-atom molecule (O, H, H).
    db_b = atompack.Database(str(shard_dir / "b.atp"), compression="none")
    third = atompack.Molecule(
        np.array([[1.5, 0.0, 0.0], [2.5, 0.0, 0.0], [3.5, 0.0, 0.0]], dtype=np.float32),
        np.array([8, 1, 1], dtype=np.uint8),
    )
    third.energy = -3.0
    third.set_property("split", "train-c")
    db_b.add_molecule(third)
    db_b.flush()

    reader = atompack.hub.open_path(shard_dir)
    # Request out of storage order: shard-b molecule first, then shard-a's two.
    batch = reader.get_molecules_flat([2, 0, 1])

    np.testing.assert_array_equal(batch["n_atoms"], np.array([3, 2, 1], dtype=np.uint32))
    np.testing.assert_allclose(batch["energy"], np.array([-3.0, -1.0, -2.0], dtype=np.float64))
    np.testing.assert_array_equal(batch["atomic_numbers"], np.array([8, 1, 1, 6, 8, 1]))
    assert batch["properties"]["split"] == ["train-c", "train-a", "train-b"]


def test_import_atompack_does_not_require_huggingface_hub(tmp_path: Path) -> None:
python_src = Path(__file__).resolve().parents[1] / "python"
sitecustomize = tmp_path / "sitecustomize.py"
Expand Down
3 changes: 3 additions & 0 deletions atompack-py/tests/test_stub_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def test_public_stub_exposes_flat_batch_reader() -> None:


def test_hub_stub_has_public_docstrings() -> None:
reader_methods = _class_method_names(HUB_STUB, "AtompackReader")
assert "get_molecules_flat" in reader_methods

reader_doc = _class_docstring(HUB_STUB, "AtompackReader") or ""
assert "lexicographically ordered shard set" in reader_doc

Expand Down
Loading