From 7d9024e44e35fc4a7a89d944430797321de9e6c3 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sat, 9 May 2026 16:22:01 +0200 Subject: [PATCH] feat: add pythonic hub reader indexing --- atompack-py/python/atompack/hub.py | 11 ++++++++++- atompack-py/python/atompack/hub.pyi | 8 +++++++- atompack-py/tests/test_hub.py | 15 +++++++++++++++ atompack-py/tests/test_stub_surface.py | 3 +++ 4 files changed, 35 insertions(+), 2 deletions(-) diff --git a/atompack-py/python/atompack/hub.py b/atompack-py/python/atompack/hub.py index 0fb8a19..bd4defe 100644 --- a/atompack-py/python/atompack/hub.py +++ b/atompack-py/python/atompack/hub.py @@ -171,9 +171,18 @@ def __len__(self) -> int: self._ensure_open() return self._total_length - def __getitem__(self, index: int) -> Molecule: + def __getitem__(self, index: int | slice) -> Molecule | list[Molecule]: + if isinstance(index, slice): + self._ensure_open() + start, stop, step = index.indices(self._total_length) + return self.get_molecules(list(range(start, stop, step))) return self.get_molecule(index) + def __iter__(self): + self._ensure_open() + for index in range(self._total_length): + yield self.get_molecule(index) + def get_molecule(self, index: int) -> Molecule: db_index, local_index = self._locate(index) return self._databases[db_index][local_index] diff --git a/atompack-py/python/atompack/hub.pyi b/atompack-py/python/atompack/hub.pyi index 134283d..7aed48e 100644 --- a/atompack-py/python/atompack/hub.pyi +++ b/atompack-py/python/atompack/hub.pyi @@ -4,7 +4,7 @@ from __future__ import annotations from pathlib import Path from types import TracebackType -from typing import Any, Sequence +from typing import Any, Iterator, Sequence, overload from . import Molecule @@ -26,6 +26,7 @@ class AtompackReader: def __len__(self) -> int: """Return the total number of molecules across all opened files.""" ... + @overload def __getitem__(self, index: int) -> Molecule: """ Fetch one molecule by index. @@ -34,6 +35,8 @@ class AtompackReader: dataset, not within a single shard. """ ... + @overload + def __getitem__(self, index: slice) -> list[Molecule]: ... def get_molecule(self, index: int) -> Molecule: """ Fetch one molecule by global index across the underlying shard set. @@ -75,6 +78,9 @@ class AtompackReader: def close(self) -> None: """Close the underlying databases and invalidate the reader.""" ... + def __iter__(self) -> Iterator[Molecule]: + """Iterate over molecules in logical reader order.""" + ... def download( repo_id: str, diff --git a/atompack-py/tests/test_hub.py b/atompack-py/tests/test_hub.py index c7db2a0..fd7700a 100644 --- a/atompack-py/tests/test_hub.py +++ b/atompack-py/tests/test_hub.py @@ -415,6 +415,21 @@ def test_open_path_directory_flattens_lexicographically(tmp_path: Path) -> None: assert [reader[i].energy for i in range(len(reader))] == pytest.approx([-1.0, -2.0, -3.0]) +def test_reader_supports_iteration_and_slices(tmp_path: Path) -> None: + shard_dir = tmp_path / "shards" + shard_dir.mkdir() + _make_db(shard_dir / "a.atp", [-1.0, -2.0]) + _make_db(shard_dir / "b.atp", [-3.0, -4.0]) + + reader = atompack.hub.open_path(shard_dir) + + assert [molecule.energy for molecule in reader] == pytest.approx([-1.0, -2.0, -3.0, -4.0]) + assert [molecule.energy for molecule in reader[1:4:2]] == pytest.approx([-2.0, -4.0]) + assert [molecule.energy for molecule in reader[::-1]] == pytest.approx( + [-4.0, -3.0, -2.0, -1.0] + ) + + def test_open_path_context_manager_closes_reader(tmp_path: Path) -> None: source = tmp_path / "single.atp" _make_db(source, [-1.0]) diff --git a/atompack-py/tests/test_stub_surface.py b/atompack-py/tests/test_stub_surface.py index b883fcc..f4cc377 100644 --- a/atompack-py/tests/test_stub_surface.py +++ b/atompack-py/tests/test_stub_surface.py @@ -76,6 +76,9 @@ def test_public_stub_exposes_flat_batch_reader() -> None: def test_hub_stub_has_public_docstrings() -> None: + reader_methods = _class_method_names(HUB_STUB, "AtompackReader") + assert {"__getitem__", "__iter__"} <= reader_methods + reader_doc = _class_docstring(HUB_STUB, "AtompackReader") or "" assert "lexicographically ordered shard set" in reader_doc