From 08980dc6f2518839f6b49f252fd6dd3ddd9332b3 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 22 Feb 2026 00:03:25 +0100 Subject: [PATCH 01/68] chore: clean up .gitignore by removing unnecessary entries --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 367285b4c..660e95452 100644 --- a/.gitignore +++ b/.gitignore @@ -128,4 +128,4 @@ interfaces/lammps/examples/*/*.dat interfaces/lammps/examples/*/deployed_model # batchwise optimizer examples -examples/howtos/howto_batchwise_relaxations_outputs/* \ No newline at end of file +examples/howtos/howto_batchwise_relaxations_outputs/* From f8879c0a17dd85f4ea6574a91d4a61591efa013e Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 22 Feb 2026 04:23:31 +0100 Subject: [PATCH 02/68] feat: add new AtomsDataModuleV2 and StatsAtomrefProvider for enhanced data handling --- src/schnetpack/data/__init__.py | 2 + src/schnetpack/data/atoms.py | 2 +- src/schnetpack/data/datamodule_v2.py | 243 ++++++++++++++++++++++ src/schnetpack/data/provider.py | 76 +++++++ src/schnetpack/datasets/qm9.py | 266 ++++++++++++++---------- src/schnetpack/datasets/qm9_legacy.py | 279 ++++++++++++++++++++++++++ src/schnetpack/transform/atomistic.py | 203 +++++++++++++++---- src/schnetpack/transform/base.py | 9 +- 8 files changed, 940 insertions(+), 140 deletions(-) create mode 100644 src/schnetpack/data/datamodule_v2.py create mode 100644 src/schnetpack/data/provider.py create mode 100644 src/schnetpack/datasets/qm9_legacy.py diff --git a/src/schnetpack/data/__init__.py b/src/schnetpack/data/__init__.py index d3c6b83fc..328def9b1 100644 --- a/src/schnetpack/data/__init__.py +++ b/src/schnetpack/data/__init__.py @@ -4,3 +4,5 @@ from .splitting import * from .datamodule import * from .sampler import * +from .datamodule_v2 import * +from .provider import * diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index f01f866fd..a06d7b5d6 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -24,7 +24,7 @@ import schnetpack as spk import schnetpack.properties as structure -from schnetpack.transform import Transform +from schnetpack.transform.base import Transform logger = logging.getLogger(__name__) diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py new file mode 100644 index 000000000..5a1a54a68 --- /dev/null +++ b/src/schnetpack/data/datamodule_v2.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +from copy import copy +from typing import Any, Dict, List, Optional, Type, Union + +import numpy as np +import pytorch_lightning as pl +from torch.utils.data import BatchSampler + +from schnetpack.data.atoms import BaseAtomsData +from schnetpack.data.loader import AtomsLoader +from schnetpack.data.provider import StatsAtomrefProvider +from schnetpack.data.splitting import RandomSplit, SplittingStrategy + + +class AtomsDataModuleV2(pl.LightningDataModule): + """ + V2 DataModule: + - accepts a dataset instance (datasets are independent) + - handles splitting + loaders/batching + - builds StatsAtomrefProvider from train split + - initializes transforms via t.initialize(provider, atomrefs=...) 
+ (no t.datamodule() dependency in the V2 path) + """ + + def __init__( + self, + dataset: BaseAtomsData, + batch_size: int, + num_train: Union[int, float], + num_val: Union[int, float], + num_test: Optional[Union[int, float]] = None, + split_file: Optional[str] = "split.npz", + splitting: Optional[SplittingStrategy] = None, + transforms: Optional[List] = None, + train_transforms: Optional[List] = None, + val_transforms: Optional[List] = None, + test_transforms: Optional[List] = None, + train_sampler_cls: Optional[Type] = None, + train_sampler_args: Optional[Dict[str, Any]] = None, + num_workers: int = 8, + num_val_workers: Optional[int] = None, + num_test_workers: Optional[int] = None, + pin_memory: bool = False, + strict_transform_init: bool = True, + val_batch_size: Optional[int] = None, + test_batch_size: Optional[int] = None, + ): + super().__init__() + + self.dataset = dataset + + self.batch_size = batch_size + self.val_batch_size = val_batch_size or batch_size + self.test_batch_size = test_batch_size or batch_size + + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test + + self.split_file = split_file + self.splitting = splitting or RandomSplit() + + # If transforms passed, replicate (copy) for each split unless split-specific provided + self._train_transforms = train_transforms or copy(transforms) or [] + self._val_transforms = val_transforms or copy(transforms) or [] + self._test_transforms = test_transforms or copy(transforms) or [] + self.strict_transform_init = strict_transform_init + + self.train_sampler_cls = train_sampler_cls + self.train_sampler_args = train_sampler_args or {} + + self.num_workers = num_workers + self.num_val_workers = num_val_workers if num_val_workers is not None else num_workers + self.num_test_workers = num_test_workers if num_test_workers is not None else num_workers + self.pin_memory = pin_memory + + self.train_idx: Optional[List[int]] = None + self.val_idx: Optional[List[int]] = None + self.test_idx: Optional[List[int]] = None + + self._train_dataset: Optional[BaseAtomsData] = None + self._val_dataset: Optional[BaseAtomsData] = None + self._test_dataset: Optional[BaseAtomsData] = None + + self._train_loader: Optional[AtomsLoader] = None + self._val_loader: Optional[AtomsLoader] = None + self._test_loader: Optional[AtomsLoader] = None + + self.provider: Optional[StatsAtomrefProvider] = None + + @property + def train_transforms(self): + return self._train_transforms + + @property + def val_transforms(self): + return self._val_transforms + + @property + def test_transforms(self): + return self._test_transforms + + @property + def train_dataset(self) -> BaseAtomsData: + if self._train_dataset is None: + raise RuntimeError("Call setup() before accessing train_dataset.") + return self._train_dataset + + @property + def val_dataset(self) -> BaseAtomsData: + if self._val_dataset is None: + raise RuntimeError("Call setup() before accessing val_dataset.") + return self._val_dataset + + @property + def test_dataset(self) -> BaseAtomsData: + if self._test_dataset is None: + raise RuntimeError("Call setup() before accessing test_dataset.") + return self._test_dataset + + def setup(self, stage: Optional[str] = None) -> None: + if self.train_idx is None: + self._load_partitions() + + # Create split datasets (no transforms attached yet) + self._train_dataset = self.dataset.subset(self.train_idx) + self._val_dataset = self.dataset.subset(self.val_idx) + self._test_dataset = self.dataset.subset(self.test_idx) + + # Build provider bound to 
train loader factory (loader created on demand) + self.provider = StatsAtomrefProvider( + train_dataloader_factory=self.train_dataloader, + train_atomrefs=getattr(self._train_dataset, "atomrefs", None), + ) + + # Initialize transforms (V2 path: datamodule-free) + train_atomrefs = getattr(self._train_dataset, "atomrefs", None) + self._initialize_transform_list(self.train_transforms, train_atomrefs=train_atomrefs) + self._initialize_transform_list(self.val_transforms, train_atomrefs=train_atomrefs) + self._initialize_transform_list(self.test_transforms, train_atomrefs=train_atomrefs) + + # Attach transforms after init (matches legacy behavior) + self._train_dataset.transforms = self.train_transforms + self._val_dataset.transforms = self.val_transforms + self._test_dataset.transforms = self.test_transforms + + def _initialize_transform_list(self, transforms: List, train_atomrefs): + if not transforms: + return + + for t in transforms: + init_fn = getattr(t, "initialize", None) + if callable(init_fn): + init_fn(self.provider, atomrefs=train_atomrefs) + continue + + if self.strict_transform_init: + raise RuntimeError( + f"Transform {type(t).__name__} does not implement initialize(provider, atomrefs=...)." + ) + + def _load_partitions(self) -> None: + import os + + total_size = len(self.dataset) + + def _to_abs(x: Optional[Union[int, float]]) -> Optional[int]: + if x is None: + return None + if isinstance(x, float) and x <= 1.0: + return int(x * total_size) + return int(x) + + num_train = _to_abs(self.num_train) + num_val = _to_abs(self.num_val) + num_test = _to_abs(self.num_test) + + self.num_train, self.num_val, self.num_test = num_train, num_val, num_test + + if self.split_file is not None and os.path.exists(self.split_file): + S = np.load(self.split_file) + self.train_idx = S["train_idx"].tolist() + self.val_idx = S["val_idx"].tolist() + self.test_idx = S["test_idx"].tolist() + return + + train_idx, val_idx, test_idx = self.splitting.split( + self.dataset, num_train, num_val, num_test + ) + self.train_idx, self.val_idx, self.test_idx = train_idx, val_idx, test_idx + + if self.split_file is not None: + np.savez( + self.split_file, + train_idx=self.train_idx, + val_idx=self.val_idx, + test_idx=self.test_idx, + ) + + def _setup_train_batch_sampler(self): + if self.train_sampler_cls is None: + return None + + sampler = self.train_sampler_cls( + data_source=self.train_dataset, + num_samples=len(self.train_dataset), + **self.train_sampler_args, + ) + return BatchSampler(sampler=sampler, batch_size=self.batch_size, drop_last=True) + + def train_dataloader(self) -> AtomsLoader: + if self._train_loader is None: + batch_sampler = self._setup_train_batch_sampler() + self._train_loader = AtomsLoader( + self.train_dataset, + batch_size=self.batch_size if batch_sampler is None else 1, + shuffle=batch_sampler is None, + batch_sampler=batch_sampler, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + return self._train_loader + + def val_dataloader(self) -> AtomsLoader: + if self._val_loader is None: + self._val_loader = AtomsLoader( + self.val_dataset, + batch_size=self.val_batch_size, + num_workers=self.num_val_workers, + pin_memory=self.pin_memory, + ) + return self._val_loader + + def test_dataloader(self) -> AtomsLoader: + if self._test_loader is None: + self._test_loader = AtomsLoader( + self.test_dataset, + batch_size=self.test_batch_size, + num_workers=self.num_test_workers, + pin_memory=self.pin_memory, + ) + return self._test_loader \ No newline at end of file diff --git 
a/src/schnetpack/data/provider.py b/src/schnetpack/data/provider.py new file mode 100644 index 000000000..99e31255b --- /dev/null +++ b/src/schnetpack/data/provider.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple + +import torch + +from schnetpack.data.loader import AtomsLoader +from schnetpack.data.stats import calculate_stats, estimate_atomrefs + + +@dataclass +class StatsAtomrefProvider: + """ + Compute and cache statistics and atom references from the *training split*. + + This replaces the logic that used to live in AtomsDataModule.get_stats/get_atomrefs, + without requiring a DataModule instance in transforms or datasets. + """ + + train_dataloader_factory: Callable[[], AtomsLoader] + train_atomrefs: Optional[Dict[str, torch.Tensor]] = None + + _stats_cache: Optional[Dict[Tuple[str, bool, bool], Tuple[torch.Tensor, torch.Tensor]]] = None + _atomref_cache: Optional[Dict[Tuple[str, bool], torch.Tensor]] = None + + def __post_init__(self) -> None: + if self._stats_cache is None: + self._stats_cache = {} + if self._atomref_cache is None: + self._atomref_cache = {} + + def get_stats( + self, property: str, divide_by_atoms: bool, remove_atomref: bool + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + property: property key + divide_by_atoms: if True, compute stats of property / n_atoms + remove_atomref: if True, subtract atomref prior to stats computation + Returns: + (mean, std) tensors + """ + key = (property, divide_by_atoms, remove_atomref) + if key in self._stats_cache: + return self._stats_cache[key] + + loader = self.train_dataloader_factory() + atomref = self.train_atomrefs if remove_atomref else None + + stats = calculate_stats( + loader, + divide_by_atoms={property: divide_by_atoms}, + atomref=atomref, + )[property] + + self._stats_cache[key] = stats + return stats + + def get_atomrefs(self, property: str, is_extensive: bool) -> Dict[str, torch.Tensor]: + """ + Args: + property: property key + is_extensive: whether property is extensive + Returns: + dict {property: atomref_tensor} + """ + key = (property, is_extensive) + if key in self._atomref_cache: + return {property: self._atomref_cache[key]} + + loader = self.train_dataloader_factory() + atomref = estimate_atomrefs(loader, is_extensive={property: is_extensive})[property] + + self._atomref_cache[key] = atomref + return {property: atomref} \ No newline at end of file diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 3085a21a7..3c2675afc 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -5,7 +5,7 @@ import shutil import tarfile import tempfile -from typing import List, Optional, Dict +from typing import Dict, List, Optional from urllib import request as request import numpy as np @@ -14,23 +14,25 @@ from tqdm import tqdm import torch -from schnetpack.data import * import schnetpack.properties as structure -from schnetpack.data import AtomsDataModuleError, AtomsDataModule +from schnetpack.data import ( + AtomsDataFormat, + AtomsDataModuleError, + BaseAtomsData, + create_dataset, + load_dataset, +) __all__ = ["QM9"] -class QM9(AtomsDataModule): - """QM9 benchmark database for organic molecules. +class QM9: + """QM9 benchmark database for organic molecules (dataset-only). - The QM9 database contains small organic molecules with up to nine non-hydrogen atoms - from including C, O, N, F. 
This class adds convenient functions to download QM9 from - figshare and load the data into pytorch. - - References: - - .. [#qm9_1] https://ndownloader.figshare.com/files/3195404 + This class: + - is a dataset wrapper (no Lightning DataModule inheritance) + - can download + build the dataset via prepare() + - forwards the BaseAtomsData API to an underlying dataset instance """ base_urls = [ @@ -63,80 +65,107 @@ class QM9(AtomsDataModule): def __init__( self, datapath: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, + format: AtomsDataFormat = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, remove_uncharacterized: bool = False, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - data_workdir: Optional[str] = None, + transforms: Optional[List[torch.nn.Module]] = None, **kwargs, ): """ - Args: - datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format + datapath: path to dataset DB (e.g. qm9.db) + format: dataset format (ASE by default) load_properties: subset of properties to load - remove_uncharacterized: do not include uncharacterized molecules. - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. 
+            remove_uncharacterized: if True, exclude uncharacterized molecules
+            property_units: optional unit overrides on load (passed to load_dataset/create_dataset)
+            distance_unit: optional distance unit override on load (passed to load_dataset/create_dataset)
+            transforms: optional default transforms (typically set by your DataModule per split)
+            **kwargs: reserved for forward compatibility
        """
-        super().__init__(
-            datapath=datapath,
-            batch_size=batch_size,
-            num_train=num_train,
-            num_val=num_val,
-            num_test=num_test,
-            split_file=split_file,
-            format=format,
-            load_properties=load_properties,
-            val_batch_size=val_batch_size,
-            test_batch_size=test_batch_size,
-            transforms=transforms,
-            train_transforms=train_transforms,
-            val_transforms=val_transforms,
-            test_transforms=test_transforms,
-            num_workers=num_workers,
-            num_val_workers=num_val_workers,
-            num_test_workers=num_test_workers,
-            property_units=property_units,
-            distance_unit=distance_unit,
-            data_workdir=data_workdir,
-            **kwargs,
-        )
-
+        self.datapath = datapath
+        self.format = format
+        self.load_properties = load_properties
        self.remove_uncharacterized = remove_uncharacterized
+        self.property_units = property_units
+        self.distance_unit = distance_unit
+        self._kwargs = kwargs
+
+        self._dataset: Optional[BaseAtomsData] = None
+        self.transforms = transforms or []
+
+    # -------------------------
+    # Dataset forwarding helpers
+    # -------------------------
+    def _ensure_loaded(self) -> None:
+        if self._dataset is None:
+            # Lazy-load if the DB already exists; otherwise require the user to call prepare_data()
+            if not os.path.exists(self.datapath):
+                raise AtomsDataModuleError(
+                    f"QM9 dataset not found at {self.datapath}. Call dataset.prepare_data() first."
+                )
+            self._dataset = load_dataset(
+                self.datapath,
+                self.format,
+                load_properties=self.load_properties,
+                property_units=self.property_units,
+                distance_unit=self.distance_unit,
+            )
+            # attach any transforms set before first load
+            self._dataset.transforms = self.transforms
+
+    def __len__(self) -> int:
+        self._ensure_loaded()
+        return len(self._dataset)
+
+    def __getitem__(self, idx: int):
+        self._ensure_loaded()
+        return self._dataset[idx]
+
+    def subset(self, indices):
+        """
+        Forward subset() to the underlying dataset.
+        Returns a BaseAtomsData-like object (the exact type depends on the backend).
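+
+        A minimal usage sketch (hypothetical indices; assumes the DB was
+        already built via prepare_data()):
+
+            qm9 = QM9("qm9.db")
+            sub = qm9.subset([0, 1, 2])
+            assert len(sub) == 3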
+ """ + self._ensure_loaded() + sub = self._dataset.subset(indices) + # Ensure transforms are carried over if caller expects it + # (DataModuleV2 will set per-split transforms anyway) + if getattr(sub, "transforms", None) is None: + sub.transforms = [] + return sub + + # Common attributes used elsewhere in SchNetPack + @property + def available_properties(self): + self._ensure_loaded() + return self._dataset.available_properties + + @property + def atomrefs(self): + self._ensure_loaded() + return getattr(self._dataset, "atomrefs", None) + + @property + def metadata(self): + self._ensure_loaded() + return getattr(self._dataset, "metadata", None) - def _download_file(self, file_id: str, destination: str): + @property + def distance_unit_internal(self): + self._ensure_loaded() + return getattr(self._dataset, "distance_unit", None) + + @property + def property_unit_dict(self): + self._ensure_loaded() + return getattr(self._dataset, "property_unit_dict", None) + + # ------------------------- + # Download / build pipeline + # ------------------------- + def _download_file(self, file_id: str, destination: str) -> None: for base_url in self.base_urls: url = f"{base_url}{file_id}" try: @@ -148,7 +177,13 @@ def _download_file(self, file_id: str, destination: str): f"Could not download file with id {file_id} from any source." ) - def prepare_data(self): + def prepare_data(self) -> None: + """ + Download + build the dataset if missing. If it already exists, verify + the uncharacterized setting is consistent. + + After prepare(), the dataset is loaded and ready to use. + """ if not os.path.exists(self.datapath): property_unit_dict = { QM9.A: "GHz", @@ -169,36 +204,62 @@ def prepare_data(self): } tmpdir = tempfile.mkdtemp("qm9") - atomrefs = self._download_atomrefs(tmpdir) - - dataset = create_dataset( - datapath=self.datapath, - format=self.format, - distance_unit="Ang", - property_unit_dict=property_unit_dict, - atomrefs=atomrefs, - ) + try: + atomrefs = self._download_atomrefs(tmpdir) + + dataset = create_dataset( + datapath=self.datapath, + format=self.format, + distance_unit=self.distance_unit or "Ang", + property_unit_dict=property_unit_dict if self.property_units is None else self.property_units, + atomrefs=atomrefs, + ) + + if self.remove_uncharacterized: + uncharacterized = self._download_uncharacterized(tmpdir) + else: + uncharacterized = None + + self._download_data(tmpdir, dataset, uncharacterized=uncharacterized) + + finally: + shutil.rmtree(tmpdir, ignore_errors=True) - if self.remove_uncharacterized: - uncharacterized = self._download_uncharacterized(tmpdir) - else: - uncharacterized = None - self._download_data(tmpdir, dataset, uncharacterized=uncharacterized) - shutil.rmtree(tmpdir) else: - dataset = load_dataset(self.datapath, self.format) + # validate uncharacterized constraint against dataset size + dataset = load_dataset( + self.datapath, + self.format, + load_properties=self.load_properties, + property_units=self.property_units, + distance_unit=self.distance_unit, + ) if self.remove_uncharacterized and len(dataset) == 133885: raise AtomsDataModuleError( "The dataset at the chosen location contains the uncharacterized 3054 molecules. " - + "Choose a different location to reload the data or set `remove_uncharacterized=False`!" + "Choose a different location to reload the data or set `remove_uncharacterized=False`." 
) - elif not self.remove_uncharacterized and len(dataset) < 133885: + if (not self.remove_uncharacterized) and len(dataset) < 133885: raise AtomsDataModuleError( "The dataset at the chosen location does NOT contain the uncharacterized 3054 molecules. " - + "Choose a different location to reload the data or set `remove_uncharacterized=True`!" + "Choose a different location to reload the data or set `remove_uncharacterized=True`." ) - def _download_uncharacterized(self, tmpdir): + # Load after prepare so wrapper is ready + self._dataset = load_dataset( + self.datapath, + self.format, + load_properties=self.load_properties, + property_units=self.property_units, + distance_unit=self.distance_unit, + ) + self._dataset.transforms = self.transforms + + # # keep Lightning naming for convenience if any code still calls it + # def prepare_data(self) -> None: + # self.prepare() + + def _download_uncharacterized(self, tmpdir: str) -> List[int]: logging.info("Downloading list of uncharacterized molecules...") tmp_path = os.path.join(tmpdir, "uncharacterized.txt") self._download_file(self.file_ids["uncharacterized"], tmp_path) @@ -211,7 +272,7 @@ def _download_uncharacterized(self, tmpdir): uncharacterized.append(int(line.split()[0])) return uncharacterized - def _download_atomrefs(self, tmpdir): + def _download_atomrefs(self, tmpdir: str) -> Dict[str, List[float]]: logging.info("Downloading GDB-9 atom references...") tmp_path = os.path.join(tmpdir, "atomrefs.txt") self._download_file(self.file_ids["atomrefs"], tmp_path) @@ -224,12 +285,14 @@ def _download_atomrefs(self, tmpdir): for z, l in zip([1, 6, 7, 8, 9], lines[5:10]): for i, p in enumerate(props): atref[p][z] = float(l.split()[i + 1]) - atref = {k: v.tolist() for k, v in atref.items()} - return atref + return {k: v.tolist() for k, v in atref.items()} def _download_data( - self, tmpdir, dataset: BaseAtomsData, uncharacterized: List[int] - ): + self, + tmpdir: str, + dataset: BaseAtomsData, + uncharacterized: Optional[List[int]], + ) -> None: logging.info("Downloading GDB-9 data...") tar_path = os.path.join(tmpdir, "gdb9.tar.gz") raw_path = os.path.join(tmpdir, "gdb9_xyz") @@ -248,7 +311,6 @@ def _download_data( ) property_list = [] - irange = np.arange(len(ordered_files), dtype=int) if uncharacterized is not None: irange = np.setdiff1d(irange, np.array(uncharacterized, dtype=int) - 1) @@ -276,4 +338,4 @@ def _download_data( logging.info("Write atoms to db...") dataset.add_systems(property_list=property_list) - logging.info("Done.") + logging.info("Done.") \ No newline at end of file diff --git a/src/schnetpack/datasets/qm9_legacy.py b/src/schnetpack/datasets/qm9_legacy.py new file mode 100644 index 000000000..3085a21a7 --- /dev/null +++ b/src/schnetpack/datasets/qm9_legacy.py @@ -0,0 +1,279 @@ +import io +import logging +import os +import re +import shutil +import tarfile +import tempfile +from typing import List, Optional, Dict +from urllib import request as request + +import numpy as np +from ase import Atoms +from ase.io.extxyz import read_xyz +from tqdm import tqdm + +import torch +from schnetpack.data import * +import schnetpack.properties as structure +from schnetpack.data import AtomsDataModuleError, AtomsDataModule + +__all__ = ["QM9"] + + +class QM9(AtomsDataModule): + """QM9 benchmark database for organic molecules. + + The QM9 database contains small organic molecules with up to nine non-hydrogen atoms + from including C, O, N, F. This class adds convenient functions to download QM9 from + figshare and load the data into pytorch. 
+ + References: + + .. [#qm9_1] https://ndownloader.figshare.com/files/3195404 + """ + + base_urls = [ + "https://ndownloader.figshare.com/files/", + "https://springernature.figshare.com/ndownloader/files/", + ] + file_ids = { + "data": "3195389", + "atomrefs": "3195395", + "uncharacterized": "3195404", + } + + # properties + A = "rotational_constant_A" + B = "rotational_constant_B" + C = "rotational_constant_C" + mu = "dipole_moment" + alpha = "isotropic_polarizability" + homo = "homo" + lumo = "lumo" + gap = "gap" + r2 = "electronic_spatial_extent" + zpve = "zpve" + U0 = "energy_U0" + U = "energy_U" + H = "enthalpy_H" + G = "free_energy" + Cv = "heat_capacity" + + def __init__( + self, + datapath: str, + batch_size: int, + num_train: Optional[int] = None, + num_val: Optional[int] = None, + num_test: Optional[int] = None, + split_file: Optional[str] = "split.npz", + format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, + load_properties: Optional[List[str]] = None, + remove_uncharacterized: bool = False, + val_batch_size: Optional[int] = None, + test_batch_size: Optional[int] = None, + transforms: Optional[List[torch.nn.Module]] = None, + train_transforms: Optional[List[torch.nn.Module]] = None, + val_transforms: Optional[List[torch.nn.Module]] = None, + test_transforms: Optional[List[torch.nn.Module]] = None, + num_workers: int = 2, + num_val_workers: Optional[int] = None, + num_test_workers: Optional[int] = None, + property_units: Optional[Dict[str, str]] = None, + distance_unit: Optional[str] = None, + data_workdir: Optional[str] = None, + **kwargs, + ): + """ + + Args: + datapath: path to dataset + batch_size: (train) batch size + num_train: number of training examples + num_val: number of validation examples + num_test: number of test examples + split_file: path to npz file with data partitions + format: dataset format + load_properties: subset of properties to load + remove_uncharacterized: do not include uncharacterized molecules. + val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. + test_batch_size: test batch size. If None, use val_batch_size, then batch_size. + transforms: Transform applied to each system separately before batching. + train_transforms: Overrides transform_fn for training. + val_transforms: Overrides transform_fn for validation. + test_transforms: Overrides transform_fn for testing. + num_workers: Number of data loader workers. + num_val_workers: Number of validation data loader workers (overrides num_workers). + num_test_workers: Number of test data loader workers (overrides num_workers). + property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). + distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). + data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. 
+ """ + super().__init__( + datapath=datapath, + batch_size=batch_size, + num_train=num_train, + num_val=num_val, + num_test=num_test, + split_file=split_file, + format=format, + load_properties=load_properties, + val_batch_size=val_batch_size, + test_batch_size=test_batch_size, + transforms=transforms, + train_transforms=train_transforms, + val_transforms=val_transforms, + test_transforms=test_transforms, + num_workers=num_workers, + num_val_workers=num_val_workers, + num_test_workers=num_test_workers, + property_units=property_units, + distance_unit=distance_unit, + data_workdir=data_workdir, + **kwargs, + ) + + self.remove_uncharacterized = remove_uncharacterized + + def _download_file(self, file_id: str, destination: str): + for base_url in self.base_urls: + url = f"{base_url}{file_id}" + try: + request.urlretrieve(url, destination) + return + except Exception: + logging.warning(f"Could not download from {url}, trying next source...") + raise AtomsDataModuleError( + f"Could not download file with id {file_id} from any source." + ) + + def prepare_data(self): + if not os.path.exists(self.datapath): + property_unit_dict = { + QM9.A: "GHz", + QM9.B: "GHz", + QM9.C: "GHz", + QM9.mu: "Debye", + QM9.alpha: "a0 a0 a0", + QM9.homo: "Ha", + QM9.lumo: "Ha", + QM9.gap: "Ha", + QM9.r2: "a0 a0", + QM9.zpve: "Ha", + QM9.U0: "Ha", + QM9.U: "Ha", + QM9.H: "Ha", + QM9.G: "Ha", + QM9.Cv: "cal/mol/K", + } + + tmpdir = tempfile.mkdtemp("qm9") + atomrefs = self._download_atomrefs(tmpdir) + + dataset = create_dataset( + datapath=self.datapath, + format=self.format, + distance_unit="Ang", + property_unit_dict=property_unit_dict, + atomrefs=atomrefs, + ) + + if self.remove_uncharacterized: + uncharacterized = self._download_uncharacterized(tmpdir) + else: + uncharacterized = None + self._download_data(tmpdir, dataset, uncharacterized=uncharacterized) + shutil.rmtree(tmpdir) + else: + dataset = load_dataset(self.datapath, self.format) + if self.remove_uncharacterized and len(dataset) == 133885: + raise AtomsDataModuleError( + "The dataset at the chosen location contains the uncharacterized 3054 molecules. " + + "Choose a different location to reload the data or set `remove_uncharacterized=False`!" + ) + elif not self.remove_uncharacterized and len(dataset) < 133885: + raise AtomsDataModuleError( + "The dataset at the chosen location does NOT contain the uncharacterized 3054 molecules. " + + "Choose a different location to reload the data or set `remove_uncharacterized=True`!" 
+ ) + + def _download_uncharacterized(self, tmpdir): + logging.info("Downloading list of uncharacterized molecules...") + tmp_path = os.path.join(tmpdir, "uncharacterized.txt") + self._download_file(self.file_ids["uncharacterized"], tmp_path) + logging.info("Done.") + + uncharacterized = [] + with open(tmp_path) as f: + lines = f.readlines() + for line in lines[9:-1]: + uncharacterized.append(int(line.split()[0])) + return uncharacterized + + def _download_atomrefs(self, tmpdir): + logging.info("Downloading GDB-9 atom references...") + tmp_path = os.path.join(tmpdir, "atomrefs.txt") + self._download_file(self.file_ids["atomrefs"], tmp_path) + logging.info("Done.") + + props = [QM9.zpve, QM9.U0, QM9.U, QM9.H, QM9.G, QM9.Cv] + atref = {p: np.zeros((100,)) for p in props} + with open(tmp_path) as f: + lines = f.readlines() + for z, l in zip([1, 6, 7, 8, 9], lines[5:10]): + for i, p in enumerate(props): + atref[p][z] = float(l.split()[i + 1]) + atref = {k: v.tolist() for k, v in atref.items()} + return atref + + def _download_data( + self, tmpdir, dataset: BaseAtomsData, uncharacterized: List[int] + ): + logging.info("Downloading GDB-9 data...") + tar_path = os.path.join(tmpdir, "gdb9.tar.gz") + raw_path = os.path.join(tmpdir, "gdb9_xyz") + self._download_file(self.file_ids["data"], tar_path) + logging.info("Done.") + + logging.info("Extracting files...") + tar = tarfile.open(tar_path) + tar.extractall(raw_path) + tar.close() + logging.info("Done.") + + logging.info("Parse xyz files...") + ordered_files = sorted( + os.listdir(raw_path), key=lambda x: (int(re.sub(r"\D", "", x)), x) + ) + + property_list = [] + + irange = np.arange(len(ordered_files), dtype=int) + if uncharacterized is not None: + irange = np.setdiff1d(irange, np.array(uncharacterized, dtype=int) - 1) + + for i in tqdm(irange): + xyzfile = os.path.join(raw_path, ordered_files[i]) + properties = {} + + tmp = io.StringIO() + with open(xyzfile, "r") as f: + lines = f.readlines() + l = lines[1].split()[2:] + for pn, p in zip(dataset.available_properties, l): + properties[pn] = np.array([float(p)]) + for line in lines: + tmp.write(line.replace("*^", "e")) + + tmp.seek(0) + ats: Atoms = list(read_xyz(tmp, 0))[0] + properties[structure.Z] = ats.numbers + properties[structure.R] = ats.positions + properties[structure.cell] = ats.cell + properties[structure.pbc] = ats.pbc + property_list.append(properties) + + logging.info("Write atoms to db...") + dataset.add_systems(property_list=property_list) + logging.info("Done.") diff --git a/src/schnetpack/transform/atomistic.py b/src/schnetpack/transform/atomistic.py index d7a87dfa1..358387086 100644 --- a/src/schnetpack/transform/atomistic.py +++ b/src/schnetpack/transform/atomistic.py @@ -6,6 +6,8 @@ import schnetpack.properties as structure from .base import Transform from schnetpack.nn import scatter_add +from schnetpack.data.provider import StatsAtomrefProvider + __all__ = [ "SubtractCenterOfMass", @@ -117,24 +119,58 @@ def __init__( property_mean = property_mean or torch.zeros((1,)) self.register_buffer("mean", property_mean) - def datamodule(self, _datamodule): + def initialize(self, provider, atomrefs=None) -> None: """ - Sets mean and atomref automatically when using PyTorchLightning integration. + Initialize mean and/or atomref using a StatsAtomrefProvider. 
""" if self.remove_atomrefs and not self._atomrefs_initialized: if self.estimate_atomref: - atrefs = _datamodule.get_atomrefs( - property=self._property, is_extensive=self.is_extensive - ) + atrefs = provider.get_atomrefs(self._property, self.is_extensive) else: - atrefs = _datamodule.train_dataset.atomrefs + if atomrefs is None: + raise RuntimeError( + "RemoveOffsets requires dataset atomrefs when estimate_atomref=False." + ) + atrefs = atomrefs self.atomref = atrefs[self._property].detach() if self.remove_mean and not self._mean_initialized: - stats = _datamodule.get_stats( + mean, _std = provider.get_stats( self._property, self.is_extensive, self.remove_atomrefs ) - self.mean = stats[0].detach() + self.mean = mean.detach() + + #legacy hook for old AtomsDataModule + def datamodule(self, _datamodule): + """ + Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. + """ + + provider = StatsAtomrefProvider( + train_dataloader_factory=_datamodule.train_dataloader, + train_atomrefs=getattr(_datamodule.train_dataset, "atomrefs", None), + ) + atomrefs = getattr(_datamodule.train_dataset, "atomrefs", None) + return self.initialize(provider, atomrefs=atomrefs) + + # def datamodule(self, _datamodule): + # """ + # Sets mean and atomref automatically when using PyTorchLightning integration. + # """ + # if self.remove_atomrefs and not self._atomrefs_initialized: + # if self.estimate_atomref: + # atrefs = _datamodule.get_atomrefs( + # property=self._property, is_extensive=self.is_extensive + # ) + # else: + # atrefs = _datamodule.train_dataset.atomrefs + # self.atomref = atrefs[self._property].detach() + + # if self.remove_mean and not self._mean_initialized: + # stats = _datamodule.get_stats( + # self._property, self.is_extensive, self.remove_atomrefs + # ) + # self.mean = stats[0].detach() def forward( self, @@ -200,12 +236,32 @@ def __init__( scale = scale or torch.ones((1,)) self.register_buffer("scale", scale) - def datamodule(self, _datamodule): + def initialize(self, provider, atomrefs=None) -> None: + """ + Initialize scaling using training statistics. + """ if not self._initialized: - stats = _datamodule.get_stats(self._target_key, True, False) - scale = stats[0] if self._scale_by_mean else stats[1] + mean, std = provider.get_stats(self._target_key, True, False) + scale = mean if self._scale_by_mean else std self.scale = torch.abs(scale).detach() + + def datamodule(self, _datamodule): + """ + Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. + """ + provider = StatsAtomrefProvider( + train_dataloader_factory=_datamodule.train_dataloader, + train_atomrefs=getattr(_datamodule.train_dataset, "atomrefs", None), + ) + return self.initialize(provider, atomrefs=None) + + # def datamodule(self, _datamodule): + # if not self._initialized: + # stats = _datamodule.get_stats(self._target_key, True, False) + # scale = stats[0] if self._scale_by_mean else stats[1] + # self.scale = torch.abs(scale).detach() + def forward( self, inputs: Dict[str, torch.Tensor], @@ -281,44 +337,119 @@ def __init__( self.register_buffer("atomref", atomrefs) self.register_buffer("mean", property_mean) - def datamodule(self, _datamodule): + def initialize(self, provider, atomrefs=None) -> None: + """ + Initialize mean and/or atomref using a StatsAtomrefProvider. 
+ """ if self.add_atomrefs and not self._atomrefs_initialized: if self.estimate_atomref: - atrefs = _datamodule.get_atomrefs( - property=self._property, is_extensive=self.is_extensive - ) + atrefs = provider.get_atomrefs(self._property, self.is_extensive) else: - atrefs = _datamodule.train_dataset.atomrefs + if atomrefs is None: + raise RuntimeError( + "AddOffsets requires dataset atomrefs when estimate_atomref=False." + ) + atrefs = atomrefs self.atomref = atrefs[self._property].detach() if self.add_mean and not self._mean_initialized: - stats = _datamodule.get_stats( + mean, _std = provider.get_stats( self._property, self.is_extensive, self.add_atomrefs ) - self.mean = stats[0].detach() + self.mean = mean.detach() - def forward( - self, - inputs: Dict[str, torch.Tensor], - ) -> Dict[str, torch.Tensor]: + + def datamodule(self, _datamodule): + """ + Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. + """ + + provider = StatsAtomrefProvider( + train_dataloader_factory=_datamodule.train_dataloader, + train_atomrefs=getattr(_datamodule.train_dataset, "atomrefs", None), + ) + atomrefs = getattr(_datamodule.train_dataset, "atomrefs", None) + return self.initialize(provider, atomrefs=atomrefs) + + # def datamodule(self, _datamodule): + # if self.add_atomrefs and not self._atomrefs_initialized: + # if self.estimate_atomref: + # atrefs = _datamodule.get_atomrefs( + # property=self._property, is_extensive=self.is_extensive + # ) + # else: + # atrefs = _datamodule.train_dataset.atomrefs + # self.atomref = atrefs[self._property].detach() + + # if self.add_mean and not self._mean_initialized: + # stats = _datamodule.get_stats( + # self._property, self.is_extensive, self.add_atomrefs + # ) + # self.mean = stats[0].detach() + + # def forward( + # self, + # inputs: Dict[str, torch.Tensor], + # ) -> Dict[str, torch.Tensor]: + # if self.add_mean: + # mean = ( + # self.mean * inputs[structure.n_atoms] + # if self.is_extensive + # else self.mean + # ) + # inputs[self._property] += mean + + # if self.add_atomrefs: + # idx_m = inputs[structure.idx_m] + # y0i = self.atomref[inputs[structure.Z]] + # maxm = int(idx_m[-1]) + 1 + + # y0 = scatter_add(y0i, idx_m, dim_size=maxm) + + # if not self.is_extensive: + # y0 /= inputs[structure.n_atoms] + + # inputs[self._property] += y0 + + # return inputs + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if self.add_mean: - mean = ( - self.mean * inputs[structure.n_atoms] - if self.is_extensive - else self.mean - ) - inputs[self._property] += mean + mean = self.mean + if self.is_extensive: + # n_atoms: (B,) in batch, (1,) in single + mean = mean * inputs[structure.n_atoms].to(mean.dtype) + # If mean is (1,) and n_atoms is (B,), broadcasting yields (B,) + inputs[self._property] = inputs[self._property] + mean if self.add_atomrefs: - idx_m = inputs[structure.idx_m] - y0i = self.atomref[inputs[structure.Z]] - maxm = int(idx_m[-1]) + 1 + z = inputs[structure.Z] + y0i = self.atomref[z] # (N_atoms, ...) ; usually (N_atoms,) - y0 = scatter_add(y0i, idx_m, dim_size=maxm) + if structure.idx_m in inputs: + # Batched path + idx_m = inputs[structure.idx_m] + maxm = int(idx_m[-1]) + 1 if idx_m.numel() > 0 else 0 + y0 = scatter_add(y0i, idx_m, dim_size=maxm) # (B, ...) 
or (B,) - if not self.is_extensive: - y0 /= inputs[structure.n_atoms] + if not self.is_extensive: + n_atoms = inputs[structure.n_atoms].to(y0.dtype) # (B,) + # Make n_atoms broadcast with y0 if y0 is (B, k) + while n_atoms.dim() < y0.dim(): + n_atoms = n_atoms.unsqueeze(-1) + y0 = y0 / n_atoms - inputs[self._property] += y0 + inputs[self._property] = inputs[self._property] + y0 - return inputs + else: + # Single-system path (no idx_m) + y0 = y0i.sum(dim=0) # scalar () or vector (k,) + + if not self.is_extensive: + n_atoms = inputs[structure.n_atoms].to(y0.dtype) # usually (1,) + # reduce n_atoms to scalar to avoid odd broadcasting + n_atoms_scalar = n_atoms.view(-1)[0] + y0 = y0 / n_atoms_scalar + + inputs[self._property] = inputs[self._property] + y0 + + return inputs \ No newline at end of file diff --git a/src/schnetpack/transform/base.py b/src/schnetpack/transform/base.py index 77535c200..75ce0a3ce 100644 --- a/src/schnetpack/transform/base.py +++ b/src/schnetpack/transform/base.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -import schnetpack as spk +#import schnetpack as spk __all__ = [ "Transform", @@ -48,3 +48,10 @@ def forward( def teardown(self): pass + + def initialize(self, provider, atomrefs=None) -> None: + """ + Preferred initialization hook (DataModule-free). + Transforms that require training stats/atomrefs override this. + """ + return From b043e7e24d764f527fd1d08a45cb5a9c67509fe7 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 22 Feb 2026 22:19:37 +0100 Subject: [PATCH 03/68] style: improve code formatting and readability in multiple files --- src/schnetpack/data/datamodule_v2.py | 22 ++++++++++++++++------ src/schnetpack/data/provider.py | 14 ++++++++++---- src/schnetpack/datasets/qm9.py | 8 ++++++-- src/schnetpack/transform/atomistic.py | 6 ++---- src/schnetpack/transform/base.py | 2 +- 5 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index 5a1a54a68..5a70dd769 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -71,8 +71,12 @@ def __init__( self.train_sampler_args = train_sampler_args or {} self.num_workers = num_workers - self.num_val_workers = num_val_workers if num_val_workers is not None else num_workers - self.num_test_workers = num_test_workers if num_test_workers is not None else num_workers + self.num_val_workers = ( + num_val_workers if num_val_workers is not None else num_workers + ) + self.num_test_workers = ( + num_test_workers if num_test_workers is not None else num_workers + ) self.pin_memory = pin_memory self.train_idx: Optional[List[int]] = None @@ -136,9 +140,15 @@ def setup(self, stage: Optional[str] = None) -> None: # Initialize transforms (V2 path: datamodule-free) train_atomrefs = getattr(self._train_dataset, "atomrefs", None) - self._initialize_transform_list(self.train_transforms, train_atomrefs=train_atomrefs) - self._initialize_transform_list(self.val_transforms, train_atomrefs=train_atomrefs) - self._initialize_transform_list(self.test_transforms, train_atomrefs=train_atomrefs) + self._initialize_transform_list( + self.train_transforms, train_atomrefs=train_atomrefs + ) + self._initialize_transform_list( + self.val_transforms, train_atomrefs=train_atomrefs + ) + self._initialize_transform_list( + self.test_transforms, train_atomrefs=train_atomrefs + ) # Attach transforms after init (matches legacy behavior) self._train_dataset.transforms = self.train_transforms @@ -240,4 +250,4 @@ def 
test_dataloader(self) -> AtomsLoader: num_workers=self.num_test_workers, pin_memory=self.pin_memory, ) - return self._test_loader \ No newline at end of file + return self._test_loader diff --git a/src/schnetpack/data/provider.py b/src/schnetpack/data/provider.py index 99e31255b..f708ee814 100644 --- a/src/schnetpack/data/provider.py +++ b/src/schnetpack/data/provider.py @@ -21,7 +21,9 @@ class StatsAtomrefProvider: train_dataloader_factory: Callable[[], AtomsLoader] train_atomrefs: Optional[Dict[str, torch.Tensor]] = None - _stats_cache: Optional[Dict[Tuple[str, bool, bool], Tuple[torch.Tensor, torch.Tensor]]] = None + _stats_cache: Optional[ + Dict[Tuple[str, bool, bool], Tuple[torch.Tensor, torch.Tensor]] + ] = None _atomref_cache: Optional[Dict[Tuple[str, bool], torch.Tensor]] = None def __post_init__(self) -> None: @@ -57,7 +59,9 @@ def get_stats( self._stats_cache[key] = stats return stats - def get_atomrefs(self, property: str, is_extensive: bool) -> Dict[str, torch.Tensor]: + def get_atomrefs( + self, property: str, is_extensive: bool + ) -> Dict[str, torch.Tensor]: """ Args: property: property key @@ -70,7 +74,9 @@ def get_atomrefs(self, property: str, is_extensive: bool) -> Dict[str, torch.Ten return {property: self._atomref_cache[key]} loader = self.train_dataloader_factory() - atomref = estimate_atomrefs(loader, is_extensive={property: is_extensive})[property] + atomref = estimate_atomrefs(loader, is_extensive={property: is_extensive})[ + property + ] self._atomref_cache[key] = atomref - return {property: atomref} \ No newline at end of file + return {property: atomref} diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 3c2675afc..dee30403f 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -211,7 +211,11 @@ def prepare_data(self) -> None: datapath=self.datapath, format=self.format, distance_unit=self.distance_unit or "Ang", - property_unit_dict=property_unit_dict if self.property_units is None else self.property_units, + property_unit_dict=( + property_unit_dict + if self.property_units is None + else self.property_units + ), atomrefs=atomrefs, ) @@ -338,4 +342,4 @@ def _download_data( logging.info("Write atoms to db...") dataset.add_systems(property_list=property_list) - logging.info("Done.") \ No newline at end of file + logging.info("Done.") diff --git a/src/schnetpack/transform/atomistic.py b/src/schnetpack/transform/atomistic.py index 358387086..354dabc44 100644 --- a/src/schnetpack/transform/atomistic.py +++ b/src/schnetpack/transform/atomistic.py @@ -140,7 +140,7 @@ def initialize(self, provider, atomrefs=None) -> None: ) self.mean = mean.detach() - #legacy hook for old AtomsDataModule + # legacy hook for old AtomsDataModule def datamodule(self, _datamodule): """ Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. @@ -245,7 +245,6 @@ def initialize(self, provider, atomrefs=None) -> None: scale = mean if self._scale_by_mean else std self.scale = torch.abs(scale).detach() - def datamodule(self, _datamodule): """ Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. @@ -358,7 +357,6 @@ def initialize(self, provider, atomrefs=None) -> None: ) self.mean = mean.detach() - def datamodule(self, _datamodule): """ Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. 
@@ -452,4 +450,4 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: inputs[self._property] = inputs[self._property] + y0 - return inputs \ No newline at end of file + return inputs diff --git a/src/schnetpack/transform/base.py b/src/schnetpack/transform/base.py index 75ce0a3ce..f2b51eb40 100644 --- a/src/schnetpack/transform/base.py +++ b/src/schnetpack/transform/base.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -#import schnetpack as spk +# import schnetpack as spk __all__ = [ "Transform", From 6e67e015fdb963a133cbaba4d35e449e27cd06b8 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 22 Feb 2026 22:53:28 +0100 Subject: [PATCH 04/68] refactor: simplify AtomsDataModuleV2 by removing unused parameters and improving clarity --- src/schnetpack/data/datamodule_v2.py | 103 +++++++++------------------ 1 file changed, 35 insertions(+), 68 deletions(-) diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index 5a70dd769..6e5ec6f07 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -1,11 +1,10 @@ from __future__ import annotations from copy import copy -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Union import numpy as np import pytorch_lightning as pl -from torch.utils.data import BatchSampler from schnetpack.data.atoms import BaseAtomsData from schnetpack.data.loader import AtomsLoader @@ -20,7 +19,6 @@ class AtomsDataModuleV2(pl.LightningDataModule): - handles splitting + loaders/batching - builds StatsAtomrefProvider from train split - initializes transforms via t.initialize(provider, atomrefs=...) - (no t.datamodule() dependency in the V2 path) """ def __init__( @@ -36,24 +34,21 @@ def __init__( train_transforms: Optional[List] = None, val_transforms: Optional[List] = None, test_transforms: Optional[List] = None, - train_sampler_cls: Optional[Type] = None, - train_sampler_args: Optional[Dict[str, Any]] = None, - num_workers: int = 8, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, - pin_memory: bool = False, + num_workers: int = 0, strict_transform_init: bool = True, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, + loader_kwargs: Optional[Dict[str, Any]] = None, + val_loader_kwargs: Optional[Dict[str, Any]] = None, + test_loader_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, # swallow legacy knobs without breaking configs ): super().__init__() + if kwargs: + pass + self.dataset = dataset self.batch_size = batch_size - self.val_batch_size = val_batch_size or batch_size - self.test_batch_size = test_batch_size or batch_size - self.num_train = num_train self.num_val = num_val self.num_test = num_test @@ -61,35 +56,28 @@ def __init__( self.split_file = split_file self.splitting = splitting or RandomSplit() - # If transforms passed, replicate (copy) for each split unless split-specific provided self._train_transforms = train_transforms or copy(transforms) or [] self._val_transforms = val_transforms or copy(transforms) or [] self._test_transforms = test_transforms or copy(transforms) or [] self.strict_transform_init = strict_transform_init - self.train_sampler_cls = train_sampler_cls - self.train_sampler_args = train_sampler_args or {} - self.num_workers = num_workers - self.num_val_workers = ( - num_val_workers if num_val_workers is not None else num_workers - ) - self.num_test_workers = ( - num_test_workers if num_test_workers is not None else 
num_workers - ) - self.pin_memory = pin_memory - self.train_idx: Optional[List[int]] = None - self.val_idx: Optional[List[int]] = None - self.test_idx: Optional[List[int]] = None + self.loader_kwargs = loader_kwargs or {} + self.val_loader_kwargs = val_loader_kwargs or {} + self.test_loader_kwargs = test_loader_kwargs or {} + + self.train_idx = None + self.val_idx = None + self.test_idx = None - self._train_dataset: Optional[BaseAtomsData] = None - self._val_dataset: Optional[BaseAtomsData] = None - self._test_dataset: Optional[BaseAtomsData] = None + self._train_dataset = None + self._val_dataset = None + self._test_dataset = None - self._train_loader: Optional[AtomsLoader] = None - self._val_loader: Optional[AtomsLoader] = None - self._test_loader: Optional[AtomsLoader] = None + self._train_loader = None + self._val_loader = None + self._test_loader = None self.provider: Optional[StatsAtomrefProvider] = None @@ -127,19 +115,17 @@ def setup(self, stage: Optional[str] = None) -> None: if self.train_idx is None: self._load_partitions() - # Create split datasets (no transforms attached yet) self._train_dataset = self.dataset.subset(self.train_idx) self._val_dataset = self.dataset.subset(self.val_idx) self._test_dataset = self.dataset.subset(self.test_idx) - # Build provider bound to train loader factory (loader created on demand) + train_atomrefs = getattr(self._train_dataset, "atomrefs", None) + self.provider = StatsAtomrefProvider( train_dataloader_factory=self.train_dataloader, - train_atomrefs=getattr(self._train_dataset, "atomrefs", None), + train_atomrefs=train_atomrefs, ) - # Initialize transforms (V2 path: datamodule-free) - train_atomrefs = getattr(self._train_dataset, "atomrefs", None) self._initialize_transform_list( self.train_transforms, train_atomrefs=train_atomrefs ) @@ -150,7 +136,6 @@ def setup(self, stage: Optional[str] = None) -> None: self.test_transforms, train_atomrefs=train_atomrefs ) - # Attach transforms after init (matches legacy behavior) self._train_dataset.transforms = self.train_transforms self._val_dataset.transforms = self.val_transforms self._test_dataset.transforms = self.test_transforms @@ -158,13 +143,11 @@ def setup(self, stage: Optional[str] = None) -> None: def _initialize_transform_list(self, transforms: List, train_atomrefs): if not transforms: return - for t in transforms: init_fn = getattr(t, "initialize", None) if callable(init_fn): init_fn(self.provider, atomrefs=train_atomrefs) continue - if self.strict_transform_init: raise RuntimeError( f"Transform {type(t).__name__} does not implement initialize(provider, atomrefs=...)." 
@@ -202,33 +185,17 @@ def _to_abs(x: Optional[Union[int, float]]) -> Optional[int]: if self.split_file is not None: np.savez( - self.split_file, - train_idx=self.train_idx, - val_idx=self.val_idx, - test_idx=self.test_idx, + self.split_file, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx ) - def _setup_train_batch_sampler(self): - if self.train_sampler_cls is None: - return None - - sampler = self.train_sampler_cls( - data_source=self.train_dataset, - num_samples=len(self.train_dataset), - **self.train_sampler_args, - ) - return BatchSampler(sampler=sampler, batch_size=self.batch_size, drop_last=True) - def train_dataloader(self) -> AtomsLoader: if self._train_loader is None: - batch_sampler = self._setup_train_batch_sampler() self._train_loader = AtomsLoader( self.train_dataset, - batch_size=self.batch_size if batch_sampler is None else 1, - shuffle=batch_sampler is None, - batch_sampler=batch_sampler, + batch_size=self.batch_size, + shuffle=True, num_workers=self.num_workers, - pin_memory=self.pin_memory, + **self.loader_kwargs, ) return self._train_loader @@ -236,9 +203,9 @@ def val_dataloader(self) -> AtomsLoader: if self._val_loader is None: self._val_loader = AtomsLoader( self.val_dataset, - batch_size=self.val_batch_size, - num_workers=self.num_val_workers, - pin_memory=self.pin_memory, + batch_size=self.batch_size, + num_workers=self.num_workers, + **{**self.loader_kwargs, **self.val_loader_kwargs}, ) return self._val_loader @@ -246,8 +213,8 @@ def test_dataloader(self) -> AtomsLoader: if self._test_loader is None: self._test_loader = AtomsLoader( self.test_dataset, - batch_size=self.test_batch_size, - num_workers=self.num_test_workers, - pin_memory=self.pin_memory, + batch_size=self.batch_size, + num_workers=self.num_workers, + **{**self.loader_kwargs, **self.test_loader_kwargs}, ) return self._test_loader From f7dd667ba3637395f76b255aec3afb9ce5f6239a Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 22 Feb 2026 23:09:12 +0100 Subject: [PATCH 05/68] docs: update AtomsDataModuleV2 docstring for clarity by removing redundant information --- src/schnetpack/data/datamodule_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index 6e5ec6f07..235b0bd26 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -15,7 +15,7 @@ class AtomsDataModuleV2(pl.LightningDataModule): """ V2 DataModule: - - accepts a dataset instance (datasets are independent) + - accepts a dataset instance - handles splitting + loaders/batching - builds StatsAtomrefProvider from train split - initializes transforms via t.initialize(provider, atomrefs=...) @@ -150,7 +150,7 @@ def _initialize_transform_list(self, transforms: List, train_atomrefs): continue if self.strict_transform_init: raise RuntimeError( - f"Transform {type(t).__name__} does not implement initialize(provider, atomrefs=...)." + f"Transform {type(t).__name__} does not implement initialize." 
) def _load_partitions(self) -> None: From 8b5a014440605e47429292a30efac95d2fa4148a Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 22 Feb 2026 23:18:59 +0100 Subject: [PATCH 06/68] test: add pytests for AtomsDataset and AtomsDataModuleV2 functionality --- tests/data/test_refactor.py | 235 ++++++++++++++++++++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 tests/data/test_refactor.py diff --git a/tests/data/test_refactor.py b/tests/data/test_refactor.py new file mode 100644 index 000000000..f3f6f9465 --- /dev/null +++ b/tests/data/test_refactor.py @@ -0,0 +1,235 @@ +import pytest +import torch +import numpy as np + +import schnetpack.properties as structure +import schnetpack.data.provider as providers_mod +from schnetpack.data.datamodule_v2 import AtomsDataModuleV2 +from schnetpack.transform.atomistic import AddOffsets, RemoveOffsets, ScaleProperty + + +class AtomsDataset: + """ + Adapter to make the existing `example_data` list fixture behave like a dataset. + """ + + def __init__(self, data): + self._data = list(data) + self.transforms = [] + + _, props0 = self._data[0] + self.available_properties = list(props0.keys()) + + def __len__(self): + return len(self._data) + + def subset(self, indices): + # indices can be list/np array + idx_list = list(indices) + sub = AtomsDataset([self._data[i] for i in idx_list]) + # do not carry transforms automatically; DM will attach per split + return sub + + def __getitem__(self, idx): + ats, props = self._data[idx] + out = {} + + # Structure keys expected by SchNetPack transforms/loader + out[structure.Z] = torch.tensor(ats.numbers, dtype=torch.long) + out[structure.R] = torch.tensor(np.asarray(ats.positions), dtype=torch.float) + out[structure.cell] = torch.tensor( + np.asarray(ats.cell.array), dtype=torch.float + ) + out[structure.pbc] = torch.tensor(np.asarray(ats.pbc), dtype=torch.bool) + out[structure.n_atoms] = torch.tensor([len(ats.numbers)], dtype=torch.long) + + # Add properties + for k, v in props.items(): + # ensure torch tensor + if isinstance(v, torch.Tensor): + out[k] = v + else: + out[k] = torch.tensor(np.asarray(v), dtype=torch.float) + + # Apply transforms per-system (like BaseAtomsData.__getitem__) + for t in self.transforms: + out = t(out) + + return out + + +def _first_scalar_property_key(dataset: AtomsDataset): + # choose the first property key, but ensure it's not a structure key + for p in dataset.available_properties: + if p not in ( + structure.Z, + structure.R, + structure.cell, + structure.pbc, + structure.n_atoms, + ): + return p + raise AssertionError( + "No suitable scalar property found in dataset.available_properties" + ) + + +def _make_constant_atomrefs(zmax=100, value=1.0): + atref = torch.zeros((zmax,), dtype=torch.float) + atref[:] = float(value) + return atref + + +@pytest.mark.parametrize("batch_size", [1, 4]) +def test_v2_setup_attaches_transforms(example_data, batch_size): + dataset = AtomsDataset(example_data) + + dm = AtomsDataModuleV2( + dataset=dataset, + batch_size=batch_size, + num_train=0.6, + num_val=0.2, + num_test=0.2, + transforms=[], + num_workers=0, + split_file=None, + ) + dm.setup() + + assert dm.train_dataset is not None + assert dm.val_dataset is not None + assert dm.test_dataset is not None + + # transforms attribute exists and is settable + dm.train_dataset.transforms = [] + dm.val_dataset.transforms = [] + dm.test_dataset.transforms = [] + + +def test_provider_initializes_stats_transforms(example_data): + dataset = AtomsDataset(example_data) + prop = 
_first_scalar_property_key(dataset) + + transforms = [ + RemoveOffsets( + property=prop, remove_mean=True, remove_atomrefs=False, is_extensive=True + ), + ScaleProperty( + input_key=prop, target_key=prop, output_key=prop, scale_by_mean=False + ), + ] + + dm = AtomsDataModuleV2( + dataset=dataset, + batch_size=4, + num_train=0.6, + num_val=0.2, + num_test=0.2, + transforms=transforms, + num_workers=0, + split_file=None, + ) + dm.setup() + + ro = transforms[0] + sp = transforms[1] + + assert hasattr(ro, "mean") + assert ro.mean is not None + assert hasattr(sp, "scale") + assert sp.scale is not None + + +@pytest.mark.parametrize("is_extensive", [True, False]) +def test_addoffsets_unbatched_and_batched(example_data, is_extensive): + dataset = AtomsDataset(example_data) + prop = _first_scalar_property_key(dataset) + + zmax = 100 + atomref_tensor = _make_constant_atomrefs(zmax=zmax, value=1.0) + + t = AddOffsets( + property=prop, + add_mean=False, + add_atomrefs=True, + is_extensive=is_extensive, + zmax=zmax, + atomrefs=atomref_tensor, + ) + + dm = AtomsDataModuleV2( + dataset=dataset, + batch_size=4, + num_train=0.6, + num_val=0.2, + num_test=0.2, + transforms=[t], + num_workers=0, + split_file=None, + ) + dm.setup() + + # Unbatched: dataset[0] (transform runs in __getitem__) + old = dm.train_dataset.transforms + dm.train_dataset.transforms = [] + raw = dm.train_dataset[0] + dm.train_dataset.transforms = old + one = dm.train_dataset[0] + + y_raw = raw[prop] + y_one = one[prop] + delta = (y_one - y_raw).detach().view(-1)[0] + + n_atoms = int(one[structure.n_atoms].view(-1)[0].item()) + expected = float(n_atoms) if is_extensive else 1.0 + + assert torch.allclose(delta, torch.tensor(expected, dtype=delta.dtype), atol=1e-6) + + # Batched: loader should include idx_m and not crash + batch = next(iter(dm.train_dataloader())) + assert structure.idx_m in batch + assert prop in batch + + +def test_provider_caches_stats_calls(example_data, monkeypatch): + """ + Ensure provider caching prevents recomputing the same stats key multiple times. 
+ In this test transforms request the same key when: + - RemoveOffsets is_extensive=True and remove_atomrefs=False -> (prop, True, False) + - ScaleProperty always requests (prop, True, False) + """ + + dataset = AtomsDataset(example_data) + prop = _first_scalar_property_key(dataset) + + call_count = {"n": 0} + real_calculate_stats = providers_mod.calculate_stats + + def wrapped_calculate_stats(*args, **kwargs): + call_count["n"] += 1 + return real_calculate_stats(*args, **kwargs) + + monkeypatch.setattr(providers_mod, "calculate_stats", wrapped_calculate_stats) + + transforms = [ + RemoveOffsets( + property=prop, remove_mean=True, remove_atomrefs=False, is_extensive=True + ), + ScaleProperty( + input_key=prop, target_key=prop, output_key=prop, scale_by_mean=False + ), + ] + + dm = AtomsDataModuleV2( + dataset=dataset, + batch_size=4, + num_train=0.6, + num_val=0.2, + num_test=0.2, + transforms=transforms, + num_workers=0, + split_file=None, + ) + dm.setup() + + assert call_count["n"] == 1 From 215587e3f0920071ac0803f2d8f2fd83eb8d67b8 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 22 Feb 2026 23:55:09 +0100 Subject: [PATCH 07/68] refactor: update QM9 and StatsAtomrefProvider docstrings for clarity and conciseness --- src/schnetpack/data/provider.py | 3 --- src/schnetpack/datasets/qm9.py | 13 ++++--------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/schnetpack/data/provider.py b/src/schnetpack/data/provider.py index f708ee814..fd5a68a1e 100644 --- a/src/schnetpack/data/provider.py +++ b/src/schnetpack/data/provider.py @@ -13,9 +13,6 @@ class StatsAtomrefProvider: """ Compute and cache statistics and atom references from the *training split*. - - This replaces the logic that used to live in AtomsDataModule.get_stats/get_atomrefs, - without requiring a DataModule instance in transforms or datasets. """ train_dataloader_factory: Callable[[], AtomsLoader] diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index dee30403f..ae53667aa 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -27,10 +27,10 @@ class QM9: - """QM9 benchmark database for organic molecules (dataset-only). + """QM9 benchmark database for organic molecules. This class: - - is a dataset wrapper (no Lightning DataModule inheritance) + - is a dataset wrapper - can download + build the dataset via prepare() - forwards the BaseAtomsData API to an underlying dataset instance """ @@ -126,17 +126,16 @@ def __getitem__(self, idx: int): def subset(self, indices): """ Forward subset() to underlying dataset. - Returns a BaseAtomsData-like object (whatever the backend returns). + Returns a BaseAtomsData-like object. 
""" self._ensure_loaded() sub = self._dataset.subset(indices) # Ensure transforms are carried over if caller expects it - # (DataModuleV2 will set per-split transforms anyway) + # (DataModuleV2 will set per-split transforms) if getattr(sub, "transforms", None) is None: sub.transforms = [] return sub - # Common attributes used elsewhere in SchNetPack @property def available_properties(self): self._ensure_loaded() @@ -259,10 +258,6 @@ def prepare_data(self) -> None: ) self._dataset.transforms = self.transforms - # # keep Lightning naming for convenience if any code still calls it - # def prepare_data(self) -> None: - # self.prepare() - def _download_uncharacterized(self, tmpdir: str) -> List[int]: logging.info("Downloading list of uncharacterized molecules...") tmp_path = os.path.join(tmpdir, "uncharacterized.txt") From db210047696f64519fc83a8fb3ebcddae064a436 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Mon, 2 Mar 2026 02:56:04 +0100 Subject: [PATCH 08/68] feat: refactor AtomsDataModuleV2 --- src/schnetpack/data/datamodule_v2.py | 127 +++++++-------------------- 1 file changed, 33 insertions(+), 94 deletions(-) diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index 235b0bd26..f7ef803f9 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -1,13 +1,12 @@ from __future__ import annotations from copy import copy -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional, Union import numpy as np import pytorch_lightning as pl from schnetpack.data.atoms import BaseAtomsData -from schnetpack.data.loader import AtomsLoader from schnetpack.data.provider import StatsAtomrefProvider from schnetpack.data.splitting import RandomSplit, SplittingStrategy @@ -16,9 +15,9 @@ class AtomsDataModuleV2(pl.LightningDataModule): """ V2 DataModule: - accepts a dataset instance - - handles splitting + loaders/batching + - handles splitting - builds StatsAtomrefProvider from train split - - initializes transforms via t.initialize(provider, atomrefs=...) 
+ - initializes transforms """ def __init__( @@ -36,36 +35,22 @@ def __init__( test_transforms: Optional[List] = None, num_workers: int = 0, strict_transform_init: bool = True, - loader_kwargs: Optional[Dict[str, Any]] = None, - val_loader_kwargs: Optional[Dict[str, Any]] = None, - test_loader_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, # swallow legacy knobs without breaking configs ): super().__init__() - if kwargs: - pass - self.dataset = dataset - self.batch_size = batch_size self.num_train = num_train self.num_val = num_val self.num_test = num_test - self.split_file = split_file self.splitting = splitting or RandomSplit() - - self._train_transforms = train_transforms or copy(transforms) or [] - self._val_transforms = val_transforms or copy(transforms) or [] - self._test_transforms = test_transforms or copy(transforms) or [] - self.strict_transform_init = strict_transform_init - self.num_workers = num_workers + self.strict_transform_init = strict_transform_init - self.loader_kwargs = loader_kwargs or {} - self.val_loader_kwargs = val_loader_kwargs or {} - self.test_loader_kwargs = test_loader_kwargs or {} + self.train_transforms = train_transforms or copy(transforms) or [] + self.val_transforms = val_transforms or copy(transforms) or [] + self.test_transforms = test_transforms or copy(transforms) or [] self.train_idx = None self.val_idx = None @@ -75,24 +60,8 @@ def __init__( self._val_dataset = None self._test_dataset = None - self._train_loader = None - self._val_loader = None - self._test_loader = None - self.provider: Optional[StatsAtomrefProvider] = None - @property - def train_transforms(self): - return self._train_transforms - - @property - def val_transforms(self): - return self._val_transforms - - @property - def test_transforms(self): - return self._test_transforms - @property def train_dataset(self) -> BaseAtomsData: if self._train_dataset is None: @@ -119,36 +88,29 @@ def setup(self, stage: Optional[str] = None) -> None: self._val_dataset = self.dataset.subset(self.val_idx) self._test_dataset = self.dataset.subset(self.test_idx) - train_atomrefs = getattr(self._train_dataset, "atomrefs", None) - self.provider = StatsAtomrefProvider( - train_dataloader_factory=self.train_dataloader, - train_atomrefs=train_atomrefs, + train_dataset=self._train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, ) - self._initialize_transform_list( - self.train_transforms, train_atomrefs=train_atomrefs - ) - self._initialize_transform_list( - self.val_transforms, train_atomrefs=train_atomrefs - ) - self._initialize_transform_list( - self.test_transforms, train_atomrefs=train_atomrefs - ) + self._initialize_transform_list(self.train_transforms) + self._initialize_transform_list(self.val_transforms) + self._initialize_transform_list(self.test_transforms) self._train_dataset.transforms = self.train_transforms self._val_dataset.transforms = self.val_transforms self._test_dataset.transforms = self.test_transforms - def _initialize_transform_list(self, transforms: List, train_atomrefs): + def _initialize_transform_list(self, transforms: List) -> None: if not transforms: return + for t in transforms: init_fn = getattr(t, "initialize", None) if callable(init_fn): - init_fn(self.provider, atomrefs=train_atomrefs) - continue - if self.strict_transform_init: + init_fn(self.provider, atomrefs=self.provider.train_atomrefs) + elif self.strict_transform_init: raise RuntimeError( f"Transform {type(t).__name__} does not implement initialize." 
) @@ -169,52 +131,29 @@ def _to_abs(x: Optional[Union[int, float]]) -> Optional[int]: num_val = _to_abs(self.num_val) num_test = _to_abs(self.num_test) - self.num_train, self.num_val, self.num_test = num_train, num_val, num_test + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test if self.split_file is not None and os.path.exists(self.split_file): - S = np.load(self.split_file) - self.train_idx = S["train_idx"].tolist() - self.val_idx = S["val_idx"].tolist() - self.test_idx = S["test_idx"].tolist() + split_data = np.load(self.split_file) + self.train_idx = split_data["train_idx"].tolist() + self.val_idx = split_data["val_idx"].tolist() + self.test_idx = split_data["test_idx"].tolist() return train_idx, val_idx, test_idx = self.splitting.split( self.dataset, num_train, num_val, num_test ) - self.train_idx, self.val_idx, self.test_idx = train_idx, val_idx, test_idx + + self.train_idx = train_idx + self.val_idx = val_idx + self.test_idx = test_idx if self.split_file is not None: np.savez( - self.split_file, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx - ) - - def train_dataloader(self) -> AtomsLoader: - if self._train_loader is None: - self._train_loader = AtomsLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers, - **self.loader_kwargs, - ) - return self._train_loader - - def val_dataloader(self) -> AtomsLoader: - if self._val_loader is None: - self._val_loader = AtomsLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - **{**self.loader_kwargs, **self.val_loader_kwargs}, - ) - return self._val_loader - - def test_dataloader(self) -> AtomsLoader: - if self._test_loader is None: - self._test_loader = AtomsLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - **{**self.loader_kwargs, **self.test_loader_kwargs}, - ) - return self._test_loader + self.split_file, + train_idx=train_idx, + val_idx=val_idx, + test_idx=test_idx, + ) \ No newline at end of file From 01fb1ddf1a551a47c884c69b1a6fe577a224e188 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Mon, 2 Mar 2026 02:56:16 +0100 Subject: [PATCH 09/68] refactor: update StatsAtomrefProvider to use BaseAtomsData and simplify initialization --- src/schnetpack/data/provider.py | 60 ++++++++++++--------------------- 1 file changed, 21 insertions(+), 39 deletions(-) diff --git a/src/schnetpack/data/provider.py b/src/schnetpack/data/provider.py index fd5a68a1e..dce5a853c 100644 --- a/src/schnetpack/data/provider.py +++ b/src/schnetpack/data/provider.py @@ -1,54 +1,38 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Callable, Dict, Optional, Tuple +from typing import Dict, Optional, Tuple import torch -from schnetpack.data.loader import AtomsLoader +from schnetpack.data.atoms import BaseAtomsData from schnetpack.data.stats import calculate_stats, estimate_atomrefs -@dataclass class StatsAtomrefProvider: """ - Compute and cache statistics and atom references from the *training split*. + Compute and cache statistics and atom references from the training dataset. 
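+
+    Example (illustrative sketch; "energy" is an assumed property key of the
+    training split):
+
+        provider = StatsAtomrefProvider(train_dataset)
+        # (mean, std) over the training data; results are cached per key
+        mean, std = provider.get_stats("energy", divide_by_atoms=True, remove_atomref=False)
+        refs = provider.get_atomrefs("energy", is_extensive=True)["energy"]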
""" - train_dataloader_factory: Callable[[], AtomsLoader] - train_atomrefs: Optional[Dict[str, torch.Tensor]] = None + def __init__(self, train_dataset: BaseAtomsData) -> None: + self.train_dataset = train_dataset + self.train_atomrefs = getattr(train_dataset, "atomrefs", None) - _stats_cache: Optional[ - Dict[Tuple[str, bool, bool], Tuple[torch.Tensor, torch.Tensor]] - ] = None - _atomref_cache: Optional[Dict[Tuple[str, bool], torch.Tensor]] = None - - def __post_init__(self) -> None: - if self._stats_cache is None: - self._stats_cache = {} - if self._atomref_cache is None: - self._atomref_cache = {} + self._stats_cache: Dict[ + Tuple[str, bool, bool], Tuple[torch.Tensor, torch.Tensor] + ] = {} + self._atomref_cache: Dict[Tuple[str, bool], torch.Tensor] = {} def get_stats( self, property: str, divide_by_atoms: bool, remove_atomref: bool ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - property: property key - divide_by_atoms: if True, compute stats of property / n_atoms - remove_atomref: if True, subtract atomref prior to stats computation - Returns: - (mean, std) tensors - """ key = (property, divide_by_atoms, remove_atomref) if key in self._stats_cache: return self._stats_cache[key] - loader = self.train_dataloader_factory() atomref = self.train_atomrefs if remove_atomref else None stats = calculate_stats( - loader, + self.train_dataset, divide_by_atoms={property: divide_by_atoms}, atomref=atomref, )[property] @@ -59,21 +43,19 @@ def get_stats( def get_atomrefs( self, property: str, is_extensive: bool ) -> Dict[str, torch.Tensor]: - """ - Args: - property: property key - is_extensive: whether property is extensive - Returns: - dict {property: atomref_tensor} - """ + # 1) If dataset already has atomrefs for this property, use them directly + if self.train_atomrefs is not None and property in self.train_atomrefs: + return {property: self.train_atomrefs[property]} + + # 2) Otherwise estimate and cache key = (property, is_extensive) if key in self._atomref_cache: return {property: self._atomref_cache[key]} - loader = self.train_dataloader_factory() - atomref = estimate_atomrefs(loader, is_extensive={property: is_extensive})[ - property - ] + atomref = estimate_atomrefs( + self.train_dataset, + is_extensive={property: is_extensive}, + )[property] self._atomref_cache[key] = atomref - return {property: atomref} + return {property: atomref} \ No newline at end of file From 6d897a92b593b04f4f39c3d9e2346dc25f0c0b2e Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Mon, 2 Mar 2026 02:56:40 +0100 Subject: [PATCH 10/68] refactor: update calculate_stats and estimate_atomrefs to use BaseAtomsData --- src/schnetpack/data/stats.py | 118 +++++++++++++++++------------------ 1 file changed, 58 insertions(+), 60 deletions(-) diff --git a/src/schnetpack/data/stats.py b/src/schnetpack/data/stats.py index 7276c8149..dc0a9db6e 100644 --- a/src/schnetpack/data/stats.py +++ b/src/schnetpack/data/stats.py @@ -1,37 +1,33 @@ -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional, Any import torch from tqdm import tqdm import schnetpack.properties as properties -from schnetpack.data import AtomsLoader +from schnetpack.data.atoms import BaseAtomsData +from schnetpack.data.loader import AtomsLoader __all__ = ["calculate_stats", "estimate_atomrefs"] def calculate_stats( - dataloader: AtomsLoader, + dataset: BaseAtomsData, divide_by_atoms: Dict[str, bool], atomref: Dict[str, torch.Tensor] = None, + batch_size: int = 10000, + num_workers: int = 4, + loader_kwargs: Optional[Dict[str, Any]] = None, 
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]: - """ - Use the incremental Welford algorithm described in [h1]_ to accumulate - the mean and standard deviation over a set of samples. - - References: - ----------- - .. [h1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - - Args: - dataloader: data loader - divide_by_atoms: dict from property name to bool: - If True, divide property by number of atoms before calculating statistics. - atomref: reference values for single atoms to be removed before calculating stats - - Returns: - Mean and standard deviation over all samples + loader_kwargs = loader_kwargs or {} + + dataloader = AtomsLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + **loader_kwargs, + ) - """ property_names = list(divide_by_atoms.keys()) norm_mask = torch.tensor( [float(divide_by_atoms[p]) for p in property_names], dtype=torch.float64 @@ -45,7 +41,8 @@ def calculate_stats( sample_values = [] for p in property_names: val = props[p][None, :] - if atomref and p in atomref.keys(): + + if atomref and p in atomref: ar = atomref[p] ar = ar[props[properties.Z]] idx_m = props[properties.idx_m] @@ -54,90 +51,91 @@ def calculate_stats( val -= v0 sample_values.append(val) + sample_values = torch.cat(sample_values, dim=0) - batch_size = sample_values.shape[1] - new_count = count + batch_size + batch_n = sample_values.shape[1] + new_count = count + batch_n norm = norm_mask[:, None] * props[properties.n_atoms][None, :] + ( 1 - norm_mask[:, None] ) - sample_values /= norm + sample_values = sample_values / norm sample_mean = torch.mean(sample_values, dim=1) sample_m2 = torch.sum((sample_values - sample_mean[:, None]) ** 2, dim=1) delta = sample_mean - mean - mean += delta * batch_size / new_count - corr = batch_size * count / new_count + mean += delta * batch_n / new_count + corr = batch_n * count / new_count M2 += sample_m2 + delta**2 * corr count = new_count stddev = torch.sqrt(M2 / count) - stats = {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)} - return stats - - -def estimate_atomrefs(dataloader, is_extensive, z_max=100): - """ - Uses linear regression to estimate the elementwise biases (atomrefs). - - Args: - dataloader: data loader - is_extensive: If True, divide atom type counts by number of atoms before - calculating statistics. 
- - Returns: - Elementwise bias estimates over all samples + return {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)} + + +def estimate_atomrefs( + dataset: BaseAtomsData, + is_extensive: Dict[str, bool], + z_max: int = 100, + batch_size: int = 10000, + num_workers: int = 4, + loader_kwargs: Optional[Dict[str, Any]] = None, +) -> Dict[str, torch.Tensor]: + loader_kwargs = loader_kwargs or {} + + dataloader = AtomsLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + **loader_kwargs, + ) - """ property_names = list(is_extensive.keys()) - n_data = len(dataloader.dataset) + n_data = len(dataset) + all_properties = {pname: torch.zeros(n_data) for pname in property_names} all_atom_types = torch.zeros((n_data, z_max)) data_counter = 0 - # loop over all batches for batch in tqdm(dataloader, "estimating atomrefs"): - # load data idx_m = batch[properties.idx_m] atomic_numbers = batch[properties.Z] - # get counts for atomic numbers - unique_ids = torch.unique(idx_m) - for i in unique_ids: + for i in torch.unique(idx_m): atomic_numbers_i = atomic_numbers[idx_m == i] atom_types, atom_counts = torch.unique(atomic_numbers_i, return_counts=True) - # save atom counts and properties + for atom_type, atom_count in zip(atom_types, atom_counts): all_atom_types[data_counter, atom_type] = atom_count + for pname in property_names: property_value = batch[pname][i] if not is_extensive[pname]: property_value *= batch[properties.n_atoms][i] all_properties[pname][data_counter] = property_value + data_counter += 1 - # perform linear regression to get the elementwise energy contributions existing_atom_types = torch.where(all_atom_types.sum(axis=0) != 0)[0] X = torch.squeeze(all_atom_types[:, existing_atom_types]) - w = dict() + + weights = {} for pname in property_names: if is_extensive[pname]: - w[pname] = torch.linalg.inv(X.T @ X) @ X.T @ all_properties[pname] + weights[pname] = torch.linalg.inv(X.T @ X) @ X.T @ all_properties[pname] else: - w[pname] = ( + weights[pname] = ( torch.linalg.inv(X.T @ X) @ X.T @ (all_properties[pname] / X.sum(axis=1)) ) - # compute energy estimates - elementwise_contributions = { - pname: torch.zeros((z_max)) for pname in property_names - } + out = {pname: torch.zeros((z_max,)) for pname in property_names} for pname in property_names: - for atom_type, weight in zip(existing_atom_types, w[pname]): - elementwise_contributions[pname][atom_type] = weight + for atom_type, weight in zip(existing_atom_types, weights[pname]): + out[pname][atom_type] = weight - return elementwise_contributions + return out \ No newline at end of file From 7af99a83390d017b4df73a10c657cbaa958b2bff Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Mon, 2 Mar 2026 02:57:29 +0100 Subject: [PATCH 11/68] refactor: simplify initialization in StatsAtomrefProvider --- src/schnetpack/transform/atomistic.py | 144 +++++--------------------- 1 file changed, 26 insertions(+), 118 deletions(-) diff --git a/src/schnetpack/transform/atomistic.py b/src/schnetpack/transform/atomistic.py index 354dabc44..e42643825 100644 --- a/src/schnetpack/transform/atomistic.py +++ b/src/schnetpack/transform/atomistic.py @@ -119,7 +119,7 @@ def __init__( property_mean = property_mean or torch.zeros((1,)) self.register_buffer("mean", property_mean) - def initialize(self, provider, atomrefs=None) -> None: + def initialize(self, provider, atomrefs=None, **kwargs) -> None: """ Initialize mean and/or atomref using a StatsAtomrefProvider. 
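+
+        Example (sketch; mirrors the call the V2 datamodule makes during setup(),
+        where ``provider`` wraps the training split):
+
+            transform.initialize(provider, atomrefs=provider.train_atomrefs)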
""" @@ -145,32 +145,8 @@ def datamodule(self, _datamodule): """ Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. """ - - provider = StatsAtomrefProvider( - train_dataloader_factory=_datamodule.train_dataloader, - train_atomrefs=getattr(_datamodule.train_dataset, "atomrefs", None), - ) - atomrefs = getattr(_datamodule.train_dataset, "atomrefs", None) - return self.initialize(provider, atomrefs=atomrefs) - - # def datamodule(self, _datamodule): - # """ - # Sets mean and atomref automatically when using PyTorchLightning integration. - # """ - # if self.remove_atomrefs and not self._atomrefs_initialized: - # if self.estimate_atomref: - # atrefs = _datamodule.get_atomrefs( - # property=self._property, is_extensive=self.is_extensive - # ) - # else: - # atrefs = _datamodule.train_dataset.atomrefs - # self.atomref = atrefs[self._property].detach() - - # if self.remove_mean and not self._mean_initialized: - # stats = _datamodule.get_stats( - # self._property, self.is_extensive, self.remove_atomrefs - # ) - # self.mean = stats[0].detach() + provider = StatsAtomrefProvider(_datamodule.train_dataset) + return self.initialize(provider, atomrefs=provider.train_atomrefs) def forward( self, @@ -183,6 +159,7 @@ def forward( else self.mean ) inputs[self._property] -= mean + if self.remove_atomrefs: atomref_bias = torch.sum(self.atomref[inputs[structure.Z]]) if not self.is_extensive: @@ -249,18 +226,9 @@ def datamodule(self, _datamodule): """ Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. """ - provider = StatsAtomrefProvider( - train_dataloader_factory=_datamodule.train_dataloader, - train_atomrefs=getattr(_datamodule.train_dataset, "atomrefs", None), - ) + provider = StatsAtomrefProvider(_datamodule.train_dataset) return self.initialize(provider, atomrefs=None) - # def datamodule(self, _datamodule): - # if not self._initialized: - # stats = _datamodule.get_stats(self._target_key, True, False) - # scale = stats[0] if self._scale_by_mean else stats[1] - # self.scale = torch.abs(scale).detach() - def forward( self, inputs: Dict[str, torch.Tensor], @@ -361,93 +329,33 @@ def datamodule(self, _datamodule): """ Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. 
""" + provider = StatsAtomrefProvider(_datamodule.train_dataset) + return self.initialize(provider, atomrefs=provider.train_atomrefs) + - provider = StatsAtomrefProvider( - train_dataloader_factory=_datamodule.train_dataloader, - train_atomrefs=getattr(_datamodule.train_dataset, "atomrefs", None), - ) - atomrefs = getattr(_datamodule.train_dataset, "atomrefs", None) - return self.initialize(provider, atomrefs=atomrefs) - - # def datamodule(self, _datamodule): - # if self.add_atomrefs and not self._atomrefs_initialized: - # if self.estimate_atomref: - # atrefs = _datamodule.get_atomrefs( - # property=self._property, is_extensive=self.is_extensive - # ) - # else: - # atrefs = _datamodule.train_dataset.atomrefs - # self.atomref = atrefs[self._property].detach() - - # if self.add_mean and not self._mean_initialized: - # stats = _datamodule.get_stats( - # self._property, self.is_extensive, self.add_atomrefs - # ) - # self.mean = stats[0].detach() - - # def forward( - # self, - # inputs: Dict[str, torch.Tensor], - # ) -> Dict[str, torch.Tensor]: - # if self.add_mean: - # mean = ( - # self.mean * inputs[structure.n_atoms] - # if self.is_extensive - # else self.mean - # ) - # inputs[self._property] += mean - - # if self.add_atomrefs: - # idx_m = inputs[structure.idx_m] - # y0i = self.atomref[inputs[structure.Z]] - # maxm = int(idx_m[-1]) + 1 - - # y0 = scatter_add(y0i, idx_m, dim_size=maxm) - - # if not self.is_extensive: - # y0 /= inputs[structure.n_atoms] - - # inputs[self._property] += y0 - - # return inputs - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward( + self, + inputs: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: if self.add_mean: - mean = self.mean - if self.is_extensive: - # n_atoms: (B,) in batch, (1,) in single - mean = mean * inputs[structure.n_atoms].to(mean.dtype) - # If mean is (1,) and n_atoms is (B,), broadcasting yields (B,) - inputs[self._property] = inputs[self._property] + mean + mean = ( + self.mean * inputs[structure.n_atoms] + if self.is_extensive + else self.mean + ) + inputs[self._property] += mean if self.add_atomrefs: - z = inputs[structure.Z] - y0i = self.atomref[z] # (N_atoms, ...) ; usually (N_atoms,) - - if structure.idx_m in inputs: - # Batched path - idx_m = inputs[structure.idx_m] - maxm = int(idx_m[-1]) + 1 if idx_m.numel() > 0 else 0 - y0 = scatter_add(y0i, idx_m, dim_size=maxm) # (B, ...) 
 or (B,)
+            idx_m = inputs[structure.idx_m]
+            y0i = self.atomref[inputs[structure.Z]]
+            maxm = int(idx_m[-1]) + 1
 
-                if not self.is_extensive:
-                    n_atoms = inputs[structure.n_atoms].to(y0.dtype)  # (B,)
-                    # Make n_atoms broadcast with y0 if y0 is (B, k)
-                    while n_atoms.dim() < y0.dim():
-                        n_atoms = n_atoms.unsqueeze(-1)
-                    y0 = y0 / n_atoms
+            y0 = scatter_add(y0i, idx_m, dim_size=maxm)
 
-                inputs[self._property] = inputs[self._property] + y0
-
-            else:
-                # Single-system path (no idx_m)
-                y0 = y0i.sum(dim=0)  # scalar () or vector (k,)
-
-                if not self.is_extensive:
-                    n_atoms = inputs[structure.n_atoms].to(y0.dtype)  # usually (1,)
-                    # reduce n_atoms to scalar to avoid odd broadcasting
-                    n_atoms_scalar = n_atoms.view(-1)[0]
-                    y0 = y0 / n_atoms_scalar
+            if not self.is_extensive:
+                y0 /= inputs[structure.n_atoms]
 
-                inputs[self._property] = inputs[self._property] + y0
+            inputs[self._property] += y0
 
         return inputs
+
\ No newline at end of file

From e5927437adfba7aa3b4199d57add3c32fff7dd6f Mon Sep 17 00:00:00 2001
From: sundusaijaz
Date: Mon, 2 Mar 2026 02:57:48 +0100
Subject: [PATCH 12/68] refactor: update Transform class

---
 src/schnetpack/transform/base.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/schnetpack/transform/base.py b/src/schnetpack/transform/base.py
index f2b51eb40..b8b7c7c02 100644
--- a/src/schnetpack/transform/base.py
+++ b/src/schnetpack/transform/base.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict
+from typing import Dict
 
 import torch
 import torch.nn as nn
@@ -32,6 +32,7 @@ class Transform(nn.Module):
 
     def datamodule(self, value):
         """
+        Legacy hook for transforms initialized from an old AtomsDataModule.
         Extract all required information from data module automatically when
         using PyTorch Lightning integration. The transform should also implement
         a way to set these things manually, to make it usable independent of PL.
@@ -49,9 +50,8 @@ def forward(
     def teardown(self):
         pass
 
-    def initialize(self, provider, atomrefs=None) -> None:
+    def initialize(self, **kwargs) -> None:
         """
-        Preferred initialization hook (DataModule-free).
-        Transforms that require training stats/atomrefs override this.
+        Initialization hook for transforms that require training statistics or atomrefs.
         """
         return

From ba8d46fc06d58e62d6e1330300ddb19c15289915 Mon Sep 17 00:00:00 2001
From: sundusaijaz
Date: Mon, 2 Mar 2026 03:13:03 +0100
Subject: [PATCH 13/68] refactor: QM9 class by removing unused parameters and
 simplifying docstring

---
 src/schnetpack/datasets/qm9.py | 181 +++++++++------------------------
 1 file changed, 46 insertions(+), 135 deletions(-)

diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py
index ae53667aa..939648460 100644
--- a/src/schnetpack/datasets/qm9.py
+++ b/src/schnetpack/datasets/qm9.py
@@ -13,7 +13,6 @@
 from ase.io.extxyz import read_xyz
 from tqdm import tqdm
 
-import torch
 import schnetpack.properties as structure
 from schnetpack.data import (
     AtomsDataFormat,
@@ -27,12 +26,11 @@
 
 
 class QM9:
-    """QM9 benchmark database for organic molecules.
+    """
+    QM9 benchmark database downloader/builder.
 
-    This class:
-    - is a dataset wrapper
-    - can download + build the dataset via prepare()
-    - forwards the BaseAtomsData API to an underlying dataset instance
+    This class only prepares the QM9 dataset on disk.
+    `prepare()` returns the loaded dataset.
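+
+    Example (illustrative sketch; the db path is an assumed location):
+
+        qm9 = QM9("./qm9.db", remove_uncharacterized=True)
+        dataset = qm9.prepare()  # downloads/builds if missing, then loads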
""" base_urls = [ @@ -45,7 +43,6 @@ class QM9: "uncharacterized": "3195404", } - # properties A = "rotational_constant_A" B = "rotational_constant_B" C = "rotational_constant_C" @@ -70,118 +67,19 @@ def __init__( remove_uncharacterized: bool = False, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - transforms: Optional[List[torch.nn.Module]] = None, - **kwargs, ): - """ - Args: - datapath: path to dataset DB (e.g. qm9.db) - format: dataset format (ASE by default) - load_properties: subset of properties to load - remove_uncharacterized: if True, exclude uncharacterized molecules - property_units: optional unit overrides on load (passed to load_dataset/create_dataset) - distance_unit: optional distance unit override on load (passed to load_dataset/create_dataset) - transforms: optional default transforms (typically set by your DataModule per split) - **kwargs: reserved for forward compatibility - """ self.datapath = datapath self.format = format self.load_properties = load_properties self.remove_uncharacterized = remove_uncharacterized self.property_units = property_units self.distance_unit = distance_unit - self._kwargs = kwargs - - self._dataset: Optional[BaseAtomsData] = None - self.transforms = transforms or [] - - # ------------------------- - # Dataset forwarding helpers - # ------------------------- - def _ensure_loaded(self) -> None: - if self._dataset is None: - # Lazy-load if it already exists; otherwise require user to call prepare() - if not os.path.exists(self.datapath): - raise AtomsDataModuleError( - f"QM9 dataset not found at {self.datapath}. Call dataset.prepare() first." - ) - self._dataset = load_dataset( - self.datapath, - self.format, - load_properties=self.load_properties, - property_units=self.property_units, - distance_unit=self.distance_unit, - ) - # attach any transforms set before first load - self._dataset.transforms = self.transforms - - def __len__(self) -> int: - self._ensure_loaded() - return len(self._dataset) - - def __getitem__(self, idx: int): - self._ensure_loaded() - return self._dataset[idx] - def subset(self, indices): + def prepare(self) -> BaseAtomsData: """ - Forward subset() to underlying dataset. - Returns a BaseAtomsData-like object. 
- """ - self._ensure_loaded() - sub = self._dataset.subset(indices) - # Ensure transforms are carried over if caller expects it - # (DataModuleV2 will set per-split transforms) - if getattr(sub, "transforms", None) is None: - sub.transforms = [] - return sub - - @property - def available_properties(self): - self._ensure_loaded() - return self._dataset.available_properties - - @property - def atomrefs(self): - self._ensure_loaded() - return getattr(self._dataset, "atomrefs", None) - - @property - def metadata(self): - self._ensure_loaded() - return getattr(self._dataset, "metadata", None) - - @property - def distance_unit_internal(self): - self._ensure_loaded() - return getattr(self._dataset, "distance_unit", None) - - @property - def property_unit_dict(self): - self._ensure_loaded() - return getattr(self._dataset, "property_unit_dict", None) - - # ------------------------- - # Download / build pipeline - # ------------------------- - def _download_file(self, file_id: str, destination: str) -> None: - for base_url in self.base_urls: - url = f"{base_url}{file_id}" - try: - request.urlretrieve(url, destination) - return - except Exception: - logging.warning(f"Could not download from {url}, trying next source...") - raise AtomsDataModuleError( - f"Could not download file with id {file_id} from any source." - ) - - def prepare_data(self) -> None: - """ - Download + build the dataset if missing. If it already exists, verify - the uncharacterized setting is consistent. - - After prepare(), the dataset is loaded and ready to use. + Download + build the dataset if missing. + If it already exists, verify consistency. + Returns the loaded dataset. """ if not os.path.exists(self.datapath): property_unit_dict = { @@ -210,11 +108,7 @@ def prepare_data(self) -> None: datapath=self.datapath, format=self.format, distance_unit=self.distance_unit or "Ang", - property_unit_dict=( - property_unit_dict - if self.property_units is None - else self.property_units - ), + property_unit_dict=property_unit_dict, atomrefs=atomrefs, ) @@ -223,13 +117,11 @@ def prepare_data(self) -> None: else: uncharacterized = None - self._download_data(tmpdir, dataset, uncharacterized=uncharacterized) - + self._download_data(tmpdir, dataset, uncharacterized) finally: shutil.rmtree(tmpdir, ignore_errors=True) else: - # validate uncharacterized constraint against dataset size dataset = load_dataset( self.datapath, self.format, @@ -237,26 +129,38 @@ def prepare_data(self) -> None: property_units=self.property_units, distance_unit=self.distance_unit, ) + if self.remove_uncharacterized and len(dataset) == 133885: raise AtomsDataModuleError( "The dataset at the chosen location contains the uncharacterized 3054 molecules. " "Choose a different location to reload the data or set `remove_uncharacterized=False`." ) + if (not self.remove_uncharacterized) and len(dataset) < 133885: raise AtomsDataModuleError( "The dataset at the chosen location does NOT contain the uncharacterized 3054 molecules. " "Choose a different location to reload the data or set `remove_uncharacterized=True`." 
) - # Load after prepare so wrapper is ready - self._dataset = load_dataset( + return load_dataset( self.datapath, self.format, load_properties=self.load_properties, property_units=self.property_units, distance_unit=self.distance_unit, ) - self._dataset.transforms = self.transforms + + def _download_file(self, file_id: str, destination: str) -> None: + for base_url in self.base_urls: + url = f"{base_url}{file_id}" + try: + request.urlretrieve(url, destination) + return + except Exception: + logging.warning(f"Could not download from {url}, trying next source...") + raise AtomsDataModuleError( + f"Could not download file with id {file_id} from any source." + ) def _download_uncharacterized(self, tmpdir: str) -> List[int]: logging.info("Downloading list of uncharacterized molecules...") @@ -279,11 +183,13 @@ def _download_atomrefs(self, tmpdir: str) -> Dict[str, List[float]]: props = [QM9.zpve, QM9.U0, QM9.U, QM9.H, QM9.G, QM9.Cv] atref = {p: np.zeros((100,)) for p in props} + with open(tmp_path) as f: lines = f.readlines() - for z, l in zip([1, 6, 7, 8, 9], lines[5:10]): + for z, line in zip([1, 6, 7, 8, 9], lines[5:10]): for i, p in enumerate(props): - atref[p][z] = float(l.split()[i + 1]) + atref[p][z] = float(line.split()[i + 1]) + return {k: v.tolist() for k, v in atref.items()} def _download_data( @@ -310,31 +216,36 @@ def _download_data( ) property_list = [] - irange = np.arange(len(ordered_files), dtype=int) + indices = np.arange(len(ordered_files), dtype=int) + if uncharacterized is not None: - irange = np.setdiff1d(irange, np.array(uncharacterized, dtype=int) - 1) + indices = np.setdiff1d(indices, np.array(uncharacterized, dtype=int) - 1) - for i in tqdm(irange): + for i in tqdm(indices): xyzfile = os.path.join(raw_path, ordered_files[i]) properties = {} tmp = io.StringIO() with open(xyzfile, "r") as f: lines = f.readlines() - l = lines[1].split()[2:] - for pn, p in zip(dataset.available_properties, l): - properties[pn] = np.array([float(p)]) + values = lines[1].split()[2:] + + for pname, value in zip(dataset.available_properties, values): + properties[pname] = np.array([float(value)]) + for line in lines: tmp.write(line.replace("*^", "e")) tmp.seek(0) - ats: Atoms = list(read_xyz(tmp, 0))[0] - properties[structure.Z] = ats.numbers - properties[structure.R] = ats.positions - properties[structure.cell] = ats.cell - properties[structure.pbc] = ats.pbc + atoms: Atoms = list(read_xyz(tmp, 0))[0] + + properties[structure.Z] = atoms.numbers + properties[structure.R] = atoms.positions + properties[structure.cell] = atoms.cell + properties[structure.pbc] = atoms.pbc + property_list.append(properties) logging.info("Write atoms to db...") dataset.add_systems(property_list=property_list) - logging.info("Done.") + logging.info("Done.") \ No newline at end of file From ee24a7e289b6dd1f4aed4b1950677c2e125af10a Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Wed, 4 Mar 2026 00:30:38 +0100 Subject: [PATCH 14/68] refactor: enhance AtomsDataModuleV2 and QM9 class by simplifying initialization and adding dataloader methods --- src/schnetpack/data/datamodule_v2.py | 46 ++++--- src/schnetpack/datasets/qm9.py | 175 +++++++++++++------------- src/schnetpack/transform/atomistic.py | 2 - src/schnetpack/transform/base.py | 2 +- 4 files changed, 119 insertions(+), 106 deletions(-) diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index f7ef803f9..37c3b26ed 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -9,6 +9,7 @@ from 
schnetpack.data.atoms import BaseAtomsData from schnetpack.data.provider import StatsAtomrefProvider from schnetpack.data.splitting import RandomSplit, SplittingStrategy +from schnetpack.data.loader import AtomsLoader class AtomsDataModuleV2(pl.LightningDataModule): @@ -17,7 +18,7 @@ class AtomsDataModuleV2(pl.LightningDataModule): - accepts a dataset instance - handles splitting - builds StatsAtomrefProvider from train split - - initializes transforms + - initializes transforms """ def __init__( @@ -34,7 +35,7 @@ def __init__( val_transforms: Optional[List] = None, test_transforms: Optional[List] = None, num_workers: int = 0, - strict_transform_init: bool = True, + **kwargs, ): super().__init__() @@ -46,7 +47,6 @@ def __init__( self.split_file = split_file self.splitting = splitting or RandomSplit() self.num_workers = num_workers - self.strict_transform_init = strict_transform_init self.train_transforms = train_transforms or copy(transforms) or [] self.val_transforms = val_transforms or copy(transforms) or [] @@ -88,11 +88,7 @@ def setup(self, stage: Optional[str] = None) -> None: self._val_dataset = self.dataset.subset(self.val_idx) self._test_dataset = self.dataset.subset(self.test_idx) - self.provider = StatsAtomrefProvider( - train_dataset=self._train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - ) + self.provider = StatsAtomrefProvider(self._train_dataset) self._initialize_transform_list(self.train_transforms) self._initialize_transform_list(self.val_transforms) @@ -107,13 +103,7 @@ def _initialize_transform_list(self, transforms: List) -> None: return for t in transforms: - init_fn = getattr(t, "initialize", None) - if callable(init_fn): - init_fn(self.provider, atomrefs=self.provider.train_atomrefs) - elif self.strict_transform_init: - raise RuntimeError( - f"Transform {type(t).__name__} does not implement initialize." - ) + t.initialize(provider=self.provider, atomrefs=self.provider.train_atomrefs) def _load_partitions(self) -> None: import os @@ -156,4 +146,28 @@ def _to_abs(x: Optional[Union[int, float]]) -> Optional[int]: train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, - ) \ No newline at end of file + ) + + def train_dataloader(self): + return AtomsLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + return AtomsLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + def test_dataloader(self): + return AtomsLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 939648460..091505c6a 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -11,26 +11,19 @@ import numpy as np from ase import Atoms from ase.io.extxyz import read_xyz + from tqdm import tqdm import schnetpack.properties as structure -from schnetpack.data import ( - AtomsDataFormat, - AtomsDataModuleError, - BaseAtomsData, - create_dataset, - load_dataset, -) +from schnetpack.data import AtomsDataFormat +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError, load_dataset __all__ = ["QM9"] -class QM9: +class QM9(ASEAtomsData): """ - QM9 benchmark database downloader/builder. - - This class only prepares the QM9 dataset on disk. - `prepare()` returns the loaded dataset. + QM9 benchmark database for organic molecules. 
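+
+    Example (illustrative sketch; the db path is an assumed location and the
+    split sizes are example values):
+
+        dataset = QM9("./qm9.db", remove_uncharacterized=True)  # builds on first use
+        dm = AtomsDataModuleV2(
+            dataset=dataset,
+            batch_size=100,
+            num_train=110000,
+            num_val=10000,
+            num_test=10000,
+        )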
""" base_urls = [ @@ -62,93 +55,98 @@ class QM9: def __init__( self, datapath: str, - format: AtomsDataFormat = AtomsDataFormat.ASE, - load_properties: Optional[List[str]] = None, + format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, remove_uncharacterized: bool = False, + load_properties: Optional[List[str]] = None, + # transforms=None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, + **kwargs, ): - self.datapath = datapath - self.format = format - self.load_properties = load_properties self.remove_uncharacterized = remove_uncharacterized - self.property_units = property_units - self.distance_unit = distance_unit + self.format = format - def prepare(self) -> BaseAtomsData: - """ - Download + build the dataset if missing. - If it already exists, verify consistency. - Returns the loaded dataset. + self.prepare( + datapath=datapath, + distance_unit=distance_unit or "Ang", + ) + + super().__init__( + datapath=datapath, + load_properties=load_properties, + # transforms=transforms, + subset_idx=subset_idx, + property_units=property_units, + distance_unit=distance_unit, + **kwargs, + ) + + @staticmethod + def _native_property_units() -> Dict[str, str]: + # IMPORTANT: full native QM9 schema, stored in DB metadata + return { + QM9.A: "GHz", + QM9.B: "GHz", + QM9.C: "GHz", + QM9.mu: "Debye", + QM9.alpha: "a0 a0 a0", + QM9.homo: "Ha", + QM9.lumo: "Ha", + QM9.gap: "Ha", + QM9.r2: "a0 a0", + QM9.zpve: "Ha", + QM9.U0: "Ha", + QM9.U: "Ha", + QM9.H: "Ha", + QM9.G: "Ha", + QM9.Cv: "cal/mol/K", + } + + def prepare(self, datapath: str, distance_unit: str = "Ang") -> None: """ - if not os.path.exists(self.datapath): - property_unit_dict = { - QM9.A: "GHz", - QM9.B: "GHz", - QM9.C: "GHz", - QM9.mu: "Debye", - QM9.alpha: "a0 a0 a0", - QM9.homo: "Ha", - QM9.lumo: "Ha", - QM9.gap: "Ha", - QM9.r2: "a0 a0", - QM9.zpve: "Ha", - QM9.U0: "Ha", - QM9.U: "Ha", - QM9.H: "Ha", - QM9.G: "Ha", - QM9.Cv: "cal/mol/K", - } - - tmpdir = tempfile.mkdtemp("qm9") - try: - atomrefs = self._download_atomrefs(tmpdir) - - dataset = create_dataset( - datapath=self.datapath, - format=self.format, - distance_unit=self.distance_unit or "Ang", - property_unit_dict=property_unit_dict, - atomrefs=atomrefs, - ) + Make sure the QM9 database exists. - if self.remove_uncharacterized: - uncharacterized = self._download_uncharacterized(tmpdir) - else: - uncharacterized = None - - self._download_data(tmpdir, dataset, uncharacterized) - finally: - shutil.rmtree(tmpdir, ignore_errors=True) - - else: - dataset = load_dataset( - self.datapath, - self.format, - load_properties=self.load_properties, - property_units=self.property_units, - distance_unit=self.distance_unit, - ) + If the DB already exists, validate consistency with the + remove_uncharacterized setting. + """ + if os.path.exists(datapath): + dataset = load_dataset(datapath, self.format, load_structure=False) if self.remove_uncharacterized and len(dataset) == 133885: - raise AtomsDataModuleError( + raise AtomsDataError( "The dataset at the chosen location contains the uncharacterized 3054 molecules. " - "Choose a different location to reload the data or set `remove_uncharacterized=False`." + "Choose a different location to reload the data or set " + "`remove_uncharacterized=False`." ) if (not self.remove_uncharacterized) and len(dataset) < 133885: - raise AtomsDataModuleError( + raise AtomsDataError( "The dataset at the chosen location does NOT contain the uncharacterized 3054 molecules. 
" - "Choose a different location to reload the data or set `remove_uncharacterized=True`." + "Choose a different location to reload the data or set " + "`remove_uncharacterized=True`." ) + return - return load_dataset( - self.datapath, - self.format, - load_properties=self.load_properties, - property_units=self.property_units, - distance_unit=self.distance_unit, - ) + tmpdir = tempfile.mkdtemp("qm9") + try: + atomrefs = self._download_atomrefs(tmpdir) + + dataset = ASEAtomsData.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=self._native_property_units(), + atomrefs=atomrefs, + ) + + if self.remove_uncharacterized: + uncharacterized = self._download_uncharacterized(tmpdir) + else: + uncharacterized = None + + self._download_data(tmpdir, dataset, uncharacterized) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) def _download_file(self, file_id: str, destination: str) -> None: for base_url in self.base_urls: @@ -158,11 +156,13 @@ def _download_file(self, file_id: str, destination: str) -> None: return except Exception: logging.warning(f"Could not download from {url}, trying next source...") - raise AtomsDataModuleError( + + raise AtomsDataError( f"Could not download file with id {file_id} from any source." ) def _download_uncharacterized(self, tmpdir: str) -> List[int]: + logging.info("Downloading list of uncharacterized molecules...") tmp_path = os.path.join(tmpdir, "uncharacterized.txt") self._download_file(self.file_ids["uncharacterized"], tmp_path) @@ -186,18 +186,19 @@ def _download_atomrefs(self, tmpdir: str) -> Dict[str, List[float]]: with open(tmp_path) as f: lines = f.readlines() - for z, line in zip([1, 6, 7, 8, 9], lines[5:10]): + for z, l in zip([1, 6, 7, 8, 9], lines[5:10]): for i, p in enumerate(props): - atref[p][z] = float(line.split()[i + 1]) + atref[p][z] = float(l.split()[i + 1]) return {k: v.tolist() for k, v in atref.items()} def _download_data( self, tmpdir: str, - dataset: BaseAtomsData, + dataset: ASEAtomsData, uncharacterized: Optional[List[int]], ) -> None: + logging.info("Downloading GDB-9 data...") tar_path = os.path.join(tmpdir, "gdb9.tar.gz") raw_path = os.path.join(tmpdir, "gdb9_xyz") @@ -248,4 +249,4 @@ def _download_data( logging.info("Write atoms to db...") dataset.add_systems(property_list=property_list) - logging.info("Done.") \ No newline at end of file + logging.info("Done.") diff --git a/src/schnetpack/transform/atomistic.py b/src/schnetpack/transform/atomistic.py index e42643825..ad6ba9676 100644 --- a/src/schnetpack/transform/atomistic.py +++ b/src/schnetpack/transform/atomistic.py @@ -332,7 +332,6 @@ def datamodule(self, _datamodule): provider = StatsAtomrefProvider(_datamodule.train_dataset) return self.initialize(provider, atomrefs=provider.train_atomrefs) - def forward( self, inputs: Dict[str, torch.Tensor], @@ -358,4 +357,3 @@ def forward( inputs[self._property] += y0 return inputs - \ No newline at end of file diff --git a/src/schnetpack/transform/base.py b/src/schnetpack/transform/base.py index b8b7c7c02..4dbff1f61 100644 --- a/src/schnetpack/transform/base.py +++ b/src/schnetpack/transform/base.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict import torch import torch.nn as nn From 8a900555446772cf0b97cd8fe61166ae9697c388 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Wed, 4 Mar 2026 00:31:15 +0100 Subject: [PATCH 15/68] fix: black format --- src/schnetpack/data/provider.py | 2 +- src/schnetpack/data/stats.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/schnetpack/data/provider.py b/src/schnetpack/data/provider.py index dce5a853c..d05e60d2b 100644 --- a/src/schnetpack/data/provider.py +++ b/src/schnetpack/data/provider.py @@ -58,4 +58,4 @@ def get_atomrefs( )[property] self._atomref_cache[key] = atomref - return {property: atomref} \ No newline at end of file + return {property: atomref} diff --git a/src/schnetpack/data/stats.py b/src/schnetpack/data/stats.py index dc0a9db6e..51f049507 100644 --- a/src/schnetpack/data/stats.py +++ b/src/schnetpack/data/stats.py @@ -138,4 +138,4 @@ def estimate_atomrefs( for atom_type, weight in zip(existing_atom_types, weights[pname]): out[pname][atom_type] = weight - return out \ No newline at end of file + return out From ee32d85f0b727fe9a62bb3df927a12601107e26c Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Wed, 4 Mar 2026 01:29:08 +0100 Subject: [PATCH 16/68] refactor: update custom and qm9 config files --- src/schnetpack/configs/data/custom.yaml | 22 ++++++++++---- src/schnetpack/configs/data/qm9.yaml | 38 +++++++++++++++---------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/src/schnetpack/configs/data/custom.yaml b/src/schnetpack/configs/data/custom.yaml index 29fc00245..a6b44d5e1 100644 --- a/src/schnetpack/configs/data/custom.yaml +++ b/src/schnetpack/configs/data/custom.yaml @@ -1,12 +1,24 @@ -_target_: schnetpack.data.AtomsDataModule +# @package data +_target_: schnetpack.data.datamodule_v2.AtomsDataModuleV2 +# legacy field if some old structured config still expects it datapath: ??? -data_workdir: null + +# dataset must be provided by concrete config +dataset: ??? + batch_size: 10 num_train: ??? num_val: ??? num_test: null + +split_file: ${run.data_dir}/split.npz +splitting: null + +transforms: ${data.transforms} +train_transforms: null +val_transforms: null +test_transforms: null + num_workers: 8 -num_val_workers: null -num_test_workers: null -train_sampler_cls: null \ No newline at end of file + diff --git a/src/schnetpack/configs/data/qm9.yaml b/src/schnetpack/configs/data/qm9.yaml index 02c3d1150..bb12876f7 100644 --- a/src/schnetpack/configs/data/qm9.yaml +++ b/src/schnetpack/configs/data/qm9.yaml @@ -1,22 +1,30 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.QM9 +datapath: ${run.data_dir}/qm9.db +#legacy inputs +train_sampler_cls: null +train_sampler_args: {} + +dataset: + _target_: schnetpack.datasets.qm9.QM9 + datapath: ${run.data_dir}/qm9.db + remove_uncharacterized: true + load_properties: null + distance_unit: Ang + property_units: + energy_U0: eV + energy_U: eV + enthalpy_H: eV + free_energy: eV + homo: eV + lumo: eV + gap: eV + zpve: eV -datapath: ${run.data_dir}/qm9.db # data_dir is specified in train.yaml batch_size: 100 num_train: 110000 num_val: 10000 -remove_uncharacterized: True - -# convert to typically used units -distance_unit: Ang -property_units: - energy_U0: eV - energy_U: eV - enthalpy_H: eV - free_energy: eV - homo: eV - lumo: eV - gap: eV - zpve: eV \ No newline at end of file +num_test: 10000 +num_workers: 2 From ecaca866aa3d09f51f83badd09e8bf2e4e3f5025 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Wed, 4 Mar 2026 01:46:55 +0100 Subject: [PATCH 17/68] refactor: improve model testing and checkpoint handling in cli --- src/schnetpack/cli.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/schnetpack/cli.py b/src/schnetpack/cli.py index 6025f2ed1..9f45163ce 100644 --- a/src/schnetpack/cli.py +++ b/src/schnetpack/cli.py @@ -178,17 +178,33 @@ def train(config: DictConfig): 
# Evaluate model on test set after training log.info("Starting testing.") - trainer.test(model=task, datamodule=datamodule, ckpt_path="best") + # trainer.test(model=task, datamodule=datamodule, ckpt_path="best") # Store best model best_path = trainer.checkpoint_callback.best_model_path + if not best_path: + raise RuntimeError("No best checkpoint found (best_model_path is empty).") + + # Load Lightning checkpoint dict (requires weights_only=False on torch 2.6+) + ckpt = torch.load( + best_path, + map_location=trainer.strategy.root_device, + weights_only=False, + ) + + # Restore weights into the already-instantiated task + task.load_state_dict(ckpt["state_dict"], strict=True) + + # Test without Lightning re-loading the checkpoint + trainer.test(model=task, datamodule=datamodule, ckpt_path=None) + log.info(f"Best checkpoint path:\n{best_path}") log.info(f"Store best model") - best_task = type(task).load_from_checkpoint(best_path) - torch.save(best_task, config.globals.model_path + ".task") + # best_task = type(task).load_from_checkpoint(best_path) + torch.save(task, config.globals.model_path + ".task") - best_task.save_model(config.globals.model_path, do_postprocessing=True) + task.save_model(config.globals.model_path, do_postprocessing=True) log.info(f"Best model stored at {os.path.abspath(config.globals.model_path)}") From 1339fb62a36bc74d4d4b0a8a994a5fae99cf3866 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Wed, 4 Mar 2026 03:13:06 +0100 Subject: [PATCH 18/68] refactor: merged ASEAtomsData class and BaseAtomsData --- src/schnetpack/data/atoms.py | 440 ++++++------------- src/schnetpack/data/atoms_legacy.py | 635 ++++++++++++++++++++++++++++ 2 files changed, 761 insertions(+), 314 deletions(-) create mode 100644 src/schnetpack/data/atoms_legacy.py diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index a06d7b5d6..f3f911766 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -1,24 +1,10 @@ -""" -This module contains all functionalities required to load atomistic data, -generate batches and compute statistics. It makes use of the ASE database -for atoms [#ase2]_. - -References ----------- -.. [#ase2] Larsen, Mortensen, Blomqvist, Castelli, Christensen, Dułak, Friis, - Groves, Hammer, Hargus: - The atomic simulation environment -- a Python library for working with atoms. - Journal of Physics: Condensed Matter, 9, 27. 2017. -""" - +import copy import logging import os -from abc import ABC, abstractmethod from enum import Enum from typing import Optional, List, Dict, Any, Iterable, Union, Tuple import torch -import copy from ase import Atoms from ase.db import connect @@ -30,7 +16,6 @@ __all__ = [ "ASEAtomsData", - "BaseAtomsData", "AtomsDataFormat", "resolve_format", "create_dataset", @@ -48,148 +33,10 @@ class AtomsDataError(Exception): pass -extension_map = {AtomsDataFormat.ASE: ".db"} - - -class BaseAtomsData(ABC): - """ - Base mixin class for atomistic data. Use together with PyTorch Dataset or - IterableDataset to implement concrete data formats. - """ - - def __init__( - self, - load_properties: Optional[List[str]] = None, - load_structure: bool = True, - transforms: Optional[List[Transform]] = None, - subset_idx: Optional[List[int]] = None, - ): - """ - Args: - load_properties: Set of properties to be loaded and returned. - If None, all properties in the ASE dB will be returned. - load_structure: If True, load structure properties. - transforms: preprocessing transforms (see schnetpack.data.transforms) - subset: List of data indices. 
- """ - self._transform_module = None - self.load_properties = load_properties - self.load_structure = load_structure - self.transforms = transforms - self.subset_idx = subset_idx - - def __len__(self) -> int: - raise NotImplementedError - - @property - def transforms(self): - return self._transforms - - @transforms.setter - def transforms(self, value: Optional[List[Transform]]): - self._transforms = [] - self._transform_module = None - - if value is not None: - for tf in value: - self._transforms.append(tf) - self._transform_module = torch.nn.Sequential(*self._transforms) - - def subset(self, subset_idx: List[int]): - assert ( - subset_idx is not None - ), "Indices for creation of the subset need to be provided!" - ds = copy.copy(self) - if ds.subset_idx: - ds.subset_idx = [ds.subset_idx[i] for i in subset_idx] - else: - ds.subset_idx = subset_idx - return ds - - @property - @abstractmethod - def available_properties(self) -> List[str]: - """Available properties in the dataset""" - pass - - @property - @abstractmethod - def units(self) -> Dict[str, str]: - """Property to unit dict""" - pass - - @property - def load_properties(self) -> List[str]: - """Properties to be loaded""" - if self._load_properties is None: - return self.available_properties - else: - return self._load_properties - - @load_properties.setter - def load_properties(self, val: List[str]): - if val is not None: - props = self.available_properties - assert all( - [p in props for p in val] - ), "Not all given properties are available in the dataset!" - self._load_properties = val - - @property - @abstractmethod - def metadata(self) -> Dict[str, Any]: - """Global metadata""" - pass - - @property - @abstractmethod - def atomrefs(self) -> Dict[str, torch.Tensor]: - """Single-atom reference values for properties""" - pass - - @abstractmethod - def update_metadata(self, **kwargs): - pass - - @abstractmethod - def iter_properties( - self, - indices: Union[int, Iterable[int]] = None, - load_properties: List[str] = None, - load_structure: Optional[bool] = None, - ): - pass - - @staticmethod - @abstractmethod - def create( - datapath: str, - position_unit: str, - property_unit_dict: Dict[str, str], - atomrefs: Dict[str, List[float]], - **kwargs, - ) -> "BaseAtomsData": - pass - - @abstractmethod - def add_systems( - self, - property_list: List[Dict[str, Any]], - atoms_list: Optional[List[Atoms]] = None, - atoms_metadata_list: Optional[List[Dict[str, Any]]] = None, - ): - pass - - @abstractmethod - def add_system(self, atoms: Optional[Atoms] = None, **properties): - pass - - -class ASEAtomsData(BaseAtomsData): +class ASEAtomsData(torch.utils.data.Dataset): """ PyTorch dataset for atomistic data. The raw data is stored in the specified ASE database. - """ def __init__( @@ -197,46 +44,32 @@ def __init__( datapath: str, load_properties: Optional[List[str]] = None, load_structure: bool = True, - transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, ): - """ - Args: - datapath: Path to ASE DB. - load_properties: Set of properties to be loaded and returned. - If None, all properties in the ASE dB will be returned. - load_structure: If True, load structure properties. - transforms: preprocessing torch.nn.Module (see schnetpack.data.transforms) - subset_idx: List of data indices. - units: property-> unit string dictionary that overwrites the native units - of the dataset. 
Units are converted automatically during loading. - """ self.datapath = datapath self.conn = connect(self.datapath, use_lock_file=False) - BaseAtomsData.__init__( - self, - load_properties=load_properties, - load_structure=load_structure, - transforms=transforms, - subset_idx=subset_idx, - ) + # merged ASEAtomsData state + self._transform_module = None + self._transforms: List[Transform] = [] + self._load_properties: Optional[List[str]] = None + self.load_structure = load_structure + self.subset_idx = subset_idx self._check_db() - # initialize units + # units from metadata md = self.metadata - if "_distance_unit" not in md.keys(): + if "_distance_unit" not in md: raise AtomsDataError( - "Dataset does not have a distance unit set. Please add units to the " - + "dataset using `spkconvert`!" + "Dataset does not have a distance unit set. Please add units to the dataset." ) - if "_property_unit_dict" not in md.keys(): + if "_property_unit_dict" not in md: raise AtomsDataError( - "Dataset does not have a property units set. Please add units to the " - + "dataset using `spkconvert`!" + "Dataset does not have property units set. Please add units to the dataset." ) if distance_unit: @@ -248,8 +81,10 @@ def __init__( self.distance_conversion = 1.0 self.distance_unit = md["_distance_unit"] - self._units = md["_property_unit_dict"] + self._units = dict(md["_property_unit_dict"]) self.conversions = {prop: 1.0 for prop in self._units} + + # apply unit overrides on load only if property_units is not None: for prop, unit in property_units.items(): self.conversions[prop] = spk.units.convert_units( @@ -257,24 +92,69 @@ def __init__( ) self._units[prop] = unit + # now validate load_properties against available_properties + self.load_properties = load_properties + + # set transforms last + self.transforms = transforms + + # ---------- merged ASEAtomsData bits ---------- + + @property + def transforms(self) -> List[Transform]: + return self._transforms + + @transforms.setter + def transforms(self, value: Optional[List[Transform]]): + self._transforms = [] + self._transform_module = None + if value: + self._transforms.extend(value) + self._transform_module = torch.nn.Sequential(*self._transforms) + + def subset(self, subset_idx: List[int]): + if subset_idx is None: + raise ValueError("subset_idx must be provided.") + ds = copy.copy(self) + if ds.subset_idx: + ds.subset_idx = [ds.subset_idx[i] for i in subset_idx] + else: + ds.subset_idx = subset_idx + return ds + + @property + def load_properties(self) -> List[str]: + if self._load_properties is None: + return self.available_properties + return self._load_properties + + @load_properties.setter + def load_properties(self, val: Optional[List[str]]): + if val is not None: + props = self.available_properties + missing = [p for p in val if p not in props] + if missing: + raise AtomsDataError(f"Properties not available in dataset: {missing}") + self._load_properties = val + + # ---------- core dataset API ---------- + def __len__(self) -> int: if self.subset_idx is not None: return len(self.subset_idx) - return self.conn.count() def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: if self.subset_idx is not None: idx = self.subset_idx[idx] - props = self._get_properties( self.conn, idx, self.load_properties, self.load_structure ) - props = self._apply_transforms(props) - - return props + return self._apply_transforms(props) - def _apply_transforms(self, props): + def _apply_transforms( + self, props: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: if 
self._transform_module is not None: props = self._transform_module(props) return props @@ -286,47 +166,69 @@ def _check_db(self): if self.subset_idx: with connect(self.datapath, use_lock_file=False) as conn: n_structures = conn.count() + if max(self.subset_idx) >= n_structures: + raise AtomsDataError("subset_idx contains out-of-range indices") + + # ---------- metadata / units ---------- + + @property + def metadata(self) -> Dict[str, Any]: + with connect(self.datapath, use_lock_file=False) as conn: + return conn.metadata + + def _set_metadata(self, val: Dict[str, Any]): + with connect(self.datapath, use_lock_file=False) as conn: + conn.metadata = val + + def update_metadata(self, **kwargs): + if not all(k and k[0] != "_" for k in kwargs): + raise AtomsDataError("Metadata keys starting with '_' are protected!") + md = self.metadata + md.update(kwargs) + self._set_metadata(md) + + @property + def available_properties(self) -> List[str]: + md = self.metadata + return list(md["_property_unit_dict"].keys()) + + @property + def units(self) -> Dict[str, str]: + return self._units - assert max(self.subset_idx) < n_structures + @property + def atomrefs(self) -> Dict[str, torch.Tensor]: + md = self.metadata + arefs = md.get("atomrefs", {}) + return {k: self.conversions[k] * torch.tensor(v) for k, v in arefs.items()} + + # ---------- iteration ---------- def iter_properties( self, indices: Union[int, Iterable[int]] = None, - load_properties: List[str] = None, + load_properties: Optional[List[str]] = None, load_structure: Optional[bool] = None, load_metadata: bool = False, ): - """ - Return property dictionary at given indices. - - Args: - indices: data indices - load_properties (sequence or None): subset of available properties to load - load_structure: load and return structure - load_metadata: load and return metadata - - Returns: - properties (dict): dictionary with molecular properties - - """ if load_properties is None: load_properties = self.load_properties - load_structure = load_structure or self.load_structure + if load_structure is None: + load_structure = self.load_structure if self.subset_idx: if indices is None: indices = self.subset_idx - elif type(indices) is int: + elif isinstance(indices, int): indices = [self.subset_idx[indices]] else: indices = [self.subset_idx[i] for i in indices] else: if indices is None: indices = range(len(self)) - elif type(indices) is int: + elif isinstance(indices, int): indices = [indices] - # read from ase db for i in indices: yield self._get_properties( self.conn, @@ -345,11 +247,10 @@ def _get_properties( load_metadata: bool = False, ): row = conn.get(idx + 1) - - # extract properties # TODO: can the copies be avoided? - properties = {} + properties: Dict[str, torch.Tensor] = {} properties[structure.idx] = torch.tensor([idx]) + for pname in load_properties: properties[pname] = ( torch.tensor(row.data[pname].copy()) * self.conversions[pname] @@ -373,43 +274,7 @@ def _get_properties( return properties - # Metadata - @property - def metadata(self): - with connect(self.datapath, use_lock_file=False) as conn: - return conn.metadata - - def _set_metadata(self, val: Dict[str, Any]): - with connect(self.datapath, use_lock_file=False) as conn: - conn.metadata = val - - def update_metadata(self, **kwargs): - assert all( - key[0] != 0 for key in kwargs - ), "Metadata keys starting with '_' are protected!" 
- - md = self.metadata - md.update(kwargs) - self._set_metadata(md) - - @property - def available_properties(self) -> List[str]: - md = self.metadata - return list(md["_property_unit_dict"].keys()) - - @property - def units(self) -> Dict[str, str]: - """Dictionary of properties to units""" - return self._units - - @property - def atomrefs(self) -> Dict[str, torch.Tensor]: - md = self.metadata - arefs = md["atomrefs"] - arefs = {k: self.conversions[k] * torch.tensor(v) for k, v in arefs.items()} - return arefs - - ## Creation + # ---------- creation / writing ---------- @staticmethod def create( @@ -419,35 +284,14 @@ def create( atomrefs: Optional[Dict[str, List[float]]] = None, **kwargs, ) -> "ASEAtomsData": - """ - - Args: - datapath: Path to ASE DB. - distance_unit: unit of atom positions and cell - property_unit_dict: Defines the available properties of the datasetseta and - provides units for ALL properties of the dataset. If a property is - unit-less, you can pass "arb. unit" or `None`. - atomrefs: dictionary mapping properies (the keys) to lists of single-atom - reference values of the property. This is especially useful for - extensive properties such as the energy, where the single atom energies - contribute a major part to the overall value. - kwargs: Pass arguments to init. - - Returns: - newly created ASEAtomsData - - """ if not datapath.endswith(".db"): - raise AtomsDataError( - "Invalid datapath! Please make sure to add the file extension '.db' to " - "your dbpath." - ) - + raise AtomsDataError("Invalid datapath! Add '.db' extension.") if os.path.exists(datapath): raise AtomsDataError(f"Dataset already exists: {datapath}") - atomrefs = atomrefs or {} + os.makedirs(os.path.dirname(datapath) or ".", exist_ok=True) + atomrefs = atomrefs or {} with connect(datapath) as conn: conn.metadata = { "_property_unit_dict": property_unit_dict, @@ -457,7 +301,6 @@ def create( return ASEAtomsData(datapath, **kwargs) - # add systems def add_system( self, atoms: Optional[Atoms] = None, @@ -486,36 +329,15 @@ def add_systems( atoms_list: Optional[List[Atoms]] = None, atoms_metadata_list: Optional[List[Dict[str, Any]]] = None, ): - """ - Add atoms data to the dataset. - - Args: - atoms_list: System composition and geometry. If Atoms are None, - the structure needs to be given as part of the property dicts - (using structure.Z, structure.R, structure.cell, structure.pbc) - property_list: Properties as list of key-value pairs in the same - order as corresponding list of `atoms`. - Keys have to match the `available_properties` of the dataset - plus additional structure properties, if atoms is None. - atoms_metadata_list: Metadata of the atoms objects as list of key-value pairs in the same - order as corresponding list of `atoms`. - Metadata can not be used as a training property, but can be used for splitting - strategies (e.g. material_id, timestamp, ...). - """ if atoms_list is None: atoms_list = [None] * len(property_list) - if atoms_metadata_list is None: atoms_metadata_list = [{}] * len(property_list) for atoms, prop, atoms_metadata in zip( atoms_list, property_list, atoms_metadata_list ): - self._add_system( - atoms, - atoms_metadata, - **prop, - ) + self._add_system(atoms, atoms_metadata, **prop) def _add_system( self, @@ -523,10 +345,6 @@ def _add_system( atoms_metadata: Optional[Dict[str, Any]] = None, **properties, ): - """ - Add systems to DB. 
- """ - # create atoms object if not provided if atoms is None: try: Z = properties[structure.Z] @@ -535,9 +353,7 @@ def _add_system( pbc = properties[structure.pbc] atoms = Atoms(numbers=Z, positions=R, cell=cell, pbc=pbc) except KeyError as e: - raise AtomsDataError( - "Property dict does not contain all necessary structure keys" - ) from e + raise AtomsDataError("Missing structure keys in properties") from e if atoms_metadata is None: atoms_metadata = {} @@ -545,25 +361,20 @@ def _add_system( with connect(self.datapath, use_lock_file=False) as conn: prop_keys = conn.metadata["_property_unit_dict"].keys() - valid_props = set().union( - prop_keys, - [structure.Z, structure.R, structure.cell, structure.pbc], + valid_props = set(prop_keys).union( + {structure.Z, structure.R, structure.cell, structure.pbc} ) for pname in properties: if pname not in valid_props: logger.warning( - f"Property `{pname}` is not a defined property for this dataset and " - + f"will be ignored. If it should be included, it has to be " - + f"provided together with its unit when calling " - + f"AseAtomsData.create()." + f"Property `{pname}` is not defined for this dataset and will be ignored." ) data = {} for pname in prop_keys: - if pname in properties: - data[pname] = properties[pname] - else: - raise AtomsDataError("Required property missing:" + pname) + if pname not in properties: + raise AtomsDataError("Required property missing: " + pname) + data[pname] = properties[pname] conn.write(atoms, data=data, key_value_pairs=atoms_metadata) @@ -574,7 +385,7 @@ def create_dataset( distance_unit: str, property_unit_dict: Dict[str, str], **kwargs, -) -> BaseAtomsData: +) -> ASEAtomsData: """ Create a new atoms dataset. @@ -598,10 +409,11 @@ def create_dataset( ) else: raise AtomsDataError(f"Unknown format: {format}") + return dataset -def load_dataset(datapath: str, format: AtomsDataFormat, **kwargs) -> BaseAtomsData: +def load_dataset(datapath: str, format: AtomsDataFormat, **kwargs) -> ASEAtomsData: """ Load dataset. diff --git a/src/schnetpack/data/atoms_legacy.py b/src/schnetpack/data/atoms_legacy.py new file mode 100644 index 000000000..d3888dd9f --- /dev/null +++ b/src/schnetpack/data/atoms_legacy.py @@ -0,0 +1,635 @@ +""" +This module contains all functionalities required to load atomistic data, +generate batches and compute statistics. It makes use of the ASE database +for atoms [#ase2]_. + +References +---------- +.. [#ase2] Larsen, Mortensen, Blomqvist, Castelli, Christensen, Dułak, Friis, + Groves, Hammer, Hargus: + The atomic simulation environment -- a Python library for working with atoms. + Journal of Physics: Condensed Matter, 9, 27. 2017. +""" + +import logging +import os +from abc import ABC, abstractmethod +from enum import Enum +from typing import Optional, List, Dict, Any, Iterable, Union, Tuple + +import torch +import copy +from ase import Atoms +from ase.db import connect + +import schnetpack as spk +import schnetpack.properties as structure +from schnetpack.transform.base import Transform + +logger = logging.getLogger(__name__) + +__all__ = [ + "ASEAtomsData", + "BaseAtomsData", + "AtomsDataFormat", + "resolve_format", + "create_dataset", + "load_dataset", +] + + +class AtomsDataFormat(Enum): + """Enumeration of data formats""" + + ASE = "ase" + + +class AtomsDataError(Exception): + pass + + +extension_map = {AtomsDataFormat.ASE: ".db"} + + +class BaseAtomsData(ABC): + """ + Base mixin class for atomistic data. 
Use together with PyTorch Dataset or
+    IterableDataset to implement concrete data formats.
+    """
+
+    def __init__(
+        self,
+        load_properties: Optional[List[str]] = None,
+        load_structure: bool = True,
+        transforms: Optional[List[Transform]] = None,
+        subset_idx: Optional[List[int]] = None,
+    ):
+        """
+        Args:
+            load_properties: Set of properties to be loaded and returned.
+                If None, all properties in the ASE dB will be returned.
+            load_structure: If True, load structure properties.
+            transforms: preprocessing transforms (see schnetpack.data.transforms)
+            subset_idx: List of data indices.
+        """
+        self._transform_module = None
+        self.load_properties = load_properties
+        self.load_structure = load_structure
+        self.transforms = transforms
+        self.subset_idx = subset_idx
+
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    @property
+    def transforms(self):
+        return self._transforms
+
+    @transforms.setter
+    def transforms(self, value: Optional[List[Transform]]):
+        self._transforms = []
+        self._transform_module = None
+
+        if value is not None:
+            for tf in value:
+                self._transforms.append(tf)
+            self._transform_module = torch.nn.Sequential(*self._transforms)
+
+    def subset(self, subset_idx: List[int]):
+        assert (
+            subset_idx is not None
+        ), "Indices for creation of the subset need to be provided!"
+        ds = copy.copy(self)
+        if ds.subset_idx:
+            ds.subset_idx = [ds.subset_idx[i] for i in subset_idx]
+        else:
+            ds.subset_idx = subset_idx
+        return ds
+
+    @property
+    @abstractmethod
+    def available_properties(self) -> List[str]:
+        """Available properties in the dataset"""
+        pass
+
+    @property
+    @abstractmethod
+    def units(self) -> Dict[str, str]:
+        """Property to unit dict"""
+        pass
+
+    @property
+    def load_properties(self) -> List[str]:
+        """Properties to be loaded"""
+        if self._load_properties is None:
+            return self.available_properties
+        else:
+            return self._load_properties
+
+    @load_properties.setter
+    def load_properties(self, val: List[str]):
+        if val is not None:
+            props = self.available_properties
+            assert all(
+                [p in props for p in val]
+            ), "Not all given properties are available in the dataset!"
+        self._load_properties = val
+
+    @property
+    @abstractmethod
+    def metadata(self) -> Dict[str, Any]:
+        """Global metadata"""
+        pass
+
+    @property
+    @abstractmethod
+    def atomrefs(self) -> Dict[str, torch.Tensor]:
+        """Single-atom reference values for properties"""
+        pass
+
+    @abstractmethod
+    def update_metadata(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def iter_properties(
+        self,
+        indices: Union[int, Iterable[int]] = None,
+        load_properties: List[str] = None,
+        load_structure: Optional[bool] = None,
+    ):
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def create(
+        datapath: str,
+        position_unit: str,
+        property_unit_dict: Dict[str, str],
+        atomrefs: Dict[str, List[float]],
+        **kwargs,
+    ) -> "BaseAtomsData":
+        pass
+
+    @abstractmethod
+    def add_systems(
+        self,
+        property_list: List[Dict[str, Any]],
+        atoms_list: Optional[List[Atoms]] = None,
+        atoms_metadata_list: Optional[List[Dict[str, Any]]] = None,
+    ):
+        pass
+
+    @abstractmethod
+    def add_system(self, atoms: Optional[Atoms] = None, **properties):
+        pass
+
+
+class ASEAtomsData(BaseAtomsData):
+    """
+    PyTorch dataset for atomistic data. The raw data is stored in the specified
+    ASE database.
+
+    """
+
+    def __init__(
+        self,
+        datapath: str,
+        load_properties: Optional[List[str]] = None,
+        load_structure: bool = True,
+        transforms: Optional[List[torch.nn.Module]] = None,
+        subset_idx: Optional[List[int]] = None,
+        property_units: Optional[Dict[str, str]] = None,
+        distance_unit: Optional[str] = None,
+    ):
+        """
+        Args:
+            datapath: Path to ASE DB.
+            load_properties: Set of properties to be loaded and returned.
+                If None, all properties in the ASE dB will be returned.
+            load_structure: If True, load structure properties.
+            transforms: preprocessing torch.nn.Module (see schnetpack.data.transforms)
+            subset_idx: List of data indices.
+            property_units: property -> unit string dictionary that overwrites the
+                native units of the dataset. Units are converted automatically
+                during loading.
+        """
+        self.datapath = datapath
+        self.conn = connect(self.datapath, use_lock_file=False)
+
+        BaseAtomsData.__init__(
+            self,
+            load_properties=load_properties,
+            load_structure=load_structure,
+            transforms=transforms,
+            subset_idx=subset_idx,
+        )
+
+        self._check_db()
+
+        # initialize units
+        md = self.metadata
+        if "_distance_unit" not in md.keys():
+            raise AtomsDataError(
+                "Dataset does not have a distance unit set. Please add units to the "
+                + "dataset using `spkconvert`!"
+            )
+        if "_property_unit_dict" not in md.keys():
+            raise AtomsDataError(
+                "Dataset does not have property units set. Please add units to the "
+                + "dataset using `spkconvert`!"
+            )
+
+        if distance_unit:
+            self.distance_conversion = spk.units.convert_units(
+                md["_distance_unit"], distance_unit
+            )
+            self.distance_unit = distance_unit
+        else:
+            self.distance_conversion = 1.0
+            self.distance_unit = md["_distance_unit"]
+
+        self._units = md["_property_unit_dict"]
+        self.conversions = {prop: 1.0 for prop in self._units}
+        if property_units is not None:
+            for prop, unit in property_units.items():
+                self.conversions[prop] = spk.units.convert_units(
+                    self._units[prop], unit
+                )
+                self._units[prop] = unit
+
+    def __len__(self) -> int:
+        if self.subset_idx is not None:
+            return len(self.subset_idx)
+
+        return self.conn.count()
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        if self.subset_idx is not None:
+            idx = self.subset_idx[idx]
+
+        props = self._get_properties(
+            self.conn, idx, self.load_properties, self.load_structure
+        )
+        props = self._apply_transforms(props)
+
+        return props
+
+    def _apply_transforms(self, props):
+        if self._transform_module is not None:
+            props = self._transform_module(props)
+        return props
+
+    def _check_db(self):
+        if not os.path.exists(self.datapath):
+            raise AtomsDataError(f"ASE DB does not exist at {self.datapath}")
+
+        if self.subset_idx:
+            with connect(self.datapath, use_lock_file=False) as conn:
+                n_structures = conn.count()
+
+            assert max(self.subset_idx) < n_structures
+
+    def iter_properties(
+        self,
+        indices: Union[int, Iterable[int]] = None,
+        load_properties: List[str] = None,
+        load_structure: Optional[bool] = None,
+        load_metadata: bool = False,
+    ):
+        """
+        Return property dictionary at given indices.
+
+        Args:
+            indices: data indices
+            load_properties (sequence or None): subset of available properties to load
+            load_structure: load and return structure
+            load_metadata: load and return metadata
+
+        Returns:
+            properties (dict): dictionary with molecular properties
+
+        """
+        if load_properties is None:
+            load_properties = self.load_properties
+        load_structure = load_structure or self.load_structure
+
+        if self.subset_idx:
+            if indices is None:
+                indices = self.subset_idx
+            elif type(indices) is int:
+                indices = [self.subset_idx[indices]]
+            else:
+                indices = [self.subset_idx[i] for i in indices]
+        else:
+            if indices is None:
+                indices = range(len(self))
+            elif type(indices) is int:
+                indices = [indices]
+
+        # read from ase db
+        for i in indices:
+            yield self._get_properties(
+                self.conn,
+                i,
+                load_properties=load_properties,
+                load_structure=load_structure,
+                load_metadata=load_metadata,
+            )
+
+    def _get_properties(
+        self,
+        conn,
+        idx: int,
+        load_properties: List[str],
+        load_structure: bool,
+        load_metadata: bool = False,
+    ):
+        row = conn.get(idx + 1)
+
+        # extract properties
+        # TODO: can the copies be avoided?
+        properties = {}
+        properties[structure.idx] = torch.tensor([idx])
+        for pname in load_properties:
+            properties[pname] = (
+                torch.tensor(row.data[pname].copy()) * self.conversions[pname]
+            )
+
+        Z = row["numbers"].copy()
+        properties[structure.n_atoms] = torch.tensor([Z.shape[0]])
+
+        if load_structure:
+            properties[structure.Z] = torch.tensor(Z, dtype=torch.long)
+            properties[structure.position] = (
+                torch.tensor(row["positions"].copy()) * self.distance_conversion
+            )
+            properties[structure.cell] = (
+                torch.tensor(row["cell"][None].copy()) * self.distance_conversion
+            )
+            properties[structure.pbc] = torch.tensor(row["pbc"])
+
+        if load_metadata:
+            properties["metadata"] = row.key_value_pairs
+
+        return properties
+
+    # Metadata
+    @property
+    def metadata(self):
+        with connect(self.datapath, use_lock_file=False) as conn:
+            return conn.metadata
+
+    def _set_metadata(self, val: Dict[str, Any]):
+        with connect(self.datapath, use_lock_file=False) as conn:
+            conn.metadata = val
+
+    def update_metadata(self, **kwargs):
+        assert all(
+            key[0] != "_" for key in kwargs
+        ), "Metadata keys starting with '_' are protected!"
+
+        md = self.metadata
+        md.update(kwargs)
+        self._set_metadata(md)
+
+    @property
+    def available_properties(self) -> List[str]:
+        md = self.metadata
+        return list(md["_property_unit_dict"].keys())
+
+    @property
+    def units(self) -> Dict[str, str]:
+        """Dictionary of properties to units"""
+        return self._units
+
+    @property
+    def atomrefs(self) -> Dict[str, torch.Tensor]:
+        md = self.metadata
+        arefs = md["atomrefs"]
+        arefs = {k: self.conversions[k] * torch.tensor(v) for k, v in arefs.items()}
+        return arefs
+
+    ## Creation
+
+    @staticmethod
+    def create(
+        datapath: str,
+        distance_unit: str,
+        property_unit_dict: Dict[str, str],
+        atomrefs: Optional[Dict[str, List[float]]] = None,
+        **kwargs,
+    ) -> "ASEAtomsData":
+        """
+
+        Args:
+            datapath: Path to ASE DB.
+            distance_unit: unit of atom positions and cell
+            property_unit_dict: Defines the available properties of the dataset and
+                provides units for ALL properties of the dataset. If a property is
+                unit-less, you can pass "arb. unit" or `None`.
+            atomrefs: dictionary mapping properties (the keys) to lists of single-atom
+                reference values of the property. This is especially useful for
+                extensive properties such as the energy, where the single atom energies
+                contribute a major part to the overall value.
+ kwargs: Pass arguments to init. + + Returns: + newly created ASEAtomsData + + """ + if not datapath.endswith(".db"): + raise AtomsDataError( + "Invalid datapath! Please make sure to add the file extension '.db' to " + "your dbpath." + ) + + if os.path.exists(datapath): + raise AtomsDataError(f"Dataset already exists: {datapath}") + + atomrefs = atomrefs or {} + + with connect(datapath) as conn: + conn.metadata = { + "_property_unit_dict": property_unit_dict, + "_distance_unit": distance_unit, + "atomrefs": atomrefs, + } + + return ASEAtomsData(datapath, **kwargs) + + # add systems + def add_system( + self, + atoms: Optional[Atoms] = None, + atoms_metadata: Optional[Dict[str, Any]] = None, + **properties, + ): + self._add_system(atoms, atoms_metadata, **properties) + + def add_systems( + self, + property_list: List[Dict[str, Any]], + atoms_list: Optional[List[Atoms]] = None, + atoms_metadata_list: Optional[List[Dict[str, Any]]] = None, + ): + """ + Add atoms data to the dataset. + + Args: + atoms_list: System composition and geometry. If Atoms are None, + the structure needs to be given as part of the property dicts + (using structure.Z, structure.R, structure.cell, structure.pbc) + property_list: Properties as list of key-value pairs in the same + order as corresponding list of `atoms`. + Keys have to match the `available_properties` of the dataset + plus additional structure properties, if atoms is None. + atoms_metadata_list: Metadata of the atoms objects as list of key-value pairs in the same + order as corresponding list of `atoms`. + Metadata can not be used as a training property, but can be used for splitting + strategies (e.g. material_id, timestamp, ...). + """ + if atoms_list is None: + atoms_list = [None] * len(property_list) + + if atoms_metadata_list is None: + atoms_metadata_list = [{}] * len(property_list) + + for atoms, prop, atoms_metadata in zip( + atoms_list, property_list, atoms_metadata_list + ): + self._add_system( + atoms, + atoms_metadata, + **prop, + ) + + def _add_system( + self, + atoms: Optional[Atoms] = None, + atoms_metadata: Optional[Dict[str, Any]] = None, + **properties, + ): + """ + Add systems to DB. + """ + # create atoms object if not provided + if atoms is None: + try: + Z = properties[structure.Z] + R = properties[structure.R] + cell = properties[structure.cell] + pbc = properties[structure.pbc] + atoms = Atoms(numbers=Z, positions=R, cell=cell, pbc=pbc) + except KeyError as e: + raise AtomsDataError( + "Property dict does not contain all necessary structure keys" + ) from e + + if atoms_metadata is None: + atoms_metadata = {} + + with connect(self.datapath, use_lock_file=False) as conn: + prop_keys = conn.metadata["_property_unit_dict"].keys() + + valid_props = set().union( + prop_keys, + [structure.Z, structure.R, structure.cell, structure.pbc], + ) + for pname in properties: + if pname not in valid_props: + logger.warning( + f"Property `{pname}` is not a defined property for this dataset and " + + f"will be ignored. If it should be included, it has to be " + + f"provided together with its unit when calling " + + f"AseAtomsData.create()." 
+ ) + + data = {} + for pname in prop_keys: + if pname in properties: + data[pname] = properties[pname] + else: + raise AtomsDataError("Required property missing:" + pname) + + conn.write(atoms, data=data, key_value_pairs=atoms_metadata) + + +def create_dataset( + datapath: str, + format: AtomsDataFormat, + distance_unit: str, + property_unit_dict: Dict[str, str], + **kwargs, +) -> ASEAtomsData: + """ + Create a new atoms dataset. + + Args: + datapath: file path + format: atoms data format + distance_unit: unit of atom positiona etc. as string + property_unit_dict: dictionary that maps properties to units, + e.g. {"energy": "kcal/mol"} + **kwargs: arguments for passed to AtomsData init + + Returns: + + """ + if format is AtomsDataFormat.ASE: + dataset = ASEAtomsData.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=property_unit_dict, + **kwargs, + ) + else: + raise AtomsDataError(f"Unknown format: {format}") + + return dataset + + +def load_dataset(datapath: str, format: AtomsDataFormat, **kwargs) -> ASEAtomsData: + """ + Load dataset. + + Args: + datapath: file path + format: atoms data format + **kwargs: arguments for passed to AtomsData init + + """ + if format is AtomsDataFormat.ASE: + dataset = ASEAtomsData(datapath=datapath, **kwargs) + else: + raise AtomsDataError(f"Unknown format: {format}") + return dataset + + +def resolve_format( + datapath: str, format: Optional[AtomsDataFormat] = None +) -> Tuple[str, AtomsDataFormat]: + """ + Extract data format from file suffix, check for consistency with (optional) given + format, or append suffix to file path. + + Args: + datapath: path to atoms data + format: atoms data format + + """ + file, suffix = os.path.splitext(datapath) + if suffix == ".db": + if format is None: + format = AtomsDataFormat.ASE + assert ( + format is AtomsDataFormat.ASE + ), f"File extension {suffix} is not compatible with chosen format {format}" + elif len(suffix) == 0 and format: + datapath = datapath + extension_map[format] + elif len(suffix) == 0 and format is None: + raise AtomsDataError( + "If format is not given, `datapath` needs a supported file extension!" 
+ ) + else: + raise AtomsDataError(f"Unsupported file extension: {suffix}") + return datapath, format From cc3cb3c6ea3c22e80579d6fdcb53b6b965750a99 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sat, 7 Mar 2026 01:17:23 +0100 Subject: [PATCH 19/68] refactor: update data handling --- src/schnetpack/cli.py | 29 ++------ src/schnetpack/configs/data/custom.yaml | 10 ++- src/schnetpack/data/datamodule_v2.py | 99 +++++++++++++++++++------ 3 files changed, 88 insertions(+), 50 deletions(-) diff --git a/src/schnetpack/cli.py b/src/schnetpack/cli.py index 9f45163ce..28af40075 100644 --- a/src/schnetpack/cli.py +++ b/src/schnetpack/cli.py @@ -17,7 +17,7 @@ import schnetpack as spk from schnetpack.utils import str2class from schnetpack.utils.script import log_hyperparameters, print_config -from schnetpack.data import BaseAtomsData, AtomsLoader +from schnetpack.data import ASEAtomsData, AtomsLoader from schnetpack.train import PredictionWriter from schnetpack import properties from schnetpack.utils import load_model @@ -178,40 +178,23 @@ def train(config: DictConfig): # Evaluate model on test set after training log.info("Starting testing.") - # trainer.test(model=task, datamodule=datamodule, ckpt_path="best") + trainer.test(model=task, datamodule=datamodule, ckpt_path="best",weights_only=False) # Store best model best_path = trainer.checkpoint_callback.best_model_path - if not best_path: - raise RuntimeError("No best checkpoint found (best_model_path is empty).") - - # Load Lightning checkpoint dict (requires weights_only=False on torch 2.6+) - ckpt = torch.load( - best_path, - map_location=trainer.strategy.root_device, - weights_only=False, - ) - - # Restore weights into the already-instantiated task - task.load_state_dict(ckpt["state_dict"], strict=True) - - # Test without Lightning re-loading the checkpoint - trainer.test(model=task, datamodule=datamodule, ckpt_path=None) - log.info(f"Best checkpoint path:\n{best_path}") log.info(f"Store best model") - # best_task = type(task).load_from_checkpoint(best_path) - torch.save(task, config.globals.model_path + ".task") + best_task = type(task).load_from_checkpoint(best_path) + torch.save(best_task, config.globals.model_path + ".task") - task.save_model(config.globals.model_path, do_postprocessing=True) + best_task.save_model(config.globals.model_path, do_postprocessing=True) log.info(f"Best model stored at {os.path.abspath(config.globals.model_path)}") - @hydra.main(config_path="configs", config_name="predict", version_base="1.2") def predict(config: DictConfig): log.info(f"Load data from `{config.data.datapath}`") - dataset: BaseAtomsData = hydra.utils.instantiate(config.data) + dataset: ASEAtomsData = hydra.utils.instantiate(config.data) loader = AtomsLoader(dataset, batch_size=config.batch_size, num_workers=8) model = load_model("best_model") diff --git a/src/schnetpack/configs/data/custom.yaml b/src/schnetpack/configs/data/custom.yaml index a6b44d5e1..6345d0a44 100644 --- a/src/schnetpack/configs/data/custom.yaml +++ b/src/schnetpack/configs/data/custom.yaml @@ -1,11 +1,13 @@ # @package data _target_: schnetpack.data.datamodule_v2.AtomsDataModuleV2 -# legacy field if some old structured config still expects it -datapath: ??? - # dataset must be provided by concrete config -dataset: ??? +dataset: + _target_: schnetpack.data.ASEAtomsData + datapath: ??? + load_properties: null + distance_unit: Ang + property_units: {} batch_size: 10 num_train: ??? 
diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index 37c3b26ed..65fca0833 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -1,12 +1,13 @@ from __future__ import annotations from copy import copy -from typing import List, Optional, Union +from typing import List, Optional, Union, Dict, Any, Type import numpy as np import pytorch_lightning as pl +from torch.utils.data import BatchSampler -from schnetpack.data.atoms import BaseAtomsData +from schnetpack.data.atoms import ASEAtomsData from schnetpack.data.provider import StatsAtomrefProvider from schnetpack.data.splitting import RandomSplit, SplittingStrategy from schnetpack.data.loader import AtomsLoader @@ -23,7 +24,7 @@ class AtomsDataModuleV2(pl.LightningDataModule): def __init__( self, - dataset: BaseAtomsData, + dataset: ASEAtomsData, batch_size: int, num_train: Union[int, float], num_val: Union[int, float], @@ -35,18 +36,27 @@ def __init__( val_transforms: Optional[List] = None, test_transforms: Optional[List] = None, num_workers: int = 0, + val_batch_size: Optional[int] = None, + test_batch_size: Optional[int] = None, + train_sampler_cls: Optional[Type] = None, + train_sampler_args: Optional[Dict[str, Any]] = None, + pin_memory: bool = False, **kwargs, ): super().__init__() self.dataset = dataset self.batch_size = batch_size + self.val_batch_size = val_batch_size or test_batch_size or batch_size + self.test_batch_size = test_batch_size or val_batch_size or batch_size + self.num_train = num_train self.num_val = num_val self.num_test = num_test self.split_file = split_file self.splitting = splitting or RandomSplit() self.num_workers = num_workers + self._pin_memory = pin_memory self.train_transforms = train_transforms or copy(transforms) or [] self.val_transforms = val_transforms or copy(transforms) or [] @@ -60,22 +70,29 @@ def __init__( self._val_dataset = None self._test_dataset = None + self._train_dataloader = None + self._val_dataloader = None + self._test_dataloader = None + self.provider: Optional[StatsAtomrefProvider] = None + self.train_sampler_cls = train_sampler_cls + self.train_sampler_args = train_sampler_args or {} + @property - def train_dataset(self) -> BaseAtomsData: + def train_dataset(self) -> ASEAtomsData: if self._train_dataset is None: raise RuntimeError("Call setup() before accessing train_dataset.") return self._train_dataset @property - def val_dataset(self) -> BaseAtomsData: + def val_dataset(self) -> ASEAtomsData: if self._val_dataset is None: raise RuntimeError("Call setup() before accessing val_dataset.") return self._val_dataset @property - def test_dataset(self) -> BaseAtomsData: + def test_dataset(self) -> ASEAtomsData: if self._test_dataset is None: raise RuntimeError("Call setup() before accessing test_dataset.") return self._test_dataset @@ -132,6 +149,11 @@ def _to_abs(x: Optional[Union[int, float]]) -> Optional[int]: self.test_idx = split_data["test_idx"].tolist() return + if num_train is None or num_val is None: + raise ValueError( + "If no split file is given, num_train and num_val must be set." 
+ ) + train_idx, val_idx, test_idx = self.splitting.split( self.dataset, num_train, num_val, num_test ) @@ -148,26 +170,57 @@ def _to_abs(x: Optional[Union[int, float]]) -> Optional[int]: test_idx=test_idx, ) - def train_dataloader(self): - return AtomsLoader( - self.train_dataset, + def _setup_sampler(self, sampler_cls, sampler_args, dataset): + if sampler_cls is None: + return None + + return BatchSampler( + sampler=sampler_cls( + data_source=dataset, + num_samples=len(dataset), + **sampler_args, + ), batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers, + drop_last=True, ) + def train_dataloader(self): + if self._train_dataloader is None: + train_batch_sampler = self._setup_sampler( + sampler_cls=self.train_sampler_cls, + sampler_args=self.train_sampler_args, + dataset=self.train_dataset, + ) + + self._train_dataloader = AtomsLoader( + self.train_dataset, + batch_size=self.batch_size if train_batch_sampler is None else 1, + shuffle=True if train_batch_sampler is None else False, + batch_sampler=train_batch_sampler, + num_workers=self.num_workers, + pin_memory=self._pin_memory, + ) + + return self._train_dataloader + def val_dataloader(self): - return AtomsLoader( - self.val_dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - ) + if self._val_dataloader is None: + self._val_dataloader = AtomsLoader( + self.val_dataset, + batch_size=self.val_batch_size, + num_workers=self.num_workers, + pin_memory=self._pin_memory, + ) + + return self._val_dataloader def test_dataloader(self): - return AtomsLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - ) + if self._test_dataloader is None: + self._test_dataloader = AtomsLoader( + self.test_dataset, + batch_size=self.test_batch_size, + num_workers=self.num_workers, + pin_memory=self._pin_memory, + ) + + return self._test_dataloader \ No newline at end of file From b48ca020574d3ba568005b957a653eeecd015579 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sat, 7 Mar 2026 01:18:20 +0100 Subject: [PATCH 20/68] refactor: simplify ASEAtomsData by removing unused methods and properties --- src/schnetpack/data/atoms.py | 103 +++--------------------------- src/schnetpack/data/datamodule.py | 12 ++-- 2 files changed, 14 insertions(+), 101 deletions(-) diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index f3f911766..70e559b8f 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -17,9 +17,7 @@ __all__ = [ "ASEAtomsData", "AtomsDataFormat", - "resolve_format", - "create_dataset", - "load_dataset", + "load_dataset" ] @@ -50,16 +48,14 @@ def __init__( distance_unit: Optional[str] = None, ): self.datapath = datapath + self.subset_idx = subset_idx + self._check_db() self.conn = connect(self.datapath, use_lock_file=False) # merged ASEAtomsData state - self._transform_module = None - self._transforms: List[Transform] = [] + self.transforms: List[Transform] = list(transforms) if transforms is not None else [] self._load_properties: Optional[List[str]] = None self.load_structure = load_structure - self.subset_idx = subset_idx - - self._check_db() # units from metadata md = self.metadata @@ -95,23 +91,8 @@ def __init__( # now validate load_properties against available_properties self.load_properties = load_properties - # set transforms last - self.transforms = transforms # ---------- merged ASEAtomsData bits ---------- - - @property - def transforms(self) -> List[Transform]: - return self._transforms - - 
@transforms.setter - def transforms(self, value: Optional[List[Transform]]): - self._transforms = [] - self._transform_module = None - if value: - self._transforms.extend(value) - self._transform_module = torch.nn.Sequential(*self._transforms) - def subset(self, subset_idx: List[int]): if subset_idx is None: raise ValueError("subset_idx must be provided.") @@ -152,11 +133,9 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: ) return self._apply_transforms(props) - def _apply_transforms( - self, props: Dict[str, torch.Tensor] - ) -> Dict[str, torch.Tensor]: - if self._transform_module is not None: - props = self._transform_module(props) + def _apply_transforms(self, props: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + for tf in self.transforms: + props = tf(props) return props def _check_db(self): @@ -378,41 +357,6 @@ def _add_system( conn.write(atoms, data=data, key_value_pairs=atoms_metadata) - -def create_dataset( - datapath: str, - format: AtomsDataFormat, - distance_unit: str, - property_unit_dict: Dict[str, str], - **kwargs, -) -> ASEAtomsData: - """ - Create a new atoms dataset. - - Args: - datapath: file path - format: atoms data format - distance_unit: unit of atom positiona etc. as string - property_unit_dict: dictionary that maps properties to units, - e.g. {"energy": "kcal/mol"} - **kwargs: arguments for passed to AtomsData init - - Returns: - - """ - if format is AtomsDataFormat.ASE: - dataset = ASEAtomsData.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=property_unit_dict, - **kwargs, - ) - else: - raise AtomsDataError(f"Unknown format: {format}") - - return dataset - - def load_dataset(datapath: str, format: AtomsDataFormat, **kwargs) -> ASEAtomsData: """ Load dataset. @@ -426,35 +370,4 @@ def load_dataset(datapath: str, format: AtomsDataFormat, **kwargs) -> ASEAtomsDa if format is AtomsDataFormat.ASE: dataset = ASEAtomsData(datapath=datapath, **kwargs) else: - raise AtomsDataError(f"Unknown format: {format}") - return dataset - - -def resolve_format( - datapath: str, format: Optional[AtomsDataFormat] = None -) -> Tuple[str, AtomsDataFormat]: - """ - Extract data format from file suffix, check for consistency with (optional) given - format, or append suffix to file path. - - Args: - datapath: path to atoms data - format: atoms data format - - """ - file, suffix = os.path.splitext(datapath) - if suffix == ".db": - if format is None: - format = AtomsDataFormat.ASE - assert ( - format is AtomsDataFormat.ASE - ), f"File extension {suffix} is not compatible with chosen format {format}" - elif len(suffix) == 0 and format: - datapath = datapath + extension_map[format] - elif len(suffix) == 0 and format is None: - raise AtomsDataError( - "If format is not given, `datapath` needs a supported file extension!" 
- ) - else: - raise AtomsDataError(f"Unsupported file extension: {suffix}") - return datapath, format + raise AtomsDataError(f"Unknown format: {format}") \ No newline at end of file diff --git a/src/schnetpack/data/datamodule.py b/src/schnetpack/data/datamodule.py index ea361e0a5..8abead8a4 100644 --- a/src/schnetpack/data/datamodule.py +++ b/src/schnetpack/data/datamodule.py @@ -11,9 +11,9 @@ from schnetpack.data import ( AtomsDataFormat, - resolve_format, + #resolve_format, load_dataset, - BaseAtomsData, + ASEAtomsData, AtomsLoader, calculate_stats, estimate_atomrefs, @@ -116,7 +116,7 @@ def __init__( self.num_test = num_test self.splitting = splitting or RandomSplit() self.split_file = split_file - self.datapath, self.format = resolve_format(datapath, format) + #self.datapath, self.format = resolve_format(datapath, format) self.load_properties = load_properties self.num_workers = num_workers self.num_val_workers = self.num_workers @@ -386,15 +386,15 @@ def get_atomrefs( return {property: atomrefs} @property - def train_dataset(self) -> BaseAtomsData: + def train_dataset(self) -> ASEAtomsData: return self._train_dataset @property - def val_dataset(self) -> BaseAtomsData: + def val_dataset(self) -> ASEAtomsData: return self._val_dataset @property - def test_dataset(self) -> BaseAtomsData: + def test_dataset(self) -> ASEAtomsData: return self._test_dataset def train_dataloader(self) -> AtomsLoader: From bfc2c9774202d462065c9c4e4f92989b6db2904e Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sat, 7 Mar 2026 01:19:33 +0100 Subject: [PATCH 21/68] refactor: update dataset method signatures to use ASEAtomsData --- src/schnetpack/datasets/ani1.py | 2 +- src/schnetpack/datasets/materials_project.py | 2 +- src/schnetpack/datasets/md17.py | 2 +- src/schnetpack/datasets/qm7x.py | 2 +- src/schnetpack/datasets/qm9.py | 8 ++++---- src/schnetpack/datasets/qm9_legacy.py | 4 +--- src/schnetpack/datasets/rmd17.py | 2 +- src/schnetpack/datasets/tmqm.py | 2 +- 8 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/schnetpack/datasets/ani1.py b/src/schnetpack/datasets/ani1.py index 48fb0f062..ffe7e92cf 100644 --- a/src/schnetpack/datasets/ani1.py +++ b/src/schnetpack/datasets/ani1.py @@ -136,7 +136,7 @@ def prepare_data(self): else: dataset = load_dataset(self.datapath, self.format) - def _download_data(self, tmpdir, dataset: BaseAtomsData): + def _download_data(self, tmpdir, dataset: ASEAtomsData): logging.info("downloading ANI-1 data...") tar_path = os.path.join(tmpdir, "ANI1_release.tar.gz") raw_path = os.path.join(tmpdir, "data") diff --git a/src/schnetpack/datasets/materials_project.py b/src/schnetpack/datasets/materials_project.py index 8779e4bff..9189140ef 100644 --- a/src/schnetpack/datasets/materials_project.py +++ b/src/schnetpack/datasets/materials_project.py @@ -145,7 +145,7 @@ def prepare_data(self): else: dataset = load_dataset(self.datapath, self.format) - def _download_data_nextgen(self, dataset: BaseAtomsData): + def _download_data_nextgen(self, dataset: ASEAtomsData): """ Downloads dataset provided it does not exist in self.path Returns: diff --git a/src/schnetpack/datasets/md17.py b/src/schnetpack/datasets/md17.py index 0e23e7ecb..6cd866269 100644 --- a/src/schnetpack/datasets/md17.py +++ b/src/schnetpack/datasets/md17.py @@ -151,7 +151,7 @@ def prepare_data(self): def _download_data( self, tmpdir, - dataset: BaseAtomsData, + dataset: ASEAtomsData, ): logging.info("Downloading {} data".format(self.molecule)) rawpath = os.path.join(tmpdir, self.datasets_dict[self.molecule]) 
diff --git a/src/schnetpack/datasets/qm7x.py b/src/schnetpack/datasets/qm7x.py index 03b4d4087..e1179ac62 100644 --- a/src/schnetpack/datasets/qm7x.py +++ b/src/schnetpack/datasets/qm7x.py @@ -308,7 +308,7 @@ def _download_data(self, tar_dir: str, ignore_extracted: bool = True) -> List[st return extracted - def _parse_data(self, files: List[str], dataset: BaseAtomsData): + def _parse_data(self, files: List[str], dataset: ASEAtomsData): """ Parse the downloaded data files and add them to the dataset. """ diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 091505c6a..f0686d567 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -58,7 +58,7 @@ def __init__( format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, remove_uncharacterized: bool = False, load_properties: Optional[List[str]] = None, - # transforms=None, + transforms=None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -67,7 +67,7 @@ def __init__( self.remove_uncharacterized = remove_uncharacterized self.format = format - self.prepare( + self.download( datapath=datapath, distance_unit=distance_unit or "Ang", ) @@ -75,7 +75,7 @@ def __init__( super().__init__( datapath=datapath, load_properties=load_properties, - # transforms=transforms, + transforms=transforms, subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, @@ -103,7 +103,7 @@ def _native_property_units() -> Dict[str, str]: QM9.Cv: "cal/mol/K", } - def prepare(self, datapath: str, distance_unit: str = "Ang") -> None: + def download(self, datapath: str, distance_unit: str = "Ang") -> None: """ Make sure the QM9 database exists. diff --git a/src/schnetpack/datasets/qm9_legacy.py b/src/schnetpack/datasets/qm9_legacy.py index 3085a21a7..7673d02d0 100644 --- a/src/schnetpack/datasets/qm9_legacy.py +++ b/src/schnetpack/datasets/qm9_legacy.py @@ -227,9 +227,7 @@ def _download_atomrefs(self, tmpdir): atref = {k: v.tolist() for k, v in atref.items()} return atref - def _download_data( - self, tmpdir, dataset: BaseAtomsData, uncharacterized: List[int] - ): + def _download_data(self, tmpdir, dataset: ASEAtomsData, uncharacterized: List[int]): logging.info("Downloading GDB-9 data...") tar_path = os.path.join(tmpdir, "gdb9.tar.gz") raw_path = os.path.join(tmpdir, "gdb9_xyz") diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py index 50cf59385..b3f27c35c 100644 --- a/src/schnetpack/datasets/rmd17.py +++ b/src/schnetpack/datasets/rmd17.py @@ -190,7 +190,7 @@ def prepare_data(self): def _download_data( self, tmpdir, - dataset: BaseAtomsData, + dataset: ASEAtomsData, ): logging.info("Downloading {} data".format(self.molecule)) raw_path = os.path.join(tmpdir, "rmd17") diff --git a/src/schnetpack/datasets/tmqm.py b/src/schnetpack/datasets/tmqm.py index 17c856ec7..b27b5ea67 100644 --- a/src/schnetpack/datasets/tmqm.py +++ b/src/schnetpack/datasets/tmqm.py @@ -151,7 +151,7 @@ def prepare_data(self): else: dataset = load_dataset(self.datapath, self.format) - def _download_data(self, tmpdir, dataset: BaseAtomsData): + def _download_data(self, tmpdir, dataset: ASEAtomsData): tar_path = os.path.join(tmpdir, "tmQM_X1.xyz.gz") url = [ "https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X1.xyz.gz", From 860fbee35cbf5a1657355a0c310ba41d00420ae7 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sat, 7 Mar 2026 01:20:25 +0100 Subject: [PATCH 22/68] refactor: update references from BaseAtomsData to 
ASEAtomsData in data loader, provider, sampler, and stats modules --- src/schnetpack/data/loader.py | 2 +- src/schnetpack/data/provider.py | 4 ++-- src/schnetpack/data/sampler.py | 6 +++--- src/schnetpack/data/stats.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/schnetpack/data/loader.py b/src/schnetpack/data/loader.py index f116021ee..56f8d95f3 100644 --- a/src/schnetpack/data/loader.py +++ b/src/schnetpack/data/loader.py @@ -59,7 +59,7 @@ def _atoms_collate_fn(batch): class AtomsLoader(DataLoader): - """Data loader for subclasses of BaseAtomsData""" + """Data loader for subclasses of ASEAtomsData""" def __init__( self, diff --git a/src/schnetpack/data/provider.py b/src/schnetpack/data/provider.py index d05e60d2b..8fa6432b7 100644 --- a/src/schnetpack/data/provider.py +++ b/src/schnetpack/data/provider.py @@ -4,7 +4,7 @@ import torch -from schnetpack.data.atoms import BaseAtomsData +from schnetpack.data.atoms import ASEAtomsData from schnetpack.data.stats import calculate_stats, estimate_atomrefs @@ -13,7 +13,7 @@ class StatsAtomrefProvider: Compute and cache statistics and atom references from the training dataset. """ - def __init__(self, train_dataset: BaseAtomsData) -> None: + def __init__(self, train_dataset: ASEAtomsData) -> None: self.train_dataset = train_dataset self.train_atomrefs = getattr(train_dataset, "atomrefs", None) diff --git a/src/schnetpack/data/sampler.py b/src/schnetpack/data/sampler.py index 0e353ef88..45d8beac9 100644 --- a/src/schnetpack/data/sampler.py +++ b/src/schnetpack/data/sampler.py @@ -4,7 +4,7 @@ from torch.utils.data import Sampler, WeightedRandomSampler from schnetpack import properties -from schnetpack.data import BaseAtomsData +from schnetpack.data import ASEAtomsData __all__ = [ @@ -53,8 +53,8 @@ class StratifiedSampler(WeightedRandomSampler): def __init__( self, - data_source: BaseAtomsData, - partition_criterion: Callable[[BaseAtomsData], List], + data_source: ASEAtomsData, + partition_criterion: Callable[[ASEAtomsData], List], num_samples: int, num_bins: int = 10, replacement: bool = True, diff --git a/src/schnetpack/data/stats.py b/src/schnetpack/data/stats.py index 51f049507..6a7e68b3a 100644 --- a/src/schnetpack/data/stats.py +++ b/src/schnetpack/data/stats.py @@ -4,14 +4,14 @@ from tqdm import tqdm import schnetpack.properties as properties -from schnetpack.data.atoms import BaseAtomsData +from schnetpack.data.atoms import ASEAtomsData from schnetpack.data.loader import AtomsLoader __all__ = ["calculate_stats", "estimate_atomrefs"] def calculate_stats( - dataset: BaseAtomsData, + dataset: ASEAtomsData, divide_by_atoms: Dict[str, bool], atomref: Dict[str, torch.Tensor] = None, batch_size: int = 10000, @@ -76,7 +76,7 @@ def calculate_stats( def estimate_atomrefs( - dataset: BaseAtomsData, + dataset: ASEAtomsData, is_extensive: Dict[str, bool], z_max: int = 100, batch_size: int = 10000, From c83501af36aa5692c873fd293b3f8250ef4eca53 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 8 Mar 2026 00:38:28 +0100 Subject: [PATCH 23/68] refactor: update checkpoint loading in training process and adjust dataset loading in QM9 class --- src/schnetpack/cli.py | 2 +- src/schnetpack/datasets/qm9.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/schnetpack/cli.py b/src/schnetpack/cli.py index 28af40075..de4e18c48 100644 --- a/src/schnetpack/cli.py +++ b/src/schnetpack/cli.py @@ -185,7 +185,7 @@ def train(config: DictConfig): log.info(f"Best checkpoint path:\n{best_path}") log.info(f"Store best 
model") - best_task = type(task).load_from_checkpoint(best_path) + best_task = type(task).load_from_checkpoint(best_path,weights_only=False) torch.save(best_task, config.globals.model_path + ".task") best_task.save_model(config.globals.model_path, do_postprocessing=True) diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index f0686d567..0b4b4b4e4 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -16,7 +16,7 @@ import schnetpack.properties as structure from schnetpack.data import AtomsDataFormat -from schnetpack.data.atoms import ASEAtomsData, AtomsDataError, load_dataset +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["QM9"] @@ -111,7 +111,7 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: remove_uncharacterized setting. """ if os.path.exists(datapath): - dataset = load_dataset(datapath, self.format, load_structure=False) + dataset = ASEAtomsData(datapath=datapath, load_structure=False) if self.remove_uncharacterized and len(dataset) == 133885: raise AtomsDataError( From 3fa44d30921a4d6e3d190bb7e5aedc7b8b61b683 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 8 Mar 2026 16:37:49 +0100 Subject: [PATCH 24/68] refactor: clean up code formatting and remove legacy QM9 dataset file --- src/schnetpack/cli.py | 7 +- src/schnetpack/data/atoms.py | 18 +- src/schnetpack/data/datamodule.py | 4 +- src/schnetpack/data/datamodule_v2.py | 2 +- src/schnetpack/datasets/qm9.py | 3 +- src/schnetpack/datasets/qm9_legacy.py | 277 -------------------------- 6 files changed, 19 insertions(+), 292 deletions(-) delete mode 100644 src/schnetpack/datasets/qm9_legacy.py diff --git a/src/schnetpack/cli.py b/src/schnetpack/cli.py index de4e18c48..9f371c1c8 100644 --- a/src/schnetpack/cli.py +++ b/src/schnetpack/cli.py @@ -178,19 +178,22 @@ def train(config: DictConfig): # Evaluate model on test set after training log.info("Starting testing.") - trainer.test(model=task, datamodule=datamodule, ckpt_path="best",weights_only=False) + trainer.test( + model=task, datamodule=datamodule, ckpt_path="best", weights_only=False + ) # Store best model best_path = trainer.checkpoint_callback.best_model_path log.info(f"Best checkpoint path:\n{best_path}") log.info(f"Store best model") - best_task = type(task).load_from_checkpoint(best_path,weights_only=False) + best_task = type(task).load_from_checkpoint(best_path, weights_only=False) torch.save(best_task, config.globals.model_path + ".task") best_task.save_model(config.globals.model_path, do_postprocessing=True) log.info(f"Best model stored at {os.path.abspath(config.globals.model_path)}") + @hydra.main(config_path="configs", config_name="predict", version_base="1.2") def predict(config: DictConfig): log.info(f"Load data from `{config.data.datapath}`") diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index 70e559b8f..6efeded4a 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -14,11 +14,7 @@ logger = logging.getLogger(__name__) -__all__ = [ - "ASEAtomsData", - "AtomsDataFormat", - "load_dataset" -] +__all__ = ["ASEAtomsData", "AtomsDataFormat", "load_dataset"] class AtomsDataFormat(Enum): @@ -53,7 +49,9 @@ def __init__( self.conn = connect(self.datapath, use_lock_file=False) # merged ASEAtomsData state - self.transforms: List[Transform] = list(transforms) if transforms is not None else [] + self.transforms: List[Transform] = ( + list(transforms) if transforms is not None else [] + ) self._load_properties: 
Optional[List[str]] = None self.load_structure = load_structure @@ -91,7 +89,6 @@ def __init__( # now validate load_properties against available_properties self.load_properties = load_properties - # ---------- merged ASEAtomsData bits ---------- def subset(self, subset_idx: List[int]): if subset_idx is None: @@ -133,7 +130,9 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: ) return self._apply_transforms(props) - def _apply_transforms(self, props: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _apply_transforms( + self, props: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: for tf in self.transforms: props = tf(props) return props @@ -357,6 +356,7 @@ def _add_system( conn.write(atoms, data=data, key_value_pairs=atoms_metadata) + def load_dataset(datapath: str, format: AtomsDataFormat, **kwargs) -> ASEAtomsData: """ Load dataset. @@ -370,4 +370,4 @@ def load_dataset(datapath: str, format: AtomsDataFormat, **kwargs) -> ASEAtomsDa if format is AtomsDataFormat.ASE: dataset = ASEAtomsData(datapath=datapath, **kwargs) else: - raise AtomsDataError(f"Unknown format: {format}") \ No newline at end of file + raise AtomsDataError(f"Unknown format: {format}") diff --git a/src/schnetpack/data/datamodule.py b/src/schnetpack/data/datamodule.py index 8abead8a4..d9f4ebed1 100644 --- a/src/schnetpack/data/datamodule.py +++ b/src/schnetpack/data/datamodule.py @@ -11,7 +11,7 @@ from schnetpack.data import ( AtomsDataFormat, - #resolve_format, + # resolve_format, load_dataset, ASEAtomsData, AtomsLoader, @@ -116,7 +116,7 @@ def __init__( self.num_test = num_test self.splitting = splitting or RandomSplit() self.split_file = split_file - #self.datapath, self.format = resolve_format(datapath, format) + # self.datapath, self.format = resolve_format(datapath, format) self.load_properties = load_properties self.num_workers = num_workers self.num_val_workers = self.num_workers diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index 65fca0833..3cc493247 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -223,4 +223,4 @@ def test_dataloader(self): pin_memory=self._pin_memory, ) - return self._test_dataloader \ No newline at end of file + return self._test_dataloader diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 0b4b4b4e4..98f5dcad3 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -7,6 +7,7 @@ import tempfile from typing import Dict, List, Optional from urllib import request as request +import torch import numpy as np from ase import Atoms @@ -58,7 +59,7 @@ def __init__( format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, remove_uncharacterized: bool = False, load_properties: Optional[List[str]] = None, - transforms=None, + transforms: Optional[List[torch.nn.Module]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, diff --git a/src/schnetpack/datasets/qm9_legacy.py b/src/schnetpack/datasets/qm9_legacy.py deleted file mode 100644 index 7673d02d0..000000000 --- a/src/schnetpack/datasets/qm9_legacy.py +++ /dev/null @@ -1,277 +0,0 @@ -import io -import logging -import os -import re -import shutil -import tarfile -import tempfile -from typing import List, Optional, Dict -from urllib import request as request - -import numpy as np -from ase import Atoms -from ase.io.extxyz import read_xyz -from tqdm import tqdm - -import torch -from schnetpack.data 
import * -import schnetpack.properties as structure -from schnetpack.data import AtomsDataModuleError, AtomsDataModule - -__all__ = ["QM9"] - - -class QM9(AtomsDataModule): - """QM9 benchmark database for organic molecules. - - The QM9 database contains small organic molecules with up to nine non-hydrogen atoms - from including C, O, N, F. This class adds convenient functions to download QM9 from - figshare and load the data into pytorch. - - References: - - .. [#qm9_1] https://ndownloader.figshare.com/files/3195404 - """ - - base_urls = [ - "https://ndownloader.figshare.com/files/", - "https://springernature.figshare.com/ndownloader/files/", - ] - file_ids = { - "data": "3195389", - "atomrefs": "3195395", - "uncharacterized": "3195404", - } - - # properties - A = "rotational_constant_A" - B = "rotational_constant_B" - C = "rotational_constant_C" - mu = "dipole_moment" - alpha = "isotropic_polarizability" - homo = "homo" - lumo = "lumo" - gap = "gap" - r2 = "electronic_spatial_extent" - zpve = "zpve" - U0 = "energy_U0" - U = "energy_U" - H = "enthalpy_H" - G = "free_energy" - Cv = "heat_capacity" - - def __init__( - self, - datapath: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, - load_properties: Optional[List[str]] = None, - remove_uncharacterized: bool = False, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, - property_units: Optional[Dict[str, str]] = None, - distance_unit: Optional[str] = None, - data_workdir: Optional[str] = None, - **kwargs, - ): - """ - - Args: - datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format - load_properties: subset of properties to load - remove_uncharacterized: do not include uncharacterized molecules. - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. 
- """ - super().__init__( - datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, - format=format, - load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, - transforms=transforms, - train_transforms=train_transforms, - val_transforms=val_transforms, - test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, - property_units=property_units, - distance_unit=distance_unit, - data_workdir=data_workdir, - **kwargs, - ) - - self.remove_uncharacterized = remove_uncharacterized - - def _download_file(self, file_id: str, destination: str): - for base_url in self.base_urls: - url = f"{base_url}{file_id}" - try: - request.urlretrieve(url, destination) - return - except Exception: - logging.warning(f"Could not download from {url}, trying next source...") - raise AtomsDataModuleError( - f"Could not download file with id {file_id} from any source." - ) - - def prepare_data(self): - if not os.path.exists(self.datapath): - property_unit_dict = { - QM9.A: "GHz", - QM9.B: "GHz", - QM9.C: "GHz", - QM9.mu: "Debye", - QM9.alpha: "a0 a0 a0", - QM9.homo: "Ha", - QM9.lumo: "Ha", - QM9.gap: "Ha", - QM9.r2: "a0 a0", - QM9.zpve: "Ha", - QM9.U0: "Ha", - QM9.U: "Ha", - QM9.H: "Ha", - QM9.G: "Ha", - QM9.Cv: "cal/mol/K", - } - - tmpdir = tempfile.mkdtemp("qm9") - atomrefs = self._download_atomrefs(tmpdir) - - dataset = create_dataset( - datapath=self.datapath, - format=self.format, - distance_unit="Ang", - property_unit_dict=property_unit_dict, - atomrefs=atomrefs, - ) - - if self.remove_uncharacterized: - uncharacterized = self._download_uncharacterized(tmpdir) - else: - uncharacterized = None - self._download_data(tmpdir, dataset, uncharacterized=uncharacterized) - shutil.rmtree(tmpdir) - else: - dataset = load_dataset(self.datapath, self.format) - if self.remove_uncharacterized and len(dataset) == 133885: - raise AtomsDataModuleError( - "The dataset at the chosen location contains the uncharacterized 3054 molecules. " - + "Choose a different location to reload the data or set `remove_uncharacterized=False`!" - ) - elif not self.remove_uncharacterized and len(dataset) < 133885: - raise AtomsDataModuleError( - "The dataset at the chosen location does NOT contain the uncharacterized 3054 molecules. " - + "Choose a different location to reload the data or set `remove_uncharacterized=True`!" 
- ) - - def _download_uncharacterized(self, tmpdir): - logging.info("Downloading list of uncharacterized molecules...") - tmp_path = os.path.join(tmpdir, "uncharacterized.txt") - self._download_file(self.file_ids["uncharacterized"], tmp_path) - logging.info("Done.") - - uncharacterized = [] - with open(tmp_path) as f: - lines = f.readlines() - for line in lines[9:-1]: - uncharacterized.append(int(line.split()[0])) - return uncharacterized - - def _download_atomrefs(self, tmpdir): - logging.info("Downloading GDB-9 atom references...") - tmp_path = os.path.join(tmpdir, "atomrefs.txt") - self._download_file(self.file_ids["atomrefs"], tmp_path) - logging.info("Done.") - - props = [QM9.zpve, QM9.U0, QM9.U, QM9.H, QM9.G, QM9.Cv] - atref = {p: np.zeros((100,)) for p in props} - with open(tmp_path) as f: - lines = f.readlines() - for z, l in zip([1, 6, 7, 8, 9], lines[5:10]): - for i, p in enumerate(props): - atref[p][z] = float(l.split()[i + 1]) - atref = {k: v.tolist() for k, v in atref.items()} - return atref - - def _download_data(self, tmpdir, dataset: ASEAtomsData, uncharacterized: List[int]): - logging.info("Downloading GDB-9 data...") - tar_path = os.path.join(tmpdir, "gdb9.tar.gz") - raw_path = os.path.join(tmpdir, "gdb9_xyz") - self._download_file(self.file_ids["data"], tar_path) - logging.info("Done.") - - logging.info("Extracting files...") - tar = tarfile.open(tar_path) - tar.extractall(raw_path) - tar.close() - logging.info("Done.") - - logging.info("Parse xyz files...") - ordered_files = sorted( - os.listdir(raw_path), key=lambda x: (int(re.sub(r"\D", "", x)), x) - ) - - property_list = [] - - irange = np.arange(len(ordered_files), dtype=int) - if uncharacterized is not None: - irange = np.setdiff1d(irange, np.array(uncharacterized, dtype=int) - 1) - - for i in tqdm(irange): - xyzfile = os.path.join(raw_path, ordered_files[i]) - properties = {} - - tmp = io.StringIO() - with open(xyzfile, "r") as f: - lines = f.readlines() - l = lines[1].split()[2:] - for pn, p in zip(dataset.available_properties, l): - properties[pn] = np.array([float(p)]) - for line in lines: - tmp.write(line.replace("*^", "e")) - - tmp.seek(0) - ats: Atoms = list(read_xyz(tmp, 0))[0] - properties[structure.Z] = ats.numbers - properties[structure.R] = ats.positions - properties[structure.cell] = ats.cell - properties[structure.pbc] = ats.pbc - property_list.append(properties) - - logging.info("Write atoms to db...") - dataset.add_systems(property_list=property_list) - logging.info("Done.") From 7406cefc15e7480a76f4aef78089724baa1462d1 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 8 Mar 2026 17:36:23 +0100 Subject: [PATCH 25/68] refactor: update rMD17 dataset class --- src/schnetpack/datasets/rmd17.py | 356 +++++++++++++++---------------- 1 file changed, 174 insertions(+), 182 deletions(-) diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py index b3f27c35c..c75af1f5d 100644 --- a/src/schnetpack/datasets/rmd17.py +++ b/src/schnetpack/datasets/rmd17.py @@ -1,32 +1,32 @@ import logging import os import shutil -import tempfile import tarfile -from typing import List, Optional, Dict -from urllib import request as request +import tempfile +from typing import Dict, List, Optional +from urllib.request import Request, urlopen +from urllib.error import HTTPError, URLError import numpy as np from ase import Atoms -import torch import schnetpack.properties as structure - -from schnetpack.data import * +from schnetpack.data import AtomsDataFormat +from schnetpack.data.atoms import 
ASEAtomsData, AtomsDataError +from schnetpack.transform.base import Transform __all__ = ["rMD17"] -class rMD17(AtomsDataModule): +class rMD17(ASEAtomsData): """ - Revised MD17 benchmark data set for molecular dynamics of small molecules + Revised MD17 benchmark dataset for molecular dynamics of small molecules containing molecular forces. References: .. [#md17_1] https://figshare.com/articles/dataset/ Revised_MD17_dataset_rMD17_/12672038?file=24013628 .. [#md17_2] http://quantum-machine.org/gdml/#datasets - """ energy = "energy" @@ -46,219 +46,211 @@ class rMD17(AtomsDataModule): ] } - datasets_dict = dict( - aspirin="rmd17_aspirin.npz", - azobenzene="rmd17_azobenzene.npz", - benzene="rmd17_benzene.npz", - ethanol="rmd17_ethanol.npz", - malonaldehyde="rmd17_malonaldehyde.npz", - naphthalene="rmd17_naphthalene.npz", - paracetamol="rmd17_paracetamol.npz", - salicylic_acid="rmd17_salicylic.npz", - toluene="rmd17_toluene.npz", - uracil="rmd17_uracil.npz", - ) - - # properties + datasets_dict = { + "aspirin": "rmd17_aspirin.npz", + "azobenzene": "rmd17_azobenzene.npz", + "benzene": "rmd17_benzene.npz", + "ethanol": "rmd17_ethanol.npz", + "malonaldehyde": "rmd17_malonaldehyde.npz", + "naphthalene": "rmd17_naphthalene.npz", + "paracetamol": "rmd17_paracetamol.npz", + "salicylic_acid": "rmd17_salicylic.npz", + "toluene": "rmd17_toluene.npz", + "uracil": "rmd17_uracil.npz", + } + + download_urls = [ + "https://figshare.com/ndownloader/files/23950376", + "https://archive.materialscloud.org/records/pfffs-fff86/files/rmd17.tar.bz2?download=1", + ] + def __init__( self, datapath: str, molecule: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + transforms: Optional[List[Transform]] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - data_workdir: Optional[str] = None, - split_id: Optional[int] = None, **kwargs, ): - """ - Args: - datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format - load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then - batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then - batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers - (overrides num_workers). - num_test_workers: Number of test data loader workers - (overrides num_workers). 
- distance_unit: Unit of the atom positions and cell as a string - (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. cluster scratch for - faster performance. - split_id: The id of the predefined rMD17 train/test splits (0-4). - """ + if molecule not in self.datasets_dict: + raise AtomsDataError(f"Molecule {molecule} is not supported!") - if split_id is not None: - splitting = SubsamplePartitions( - split_partition_sources=["known", "known", "test"], split_id=split_id - ) - else: - splitting = RandomSplit() + self.molecule = molecule + self.format = format + + self.download( + datapath=datapath, + distance_unit=distance_unit or "Ang", + ) super().__init__( datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, - format=format, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, transforms=transforms, - train_transforms=train_transforms, - val_transforms=val_transforms, - test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, - data_workdir=data_workdir, - splitting=splitting, **kwargs, ) - if molecule not in rMD17.datasets_dict.keys(): - raise AtomsDataModuleError("Molecule {} is not supported!".format(molecule)) + @staticmethod + def _native_property_units() -> Dict[str, str]: + return { + rMD17.energy: "kcal/mol", + rMD17.forces: "kcal/mol/Ang", + } - self.molecule = molecule - - def prepare_data(self): - if not os.path.exists(self.datapath): - property_unit_dict = { - rMD17.energy: "kcal/mol", - rMD17.forces: "kcal/mol/Ang", - } - - tmpdir = tempfile.mkdtemp("md17") - - dataset = create_dataset( - datapath=self.datapath, - format=self.format, - distance_unit="Ang", - property_unit_dict=property_unit_dict, - atomrefs=rMD17.atomrefs, - ) - dataset.update_metadata(molecule=self.molecule) - - self._download_data(tmpdir, dataset) - shutil.rmtree(tmpdir) - else: - dataset = load_dataset(self.datapath, self.format) + def download(self, datapath: str, distance_unit: str = "Ang") -> None: + """ + Ensure the ASE DB exists and matches the requested molecule. + """ + if os.path.exists(datapath): + dataset = ASEAtomsData(datapath, load_structure=False) md = dataset.metadata + if "molecule" not in md: - raise AtomsDataModuleError( - "Not a valid rMD17 dataset! The molecule needs to be specified in " - + "the metadata." + raise AtomsDataError( + "Not a valid rMD17 dataset. Metadata must contain `molecule`." ) + if md["molecule"] != self.molecule: - raise AtomsDataModuleError( - f"The dataset at the given location does not contain the specified " - + f"molecule: `{md['molecule']}` instead of `{self.molecule}`" + raise AtomsDataError( + f"The dataset at the given location contains `{md['molecule']}` " + f"instead of `{self.molecule}`." 
) + return + + tmpdir = tempfile.mkdtemp("rmd17") + try: + dataset = ASEAtomsData.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=self._native_property_units(), + atomrefs=self.atomrefs, + ) + dataset.update_metadata(molecule=self.molecule) + self._download_data(tmpdir, dataset) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + def _download_data(self, tmpdir: str, dataset: ASEAtomsData) -> None: + logging.info("Downloading %s data...", self.molecule) - def _download_data( - self, - tmpdir, - dataset: ASEAtomsData, - ): - logging.info("Downloading {} data".format(self.molecule)) raw_path = os.path.join(tmpdir, "rmd17") - tar_path = os.path.join(tmpdir, "rmd17.tar.gz") - url = "https://figshare.com/ndownloader/files/23950376" - request.urlretrieve(url, tar_path) + tar_path = os.path.join(tmpdir, "rmd17.tar") + + self._download_archive(tar_path) logging.info("Done.") logging.info("Extracting data...") - tar = tarfile.open(tar_path) - tar.extract( - path=raw_path, member=f"rmd17/npz_data/{self.datasets_dict[self.molecule]}" - ) + os.makedirs(raw_path, exist_ok=True) + + with tarfile.open(tar_path, mode="r:*") as tar: + tar.extract( + path=raw_path, + member=f"rmd17/npz_data/{self.datasets_dict[self.molecule]}", + ) - logging.info("Parsing molecule {:s}".format(self.molecule)) + logging.info("Parsing molecule %s", self.molecule) - data = np.load( - os.path.join( - raw_path, "rmd17", "npz_data", self.datasets_dict[self.molecule] + data = np.load( + os.path.join( + raw_path, + "rmd17", + "npz_data", + self.datasets_dict[self.molecule], + ) ) - ) - numbers = data["nuclear_charges"] - property_list = [] - for positions, energies, forces in zip( - data["coords"], data["energies"], data["forces"] - ): - ats = Atoms(positions=positions, numbers=numbers) - properties = { - rMD17.energy: np.array([energies]), - rMD17.forces: forces, - structure.Z: ats.numbers, - structure.R: ats.positions, - structure.cell: ats.cell, - structure.pbc: ats.pbc, - } - property_list.append(properties) - - logging.info("Write atoms to db...") - dataset.add_systems(property_list=property_list) - logging.info("Done.") + numbers = data["nuclear_charges"] + property_list = [] + + for positions, energies, forces in zip( + data["coords"], data["energies"], data["forces"] + ): + ats = Atoms(positions=positions, numbers=numbers) + properties = { + rMD17.energy: np.array([energies]), + rMD17.forces: forces, + structure.Z: ats.numbers, + structure.R: ats.positions, + structure.cell: ats.cell, + structure.pbc: ats.pbc, + } + property_list.append(properties) - train_splits = [] - test_splits = [] - for i in range(1, 6): - tar.extract(path=raw_path, member=f"rmd17/splits/index_train_0{i}.csv") - tar.extract(path=raw_path, member=f"rmd17/splits/index_test_0{i}.csv") + logging.info("Write atoms to db...") + dataset.add_systems(property_list=property_list) + logging.info("Done.") - train_split = ( - np.loadtxt( - os.path.join(raw_path, "rmd17", "splits", f"index_train_0{i}.csv") + train_splits = [] + test_splits = [] + + for i in range(1, 6): + tar.extract(path=raw_path, member=f"rmd17/splits/index_train_0{i}.csv") + tar.extract(path=raw_path, member=f"rmd17/splits/index_test_0{i}.csv") + + train_split = ( + np.loadtxt( + os.path.join( + raw_path, "rmd17", "splits", f"index_train_0{i}.csv" + ) + ) + .flatten() + .astype(int) + .tolist() ) - .flatten() - .astype(int) - .tolist() - ) - train_splits.append(train_split) - test_split = ( - np.loadtxt( - os.path.join(raw_path, "rmd17", "splits", 
f"index_test_0{i}.csv") + train_splits.append(train_split) + + test_split = ( + np.loadtxt( + os.path.join( + raw_path, "rmd17", "splits", f"index_test_0{i}.csv" + ) + ) + .flatten() + .astype(int) + .tolist() ) - .flatten() - .astype(int) - .tolist() - ) - test_splits.append(test_split) + test_splits.append(test_split) dataset.update_metadata(splits={"known": train_splits, "test": test_splits}) - - tar.close() logging.info("Done.") + + def _download_archive(self, destination: str) -> None: + last_error = None + + for url in self.download_urls: + try: + logging.info("Downloading from: %s", url) + req = Request(url) + with urlopen(req, timeout=600) as resp, open(destination, "wb") as f: + shutil.copyfileobj(resp, f) + + if not os.path.exists(destination): + raise RuntimeError("Download did not create a file.") + + size = os.path.getsize(destination) + ctype = (resp.headers.get("Content-Type") or "").lower() + + if size == 0: + raise RuntimeError("Downloaded file is empty.") + + if "text/html" in ctype: + raise RuntimeError( + f"Got HTML instead of archive (Content-Type={ctype})." + ) + + return + + except (HTTPError, URLError, RuntimeError) as e: + last_error = e + logging.warning("Download failed from %s: %s", url, e) + + raise AtomsDataError( + f"rMD17 download failed from all sources. Last error: {last_error}" + ) From 091a0e5395321b1ead48d141f743e029439f1800 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 8 Mar 2026 17:37:56 +0100 Subject: [PATCH 26/68] refactor: remove irrelevant refactor pytest --- tests/data/test_refactor.py | 235 ------------------------------------ 1 file changed, 235 deletions(-) delete mode 100644 tests/data/test_refactor.py diff --git a/tests/data/test_refactor.py b/tests/data/test_refactor.py deleted file mode 100644 index f3f6f9465..000000000 --- a/tests/data/test_refactor.py +++ /dev/null @@ -1,235 +0,0 @@ -import pytest -import torch -import numpy as np - -import schnetpack.properties as structure -import schnetpack.data.provider as providers_mod -from schnetpack.data.datamodule_v2 import AtomsDataModuleV2 -from schnetpack.transform.atomistic import AddOffsets, RemoveOffsets, ScaleProperty - - -class AtomsDataset: - """ - Adapter to make the existing `example_data` list fixture behave like a dataset. 
- """ - - def __init__(self, data): - self._data = list(data) - self.transforms = [] - - _, props0 = self._data[0] - self.available_properties = list(props0.keys()) - - def __len__(self): - return len(self._data) - - def subset(self, indices): - # indices can be list/np array - idx_list = list(indices) - sub = AtomsDataset([self._data[i] for i in idx_list]) - # do not carry transforms automatically; DM will attach per split - return sub - - def __getitem__(self, idx): - ats, props = self._data[idx] - out = {} - - # Structure keys expected by SchNetPack transforms/loader - out[structure.Z] = torch.tensor(ats.numbers, dtype=torch.long) - out[structure.R] = torch.tensor(np.asarray(ats.positions), dtype=torch.float) - out[structure.cell] = torch.tensor( - np.asarray(ats.cell.array), dtype=torch.float - ) - out[structure.pbc] = torch.tensor(np.asarray(ats.pbc), dtype=torch.bool) - out[structure.n_atoms] = torch.tensor([len(ats.numbers)], dtype=torch.long) - - # Add properties - for k, v in props.items(): - # ensure torch tensor - if isinstance(v, torch.Tensor): - out[k] = v - else: - out[k] = torch.tensor(np.asarray(v), dtype=torch.float) - - # Apply transforms per-system (like BaseAtomsData.__getitem__) - for t in self.transforms: - out = t(out) - - return out - - -def _first_scalar_property_key(dataset: AtomsDataset): - # choose the first property key, but ensure it's not a structure key - for p in dataset.available_properties: - if p not in ( - structure.Z, - structure.R, - structure.cell, - structure.pbc, - structure.n_atoms, - ): - return p - raise AssertionError( - "No suitable scalar property found in dataset.available_properties" - ) - - -def _make_constant_atomrefs(zmax=100, value=1.0): - atref = torch.zeros((zmax,), dtype=torch.float) - atref[:] = float(value) - return atref - - -@pytest.mark.parametrize("batch_size", [1, 4]) -def test_v2_setup_attaches_transforms(example_data, batch_size): - dataset = AtomsDataset(example_data) - - dm = AtomsDataModuleV2( - dataset=dataset, - batch_size=batch_size, - num_train=0.6, - num_val=0.2, - num_test=0.2, - transforms=[], - num_workers=0, - split_file=None, - ) - dm.setup() - - assert dm.train_dataset is not None - assert dm.val_dataset is not None - assert dm.test_dataset is not None - - # transforms attribute exists and is settable - dm.train_dataset.transforms = [] - dm.val_dataset.transforms = [] - dm.test_dataset.transforms = [] - - -def test_provider_initializes_stats_transforms(example_data): - dataset = AtomsDataset(example_data) - prop = _first_scalar_property_key(dataset) - - transforms = [ - RemoveOffsets( - property=prop, remove_mean=True, remove_atomrefs=False, is_extensive=True - ), - ScaleProperty( - input_key=prop, target_key=prop, output_key=prop, scale_by_mean=False - ), - ] - - dm = AtomsDataModuleV2( - dataset=dataset, - batch_size=4, - num_train=0.6, - num_val=0.2, - num_test=0.2, - transforms=transforms, - num_workers=0, - split_file=None, - ) - dm.setup() - - ro = transforms[0] - sp = transforms[1] - - assert hasattr(ro, "mean") - assert ro.mean is not None - assert hasattr(sp, "scale") - assert sp.scale is not None - - -@pytest.mark.parametrize("is_extensive", [True, False]) -def test_addoffsets_unbatched_and_batched(example_data, is_extensive): - dataset = AtomsDataset(example_data) - prop = _first_scalar_property_key(dataset) - - zmax = 100 - atomref_tensor = _make_constant_atomrefs(zmax=zmax, value=1.0) - - t = AddOffsets( - property=prop, - add_mean=False, - add_atomrefs=True, - is_extensive=is_extensive, - 
zmax=zmax, - atomrefs=atomref_tensor, - ) - - dm = AtomsDataModuleV2( - dataset=dataset, - batch_size=4, - num_train=0.6, - num_val=0.2, - num_test=0.2, - transforms=[t], - num_workers=0, - split_file=None, - ) - dm.setup() - - # Unbatched: dataset[0] (transform runs in __getitem__) - old = dm.train_dataset.transforms - dm.train_dataset.transforms = [] - raw = dm.train_dataset[0] - dm.train_dataset.transforms = old - one = dm.train_dataset[0] - - y_raw = raw[prop] - y_one = one[prop] - delta = (y_one - y_raw).detach().view(-1)[0] - - n_atoms = int(one[structure.n_atoms].view(-1)[0].item()) - expected = float(n_atoms) if is_extensive else 1.0 - - assert torch.allclose(delta, torch.tensor(expected, dtype=delta.dtype), atol=1e-6) - - # Batched: loader should include idx_m and not crash - batch = next(iter(dm.train_dataloader())) - assert structure.idx_m in batch - assert prop in batch - - -def test_provider_caches_stats_calls(example_data, monkeypatch): - """ - Ensure provider caching prevents recomputing the same stats key multiple times. - In this test transforms request the same key when: - - RemoveOffsets is_extensive=True and remove_atomrefs=False -> (prop, True, False) - - ScaleProperty always requests (prop, True, False) - """ - - dataset = AtomsDataset(example_data) - prop = _first_scalar_property_key(dataset) - - call_count = {"n": 0} - real_calculate_stats = providers_mod.calculate_stats - - def wrapped_calculate_stats(*args, **kwargs): - call_count["n"] += 1 - return real_calculate_stats(*args, **kwargs) - - monkeypatch.setattr(providers_mod, "calculate_stats", wrapped_calculate_stats) - - transforms = [ - RemoveOffsets( - property=prop, remove_mean=True, remove_atomrefs=False, is_extensive=True - ), - ScaleProperty( - input_key=prop, target_key=prop, output_key=prop, scale_by_mean=False - ), - ] - - dm = AtomsDataModuleV2( - dataset=dataset, - batch_size=4, - num_train=0.6, - num_val=0.2, - num_test=0.2, - transforms=transforms, - num_workers=0, - split_file=None, - ) - dm.setup() - - assert call_count["n"] == 1 From 2d92e05c6324e3b536d463251ac2cde22ecb562b Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 8 Mar 2026 20:33:39 +0100 Subject: [PATCH 27/68] refactor: update md17, md22, qm7x, rmd17 --- src/schnetpack/datasets/md17.py | 228 +++++++------------ src/schnetpack/datasets/md22.py | 64 ++---- src/schnetpack/datasets/qm7x.py | 370 +++++++++++-------------------- src/schnetpack/datasets/qm9.py | 12 + src/schnetpack/datasets/rmd17.py | 16 +- 5 files changed, 245 insertions(+), 445 deletions(-) diff --git a/src/schnetpack/datasets/md17.py b/src/schnetpack/datasets/md17.py index 6cd866269..39903f0bb 100644 --- a/src/schnetpack/datasets/md17.py +++ b/src/schnetpack/datasets/md17.py @@ -4,149 +4,122 @@ import tempfile from typing import List, Optional, Dict from urllib import request as request +import torch import numpy as np from ase import Atoms -import torch import schnetpack.properties as structure - -from schnetpack.data import * +from schnetpack.data import AtomsDataFormat +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["MD17"] -class GDMLDataModule(AtomsDataModule): +class GDMLDataset(ASEAtomsData): """ - Base class for GDML type data (e.g. MD17 or MD22). Requires a dictionary translating between molecule and filenames - and an URL under which the molecular datasets can be found. + Base class for GDML-type datasets (e.g. MD17 or MD22). 
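# Aside: the deleted test above pinned down the caching contract of
# StatsAtomrefProvider: two transforms requesting the same stats key should
# trigger a single calculate_stats pass. A minimal sketch of that memoization
# idea (hypothetical helper, not the library API):
from typing import Any, Callable, Dict, Tuple

class MemoizedStats:
    def __init__(self, compute: Callable[[Tuple], Any]):
        self._compute = compute           # expensive stats routine
        self._cache: Dict[Tuple, Any] = {}

    def get(self, key: Tuple) -> Any:
        # e.g. key = (property_name, divide_by_atoms, use_atomrefs)
        if key not in self._cache:
            self._cache[key] = self._compute(key)  # runs at most once per key
        return self._cache[key]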
+ Requires a dictionary translating between molecule and filenames + and a URL under which the molecular datasets can be found. """ energy = "energy" forces = "forces" - # properties def __init__( self, datasets_dict: Dict[str, str], download_url: str, datapath: str, molecule: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", + tmpdir: str = "gdml_tmp", + atomrefs: Optional[Dict[str, List[float]]] = None, format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - data_workdir: Optional[str] = None, - tmpdir: str = "gdml_tmp", - atomrefs: Optional[Dict[str, List[float]]] = None, **kwargs, ): """ Args: datasets_dict: dictionary mapping molecule names to dataset names. download_url: URL where individual molecule datasets can me found. - datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format - load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. + datapath: path to dataset. + molecule: name of the molecule. tmpdir: name of temporary directory used for parsing. - atomrefs: properties of free atoms + atomrefs: properties of free atoms. + format: dataset format (e.g: ASE). + load_properties: subset of properties to load. + transforms: Transform applied to each system separately before batching. + subset_idx: indices of the subset to load. + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...). + **kwargs: additional keyword arguments. 
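# Aside: a hedged sketch of how a concrete dataset would plug into this base
# class, mirroring what MD17/MD22 below do; the file table and mirror URL are
# placeholders, not real data sources.
from schnetpack.datasets.md17 import GDMLDataset

class ToyGDML(GDMLDataset):
    def __init__(self, datapath: str, **kwargs):
        super().__init__(
            datasets_dict={"toy": "toy.npz"},             # molecule -> npz file
            download_url="http://example.org/gdml/npz/",  # placeholder mirror
            tmpdir="toy_gdml",
            datapath=datapath,
            molecule="toy",
            **kwargs,
        )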
""" - super().__init__( - datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, - format=format, - load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, - transforms=transforms, - train_transforms=train_transforms, - val_transforms=val_transforms, - test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, - property_units=property_units, - distance_unit=distance_unit, - data_workdir=data_workdir, - **kwargs, - ) - self.datasets_dict = datasets_dict self.download_url = download_url - self.atomrefs = atomrefs + self._native_atomrefs = atomrefs self.tmpdir = tmpdir + self.format = format - if molecule not in self.datasets_dict.keys(): - raise AtomsDataModuleError("Molecule {} is not supported!".format(molecule)) + if molecule not in self.datasets_dict: + raise AtomsDataError(f"Molecule {molecule} is not supported!") self.molecule = molecule - def prepare_data(self): - if not os.path.exists(self.datapath): - property_unit_dict = { - self.energy: "kcal/mol", - self.forces: "kcal/mol/Ang", - } + self.download( + datapath=datapath, + distance_unit=distance_unit or "Ang", + ) - tmpdir = tempfile.mkdtemp(self.tmpdir) + super().__init__( + datapath=datapath, + load_properties=load_properties, + transforms=transforms, + subset_idx=subset_idx, + property_units=property_units, + distance_unit=distance_unit, + **kwargs, + ) - dataset = create_dataset( - datapath=self.datapath, - format=self.format, - distance_unit="Ang", - property_unit_dict=property_unit_dict, - atomrefs=self.atomrefs, - ) - dataset.update_metadata(molecule=self.molecule) + @staticmethod + def _native_property_units() -> Dict[str, str]: + return { + GDMLDataset.energy: "kcal/mol", + GDMLDataset.forces: "kcal/mol/Ang", + } - self._download_data(tmpdir, dataset) - shutil.rmtree(tmpdir) - else: - dataset = load_dataset(self.datapath, self.format) + def download(self, datapath: str, distance_unit: str = "Ang") -> None: + if os.path.exists(datapath): + dataset = ASEAtomsData(datapath, load_structure=False) md = dataset.metadata + if "molecule" not in md: - raise AtomsDataModuleError( - "Not a valid GDML dataset! The molecule needs to be specified in the metadata." + raise AtomsDataError( + "Not a valid GDML dataset. Metadata must contain `molecule`." ) + if md["molecule"] != self.molecule: - raise AtomsDataModuleError( - f"The dataset at the given location does not contain the specified molecule: " - + f"`{md['molecule']}` instead of `{self.molecule}`" + raise AtomsDataError( + f"The dataset at the given location contains `{md['molecule']}` " + f"instead of `{self.molecule}`." ) + return + + tmpdir = tempfile.mkdtemp(self.tmpdir) + try: + dataset = ASEAtomsData.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=self._native_property_units(), + atomrefs=self._native_atomrefs, + ) + dataset.update_metadata(molecule=self.molecule) + self._download_data(tmpdir, dataset) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) def _download_data( self, @@ -184,62 +157,38 @@ def _download_data( logging.info("Done.") -class MD17(GDMLDataModule): +class MD17(GDMLDataset): """ MD17 benchmark data set for molecular dynamics of small molecules containing molecular forces. References: .. 
[#md17_1] http://quantum-machine.org/gdml/#datasets - """ def __init__( self, datapath: str, molecule: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + transforms=None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - data_workdir: Optional[str] = None, **kwargs, ): """ Args: - datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format - load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. + datapath: path to dataset. + molecule: name of the molecule. + format: dataset format (e.g: ASE). + load_properties: subset of properties to load. transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. + subset_idx: indices of the subset to load. + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...). + **kwargs: additional keyword arguments. 
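# Aside: a usage sketch for the refactored MD17, which is now a dataset
# (ASEAtomsData subclass) rather than a datamodule; splitting and batching are
# handled elsewhere, e.g. by a datamodule. The path is a placeholder, and the
# import assumes MD17 remains exported from schnetpack.datasets.
from schnetpack.datasets import MD17

ds = MD17(datapath="./md17_ethanol.db", molecule="ethanol")  # downloads on first use
print(len(ds))
sample = ds[0]  # per-system transforms, if any, are applied here
print(sample[MD17.energy].shape, sample[MD17.forces].shape)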
""" atomrefs = { self.energy: [ @@ -257,46 +206,29 @@ def __init__( datasets_dict = dict( aspirin="md17_aspirin.npz", - # aspirin_ccsd='aspirin_ccsd.zip', azobenzene="azobenzene_dft.npz", benzene="md17_benzene2017.npz", ethanol="md17_ethanol.npz", - # ethanol_ccsdt='ethanol_ccsd_t.zip', malonaldehyde="md17_malonaldehyde.npz", - # malonaldehyde_ccsdt='malonaldehyde_ccsd_t.zip', naphthalene="md17_naphthalene.npz", paracetamol="paracetamol_dft.npz", salicylic_acid="md17_salicylic.npz", toluene="md17_toluene.npz", - # toluene_ccsdt='toluene_ccsd_t.zip', uracil="md17_uracil.npz", ) - super(MD17, self).__init__( + super().__init__( datasets_dict=datasets_dict, download_url="http://www.quantum-machine.org/gdml/data/npz/", tmpdir="md17", molecule=molecule, datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, format=format, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, transforms=transforms, - train_transforms=train_transforms, - val_transforms=val_transforms, - test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, - data_workdir=data_workdir, atomrefs=atomrefs, **kwargs, ) diff --git a/src/schnetpack/datasets/md22.py b/src/schnetpack/datasets/md22.py index 9d51008f7..936f68320 100644 --- a/src/schnetpack/datasets/md22.py +++ b/src/schnetpack/datasets/md22.py @@ -1,68 +1,42 @@ -import torch from typing import Optional, Dict, List -from schnetpack.data import * -from schnetpack.datasets.md17 import GDMLDataModule +from schnetpack.data import AtomsDataFormat +from schnetpack.datasets.md17 import GDMLDataset +__all__ = ["MD22"] -all = ["MD22"] - -class MD22(GDMLDataModule): +class MD22(GDMLDataset): """ MD22 benchmark data set for extended molecules containing molecular forces. References: .. [#md22_1] http://quantum-machine.org/gdml/#datasets - """ def __init__( self, datapath: str, molecule: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + transforms=None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - data_workdir: Optional[str] = None, **kwargs, ): """ Args: datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions + molecule: name of the molecule format: dataset format load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. 
transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). + subset_idx: indices of the subset to load. + property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. + **kwargs: additional keyword arguments. """ atomrefs = { self.energy: [ @@ -77,6 +51,7 @@ def __init__( -47069.30768969713, ] } + datasets_dict = { "Ac-Ala3-NHMe": "md22_Ac-Ala3-NHMe.npz", "DHA": "md22_DHA.npz", @@ -87,31 +62,18 @@ def __init__( "double-walled_nanotube": "md22_double-walled_nanotube.npz", } - super(MD22, self).__init__( + super().__init__( datasets_dict=datasets_dict, download_url="http://www.quantum-machine.org/gdml/repo/datasets/", tmpdir="md22", molecule=molecule, datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, format=format, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, transforms=transforms, - train_transforms=train_transforms, - val_transforms=val_transforms, - test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, - data_workdir=data_workdir, atomrefs=atomrefs, **kwargs, ) diff --git a/src/schnetpack/datasets/qm7x.py b/src/schnetpack/datasets/qm7x.py index e1179ac62..6e9a7236a 100644 --- a/src/schnetpack/datasets/qm7x.py +++ b/src/schnetpack/datasets/qm7x.py @@ -11,24 +11,19 @@ import h5py import numpy as np import progressbar -import torch from ase import Atoms -from tqdm import tqdm -from schnetpack.data import * -from schnetpack.data import AtomsDataModule -from schnetpack.data.splitting import GroupSplit +import schnetpack.properties as structure +from schnetpack.data import AtomsDataFormat +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError, load_dataset +from schnetpack.transform.base import Transform __all__ = ["QM7X"] -# Helper functions pbar = None def show_progress(block_num: int, block_size: int, total_size: int): - """ - progress callback for files downloads - """ global pbar if pbar is None: pbar = progressbar.ProgressBar(maxval=total_size) @@ -42,42 +37,32 @@ def show_progress(block_num: int, block_size: int, total_size: int): pbar = None -def download_and_check(url: str, tar_path: str, checksum: str): - """ - Download file from url to tar_path and check md5 checksum. - """ +def download_and_check(url: str, target_path: str, checksum: str): + file_name = url.split("/")[-1] - file = url.split("/")[-1] - - # check if file already exists and has correct checksum - if os.path.exists(tar_path): - md5_sum = hashlib.md5(open(tar_path, "rb").read()).hexdigest() + if os.path.exists(target_path): + md5_sum = hashlib.md5(open(target_path, "rb").read()).hexdigest() if md5_sum == checksum: logging.info( - f"File {file} already exists and has correct checksum. Skipping download." 
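# Aside: download_and_check in this hunk accepts a cached archive only when its
# md5 fingerprint matches. The same check as a standalone helper, streamed in
# chunks so large archives need not fit in memory (name is illustrative):
import hashlib

def md5_matches(path: str, expected: str, chunk_size: int = 1 << 20) -> bool:
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected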
+ f"File {file_name} already exists and has correct checksum. Skipping download." ) return - else: - logging.info( - f"File {file} already exists but has wrong checksum. Redownloading." - ) - os.remove(tar_path) + logging.info( + f"File {file_name} already exists but has wrong checksum. Redownloading." + ) + os.remove(target_path) logging.info(f"Downloading {url} ...") - request.urlretrieve(url, tar_path, show_progress) + request.urlretrieve(url, target_path, show_progress) - if hashlib.md5(open(tar_path, "rb").read()).hexdigest() == checksum: - logging.info("Done.") - else: + if hashlib.md5(open(target_path, "rb").read()).hexdigest() != checksum: raise RuntimeError( - f"Checksum of downloaded file {file} does not match. Please try again." + f"Checksum of downloaded file {file_name} does not match. Please try again." ) + logging.info("Done.") def extract_xz(source: str, target: str): - """ - helper to extract xz files. - """ s_file = source.split("/")[-1] t_file = target.split("/")[-1] @@ -86,43 +71,30 @@ def extract_xz(source: str, target: str): return logging.info(f"Extracting {s_file} ...") - try: with lzma.open(source) as fin, open(target, mode="wb") as fout: shutil.copyfileobj(fin, fout) - except: + except Exception as e: if os.path.exists(target): os.remove(target) - raise RuntimeError(f"Could not extract file {s_file}. Please try again.") + raise RuntimeError(f"Could not extract file {s_file}. Please try again.") from e logging.info("Done.") -class QM7X(AtomsDataModule): +class QM7X(ASEAtomsData): """ - QM7-X a comprehensive dataset of > 40 physicochemical properties for ~4.2 M equilibrium and non-equilibrium - structure of small organic molecules with up to seven non-hydrogen (C, N, O, S, Cl) atoms. - This class adds convenient functions to download QM7-X and load the data into pytorch. - - References: - - .. [#qm7x_1] https://zenodo.org/record/4288677 - + QM7-X dataset of equilibrium and non-equilibrium structures of small organic molecules. """ - # more molecular and atomic properties can be found in the original paper and added here - # Notice that adding more properties can drastically increase the size of the dataset - # adding more properties here requires to add them to the property_unit_dict - # and there key mapping in the raw dataset in property_dataset_keys. 
- - forces = "forces" # total ePBE0+MBD forces - energy = "energy" # ePBE0+MBD: total energy after convergence of the PBE0 exchange-correlation functional and the MBD dispersion correction - Eat = "Eat" # atomization energy using PBE0 energy per atom and ePBE0+MBD total energy - EPBE0 = "EPBE0" # ePBE0: total energy at the level of PBE0 - EMBD = "EMBD" # eMBD: total energy at the level of MBD - FPBE0 = "FMBD" # FPBE0: total ePBE0 forces - FMBD = "FMBD" # FMBD: total eMBD forces - RMSD = "rmsd" # root mean square deviation of the atomic positions from the equilibrium structure + forces = "forces" + energy = "energy" + Eat = "Eat" + EPBE0 = "EPBE0" + EMBD = "EMBD" + FPBE0 = "FPBE0" + FMBD = "FMBD" + RMSD = "rmsd" property_unit_dict = { forces: "eV/Ang", @@ -135,7 +107,6 @@ class QM7X(AtomsDataModule): RMSD: "Ang", } - # the original keys in the raw dataset to query the properties property_dataset_keys = { forces: "totFOR", energy: "ePBE0+MBD", @@ -147,7 +118,6 @@ class QM7X(AtomsDataModule): RMSD: "sRMSD", } - # atom energies (atomrefs) from PBE0 EPBE0_atom = { 1: -13.641404161, 6: -1027.592489146, @@ -160,114 +130,118 @@ class QM7X(AtomsDataModule): def __init__( self, datapath: str, - batch_size: int, - raw_data_path: str = None, + raw_data_path: Optional[str] = None, remove_duplicates: bool = True, only_equilibrium: bool = False, only_non_equilibrium: bool = False, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + transforms: Optional[List[Transform]] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - data_workdir: Optional[str] = None, - splitting: Optional[SplittingStrategy] = None, **kwargs, ): - """ - Args: - datapath: path to dataset - batch_size: (train) batch size - raw_data_path: path to raw data. If None use tmp dir otherwise persist data and not remove it. - remove_duplicates: remove duplicated equilibrium structures with different non-equilibrium structures - only_equilibrium: only use equilibrium structures - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format - load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). 
- property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. - splitting: Method to generate train/validation/test partitions - (default: GroupSplit(splitting_key="smiles_id")) - """ + if only_equilibrium and only_non_equilibrium: + raise AtomsDataError( + "only_equilibrium and only_non_equilibrium cannot both be True." + ) + + self.raw_data_path = raw_data_path + self.remove_duplicates = remove_duplicates + self.duplicates_ids = None + self.only_equilibrium = only_equilibrium + self.only_non_equilibrium = only_non_equilibrium + self.format = format + + self.download( + datapath=datapath, + distance_unit=distance_unit or "Ang", + ) + + # initialize without subset first, then apply dataset-specific filtering super().__init__( datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, - format=format, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, transforms=transforms, - train_transforms=train_transforms, - val_transforms=val_transforms, - test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=None, property_units=property_units, distance_unit=distance_unit, - data_workdir=data_workdir, - splitting=splitting or GroupSplit(splitting_key="smiles_id"), **kwargs, ) - self.raw_data_path = raw_data_path - self.remove_duplicates = remove_duplicates - self.duplicates_ids = None - self.only_equilibrium = only_equilibrium - self.only_non_equilibrium = only_non_equilibrium + self._apply_structure_filter(original_subset_idx=subset_idx) + + def _apply_structure_filter(self, original_subset_idx: Optional[List[int]]) -> None: + effective_subset = original_subset_idx + + if self.only_equilibrium or self.only_non_equilibrium: + step_ids = self.metadata["groups_ids"]["step_id"] + + if len(step_ids) != self.conn.count(): + raise AtomsDataError( + "Dataset size does not match size of step_id metadata." 
+ ) + + if self.only_equilibrium: + filtered = [i for i, s in enumerate(step_ids) if s == 0] + else: + filtered = [i for i, s in enumerate(step_ids) if s != 0] + + if effective_subset is None: + effective_subset = filtered + else: + filtered_set = set(filtered) + effective_subset = [i for i in effective_subset if i in filtered_set] + + self.subset_idx = effective_subset + + def download(self, datapath: str, distance_unit: str = "Ang") -> None: + if os.path.exists(datapath): + _ = load_dataset(datapath, self.format, load_structure=False) + return + + tar_dir = self.raw_data_path or tempfile.mkdtemp("qm7x") + try: + atomrefs = { + QM7X.energy: [ + QM7X.EPBE0_atom[i] if i in QM7X.EPBE0_atom else 0.0 + for i in range(0, 18) + ] + } + + dataset = ASEAtomsData.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=QM7X.property_unit_dict, + atomrefs=atomrefs, + ) + + hd_files = self._download_data(tar_dir) + if self.remove_duplicates: + self._download_duplicates_ids(tar_dir) + self._parse_data(hd_files, dataset) + + finally: + if self.raw_data_path is None: + shutil.rmtree(tar_dir, ignore_errors=True) def _download_duplicates_ids(self, tar_dir: str): - """ - download duplicates ids for QM7-X - """ - url = f"https://zenodo.org/record/4288677/files/DupMols.dat" - tar_path = os.path.join(tar_dir, "DupMols.dat") + url = "https://zenodo.org/record/4288677/files/DupMols.dat" + target_path = os.path.join(tar_dir, "DupMols.dat") checksum = "5d886ccac38877c8cb26c07704dd1034" - download_and_check(url, tar_path, checksum) + download_and_check(url, target_path, checksum) - # fetch duplicates ids dup_mols = [] - for line in open(tar_path, "r"): - dup_mols.append(line.rstrip("\n")[:-4]) + with open(target_path, "r") as f: + for line in f: + dup_mols.append(line.rstrip("\n")[:-4]) + self.duplicates_ids = dup_mols def _download_data(self, tar_dir: str, ignore_extracted: bool = True) -> List[str]: - """ - download data and extract them - """ file_ids = ["1000", "2000", "3000", "4000", "5000", "6000", "7000", "8000"] - - # file fingerprints to check integrity checksums = [ "b50c6a5d0a4493c274368cf22285503e", "4418a813daf5e0d44aa5a26544249ee6", @@ -281,42 +255,31 @@ def _download_data(self, tar_dir: str, ignore_extracted: bool = True) -> List[st logging.info("Downloading QM7-X data files ...") - # download files for i, file_id in enumerate(file_ids): if ignore_extracted and os.path.exists( os.path.join(tar_dir, f"{file_id}.hdf5") ): logging.info( - f"File {file_id}.xz exists in extracted version {file_id}.hdf5 already, skipping download." + f"File {file_id}.hdf5 already exists. Skipping download of {file_id}.xz." ) continue url = f"https://zenodo.org/record/4288677/files/{file_id}.xz" + xz_path = os.path.join(tar_dir, f"{file_id}.xz") + download_and_check(url, xz_path, checksums[i]) - tar_path = os.path.join(tar_dir, f"{file_id}.xz") - download_and_check(url, tar_path, checksums[i]) - - # extract the compressed files extracted = [] - for i, file_id in enumerate(file_ids): + for file_id in file_ids: xz_path = os.path.join(tar_dir, f"{file_id}.xz") hd_path = os.path.join(tar_dir, f"{file_id}.hdf5") - extract_xz(xz_path, hd_path) - extracted.append(hd_path) return extracted def _parse_data(self, files: List[str], dataset: ASEAtomsData): - """ - Parse the downloaded data files and add them to the dataset. 
- """ - - # parse the data files - for file in files: - logging.info(f"Parsing {file.split('/')[-1]} ...") + logging.info(f"Parsing {os.path.basename(file)} ...") atoms_list = [] property_list = [] @@ -328,9 +291,8 @@ def _parse_data(self, files: List[str], dataset: ASEAtomsData): } with h5py.File(file, "r") as mol_dict: - for mol_id, mol in tqdm(mol_dict.items()): + for _mol_id, mol in mol_dict.items(): for conf_id, conf in mol.items(): - # exclude equilibrium duplicates trunc_id = conf_id[::-1].split("-", 1)[-1][::-1] if self.remove_duplicates and trunc_id in self.duplicates_ids: continue @@ -343,32 +305,25 @@ def _parse_data(self, files: List[str], dataset: ASEAtomsData): for key in QM7X.property_unit_dict.keys() } - # get the hierarchical ids for each system if "opt" in conf_id: - conf_id = ( - conf_id[:-3] + "d0" - ) # repalce the 'opt' key with id 'd0' + conf_id = conf_id[:-3] + "d0" + ids = map(lambda x: int(x), re.findall(r"\d+", conf_id)) atoms_list.append(ats) property_list.append(properties) - # save the hierarchical ids for each system in same order as the systems - for i, j in zip(groups_ids.keys(), ids): - groups_ids[i].append(j) - - # add the data to the dataset - logging.info(f"Write parsed data from {file.split('/')[-1]} to db ...") + for key, idx in zip(groups_ids.keys(), ids): + groups_ids[key].append(idx) + logging.info(f"Write parsed data from {os.path.basename(file)} to db ...") dataset.add_systems(property_list=property_list, atoms_list=atoms_list) - # add the hierarchical ids to the metadata md = dataset.metadata - if "groups_ids" in md.keys(): + if "groups_ids" in md: for key, ids in groups_ids.items(): groups_ids[key] = md["groups_ids"][key] + ids - # add the ids as in the database of the new added systems last_id = md["groups_ids"]["id"][-1] sys_ids = list(range(last_id + 1, last_id + len(atoms_list) + 1)) groups_ids["id"] = md["groups_ids"]["id"] + sys_ids @@ -376,79 +331,4 @@ def _parse_data(self, files: List[str], dataset: ASEAtomsData): groups_ids["id"] = list(range(1, len(atoms_list) + 1)) dataset.update_metadata(groups_ids=groups_ids) - logging.info("Done.") - - def prepare_data(self): - """ - prepare data for pytorch lightning data module - """ - if not os.path.exists(self.datapath): - tar_dir = self.raw_data_path or tempfile.mkdtemp("qm7x") - - atomrefs = { - QM7X.energy: [ - QM7X.EPBE0_atom[i] if i in QM7X.EPBE0_atom else 0.0 - for i in range(0, 18) - ] - } - - dataset = create_dataset( - datapath=self.datapath, - format=self.format, - distance_unit="Ang", - property_unit_dict=QM7X.property_unit_dict, - atomrefs=atomrefs, - ) - - hd_files = self._download_data(tar_dir) - if self.remove_duplicates: - self._download_duplicates_ids(tar_dir) - self._parse_data(hd_files, dataset) - - if self.raw_data_path is None: - shutil.rmtree(tar_dir) - - def setup(self, stage=None): - if self.data_workdir is None: - datapath = self.datapath - else: - datapath = self._copy_to_workdir() - - # (re)load datasets - if self.dataset is None: - self.dataset = load_dataset( - datapath, - self.format, - property_units=self.property_units, - distance_unit=self.distance_unit, - load_properties=self.load_properties, - ) - - # use subset of equilibrium structures - - if self.only_equilibrium or self.only_non_equilibrium: - step_ids = self.dataset.metadata["groups_ids"]["step_id"] - - if len(step_ids) != len(self.dataset): - raise ValueError( - "The dataset size does not match the size of step ids arrays in meta data." 
- ) - - if self.only_equilibrium: - eq_indices = [i for i, s in enumerate(step_ids) if s == 0] - else: - eq_indices = [i for i, s in enumerate(step_ids) if s != 0] - - self.dataset = self.dataset.subset(eq_indices) - - # load and generate partitions if needed - if self.train_idx is None: - self._load_partitions() - - # partition dataset - self._train_dataset = self.dataset.subset(self.train_idx) - self._val_dataset = self.dataset.subset(self.val_idx) - self._test_dataset = self.dataset.subset(self.test_idx) - - self._setup_transforms() diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 98f5dcad3..f6031158c 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -65,6 +65,18 @@ def __init__( distance_unit: Optional[str] = None, **kwargs, ): + """ + Args: + datapath: path to dataset + format: dataset format + remove_uncharacterized: do not include uncharacterized molecules. + load_properties: subset of properties to load + transforms: Transform applied to each system separately before batching. + subset_idx: indices of the subset to load. + property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). + distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). + **kwargs: additional keyword arguments. + """ self.remove_uncharacterized = remove_uncharacterized self.format = format diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py index c75af1f5d..58f0d83d8 100644 --- a/src/schnetpack/datasets/rmd17.py +++ b/src/schnetpack/datasets/rmd17.py @@ -6,6 +6,7 @@ from typing import Dict, List, Optional from urllib.request import Request, urlopen from urllib.error import HTTPError, URLError +import torch import numpy as np from ase import Atoms @@ -70,12 +71,25 @@ def __init__( molecule: str, format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - transforms: Optional[List[Transform]] = None, + transforms: Optional[List[torch.nn.Module]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, **kwargs, ): + """ + Args: + datapath: path to dataset + molecule: name of the molecule + format: dataset format + load_properties: subset of properties to load + transforms: Transform applied to each system separately before batching. + subset_idx: indices of the subset to load. + property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). + distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). + **kwargs: additional keyword arguments. 
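+
+            Note: ``molecule`` must be one of the keys of ``datasets_dict``;
+            unknown molecule names raise an ``AtomsDataError``.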
+ """ + if molecule not in self.datasets_dict: raise AtomsDataError(f"Molecule {molecule} is not supported!") From 61740d9b1b72985a90d6c1bdb10c969f94165126 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 8 Mar 2026 23:41:56 +0100 Subject: [PATCH 28/68] refactor: update dataset classes mp, ani1, iso17 --- src/schnetpack/datasets/ani1.py | 204 ++++++++----------- src/schnetpack/datasets/iso17.py | 183 ++++++++--------- src/schnetpack/datasets/materials_project.py | 174 ++++++---------- src/schnetpack/datasets/md22.py | 4 +- src/schnetpack/datasets/omdb.py | 4 +- src/schnetpack/datasets/qm7x.py | 27 ++- src/schnetpack/datasets/tmqm.py | 4 +- 7 files changed, 264 insertions(+), 336 deletions(-) diff --git a/src/schnetpack/datasets/ani1.py b/src/schnetpack/datasets/ani1.py index ffe7e92cf..9ae478a55 100644 --- a/src/schnetpack/datasets/ani1.py +++ b/src/schnetpack/datasets/ani1.py @@ -1,32 +1,31 @@ import logging import os import shutil +import tarfile import tempfile -from typing import List, Optional, Dict +from typing import Dict, List, Optional from urllib import request as request +import torch +import h5py import numpy as np from ase import Atoms -import torch -import tarfile -import h5py +from schnetpack.data import AtomsDataFormat +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError, load_dataset +from schnetpack.transform.base import Transform -from schnetpack.data import * +__all__ = ["ANI1"] log = logging.getLogger(__name__) -class ANI1(AtomsDataModule): +class ANI1(ASEAtomsData): """ - ANI1 benchmark database. - This class adds convenience functions to download ANI1 from figshare and - load the data into pytorch. + ANI1 benchmark dataset. References: - .. [#ani1] https://arxiv.org/abs/1708.04987 - """ energy = "energy" @@ -41,166 +40,139 @@ class ANI1(AtomsDataModule): def __init__( self, datapath: str, - batch_size: int, num_heavy_atoms: int = 8, high_energies: bool = False, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, **kwargs, ): """ - Args: datapath: path to dataset - num_heavy_atoms: number of heavy atoms. (See 'Table 1' in Ref. [#ani1]_) - high_energies: add high energy conformations. (See 'Technical Validation' of Ref. [#ani1]_) - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions + num_heavy_atoms: number of heavy atoms + high_energies: whether to include high-energy conformations format: dataset format load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. 
- transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). + transforms: Transform applied to each system separately before batching + subset_idx: indices of the subset to load property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - + **kwargs: additional keyword arguments. """ self.num_heavy_atoms = num_heavy_atoms self.high_energies = high_energies + self.format = format + + self.download( + datapath=datapath, + distance_unit=distance_unit or "Ang", + ) super().__init__( datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, - format=format, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, transforms=transforms, - train_transforms=train_transforms, - val_transforms=val_transforms, - test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, **kwargs, ) - def prepare_data(self): - if not os.path.exists(self.datapath): - property_unit_dict = { - ANI1.energy: "Hartree", - } - atomrefs = self._create_atomrefs() - - dataset = create_dataset( - datapath=self.datapath, - format=self.format, - distance_unit="Ang", - property_unit_dict=property_unit_dict, - atomrefs=atomrefs, - ) - tmpdir = tempfile.mkdtemp("ani1") + @staticmethod + def _native_property_units() -> Dict[str, str]: + return { + ANI1.energy: "Hartree", + } + def download(self, datapath: str, distance_unit: str = "Ang") -> None: + """ + Ensure the ANI1 ASE DB exists. 
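+        If the DB is missing, the raw ANI-1 release archive is downloaded to
+        a temporary directory, parsed, and written into a newly created ASE
+        DB; the temporary files are removed afterwards.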
+        """
+        if os.path.exists(datapath):
+            _ = ASEAtomsData(datapath, load_structure=False)
+            return
+
+        tmpdir = tempfile.mkdtemp("ani1")
+        try:
+            dataset = ASEAtomsData.create(
+                datapath=datapath,
+                distance_unit=distance_unit,
+                property_unit_dict=self._native_property_units(),
+                atomrefs=self._create_atomrefs(),
+            )
             self._download_data(tmpdir, dataset)
-            shutil.rmtree(tmpdir)
-        else:
-            dataset = load_dataset(self.datapath, self.format)
+        finally:
+            shutil.rmtree(tmpdir, ignore_errors=True)
 
-    def _download_data(self, tmpdir, dataset: ASEAtomsData):
-        logging.info("downloading ANI-1 data...")
+    def _download_data(self, tmpdir: str, dataset: ASEAtomsData) -> None:
+        logging.info("Downloading ANI-1 data...")
         tar_path = os.path.join(tmpdir, "ANI1_release.tar.gz")
         raw_path = os.path.join(tmpdir, "data")
         url = "https://ndownloader.figshare.com/files/9057631"
 
         request.urlretrieve(url, tar_path)
+        if not os.path.exists(tar_path):
+            raise AtomsDataError(f"Download failed, file not found: {tar_path}")
+
+        if os.path.getsize(tar_path) == 0:
+            raise AtomsDataError(f"Downloaded file is empty: {tar_path}")
+
         logging.info("Done.")
 
-        tar = tarfile.open(tar_path)
-        tar.extractall(raw_path)
-        tar.close()
+        with tarfile.open(tar_path) as tar:
+            tar.extractall(raw_path)
 
-        logging.info("parse files...")
+        logging.info("Parsing files...")
         for i in range(1, self.num_heavy_atoms + 1):
-            file_name = os.path.join(raw_path, "ANI-1_release", "ani_gdb_s0%d.h5" % i)
-            logging.info("start to parse %s" % file_name)
+            file_name = os.path.join(raw_path, "ANI-1_release", f"ani_gdb_s0{i}.h5")
+            logging.info("Start to parse %s", file_name)
             self._load_h5_file(file_name, dataset)
-        logging.info("done...")
+        logging.info("Done.")
 
-    def _load_h5_file(self, file_name, dataset):
+    def _load_h5_file(self, file_name: str, dataset: ASEAtomsData) -> None:
         atoms_list = []
         properties_list = []
 
-        store = h5py.File(file_name)
-        for file_key in store:
-            for molecule_key in store[file_key]:
-                molecule_group = store[file_key][molecule_key]
-                species = "".join([str(s)[-2] for s in molecule_group["species"]])
-                positions = molecule_group["coordinates"]
-                energies = molecule_group["energies"]
-
-                # loop over conformations
-                for i in range(energies.shape[0]):
-                    atm = Atoms(species, positions[i])
-                    energy = energies[i]
-                    properties = {self.energy: np.array([energy])}
-                    atoms_list.append(atm)
-                    properties_list.append(properties)
-
-                # high energy conformations as described in 'Technical Validation'
-                # section of https://arxiv.org/abs/1708.04987
-                if self.high_energies:
-                    high_energy_positions = molecule_group["coordinatesHE"]
-                    high_energies = molecule_group["energiesHE"]
-
-                    # loop over high energy conformations
-                    for i in range(high_energies.shape[0]):
-                        atm = Atoms(species, high_energy_positions[i])
-                        high_energy = high_energies[i]
-                        properties = {self.energy: np.array([high_energy])}
+        with h5py.File(file_name, "r") as store:
+            for file_key in store:
+                for molecule_key in store[file_key]:
+                    molecule_group = store[file_key][molecule_key]
+                    species = "".join([str(s)[-2] for s in molecule_group["species"]])
+                    positions = molecule_group["coordinates"]
+                    energies = molecule_group["energies"]
+
+                    # regular conformations
+                    for i in range(energies.shape[0]):
+                        atm = Atoms(species, positions[i])
+                        energy = energies[i]
+                        properties = {self.energy: np.array([energy])}
                         atoms_list.append(atm)
                         properties_list.append(properties)
 
-        # write data to ase db
+                    # high-energy conformations, as described in the 'Technical
+                    # Validation' section of https://arxiv.org/abs/1708.04987
+                    if self.high_energies:
+                        
high_energy_positions = molecule_group["coordinatesHE"] + high_energies = molecule_group["energiesHE"] + + for i in range(high_energies.shape[0]): + atm = Atoms(species, high_energy_positions[i]) + high_energy = high_energies[i] + properties = {self.energy: np.array([high_energy])} + atoms_list.append(atm) + properties_list.append(properties) + dataset.add_systems(atoms_list=atoms_list, property_list=properties_list) - def _create_atomrefs(self): + def _create_atomrefs(self) -> Dict[str, List[float]]: atref = np.zeros((100,)) - - # converts units to eV (which are set to one in ase) atref[1] = self.self_energies["H"] atref[6] = self.self_energies["C"] atref[7] = self.self_energies["N"] atref[8] = self.self_energies["O"] - return {ANI1.energy: atref.tolist()} diff --git a/src/schnetpack/datasets/iso17.py b/src/schnetpack/datasets/iso17.py index cfc8bb391..8694fad82 100644 --- a/src/schnetpack/datasets/iso17.py +++ b/src/schnetpack/datasets/iso17.py @@ -1,31 +1,31 @@ import logging import os import shutil +import tarfile import tempfile -from typing import List, Optional, Dict +from typing import Dict, List, Optional from urllib import request as request -from tqdm import tqdm -import numpy as np -from ase.db import connect from urllib.error import HTTPError, URLError -import tarfile import torch +import numpy as np +from ase.db import connect +from tqdm import tqdm -from schnetpack.data import * +from schnetpack.data import AtomsDataFormat +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["ISO17"] -class ISO17(AtomsDataModule): +class ISO17(ASEAtomsData): """ - ISO17 benchmark data set for molecular dynamics of C7O2H10 isomers + ISO17 benchmark dataset for molecular dynamics of C7O2H10 isomers containing molecular forces. References: .. [#iso17] http://quantum-machine.org/datasets/ - """ energy = "total_energy" @@ -39,27 +39,14 @@ class ISO17(AtomsDataModule): "test_eq", ] - # properties def __init__( self, datapath: str, fold: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, **kwargs, @@ -68,103 +55,99 @@ def __init__( Args: datapath: path to dataset fold: select a specific dataset of iso17 - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions format: dataset format load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. 
- test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). + transforms: Transform applied to each system separately before batching + subset_idx: indices of the subset to load + property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). + **kwargs: additional keyword arguments. """ if fold not in self.existing_folds: - raise ValueError("Fold {:s} does not exist".format(fold)) + raise AtomsDataError(f"Fold {fold} does not exist.") - self.path = datapath + self.root_path = datapath self.fold = fold + self.format = format + dbpath = os.path.join(datapath, "iso17", fold + ".db") + self.download(datapath=dbpath, distance_unit=distance_unit or "Ang") + super().__init__( datapath=dbpath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, - format=format, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, transforms=transforms, - train_transforms=train_transforms, - val_transforms=val_transforms, - test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, **kwargs, ) - def prepare_data(self): - if not os.path.exists(self.datapath): - self._download_data() - else: - dataset = load_dataset(self.datapath, self.format) + @staticmethod + def _native_property_units() -> Dict[str, str]: + return { + ISO17.energy: "eV", + ISO17.forces: "eV/Ang", + } + + def download(self, datapath: str, distance_unit: str = "Ang") -> None: + """ + Ensure the ISO17 DB for the selected fold exists and has proper metadata. 
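+        On first use, the full ISO17 archive is downloaded and every fold is
+        rewritten into an ASE DB with unit metadata and per-row energies.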
+ """ + if os.path.exists(datapath): + _ = ASEAtomsData(datapath, load_structure=False) + return + + self._download_data() - def _download_data(self): + def _download_data(self) -> None: logging.info("Downloading ISO17 database...") tmpdir = tempfile.mkdtemp("iso17") - tarpath = os.path.join(tmpdir, "iso17.tar.gz") - url = "http://www.quantum-machine.org/datasets/iso17.tar.gz" try: - request.urlretrieve(url, tarpath) - except HTTPError as e: - logging.error("HTTP Error:", e.code, url) - return False - except URLError as e: - logging.error("URL Error:", e.reason, url) - return False - - tar = tarfile.open(tarpath) - tar.extractall(self.path) - tar.close() - - # update metadata - for fold in ISO17.existing_folds: - dbpath = os.path.join(self.path, "iso17", fold + ".db") - tmp_dbpath = os.path.join(tmpdir, "tmp.db") - with connect(dbpath) as conn: - with connect(tmp_dbpath) as tmp_conn: - tmp_conn.metadata = { - "_property_unit_dict": { - ISO17.energy: "eV", - ISO17.forces: "eV/Ang", - }, - "_distance_unit": "Ang", - "atomrefs": {}, - } - # add energy to data dict in db - for idx in tqdm( - range(len(conn)), f"parsing database file {dbpath}" - ): - atmsrw = conn.get(idx + 1) - data = atmsrw.data - data[ISO17.forces] = np.array(data[ISO17.forces]) - data[ISO17.energy] = np.array([atmsrw.total_energy]) - tmp_conn.write(atmsrw.toatoms(), data=data) - - os.remove(dbpath) - os.rename(tmp_dbpath, dbpath) - shutil.rmtree(tmpdir) + tarpath = os.path.join(tmpdir, "iso17.tar.gz") + url = "http://www.quantum-machine.org/datasets/iso17.tar.gz" + + try: + request.urlretrieve(url, tarpath) + except HTTPError as e: + raise AtomsDataError( + f"HTTP Error {e.code} while downloading {url}" + ) from e + except URLError as e: + raise AtomsDataError( + f"URL Error {e.reason} while downloading {url}" + ) from e + + with tarfile.open(tarpath) as tar: + tar.extractall(self.root_path) + + # update metadata + convert energy into row.data for every fold + for fold in self.existing_folds: + dbpath = os.path.join(self.root_path, "iso17", fold + ".db") + tmp_dbpath = os.path.join(tmpdir, f"{fold}_tmp.db") + + with connect(dbpath) as conn: + with connect(tmp_dbpath) as tmp_conn: + tmp_conn.metadata = { + "_property_unit_dict": self._native_property_units(), + "_distance_unit": "Ang", + "atomrefs": {}, + } + + for idx in tqdm( + range(len(conn)), + desc=f"parsing database file {dbpath}", + ): + atmsrw = conn.get(idx + 1) + data = atmsrw.data + data[self.forces] = np.array(data[self.forces]) + data[self.energy] = np.array([atmsrw.total_energy]) + tmp_conn.write(atmsrw.toatoms(), data=data) + + os.remove(dbpath) + os.rename(tmp_dbpath, dbpath) + + finally: + shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/src/schnetpack/datasets/materials_project.py b/src/schnetpack/datasets/materials_project.py index 9189140ef..696c76db5 100644 --- a/src/schnetpack/datasets/materials_project.py +++ b/src/schnetpack/datasets/materials_project.py @@ -1,29 +1,25 @@ import logging import os from typing import List, Optional, Dict -import warnings - -from ase import Atoms import torch import numpy as np -from schnetpack.data import * -from schnetpack.data import AtomsDataModuleError, AtomsDataModule +from ase import Atoms +from schnetpack.data import AtomsDataFormat +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError +from schnetpack.transform.base import Transform __all__ = ["MaterialsProject"] -class MaterialsProject(AtomsDataModule): +class MaterialsProject(ASEAtomsData): """ Materials Project (MP) database of bulk crystals. 
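+    Downloading entries requires a next-gen Materials Project API key
+    (https://next-gen.materialsproject.org/).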
- This class adds convenient functions to download Materials Project data into - pytorch. References: .. [#matproj] https://materialsproject.org/ - """ # properties @@ -31,144 +27,103 @@ class MaterialsProject(AtomsDataModule): EPerAtom = "energy_per_atom" BandGap = "band_gap" TotalMagnetization = "total_magnetization" - MaterialId = ("material_id",) + MaterialId = "material_id" CreatedAt = "created_at" def __init__( self, datapath: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, apikey: Optional[str] = None, **kwargs, ): - """ - - Args: - datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format - load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - apikey: Materials project key needed to download the data. - """ if apikey is not None and len(apikey) == 16: raise DeprecationWarning( - "You are using a legacy API key. This API is deprecated and no longer supported by MaterialsProject. " - "Please use the nextgen API instead. " - "Visit https://next-gen.materialsproject.org/ to get a valid API-key. " + "You are using a legacy API key. This API is deprecated and no longer " + "supported by Materials Project. Please use the next-gen API instead. " + "Visit https://next-gen.materialsproject.org/ to get a valid API key." ) + if apikey is not None and len(apikey) != 32: - raise AtomsDataModuleError( - "Invalid API-key. MaterialsProject requires an API-key of 32 characters. " - f"Your API-key contains {len(apikey)} characters. " - "Visit https://next-gen.materialsproject.org/ to get a valid API-key. " + raise AtomsDataError( + "Invalid API key. MaterialsProject requires an API key of 32 characters. " + f"Your API key contains {len(apikey)} characters. " + "Visit https://next-gen.materialsproject.org/ to get a valid API key." 
) + self.apikey = apikey + self.format = format + + self.download( + datapath=datapath, + distance_unit=distance_unit or "Ang", + ) + super().__init__( datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, - format=format, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, transforms=transforms, - train_transforms=train_transforms, - val_transforms=val_transforms, - test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, **kwargs, ) - self.apikey = apikey - def prepare_data(self): - if not os.path.exists(self.datapath): - # check if apikey is provided - if self.apikey is None: - raise AtomsDataModuleError( - "No API-key provided, visit https://next-gen.materialsproject.org/ to get an API-key." - ) - - # initialize dataset - property_unit_dict = { - MaterialsProject.EformationPerAtom: "eV", - MaterialsProject.EPerAtom: "eV", - MaterialsProject.BandGap: "eV", - MaterialsProject.TotalMagnetization: "None", - } - dataset = create_dataset( - datapath=self.datapath, - format=self.format, - distance_unit="Ang", - property_unit_dict=property_unit_dict, + @staticmethod + def _native_property_units() -> Dict[str, str]: + return { + MaterialsProject.EformationPerAtom: "eV", + MaterialsProject.EPerAtom: "eV", + MaterialsProject.BandGap: "eV", + MaterialsProject.TotalMagnetization: "None", + } + + def download(self, datapath: str, distance_unit: str = "Ang") -> None: + """ + Ensure the Materials Project ASE DB exists. + """ + if os.path.exists(datapath): + _ = ASEAtomsData(datapath, self.format, load_structure=False) + return + + if self.apikey is None: + raise AtomsDataError( + "No API key provided. Visit https://next-gen.materialsproject.org/ " + "to get an API key." ) - self._download_data_nextgen(dataset) - else: - dataset = load_dataset(self.datapath, self.format) + dataset = ASEAtomsData.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=self._native_property_units(), + ) + + self._download_data_nextgen(dataset) - def _download_data_nextgen(self, dataset: ASEAtomsData): + def _download_data_nextgen(self, dataset: ASEAtomsData) -> None: """ - Downloads dataset provided it does not exist in self.path - Returns: - works (bool): true if download succeeded or file already exists + Download Materials Project entries and store them in the ASE DB. """ - # collect data - atms_list = [] + atoms_list = [] properties_list = [] atoms_metadata_list = [] + try: from pymatgen.core import Structure - import pymatgen as pmg from mp_api.client import MPRester - - except: + except Exception as e: raise ImportError( - "In order to download Materials Project data, you have to install " - "mp-api and pymatgen packages" - ) + "To download Materials Project data, install `mp-api` and `pymatgen`." 
+            ) from e
 
         with MPRester(self.apikey) as m:
             query = m.materials.summary.search(
-                num_sites=(0, 300, 30),
+                num_sites=(0, 300),
                 num_elements=(1, 9),
                 fields=[
                     "structure",
@@ -183,8 +138,8 @@ def _download_data_nextgen(self, dataset: ASEAtomsData):
 
         for q in query:
             s = q.structure
-            if type(s) is Structure:
-                atms_list.append(
+            if isinstance(s, Structure):
+                atoms_list.append(
                     Atoms(
                         numbers=s.atomic_numbers,
                         positions=s.cart_coords,
@@ -206,14 +161,13 @@ def _download_data_nextgen(self, dataset: ASEAtomsData):
                 )
                 atoms_metadata_list.append(
                     {
-                        "material_id": q.material_id,
+                        MaterialsProject.MaterialId: str(q.material_id),
                     }
                 )
 
-        # write systems to database
         logging.info("Write atoms to db...")
         dataset.add_systems(
-            atoms_list=atms_list,
+            atoms_list=atoms_list,
             property_list=properties_list,
             atoms_metadata_list=atoms_metadata_list,
         )
diff --git a/src/schnetpack/datasets/md22.py b/src/schnetpack/datasets/md22.py
index 936f68320..cb65eac48 100644
--- a/src/schnetpack/datasets/md22.py
+++ b/src/schnetpack/datasets/md22.py
@@ -1,5 +1,5 @@
 from typing import Optional, Dict, List
-
+import torch
 from schnetpack.data import AtomsDataFormat
 from schnetpack.datasets.md17 import GDMLDataset
 
@@ -20,7 +20,7 @@ def __init__(
         molecule: str,
         format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
         load_properties: Optional[List[str]] = None,
-        transforms=None,
+        transforms: Optional[List[torch.nn.Module]] = None,
         subset_idx: Optional[List[int]] = None,
         property_units: Optional[Dict[str, str]] = None,
         distance_unit: Optional[str] = None,
diff --git a/src/schnetpack/datasets/omdb.py b/src/schnetpack/datasets/omdb.py
index 413ef5ba0..22248c850 100644
--- a/src/schnetpack/datasets/omdb.py
+++ b/src/schnetpack/datasets/omdb.py
@@ -104,9 +104,8 @@ def prepare_data(self):
         if not os.path.exists(self.datapath):
             property_unit_dict = {OrganicMaterialsDatabase.BandGap: "eV"}
 
-            dataset = create_dataset(
+            dataset = ASEAtomsData.create(
                 datapath=self.datapath,
-                format=self.format,
                 distance_unit="Ang",
                 property_unit_dict=property_unit_dict,
             )
diff --git a/src/schnetpack/datasets/qm7x.py b/src/schnetpack/datasets/qm7x.py
index 6e9a7236a..ed5103a21 100644
--- a/src/schnetpack/datasets/qm7x.py
+++ b/src/schnetpack/datasets/qm7x.py
@@ -9,14 +9,14 @@
 from urllib import request as request
 
 import h5py
+import torch
 import numpy as np
 import progressbar
 from ase import Atoms
 
 import schnetpack.properties as structure
 from schnetpack.data import AtomsDataFormat
-from schnetpack.data.atoms import ASEAtomsData, AtomsDataError, load_dataset
-from schnetpack.transform.base import Transform
+from schnetpack.data.atoms import ASEAtomsData, AtomsDataError
 
 __all__ = ["QM7X"]
 
@@ -136,12 +136,28 @@ def __init__(
         only_non_equilibrium: bool = False,
         format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
         load_properties: Optional[List[str]] = None,
-        transforms: Optional[List[Transform]] = None,
+        transforms: Optional[List[torch.nn.Module]] = None,
         subset_idx: Optional[List[int]] = None,
         property_units: Optional[Dict[str, str]] = None,
         distance_unit: Optional[str] = None,
         **kwargs,
     ):
+        """
+        Args:
+            datapath: path to dataset
+            raw_data_path: path to raw data
+            remove_duplicates: do not include duplicate molecules
+            only_equilibrium: only include equilibrium molecules
+            only_non_equilibrium: only include non-equilibrium molecules
+            format: dataset format
+            load_properties: subset of properties to load
+            transforms: Transform applied to each system separately before batching
+            subset_idx: indices of the subset to load
+            
property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...).
+            distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...).
+            **kwargs: additional keyword arguments.
+        """
+
         if only_equilibrium and only_non_equilibrium:
             raise AtomsDataError(
                 "only_equilibrium and only_non_equilibrium cannot both be True."
@@ -197,8 +213,11 @@ def _apply_structure_filter(self, original_subset_idx: Optional[List[int]]) -> N
         self.subset_idx = effective_subset
 
     def download(self, datapath: str, distance_unit: str = "Ang") -> None:
+        """
+        Download the QM7-X dataset and create the ASEAtomsData object.
+        """
         if os.path.exists(datapath):
-            _ = load_dataset(datapath, self.format, load_structure=False)
+            _ = ASEAtomsData(datapath, self.format, load_structure=False)
             return
 
         tar_dir = self.raw_data_path or tempfile.mkdtemp("qm7x")
diff --git a/src/schnetpack/datasets/tmqm.py b/src/schnetpack/datasets/tmqm.py
index b27b5ea67..252a5a9a8 100644
--- a/src/schnetpack/datasets/tmqm.py
+++ b/src/schnetpack/datasets/tmqm.py
@@ -139,9 +139,8 @@ def prepare_data(self):
 
         tmpdir = tempfile.mkdtemp("tmQM")
 
-        dataset = create_dataset(
+        dataset = ASEAtomsData.create(
             datapath=self.datapath,
-            format=self.format,
             distance_unit="Ang",
             property_unit_dict=property_unit_dict,
         )
From c242648a9cb7a8abff44135575b9f0dd2ccb2e44 Mon Sep 17 00:00:00 2001
From: sundusaijaz
Date: Sun, 8 Mar 2026 23:44:43 +0100
Subject: [PATCH 29/68] refactor: fix format error in MaterialsProject

---
 src/schnetpack/datasets/materials_project.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/schnetpack/datasets/materials_project.py b/src/schnetpack/datasets/materials_project.py
index 696c76db5..15ff3b04a 100644
--- a/src/schnetpack/datasets/materials_project.py
+++ b/src/schnetpack/datasets/materials_project.py
@@ -88,7 +88,7 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None:
         Ensure the Materials Project ASE DB exists.
         """
         if os.path.exists(datapath):
-            _ = ASEAtomsData(datapath, self.format, load_structure=False)
+            _ = ASEAtomsData(datapath, load_structure=False)
             return
 
         if self.apikey is None:
From fb2bd3d3ce0349962080490523528c4422497620 Mon Sep 17 00:00:00 2001
From: sundusaijaz
Date: Sun, 8 Mar 2026 23:46:21 +0100
Subject: [PATCH 30/68] refactor: remove format parameter in QM7X dataset loading

---
 src/schnetpack/datasets/qm7x.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/schnetpack/datasets/qm7x.py b/src/schnetpack/datasets/qm7x.py
index ed5103a21..41a3b921a 100644
--- a/src/schnetpack/datasets/qm7x.py
+++ b/src/schnetpack/datasets/qm7x.py
@@ -217,7 +217,7 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None:
         Download the QM7-X dataset and create the ASEAtomsData object. 
""" if os.path.exists(datapath): - _ = ASEAtomsData(datapath, self.format, load_structure=False) + _ = ASEAtomsData(datapath, load_structure=False) return tar_dir = self.raw_data_path or tempfile.mkdtemp("qm7x") From 1fb9d97392b4cc2a1269170050c19fe81ca37222 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 01:28:03 +0100 Subject: [PATCH 31/68] refactor: remove legacy atoms_legacy.py file and streamline dataset loading in datamodule --- src/schnetpack/data/atoms_legacy.py | 635 --------------------------- src/schnetpack/data/datamodule.py | 5 +- src/schnetpack/data/datamodule_v2.py | 43 +- 3 files changed, 19 insertions(+), 664 deletions(-) delete mode 100644 src/schnetpack/data/atoms_legacy.py diff --git a/src/schnetpack/data/atoms_legacy.py b/src/schnetpack/data/atoms_legacy.py deleted file mode 100644 index d3888dd9f..000000000 --- a/src/schnetpack/data/atoms_legacy.py +++ /dev/null @@ -1,635 +0,0 @@ -""" -This module contains all functionalities required to load atomistic data, -generate batches and compute statistics. It makes use of the ASE database -for atoms [#ase2]_. - -References ----------- -.. [#ase2] Larsen, Mortensen, Blomqvist, Castelli, Christensen, Dułak, Friis, - Groves, Hammer, Hargus: - The atomic simulation environment -- a Python library for working with atoms. - Journal of Physics: Condensed Matter, 9, 27. 2017. -""" - -import logging -import os -from abc import ABC, abstractmethod -from enum import Enum -from typing import Optional, List, Dict, Any, Iterable, Union, Tuple - -import torch -import copy -from ase import Atoms -from ase.db import connect - -import schnetpack as spk -import schnetpack.properties as structure -from schnetpack.transform.base import Transform - -logger = logging.getLogger(__name__) - -__all__ = [ - "ASEAtomsData", - "BaseAtomsData", - "AtomsDataFormat", - "resolve_format", - "create_dataset", - "load_dataset", -] - - -class AtomsDataFormat(Enum): - """Enumeration of data formats""" - - ASE = "ase" - - -class AtomsDataError(Exception): - pass - - -extension_map = {AtomsDataFormat.ASE: ".db"} - - -class BaseAtomsData(ABC): - """ - Base mixin class for atomistic data. Use together with PyTorch Dataset or - IterableDataset to implement concrete data formats. - """ - - def __init__( - self, - load_properties: Optional[List[str]] = None, - load_structure: bool = True, - transforms: Optional[List[Transform]] = None, - subset_idx: Optional[List[int]] = None, - ): - """ - Args: - load_properties: Set of properties to be loaded and returned. - If None, all properties in the ASE dB will be returned. - load_structure: If True, load structure properties. - transforms: preprocessing transforms (see schnetpack.data.transforms) - subset: List of data indices. - """ - self._transform_module = None - self.load_properties = load_properties - self.load_structure = load_structure - self.transforms = transforms - self.subset_idx = subset_idx - - def __len__(self) -> int: - raise NotImplementedError - - @property - def transforms(self): - return self._transforms - - @transforms.setter - def transforms(self, value: Optional[List[Transform]]): - self._transforms = [] - self._transform_module = None - - if value is not None: - for tf in value: - self._transforms.append(tf) - self._transform_module = torch.nn.Sequential(*self._transforms) - - def subset(self, subset_idx: List[int]): - assert ( - subset_idx is not None - ), "Indices for creation of the subset need to be provided!" 
- ds = copy.copy(self) - if ds.subset_idx: - ds.subset_idx = [ds.subset_idx[i] for i in subset_idx] - else: - ds.subset_idx = subset_idx - return ds - - @property - @abstractmethod - def available_properties(self) -> List[str]: - """Available properties in the dataset""" - pass - - @property - @abstractmethod - def units(self) -> Dict[str, str]: - """Property to unit dict""" - pass - - @property - def load_properties(self) -> List[str]: - """Properties to be loaded""" - if self._load_properties is None: - return self.available_properties - else: - return self._load_properties - - @load_properties.setter - def load_properties(self, val: List[str]): - if val is not None: - props = self.available_properties - assert all( - [p in props for p in val] - ), "Not all given properties are available in the dataset!" - self._load_properties = val - - @property - @abstractmethod - def metadata(self) -> Dict[str, Any]: - """Global metadata""" - pass - - @property - @abstractmethod - def atomrefs(self) -> Dict[str, torch.Tensor]: - """Single-atom reference values for properties""" - pass - - @abstractmethod - def update_metadata(self, **kwargs): - pass - - @abstractmethod - def iter_properties( - self, - indices: Union[int, Iterable[int]] = None, - load_properties: List[str] = None, - load_structure: Optional[bool] = None, - ): - pass - - @staticmethod - @abstractmethod - def create( - datapath: str, - position_unit: str, - property_unit_dict: Dict[str, str], - atomrefs: Dict[str, List[float]], - **kwargs, - ) -> "ASEAtomsData": - pass - - @abstractmethod - def add_systems( - self, - property_list: List[Dict[str, Any]], - atoms_list: Optional[List[Atoms]] = None, - atoms_metadata_list: Optional[List[Dict[str, Any]]] = None, - ): - pass - - @abstractmethod - def add_system(self, atoms: Optional[Atoms] = None, **properties): - pass - - -class ASEAtomsData(ASEAtomsData): - """ - PyTorch dataset for atomistic data. The raw data is stored in the specified - ASE database. - - """ - - def __init__( - self, - datapath: str, - load_properties: Optional[List[str]] = None, - load_structure: bool = True, - transforms: Optional[List[torch.nn.Module]] = None, - subset_idx: Optional[List[int]] = None, - property_units: Optional[Dict[str, str]] = None, - distance_unit: Optional[str] = None, - ): - """ - Args: - datapath: Path to ASE DB. - load_properties: Set of properties to be loaded and returned. - If None, all properties in the ASE dB will be returned. - load_structure: If True, load structure properties. - transforms: preprocessing torch.nn.Module (see schnetpack.data.transforms) - subset_idx: List of data indices. - units: property-> unit string dictionary that overwrites the native units - of the dataset. Units are converted automatically during loading. - """ - self.datapath = datapath - self.conn = connect(self.datapath, use_lock_file=False) - - ASEAtomsData.__init__( - self, - load_properties=load_properties, - load_structure=load_structure, - transforms=transforms, - subset_idx=subset_idx, - ) - - self._check_db() - - # initialize units - md = self.metadata - if "_distance_unit" not in md.keys(): - raise AtomsDataError( - "Dataset does not have a distance unit set. Please add units to the " - + "dataset using `spkconvert`!" - ) - if "_property_unit_dict" not in md.keys(): - raise AtomsDataError( - "Dataset does not have a property units set. Please add units to the " - + "dataset using `spkconvert`!" 
- ) - - if distance_unit: - self.distance_conversion = spk.units.convert_units( - md["_distance_unit"], distance_unit - ) - self.distance_unit = distance_unit - else: - self.distance_conversion = 1.0 - self.distance_unit = md["_distance_unit"] - - self._units = md["_property_unit_dict"] - self.conversions = {prop: 1.0 for prop in self._units} - if property_units is not None: - for prop, unit in property_units.items(): - self.conversions[prop] = spk.units.convert_units( - self._units[prop], unit - ) - self._units[prop] = unit - - def __len__(self) -> int: - if self.subset_idx is not None: - return len(self.subset_idx) - - return self.conn.count() - - def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: - if self.subset_idx is not None: - idx = self.subset_idx[idx] - - props = self._get_properties( - self.conn, idx, self.load_properties, self.load_structure - ) - props = self._apply_transforms(props) - - return props - - def _apply_transforms(self, props): - if self._transform_module is not None: - props = self._transform_module(props) - return props - - def _check_db(self): - if not os.path.exists(self.datapath): - raise AtomsDataError(f"ASE DB does not exist at {self.datapath}") - - if self.subset_idx: - with connect(self.datapath, use_lock_file=False) as conn: - n_structures = conn.count() - - assert max(self.subset_idx) < n_structures - - def iter_properties( - self, - indices: Union[int, Iterable[int]] = None, - load_properties: List[str] = None, - load_structure: Optional[bool] = None, - load_metadata: bool = False, - ): - """ - Return property dictionary at given indices. - - Args: - indices: data indices - load_properties (sequence or None): subset of available properties to load - load_structure: load and return structure - load_metadata: load and return metadata - - Returns: - properties (dict): dictionary with molecular properties - - """ - if load_properties is None: - load_properties = self.load_properties - load_structure = load_structure or self.load_structure - - if self.subset_idx: - if indices is None: - indices = self.subset_idx - elif type(indices) is int: - indices = [self.subset_idx[indices]] - else: - indices = [self.subset_idx[i] for i in indices] - else: - if indices is None: - indices = range(len(self)) - elif type(indices) is int: - indices = [indices] - - # read from ase db - for i in indices: - yield self._get_properties( - self.conn, - i, - load_properties=load_properties, - load_structure=load_structure, - load_metadata=load_metadata, - ) - - def _get_properties( - self, - conn, - idx: int, - load_properties: List[str], - load_structure: bool, - load_metadata: bool = False, - ): - row = conn.get(idx + 1) - - # extract properties - # TODO: can the copies be avoided? 
- properties = {} - properties[structure.idx] = torch.tensor([idx]) - for pname in load_properties: - properties[pname] = ( - torch.tensor(row.data[pname].copy()) * self.conversions[pname] - ) - - Z = row["numbers"].copy() - properties[structure.n_atoms] = torch.tensor([Z.shape[0]]) - - if load_structure: - properties[structure.Z] = torch.tensor(Z, dtype=torch.long) - properties[structure.position] = ( - torch.tensor(row["positions"].copy()) * self.distance_conversion - ) - properties[structure.cell] = ( - torch.tensor(row["cell"][None].copy()) * self.distance_conversion - ) - properties[structure.pbc] = torch.tensor(row["pbc"]) - - if load_metadata: - properties["metadata"] = row.key_value_pairs - - return properties - - # Metadata - @property - def metadata(self): - with connect(self.datapath, use_lock_file=False) as conn: - return conn.metadata - - def _set_metadata(self, val: Dict[str, Any]): - with connect(self.datapath, use_lock_file=False) as conn: - conn.metadata = val - - def update_metadata(self, **kwargs): - assert all( - key[0] != 0 for key in kwargs - ), "Metadata keys starting with '_' are protected!" - - md = self.metadata - md.update(kwargs) - self._set_metadata(md) - - @property - def available_properties(self) -> List[str]: - md = self.metadata - return list(md["_property_unit_dict"].keys()) - - @property - def units(self) -> Dict[str, str]: - """Dictionary of properties to units""" - return self._units - - @property - def atomrefs(self) -> Dict[str, torch.Tensor]: - md = self.metadata - arefs = md["atomrefs"] - arefs = {k: self.conversions[k] * torch.tensor(v) for k, v in arefs.items()} - return arefs - - ## Creation - - @staticmethod - def create( - datapath: str, - distance_unit: str, - property_unit_dict: Dict[str, str], - atomrefs: Optional[Dict[str, List[float]]] = None, - **kwargs, - ) -> "ASEAtomsData": - """ - - Args: - datapath: Path to ASE DB. - distance_unit: unit of atom positions and cell - property_unit_dict: Defines the available properties of the datasetseta and - provides units for ALL properties of the dataset. If a property is - unit-less, you can pass "arb. unit" or `None`. - atomrefs: dictionary mapping properies (the keys) to lists of single-atom - reference values of the property. This is especially useful for - extensive properties such as the energy, where the single atom energies - contribute a major part to the overall value. - kwargs: Pass arguments to init. - - Returns: - newly created ASEAtomsData - - """ - if not datapath.endswith(".db"): - raise AtomsDataError( - "Invalid datapath! Please make sure to add the file extension '.db' to " - "your dbpath." - ) - - if os.path.exists(datapath): - raise AtomsDataError(f"Dataset already exists: {datapath}") - - atomrefs = atomrefs or {} - - with connect(datapath) as conn: - conn.metadata = { - "_property_unit_dict": property_unit_dict, - "_distance_unit": distance_unit, - "atomrefs": atomrefs, - } - - return ASEAtomsData(datapath, **kwargs) - - # add systems - def add_system( - self, - atoms: Optional[Atoms] = None, - atoms_metadata: Optional[Dict[str, Any]] = None, - **properties, - ): - self._add_system(atoms, atoms_metadata, **properties) - - def add_systems( - self, - property_list: List[Dict[str, Any]], - atoms_list: Optional[List[Atoms]] = None, - atoms_metadata_list: Optional[List[Dict[str, Any]]] = None, - ): - """ - Add atoms data to the dataset. - - Args: - atoms_list: System composition and geometry. 
If Atoms are None, - the structure needs to be given as part of the property dicts - (using structure.Z, structure.R, structure.cell, structure.pbc) - property_list: Properties as list of key-value pairs in the same - order as corresponding list of `atoms`. - Keys have to match the `available_properties` of the dataset - plus additional structure properties, if atoms is None. - atoms_metadata_list: Metadata of the atoms objects as list of key-value pairs in the same - order as corresponding list of `atoms`. - Metadata can not be used as a training property, but can be used for splitting - strategies (e.g. material_id, timestamp, ...). - """ - if atoms_list is None: - atoms_list = [None] * len(property_list) - - if atoms_metadata_list is None: - atoms_metadata_list = [{}] * len(property_list) - - for atoms, prop, atoms_metadata in zip( - atoms_list, property_list, atoms_metadata_list - ): - self._add_system( - atoms, - atoms_metadata, - **prop, - ) - - def _add_system( - self, - atoms: Optional[Atoms] = None, - atoms_metadata: Optional[Dict[str, Any]] = None, - **properties, - ): - """ - Add systems to DB. - """ - # create atoms object if not provided - if atoms is None: - try: - Z = properties[structure.Z] - R = properties[structure.R] - cell = properties[structure.cell] - pbc = properties[structure.pbc] - atoms = Atoms(numbers=Z, positions=R, cell=cell, pbc=pbc) - except KeyError as e: - raise AtomsDataError( - "Property dict does not contain all necessary structure keys" - ) from e - - if atoms_metadata is None: - atoms_metadata = {} - - with connect(self.datapath, use_lock_file=False) as conn: - prop_keys = conn.metadata["_property_unit_dict"].keys() - - valid_props = set().union( - prop_keys, - [structure.Z, structure.R, structure.cell, structure.pbc], - ) - for pname in properties: - if pname not in valid_props: - logger.warning( - f"Property `{pname}` is not a defined property for this dataset and " - + f"will be ignored. If it should be included, it has to be " - + f"provided together with its unit when calling " - + f"AseAtomsData.create()." - ) - - data = {} - for pname in prop_keys: - if pname in properties: - data[pname] = properties[pname] - else: - raise AtomsDataError("Required property missing:" + pname) - - conn.write(atoms, data=data, key_value_pairs=atoms_metadata) - - -def create_dataset( - datapath: str, - format: AtomsDataFormat, - distance_unit: str, - property_unit_dict: Dict[str, str], - **kwargs, -) -> ASEAtomsData: - """ - Create a new atoms dataset. - - Args: - datapath: file path - format: atoms data format - distance_unit: unit of atom positiona etc. as string - property_unit_dict: dictionary that maps properties to units, - e.g. {"energy": "kcal/mol"} - **kwargs: arguments for passed to AtomsData init - - Returns: - - """ - if format is AtomsDataFormat.ASE: - dataset = ASEAtomsData.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=property_unit_dict, - **kwargs, - ) - else: - raise AtomsDataError(f"Unknown format: {format}") - - return dataset - - -def load_dataset(datapath: str, format: AtomsDataFormat, **kwargs) -> ASEAtomsData: - """ - Load dataset. 
- - Args: - datapath: file path - format: atoms data format - **kwargs: arguments for passed to AtomsData init - - """ - if format is AtomsDataFormat.ASE: - dataset = ASEAtomsData(datapath=datapath, **kwargs) - else: - raise AtomsDataError(f"Unknown format: {format}") - return dataset - - -def resolve_format( - datapath: str, format: Optional[AtomsDataFormat] = None -) -> Tuple[str, AtomsDataFormat]: - """ - Extract data format from file suffix, check for consistency with (optional) given - format, or append suffix to file path. - - Args: - datapath: path to atoms data - format: atoms data format - - """ - file, suffix = os.path.splitext(datapath) - if suffix == ".db": - if format is None: - format = AtomsDataFormat.ASE - assert ( - format is AtomsDataFormat.ASE - ), f"File extension {suffix} is not compatible with chosen format {format}" - elif len(suffix) == 0 and format: - datapath = datapath + extension_map[format] - elif len(suffix) == 0 and format is None: - raise AtomsDataError( - "If format is not given, `datapath` needs a supported file extension!" - ) - else: - raise AtomsDataError(f"Unsupported file extension: {suffix}") - return datapath, format diff --git a/src/schnetpack/data/datamodule.py b/src/schnetpack/data/datamodule.py index d9f4ebed1..3deae9166 100644 --- a/src/schnetpack/data/datamodule.py +++ b/src/schnetpack/data/datamodule.py @@ -10,9 +10,6 @@ from torch.utils.data import BatchSampler from schnetpack.data import ( - AtomsDataFormat, - # resolve_format, - load_dataset, ASEAtomsData, AtomsLoader, calculate_stats, @@ -181,7 +178,7 @@ def setup(self, stage: Optional[str] = None): # (re)load datasets if self.dataset is None: - self.dataset = load_dataset( + self.dataset = ASEAtomsData( datapath, self.format, property_units=self.property_units, diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index 3cc493247..902af3ab9 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -1,7 +1,6 @@ from __future__ import annotations - -from copy import copy from typing import List, Optional, Union, Dict, Any, Type +import os import numpy as np import pytorch_lightning as pl @@ -31,10 +30,6 @@ def __init__( num_test: Optional[Union[int, float]] = None, split_file: Optional[str] = "split.npz", splitting: Optional[SplittingStrategy] = None, - transforms: Optional[List] = None, - train_transforms: Optional[List] = None, - val_transforms: Optional[List] = None, - test_transforms: Optional[List] = None, num_workers: int = 0, val_batch_size: Optional[int] = None, test_batch_size: Optional[int] = None, @@ -58,10 +53,6 @@ def __init__( self.num_workers = num_workers self._pin_memory = pin_memory - self.train_transforms = train_transforms or copy(transforms) or [] - self.val_transforms = val_transforms or copy(transforms) or [] - self.test_transforms = test_transforms or copy(transforms) or [] - self.train_idx = None self.val_idx = None self.test_idx = None @@ -105,26 +96,30 @@ def setup(self, stage: Optional[str] = None) -> None: self._val_dataset = self.dataset.subset(self.val_idx) self._test_dataset = self.dataset.subset(self.test_idx) - self.provider = StatsAtomrefProvider(self._train_dataset) + transforms = self.dataset.transforms or [] + + train_transforms = self.dataset.train_transforms or transforms + val_transforms = self.dataset.val_transforms or transforms + test_transforms = self.dataset.test_transforms or transforms - self._initialize_transform_list(self.train_transforms) - 
self._initialize_transform_list(self.val_transforms) - self._initialize_transform_list(self.test_transforms) + self._train_dataset.transforms = train_transforms + self._val_dataset.transforms = val_transforms + self._test_dataset.transforms = test_transforms - self._train_dataset.transforms = self.train_transforms - self._val_dataset.transforms = self.val_transforms - self._test_dataset.transforms = self.test_transforms + self.provider = StatsAtomrefProvider(self._train_dataset) + + self._initialize_transforms(self._train_dataset) + self._initialize_transforms(self._val_dataset) + self._initialize_transforms(self._test_dataset) - def _initialize_transform_list(self, transforms: List) -> None: - if not transforms: + def _initialize_transforms(self, dataset: ASEAtomsData) -> None: + if not dataset.transforms: return - for t in transforms: + for t in dataset.transforms: t.initialize(provider=self.provider, atomrefs=self.provider.train_atomrefs) def _load_partitions(self) -> None: - import os - total_size = len(self.dataset) def _to_abs(x: Optional[Union[int, float]]) -> Optional[int]: @@ -150,9 +145,7 @@ def _to_abs(x: Optional[Union[int, float]]) -> Optional[int]: return if num_train is None or num_val is None: - raise ValueError( - "If no split file is given, num_train and num_val must be set." - ) + raise ValueError("num_train and num_val must be set if no split file.") train_idx, val_idx, test_idx = self.splitting.split( self.dataset, num_train, num_val, num_test From bc8cb0238cd6dd78029ea17813d502ea5b49f9be Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 01:28:39 +0100 Subject: [PATCH 32/68] refactor: change ASEAtomsData class with additional transform options and clean up unused code --- src/schnetpack/data/atoms.py | 50 +++++++++++++++--------------------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index 6efeded4a..2da6945d9 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -2,7 +2,7 @@ import logging import os from enum import Enum -from typing import Optional, List, Dict, Any, Iterable, Union, Tuple +from typing import Optional, List, Dict, Any, Iterable, Union import torch from ase import Atoms @@ -14,13 +14,7 @@ logger = logging.getLogger(__name__) -__all__ = ["ASEAtomsData", "AtomsDataFormat", "load_dataset"] - - -class AtomsDataFormat(Enum): - """Enumeration of data formats""" - - ASE = "ase" +__all__ = ["ASEAtomsData", "AtomsDataError"] class AtomsDataError(Exception): @@ -39,6 +33,9 @@ def __init__( load_properties: Optional[List[str]] = None, load_structure: bool = True, transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -48,10 +45,18 @@ def __init__( self._check_db() self.conn = connect(self.datapath, use_lock_file=False) - # merged ASEAtomsData state self.transforms: List[Transform] = ( list(transforms) if transforms is not None else [] ) + self.train_transforms: Optional[List[Transform]] = ( + list(train_transforms) if train_transforms is not None else None + ) + self.val_transforms: Optional[List[Transform]] = ( + list(val_transforms) if val_transforms is not None else None + ) + self.test_transforms: Optional[List[Transform]] = ( + list(test_transforms) if 
test_transforms is not None else None
+        )
 
         self._load_properties: Optional[List[str]] = None
         self.load_structure = load_structure
@@ -90,11 +95,12 @@ def __init__(
         self.load_properties = load_properties
 
     # ---------- merged ASEAtomsData bits ----------
+
     def subset(self, subset_idx: List[int]):
         if subset_idx is None:
             raise ValueError("subset_idx must be provided.")
         ds = copy.copy(self)
-        if ds.subset_idx:
+        if ds.subset_idx is not None:
             ds.subset_idx = [ds.subset_idx[i] for i in subset_idx]
         else:
             ds.subset_idx = subset_idx
@@ -141,13 +147,13 @@ def _check_db(self):
         if not os.path.exists(self.datapath):
             raise AtomsDataError(f"ASE DB does not exist at {self.datapath}")
 
-        if self.subset_idx:
+        if self.subset_idx is not None:
             with connect(self.datapath, use_lock_file=False) as conn:
                 n_structures = conn.count()
                 if max(self.subset_idx) >= n_structures:
                     raise AtomsDataError("subset_idx contains out-of-range indices")
 
     # ---------- metadata / units ----------
 
     @property
     def metadata(self) -> Dict[str, Any]:
@@ -194,7 +200,7 @@ def iter_properties(
         if load_structure is None:
             load_structure = self.load_structure
 
-        if self.subset_idx:
+        if self.subset_idx is not None:
             if indices is None:
                 indices = self.subset_idx
             elif isinstance(indices, int):
@@ -277,7 +283,7 @@ def create(
             "atomrefs": atomrefs,
         }
 
         return ASEAtomsData(datapath, **kwargs)
 
     def add_system(
         self,
@@ -355,19 +361,3 @@ def _add_system(
             data[pname] = properties[pname]
 
         conn.write(atoms, data=data, key_value_pairs=atoms_metadata)
-
-
-def load_dataset(datapath: str, format: AtomsDataFormat, **kwargs) -> ASEAtomsData:
-    """
-    Load dataset.
-
-    Args:
-        datapath: file path
-        format: atoms data format
-        **kwargs: arguments for passed to AtomsData init
-
-    """
-    if format is AtomsDataFormat.ASE:
-        dataset = ASEAtomsData(datapath=datapath, **kwargs)
-    else:
-        raise AtomsDataError(f"Unknown format: {format}")

From 4efe59bf163e8f81ab4b08034e97e32c2784b0ab Mon Sep 17 00:00:00 2001
From: sundusaijaz
Date: Thu, 12 Mar 2026 01:29:30 +0100
Subject: [PATCH 33/68] refactor: remove format parameter from all dataset
 classes

---
 src/schnetpack/datasets/ani1.py              |  6 +-----
 src/schnetpack/datasets/iso17.py             |  4 ----
 src/schnetpack/datasets/materials_project.py |  3 ---
 src/schnetpack/datasets/md17.py              |  7 -------
 src/schnetpack/datasets/md22.py              |  4 ----
 src/schnetpack/datasets/omdb.py              |  6 +-----
 src/schnetpack/datasets/qm7x.py              |  4 ----
 src/schnetpack/datasets/qm9.py               | 14 +++++---------
 src/schnetpack/datasets/rmd17.py             |  4 ----
 src/schnetpack/datasets/tmqm.py              |  6 +-----
 10 files changed, 8 insertions(+), 50 deletions(-)

diff --git a/src/schnetpack/datasets/ani1.py b/src/schnetpack/datasets/ani1.py
index 9ae478a55..e8dce9f4e 100644
--- a/src/schnetpack/datasets/ani1.py
+++ b/src/schnetpack/datasets/ani1.py
@@ -11,8 +11,7 @@
 import numpy as np
 from ase import Atoms
 
-from schnetpack.data import AtomsDataFormat
-from schnetpack.data.atoms import ASEAtomsData, AtomsDataError, load_dataset
+from schnetpack.data.atoms import ASEAtomsData, AtomsDataError
 from schnetpack.transform.base import Transform
 
 __all__ = ["ANI1"]
@@ -42,7 +41,6 @@ def __init__(
         datapath: str,
         num_heavy_atoms: int = 8,
         high_energies: bool = False,
-        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
         load_properties: Optional[List[str]] = None,
         transforms: Optional[List[torch.nn.Module]] = None,
         subset_idx: Optional[List[int]] = None,
@@ -55,7 +53,6 @@ def __init__(
             datapath: path to dataset
num_heavy_atoms: number of heavy atoms high_energies: whether to include high-energy conformations - format: dataset format load_properties: subset of properties to load transforms: Transform applied to each system separately before batching subset_idx: indices of the subset to load @@ -65,7 +62,6 @@ def __init__( """ self.num_heavy_atoms = num_heavy_atoms self.high_energies = high_energies - self.format = format self.download( datapath=datapath, diff --git a/src/schnetpack/datasets/iso17.py b/src/schnetpack/datasets/iso17.py index 8694fad82..852666afa 100644 --- a/src/schnetpack/datasets/iso17.py +++ b/src/schnetpack/datasets/iso17.py @@ -12,7 +12,6 @@ from ase.db import connect from tqdm import tqdm -from schnetpack.data import AtomsDataFormat from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["ISO17"] @@ -43,7 +42,6 @@ def __init__( self, datapath: str, fold: str, - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, transforms: Optional[List[torch.nn.Module]] = None, subset_idx: Optional[List[int]] = None, @@ -55,7 +53,6 @@ def __init__( Args: datapath: path to dataset fold: select a specific dataset of iso17 - format: dataset format load_properties: subset of properties to load transforms: Transform applied to each system separately before batching subset_idx: indices of the subset to load @@ -68,7 +65,6 @@ def __init__( self.root_path = datapath self.fold = fold - self.format = format dbpath = os.path.join(datapath, "iso17", fold + ".db") diff --git a/src/schnetpack/datasets/materials_project.py b/src/schnetpack/datasets/materials_project.py index 15ff3b04a..d361d555c 100644 --- a/src/schnetpack/datasets/materials_project.py +++ b/src/schnetpack/datasets/materials_project.py @@ -6,7 +6,6 @@ import numpy as np from ase import Atoms -from schnetpack.data import AtomsDataFormat from schnetpack.data.atoms import ASEAtomsData, AtomsDataError from schnetpack.transform.base import Transform @@ -33,7 +32,6 @@ class MaterialsProject(ASEAtomsData): def __init__( self, datapath: str, - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, transforms: Optional[List[torch.nn.Module]] = None, subset_idx: Optional[List[int]] = None, @@ -57,7 +55,6 @@ def __init__( ) self.apikey = apikey - self.format = format self.download( datapath=datapath, diff --git a/src/schnetpack/datasets/md17.py b/src/schnetpack/datasets/md17.py index 39903f0bb..5497e5cdf 100644 --- a/src/schnetpack/datasets/md17.py +++ b/src/schnetpack/datasets/md17.py @@ -10,7 +10,6 @@ from ase import Atoms import schnetpack.properties as structure -from schnetpack.data import AtomsDataFormat from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["MD17"] @@ -34,7 +33,6 @@ def __init__( molecule: str, tmpdir: str = "gdml_tmp", atomrefs: Optional[Dict[str, List[float]]] = None, - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, transforms: Optional[List[torch.nn.Module]] = None, subset_idx: Optional[List[int]] = None, @@ -50,7 +48,6 @@ def __init__( molecule: name of the molecule. tmpdir: name of temporary directory used for parsing. atomrefs: properties of free atoms. - format: dataset format (e.g: ASE). load_properties: subset of properties to load. transforms: Transform applied to each system separately before batching. subset_idx: indices of the subset to load. 
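
A minimal usage sketch of the constructors after this change (illustrative only: the
datapath below is an assumed local path, the DB is downloaded on first use, and
"energy"/"forces" are the MD17 property names used throughout this series):

    from schnetpack.datasets import MD17

    data = MD17(
        datapath="./md17_ethanol.db",  # assumed path
        molecule="ethanol",
        load_properties=["energy", "forces"],
    )
    example = data[0]  # one structure as a dict of tensors

Note the absence of the removed `format` argument; the ASE backend is now implied by
`ASEAtomsData` itself.
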
@@ -62,7 +59,6 @@ def __init__( self.download_url = download_url self._native_atomrefs = atomrefs self.tmpdir = tmpdir - self.format = format if molecule not in self.datasets_dict: raise AtomsDataError(f"Molecule {molecule} is not supported!") @@ -170,7 +166,6 @@ def __init__( self, datapath: str, molecule: str, - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, transforms=None, subset_idx: Optional[List[int]] = None, @@ -182,7 +177,6 @@ def __init__( Args: datapath: path to dataset. molecule: name of the molecule. - format: dataset format (e.g: ASE). load_properties: subset of properties to load. transforms: Transform applied to each system separately before batching. subset_idx: indices of the subset to load. @@ -223,7 +217,6 @@ def __init__( tmpdir="md17", molecule=molecule, datapath=datapath, - format=format, load_properties=load_properties, transforms=transforms, subset_idx=subset_idx, diff --git a/src/schnetpack/datasets/md22.py b/src/schnetpack/datasets/md22.py index cb65eac48..ffb1db528 100644 --- a/src/schnetpack/datasets/md22.py +++ b/src/schnetpack/datasets/md22.py @@ -1,6 +1,5 @@ from typing import Optional, Dict, List import torch -from schnetpack.data import AtomsDataFormat from schnetpack.datasets.md17 import GDMLDataset __all__ = ["MD22"] @@ -18,7 +17,6 @@ def __init__( self, datapath: str, molecule: str, - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, transforms: Optional[List[torch.nn.Module]] = None, subset_idx: Optional[List[int]] = None, @@ -30,7 +28,6 @@ def __init__( Args: datapath: path to dataset molecule: name of the molecule - format: dataset format load_properties: subset of properties to load transforms: Transform applied to each system separately before batching. subset_idx: indices of the subset to load. @@ -68,7 +65,6 @@ def __init__( tmpdir="md22", molecule=molecule, datapath=datapath, - format=format, load_properties=load_properties, transforms=transforms, subset_idx=subset_idx, diff --git a/src/schnetpack/datasets/omdb.py b/src/schnetpack/datasets/omdb.py index 22248c850..dd9b6f346 100644 --- a/src/schnetpack/datasets/omdb.py +++ b/src/schnetpack/datasets/omdb.py @@ -37,7 +37,6 @@ def __init__( num_val: Optional[int] = None, num_test: Optional[int] = None, split_file: Optional[str] = "split.npz", - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, val_batch_size: Optional[int] = None, test_batch_size: Optional[int] = None, @@ -61,7 +60,6 @@ def __init__( num_val: number of validation examples num_test: number of test examples split_file: path to npz file with data partitions - format: dataset format load_properties: subset of properties to load val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. test_batch_size: test batch size. If None, use val_batch_size, then batch_size. 
@@ -83,7 +81,6 @@ def __init__( num_val=num_val, num_test=num_test, split_file=split_file, - format=format, load_properties=load_properties, val_batch_size=val_batch_size, test_batch_size=test_batch_size, @@ -106,14 +103,13 @@ def prepare_data(self): dataset = ASEAtomsData( datapath=self.datapath, - # format=self.format, distance_unit="Ang", property_unit_dict=property_unit_dict, ) self._convert(dataset) else: - dataset = load_dataset(self.datapath, self.format) + dataset = ASEAtomsData(self.datapath) def _convert(self, dataset): """ diff --git a/src/schnetpack/datasets/qm7x.py b/src/schnetpack/datasets/qm7x.py index 41a3b921a..054988904 100644 --- a/src/schnetpack/datasets/qm7x.py +++ b/src/schnetpack/datasets/qm7x.py @@ -15,7 +15,6 @@ from ase import Atoms import schnetpack.properties as structure -from schnetpack.data import AtomsDataFormat from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["QM7X"] @@ -134,7 +133,6 @@ def __init__( remove_duplicates: bool = True, only_equilibrium: bool = False, only_non_equilibrium: bool = False, - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, transforms: Optional[List[torch.nn.Module]] = None, subset_idx: Optional[List[int]] = None, @@ -149,7 +147,6 @@ def __init__( remove_duplicates: do not include duplicate molecules only_equilibrium: only include equilibrium molecules only_non_equilibrium: only include non-equilibrium molecules - format: dataset format load_properties: subset of properties to load transforms: Transform applied to each system separately before batching subset_idx: indices of the subset to load @@ -168,7 +165,6 @@ def __init__( self.duplicates_ids = None self.only_equilibrium = only_equilibrium self.only_non_equilibrium = only_non_equilibrium - self.format = format self.download( datapath=datapath, diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index f6031158c..0083d3a3f 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -16,7 +16,6 @@ from tqdm import tqdm import schnetpack.properties as structure -from schnetpack.data import AtomsDataFormat from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["QM9"] @@ -56,7 +55,6 @@ class QM9(ASEAtomsData): def __init__( self, datapath: str, - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, remove_uncharacterized: bool = False, load_properties: Optional[List[str]] = None, transforms: Optional[List[torch.nn.Module]] = None, @@ -68,7 +66,6 @@ def __init__( """ Args: datapath: path to dataset - format: dataset format remove_uncharacterized: do not include uncharacterized molecules. load_properties: subset of properties to load transforms: Transform applied to each system separately before batching. @@ -78,7 +75,6 @@ def __init__( **kwargs: additional keyword arguments. 
""" self.remove_uncharacterized = remove_uncharacterized - self.format = format self.download( datapath=datapath, @@ -145,7 +141,7 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: try: atomrefs = self._download_atomrefs(tmpdir) - dataset = ASEAtomsData.create( + self.create( datapath=datapath, distance_unit=distance_unit, property_unit_dict=self._native_property_units(), @@ -157,7 +153,7 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: else: uncharacterized = None - self._download_data(tmpdir, dataset, uncharacterized) + self._download_data(tmpdir, uncharacterized) finally: shutil.rmtree(tmpdir, ignore_errors=True) @@ -208,7 +204,7 @@ def _download_atomrefs(self, tmpdir: str) -> Dict[str, List[float]]: def _download_data( self, tmpdir: str, - dataset: ASEAtomsData, + # dataset: ASEAtomsData, uncharacterized: Optional[List[int]], ) -> None: @@ -244,7 +240,7 @@ def _download_data( lines = f.readlines() values = lines[1].split()[2:] - for pname, value in zip(dataset.available_properties, values): + for pname, value in zip(self.available_properties, values): properties[pname] = np.array([float(value)]) for line in lines: @@ -261,5 +257,5 @@ def _download_data( property_list.append(properties) logging.info("Write atoms to db...") - dataset.add_systems(property_list=property_list) + self.add_systems(property_list=property_list) logging.info("Done.") diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py index 58f0d83d8..204f21f3d 100644 --- a/src/schnetpack/datasets/rmd17.py +++ b/src/schnetpack/datasets/rmd17.py @@ -12,7 +12,6 @@ from ase import Atoms import schnetpack.properties as structure -from schnetpack.data import AtomsDataFormat from schnetpack.data.atoms import ASEAtomsData, AtomsDataError from schnetpack.transform.base import Transform @@ -69,7 +68,6 @@ def __init__( self, datapath: str, molecule: str, - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, transforms: Optional[List[torch.nn.Module]] = None, subset_idx: Optional[List[int]] = None, @@ -81,7 +79,6 @@ def __init__( Args: datapath: path to dataset molecule: name of the molecule - format: dataset format load_properties: subset of properties to load transforms: Transform applied to each system separately before batching. subset_idx: indices of the subset to load. @@ -94,7 +91,6 @@ def __init__( raise AtomsDataError(f"Molecule {molecule} is not supported!") self.molecule = molecule - self.format = format self.download( datapath=datapath, diff --git a/src/schnetpack/datasets/tmqm.py b/src/schnetpack/datasets/tmqm.py index 252a5a9a8..b08527414 100644 --- a/src/schnetpack/datasets/tmqm.py +++ b/src/schnetpack/datasets/tmqm.py @@ -59,7 +59,6 @@ def __init__( num_val: Optional[int] = None, num_test: Optional[int] = None, split_file: Optional[str] = "split.npz", - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, val_batch_size: Optional[int] = None, test_batch_size: Optional[int] = None, @@ -84,7 +83,6 @@ def __init__( num_val: number of validation examples num_test: number of test examples split_file: path to npz file with data partitions - format: dataset format load_properties: subset of properties to load remove_uncharacterized: do not include uncharacterized molecules. val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. 
@@ -107,7 +105,6 @@ def __init__( num_val=num_val, num_test=num_test, split_file=split_file, - format=format, load_properties=load_properties, val_batch_size=val_batch_size, test_batch_size=test_batch_size, @@ -141,7 +138,6 @@ def prepare_data(self): dataset = ASEAtomsData( datapath=self.datapath, - # format=self.format, distance_unit="Ang", property_unit_dict=property_unit_dict, ) @@ -149,7 +145,7 @@ def prepare_data(self): self._download_data(tmpdir, dataset) shutil.rmtree(tmpdir) else: - dataset = load_dataset(self.datapath, self.format) + dataset = ASEAtomsData(self.datapath) def _download_data(self, tmpdir, dataset: ASEAtomsData): tar_path = os.path.join(tmpdir, "tmQM_X1.xyz.gz") From de23e5d6ebdfb0c1b3bd23ad58e1dab05b5a3e35 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 01:38:20 +0100 Subject: [PATCH 34/68] refactor: simplify format handling in AtomsDataModule (old) --- src/schnetpack/data/datamodule.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/schnetpack/data/datamodule.py b/src/schnetpack/data/datamodule.py index 3deae9166..23d9a08c6 100644 --- a/src/schnetpack/data/datamodule.py +++ b/src/schnetpack/data/datamodule.py @@ -40,7 +40,7 @@ def __init__( num_val: Union[int, float] = None, num_test: Optional[Union[int, float]] = None, split_file: Optional[str] = "split.npz", - format: Optional[AtomsDataFormat] = None, + format=None, load_properties: Optional[List[str]] = None, val_batch_size: Optional[int] = None, test_batch_size: Optional[int] = None, @@ -113,7 +113,8 @@ def __init__( self.num_test = num_test self.splitting = splitting or RandomSplit() self.split_file = split_file - # self.datapath, self.format = resolve_format(datapath, format) + self.datapath = datapath + self.format = format self.load_properties = load_properties self.num_workers = num_workers self.num_val_workers = self.num_workers From dadbf1b9227dc4f0058dcca49757be4b00c2a5f0 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 02:03:33 +0100 Subject: [PATCH 35/68] refactor: removed dict in ASEAtomsData and simplify download method in GDMLDataset --- src/schnetpack/data/atoms.py | 2 +- src/schnetpack/data/datamodule_v2.py | 3 +-- src/schnetpack/datasets/md17.py | 8 ++------ 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index 2da6945d9..30f4306a9 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -80,7 +80,7 @@ def __init__( self.distance_conversion = 1.0 self.distance_unit = md["_distance_unit"] - self._units = dict(md["_property_unit_dict"]) + self._units = md["_property_unit_dict"] self.conversions = {prop: 1.0 for prop in self._units} # apply unit overrides on load only diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index 902af3ab9..bcf59287d 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -1,5 +1,4 @@ -from __future__ import annotations -from typing import List, Optional, Union, Dict, Any, Type +from typing import Optional, Union, Dict, Any, Type import os import numpy as np diff --git a/src/schnetpack/datasets/md17.py b/src/schnetpack/datasets/md17.py index 5497e5cdf..ae074d353 100644 --- a/src/schnetpack/datasets/md17.py +++ b/src/schnetpack/datasets/md17.py @@ -113,15 +113,11 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: atomrefs=self._native_atomrefs, ) dataset.update_metadata(molecule=self.molecule) - 
self._download_data(tmpdir, dataset) + self._download_data(tmpdir) finally: shutil.rmtree(tmpdir, ignore_errors=True) - def _download_data( - self, - tmpdir, - dataset: ASEAtomsData, - ): + def _download_data(self, tmpdir): logging.info("Downloading {} data".format(self.molecule)) rawpath = os.path.join(tmpdir, self.datasets_dict[self.molecule]) url = self.download_url + self.datasets_dict[self.molecule] From f788a91b9af9f31fce28e0b1213ad8cdad609813 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 15:06:14 +0100 Subject: [PATCH 36/68] refactor: add deprecation warnings for legacy datamodule methods in atomistic transforms --- src/schnetpack/transform/atomistic.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/schnetpack/transform/atomistic.py b/src/schnetpack/transform/atomistic.py index ad6ba9676..11103a625 100644 --- a/src/schnetpack/transform/atomistic.py +++ b/src/schnetpack/transform/atomistic.py @@ -1,4 +1,5 @@ -from typing import Dict, Optional +from typing import Dict +import warnings import torch from ase.data import atomic_masses @@ -145,6 +146,12 @@ def datamodule(self, _datamodule): """ Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. """ + warnings.warn( + "RemoveOffsets.datamodule(...) is deprecated and will be removed in a future " + "release. Use initialize(provider=..., atomrefs=...) instead.", + DeprecationWarning, + stacklevel=2, + ) provider = StatsAtomrefProvider(_datamodule.train_dataset) return self.initialize(provider, atomrefs=provider.train_atomrefs) @@ -226,6 +233,12 @@ def datamodule(self, _datamodule): """ Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. """ + warnings.warn( + "ScaleProperty.datamodule(...) is deprecated and will be removed in a future " + "release. Use initialize(provider=..., atomrefs=...) instead.", + DeprecationWarning, + stacklevel=2, + ) provider = StatsAtomrefProvider(_datamodule.train_dataset) return self.initialize(provider, atomrefs=None) @@ -250,7 +263,7 @@ class AddOffsets(Transform): precision. """ - is_preprocessor: bool = True + is_preprocessor: bool = False is_postprocessor: bool = True atomref: torch.Tensor @@ -329,6 +342,12 @@ def datamodule(self, _datamodule): """ Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. """ + warnings.warn( + "AddOffsets.datamodule(...) is deprecated and will be removed in a future " + "release. Use initialize(provider=..., atomrefs=...) 
instead.", + DeprecationWarning, + stacklevel=2, + ) provider = StatsAtomrefProvider(_datamodule.train_dataset) return self.initialize(provider, atomrefs=provider.train_atomrefs) From 70cb3e1db9e6be3450f6e6d28ed0a9d0795d2168 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 15:21:21 +0100 Subject: [PATCH 37/68] refactor: add docstring and deprecation warnings for legacy arguments in AtomsDataModuleV2 --- src/schnetpack/data/datamodule_v2.py | 47 ++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index bcf59287d..e932ef71a 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -1,5 +1,6 @@ from typing import Optional, Union, Dict, Any, Type import os +import warnings import numpy as np import pytorch_lightning as pl @@ -37,6 +38,52 @@ def __init__( pin_memory: bool = False, **kwargs, ): + """ + dataset: prebuilt ASEAtomsData dataset instance + batch_size: (train) batch size + num_train: number of training examples (absolute or relative) + num_val: number of validation examples (absolute or relative) + num_test: number of test examples (absolute or relative) + split_file: path to npz file with data partitions + splitting: Method to generate train/validation/test partitions + (default: RandomSplit) + num_workers: Number of data loader workers + val_batch_size: validation batch size. If None, use test_batch_size, then + batch_size + test_batch_size: test batch size. If None, use val_batch_size, then + batch_size + train_sampler_cls: type of torch training sampler. + This is by default wrapped into a torch.utils.data.BatchSampler. + train_sampler_args: dict of train_sampler keyword arguments. + pin_memory: If true, pin memory of loaded data to GPU. Default: Will be + set to true, when GPUs are used. + """ + legacy_args = { + "datapath", + "format", + "load_properties", + "transforms", + "train_transforms", + "val_transforms", + "test_transforms", + "num_val_workers", + "num_test_workers", + "property_units", + "distance_unit", + "data_workdir", + "cleanup_workdir_stage", + } + used_legacy_args = [k for k in legacy_args if k in kwargs] + + if used_legacy_args: + warnings.warn( + "The following arguments are deprecated in `AtomsDataModuleV2`: " + f"{used_legacy_args}. 
" + "Use a prebuilt dataset instance and configure these options on the " + "dataset instead.", + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.dataset = dataset From 3bbc31dfb172c688e2a766746edb097dcbd59cf9 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 16:34:43 +0100 Subject: [PATCH 38/68] refactor: replace property_unit_dict with _native_property_units method in QM7X dataset --- src/schnetpack/datasets/qm7x.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/schnetpack/datasets/qm7x.py b/src/schnetpack/datasets/qm7x.py index 054988904..d459f7a28 100644 --- a/src/schnetpack/datasets/qm7x.py +++ b/src/schnetpack/datasets/qm7x.py @@ -95,17 +95,6 @@ class QM7X(ASEAtomsData): FMBD = "FMBD" RMSD = "rmsd" - property_unit_dict = { - forces: "eV/Ang", - energy: "eV", - Eat: "eV", - EPBE0: "eV", - EMBD: "eV", - FPBE0: "eV/Ang", - FMBD: "eV/Ang", - RMSD: "Ang", - } - property_dataset_keys = { forces: "totFOR", energy: "ePBE0+MBD", @@ -184,6 +173,19 @@ def __init__( self._apply_structure_filter(original_subset_idx=subset_idx) + @staticmethod + def _native_property_units() -> Dict[str, str]: + return { + QM7X.forces: "totFOR", + QM7X.energy: "ePBE0+MBD", + QM7X.Eat: "eAT", + QM7X.EPBE0: "ePBE0", + QM7X.EMBD: "eMBD", + QM7X.FPBE0: "pbe0FOR", + QM7X.FMBD: "vdwFOR", + QM7X.RMSD: "sRMSD", + } + def _apply_structure_filter(self, original_subset_idx: Optional[List[int]]) -> None: effective_subset = original_subset_idx @@ -228,7 +230,7 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: dataset = ASEAtomsData.create( datapath=datapath, distance_unit=distance_unit, - property_unit_dict=QM7X.property_unit_dict, + property_unit_dict=self._native_property_units(), atomrefs=atomrefs, ) @@ -317,7 +319,7 @@ def _parse_data(self, files: List[str], dataset: ASEAtomsData): key: np.array( conf[QM7X.property_dataset_keys[key]], dtype=np.float64 ) - for key in QM7X.property_unit_dict.keys() + for key in QM7X._native_property_units().keys() } if "opt" in conf_id: From ed7328494fb856f7d4ed144393b574ae56a0add2 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 16:37:19 +0100 Subject: [PATCH 39/68] refactor: enhance QM9 dataset with train/val/test transform options --- src/schnetpack/datasets/qm9.py | 55 +++++++++++++++++----------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 0083d3a3f..fd84e2b4f 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -7,7 +7,6 @@ import tempfile from typing import Dict, List, Optional from urllib import request as request -import torch import numpy as np from ase import Atoms @@ -17,6 +16,8 @@ import schnetpack.properties as structure from schnetpack.data.atoms import ASEAtomsData, AtomsDataError +from schnetpack.transform.base import Transform + __all__ = ["QM9"] @@ -57,7 +58,10 @@ def __init__( datapath: str, remove_uncharacterized: bool = False, load_properties: Optional[List[str]] = None, - transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -69,10 +73,12 @@ def __init__( remove_uncharacterized: do not include 
uncharacterized molecules. load_properties: subset of properties to load transforms: Transform applied to each system separately before batching. + train_transforms: optional train-only transforms + val_transforms: optional val-only transforms + test_transforms: optional test-only transforms subset_idx: indices of the subset to load. property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - **kwargs: additional keyword arguments. """ self.remove_uncharacterized = remove_uncharacterized @@ -93,7 +99,6 @@ def __init__( @staticmethod def _native_property_units() -> Dict[str, str]: - # IMPORTANT: full native QM9 schema, stored in DB metadata return { QM9.A: "GHz", QM9.B: "GHz", @@ -138,40 +143,36 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: return tmpdir = tempfile.mkdtemp("qm9") - try: - atomrefs = self._download_atomrefs(tmpdir) - self.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=self._native_property_units(), - atomrefs=atomrefs, - ) + atomrefs = self._download_atomrefs(tmpdir) + + dataset = self.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=self._native_property_units(), + atomrefs=atomrefs, + ) + + if self.remove_uncharacterized: + uncharacterized = self._download_uncharacterized(tmpdir) + else: + uncharacterized = None - if self.remove_uncharacterized: - uncharacterized = self._download_uncharacterized(tmpdir) - else: - uncharacterized = None + self._download_data(tmpdir, dataset, uncharacterized) - self._download_data(tmpdir, uncharacterized) - finally: - shutil.rmtree(tmpdir, ignore_errors=True) + shutil.rmtree(tmpdir, ignore_errors=True) def _download_file(self, file_id: str, destination: str) -> None: for base_url in self.base_urls: url = f"{base_url}{file_id}" - try: - request.urlretrieve(url, destination) - return - except Exception: - logging.warning(f"Could not download from {url}, trying next source...") + request.urlretrieve(url, destination) + return raise AtomsDataError( f"Could not download file with id {file_id} from any source." 
) def _download_uncharacterized(self, tmpdir: str) -> List[int]: - logging.info("Downloading list of uncharacterized molecules...") tmp_path = os.path.join(tmpdir, "uncharacterized.txt") self._download_file(self.file_ids["uncharacterized"], tmp_path) @@ -204,7 +205,7 @@ def _download_atomrefs(self, tmpdir: str) -> Dict[str, List[float]]: def _download_data( self, tmpdir: str, - # dataset: ASEAtomsData, + dataset: ASEAtomsData, uncharacterized: Optional[List[int]], ) -> None: @@ -257,5 +258,5 @@ def _download_data( property_list.append(properties) logging.info("Write atoms to db...") - self.add_systems(property_list=property_list) + dataset.add_systems(property_list=property_list) logging.info("Done.") From e627a6962a5bada49cca33fc23725455f814bef8 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 22:07:07 +0100 Subject: [PATCH 40/68] refactor: add train/val/test transform options and docstring across multiple dataset classes --- src/schnetpack/datasets/ani1.py | 17 +++++--- src/schnetpack/datasets/iso17.py | 20 +++++++--- src/schnetpack/datasets/materials_project.py | 22 ++++++++++- src/schnetpack/datasets/md17.py | 36 ++++++++++------- src/schnetpack/datasets/md22.py | 22 +++++++---- src/schnetpack/datasets/omdb.py | 29 +++++++------- src/schnetpack/datasets/qm7x.py | 23 +++++++---- src/schnetpack/datasets/qm9.py | 19 +++++---- src/schnetpack/datasets/rmd17.py | 21 ++++++---- src/schnetpack/datasets/tmqm.py | 41 ++++++++------------ 10 files changed, 154 insertions(+), 96 deletions(-) diff --git a/src/schnetpack/datasets/ani1.py b/src/schnetpack/datasets/ani1.py index e8dce9f4e..7b78d54c9 100644 --- a/src/schnetpack/datasets/ani1.py +++ b/src/schnetpack/datasets/ani1.py @@ -6,7 +6,6 @@ from typing import Dict, List, Optional from urllib import request as request -import torch import h5py import numpy as np from ase import Atoms @@ -42,7 +41,10 @@ def __init__( num_heavy_atoms: int = 8, high_energies: bool = False, load_properties: Optional[List[str]] = None, - transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -55,10 +57,12 @@ def __init__( high_energies: whether to include high-energy conformations load_properties: subset of properties to load transforms: Transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing subset_idx: indices of the subset to load - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - **kwargs: additional keyword arguments. + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...). 
""" self.num_heavy_atoms = num_heavy_atoms self.high_energies = high_energies @@ -72,6 +76,9 @@ def __init__( datapath=datapath, load_properties=load_properties, transforms=transforms, + train_transforms=train_transforms, + val_transforms=val_transforms, + test_transforms=test_transforms, subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, diff --git a/src/schnetpack/datasets/iso17.py b/src/schnetpack/datasets/iso17.py index 852666afa..2bbf058b6 100644 --- a/src/schnetpack/datasets/iso17.py +++ b/src/schnetpack/datasets/iso17.py @@ -7,11 +7,11 @@ from urllib import request as request from urllib.error import HTTPError, URLError -import torch import numpy as np from ase.db import connect from tqdm import tqdm +from schnetpack.transform.base import Transform from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["ISO17"] @@ -43,7 +43,10 @@ def __init__( datapath: str, fold: str, load_properties: Optional[List[str]] = None, - transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -54,11 +57,13 @@ def __init__( datapath: path to dataset fold: select a specific dataset of iso17 load_properties: subset of properties to load - transforms: Transform applied to each system separately before batching + transforms: transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing subset_idx: indices of the subset to load - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - **kwargs: additional keyword arguments. + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...) + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) 
""" if fold not in self.existing_folds: raise AtomsDataError(f"Fold {fold} does not exist.") @@ -74,6 +79,9 @@ def __init__( datapath=dbpath, load_properties=load_properties, transforms=transforms, + train_transforms=train_transforms, + val_transforms=val_transforms, + test_transforms=test_transforms, subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, diff --git a/src/schnetpack/datasets/materials_project.py b/src/schnetpack/datasets/materials_project.py index d361d555c..400d3c2d7 100644 --- a/src/schnetpack/datasets/materials_project.py +++ b/src/schnetpack/datasets/materials_project.py @@ -2,7 +2,6 @@ import os from typing import List, Optional, Dict -import torch import numpy as np from ase import Atoms @@ -33,13 +32,29 @@ def __init__( self, datapath: str, load_properties: Optional[List[str]] = None, - transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, apikey: Optional[str] = None, **kwargs, ): + """ + Args: + datapath: path to dataset + load_properties: subset of properties to load + transforms: transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing + subset_idx: indices of the subset to load + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...) + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) + apikey: api key use to get data + """ if apikey is not None and len(apikey) == 16: raise DeprecationWarning( "You are using a legacy API key. This API is deprecated and no longer " @@ -65,6 +80,9 @@ def __init__( datapath=datapath, load_properties=load_properties, transforms=transforms, + train_transforms=train_transforms, + val_transforms=val_transforms, + test_transforms=test_transforms, subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, diff --git a/src/schnetpack/datasets/md17.py b/src/schnetpack/datasets/md17.py index ae074d353..eaf860edd 100644 --- a/src/schnetpack/datasets/md17.py +++ b/src/schnetpack/datasets/md17.py @@ -4,12 +4,12 @@ import tempfile from typing import List, Optional, Dict from urllib import request as request -import torch import numpy as np from ase import Atoms import schnetpack.properties as structure +from schnetpack.transform.base import Transform from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["MD17"] @@ -34,7 +34,10 @@ def __init__( tmpdir: str = "gdml_tmp", atomrefs: Optional[Dict[str, List[float]]] = None, load_properties: Optional[List[str]] = None, - transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -42,18 +45,20 @@ def __init__( ): """ Args: - datasets_dict: dictionary mapping molecule names to dataset names. 
- download_url: URL where individual molecule datasets can me found. - datapath: path to dataset. - molecule: name of the molecule. - tmpdir: name of temporary directory used for parsing. - atomrefs: properties of free atoms. - load_properties: subset of properties to load. - transforms: Transform applied to each system separately before batching. - subset_idx: indices of the subset to load. - property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...). - **kwargs: additional keyword arguments. + datasets_dict: dictionary mapping molecule names to dataset names + download_url: URL where individual molecule datasets can me found + datapath: path to dataset + molecule: name of the molecule + tmpdir: name of temporary directory used for parsing + atomrefs: properties of free atoms + load_properties: subset of properties to load + transforms: transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing + subset_idx: indices of the subset to load + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...) + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) """ self.datasets_dict = datasets_dict self.download_url = download_url @@ -74,6 +79,9 @@ def __init__( datapath=datapath, load_properties=load_properties, transforms=transforms, + train_transforms=train_transforms, + val_transforms=val_transforms, + test_transforms=test_transforms, subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, diff --git a/src/schnetpack/datasets/md22.py b/src/schnetpack/datasets/md22.py index ffb1db528..4cc0f8997 100644 --- a/src/schnetpack/datasets/md22.py +++ b/src/schnetpack/datasets/md22.py @@ -1,6 +1,6 @@ from typing import Optional, Dict, List -import torch from schnetpack.datasets.md17 import GDMLDataset +from schnetpack.transform.base import Transform __all__ = ["MD22"] @@ -18,7 +18,10 @@ def __init__( datapath: str, molecule: str, load_properties: Optional[List[str]] = None, - transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -29,11 +32,13 @@ def __init__( datapath: path to dataset molecule: name of the molecule load_properties: subset of properties to load - transforms: Transform applied to each system separately before batching. - subset_idx: indices of the subset to load. - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - **kwargs: additional keyword arguments. + transforms: transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing + subset_idx: indices of the subset to load + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...) 
+ distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) """ atomrefs = { self.energy: [ @@ -67,6 +72,9 @@ def __init__( datapath=datapath, load_properties=load_properties, transforms=transforms, + train_transforms=train_transforms, + val_transforms=val_transforms, + test_transforms=test_transforms, subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, diff --git a/src/schnetpack/datasets/omdb.py b/src/schnetpack/datasets/omdb.py index dd9b6f346..9de26fb47 100644 --- a/src/schnetpack/datasets/omdb.py +++ b/src/schnetpack/datasets/omdb.py @@ -3,12 +3,11 @@ import tarfile from typing import List, Optional, Dict from ase.io import read - import numpy as np -import torch from schnetpack.data import * from schnetpack.data import AtomsDataModuleError, AtomsDataModule +from schnetpack.transform.base import Transform __all__ = ["OrganicMaterialsDatabase"] @@ -40,10 +39,10 @@ def __init__( load_properties: Optional[List[str]] = None, val_batch_size: Optional[int] = None, test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, num_workers: int = 2, num_val_workers: Optional[int] = None, num_test_workers: Optional[int] = None, @@ -63,15 +62,15 @@ def __init__( load_properties: subset of properties to load val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). + transforms: transform applied to each system separately before batching. + train_transforms: overrides transform_fn for training. + val_transforms: overrides transform_fn for validation. + test_transforms: overrides transform_fn for testing. + num_workers: number of data loader workers. + num_val_workers: number of validation data loader workers (overrides num_workers). + num_test_workers: number of test data loader workers (overrides num_workers). + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...). 
raw_path: path to raw tar.gz file with the data """ super().__init__( diff --git a/src/schnetpack/datasets/qm7x.py b/src/schnetpack/datasets/qm7x.py index d459f7a28..e52619e0d 100644 --- a/src/schnetpack/datasets/qm7x.py +++ b/src/schnetpack/datasets/qm7x.py @@ -9,12 +9,11 @@ from urllib import request as request import h5py -import torch import numpy as np import progressbar from ase import Atoms -import schnetpack.properties as structure +from schnetpack.transform.base import Transform from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["QM7X"] @@ -123,7 +122,10 @@ def __init__( only_equilibrium: bool = False, only_non_equilibrium: bool = False, load_properties: Optional[List[str]] = None, - transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -137,11 +139,13 @@ def __init__( only_equilibrium: only include equilibrium molecules only_non_equilibrium: only include non-equilibrium molecules load_properties: subset of properties to load - transforms: Transform applied to each system separately before batching + transforms: transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing subset_idx: indices of the subset to load - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - **kwargs: additional keyword arguments. + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...) + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) """ if only_equilibrium and only_non_equilibrium: @@ -165,7 +169,10 @@ def __init__( datapath=datapath, load_properties=load_properties, transforms=transforms, - subset_idx=None, + train_transforms=train_transforms, + val_transforms=val_transforms, + test_transforms=test_transforms, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, **kwargs, diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index fd84e2b4f..cdff9f94d 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -72,13 +72,13 @@ def __init__( datapath: path to dataset remove_uncharacterized: do not include uncharacterized molecules. load_properties: subset of properties to load - transforms: Transform applied to each system separately before batching. - train_transforms: optional train-only transforms - val_transforms: optional val-only transforms - test_transforms: optional test-only transforms - subset_idx: indices of the subset to load. - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). 
+ transforms: transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing + subset_idx: indices of the subset to load + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...) + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) """ self.remove_uncharacterized = remove_uncharacterized @@ -91,6 +91,9 @@ def __init__( datapath=datapath, load_properties=load_properties, transforms=transforms, + train_transforms=train_transforms, + val_transforms=val_transforms, + test_transforms=test_transforms, subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, @@ -241,7 +244,7 @@ def _download_data( lines = f.readlines() values = lines[1].split()[2:] - for pname, value in zip(self.available_properties, values): + for pname, value in zip(dataset.available_properties, values): properties[pname] = np.array([float(value)]) for line in lines: diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py index 204f21f3d..317cb41ac 100644 --- a/src/schnetpack/datasets/rmd17.py +++ b/src/schnetpack/datasets/rmd17.py @@ -6,7 +6,6 @@ from typing import Dict, List, Optional from urllib.request import Request, urlopen from urllib.error import HTTPError, URLError -import torch import numpy as np from ase import Atoms @@ -69,7 +68,10 @@ def __init__( datapath: str, molecule: str, load_properties: Optional[List[str]] = None, - transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -80,11 +82,13 @@ def __init__( datapath: path to dataset molecule: name of the molecule load_properties: subset of properties to load - transforms: Transform applied to each system separately before batching. - subset_idx: indices of the subset to load. - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - **kwargs: additional keyword arguments. + transforms: transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing + subset_idx: indices of the subset to load + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...) + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) 
""" if molecule not in self.datasets_dict: @@ -101,6 +105,9 @@ def __init__( datapath=datapath, load_properties=load_properties, transforms=transforms, + train_transforms=train_transforms, + val_transforms=val_transforms, + test_transforms=test_transforms, subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, diff --git a/src/schnetpack/datasets/tmqm.py b/src/schnetpack/datasets/tmqm.py index b08527414..cbf75dfd5 100644 --- a/src/schnetpack/datasets/tmqm.py +++ b/src/schnetpack/datasets/tmqm.py @@ -1,24 +1,17 @@ -import io -import logging import os -import re import shutil -import tarfile import tempfile from typing import List, Optional, Dict from urllib import request as request import gzip import numpy as np -from ase import Atoms -from ase.io.extxyz import read_xyz from ase.io import read -from tqdm import tqdm -import torch from schnetpack.data import * -import schnetpack.properties as structure -from schnetpack.data import AtomsDataModuleError, AtomsDataModule +from schnetpack.data import AtomsDataModule +from schnetpack.transform.base import Transform + __all__ = ["TMQM"] @@ -62,10 +55,10 @@ def __init__( load_properties: Optional[List[str]] = None, val_batch_size: Optional[int] = None, test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, num_workers: int = 2, num_val_workers: Optional[int] = None, num_test_workers: Optional[int] = None, @@ -87,16 +80,16 @@ def __init__( remove_uncharacterized: do not include uncharacterized molecules. val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. + transforms: transform applied to each system separately before batching. + train_transforms: overrides transform_fn for training. + val_transforms: overrides transform_fn for validation. + test_transforms: overrides transform_fn for testing. + num_workers: number of data loader workers. + num_val_workers: number of validation data loader workers (overrides num_workers). + num_test_workers: number of test data loader workers (overrides num_workers). + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...). + data_workdir: copy data here as part of setup, e.g. cluster scratch for faster performance. 
""" super().__init__( datapath=datapath, From 607eab1629a51fdfcd86dd777e10c32f1b7e561e Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 22:12:25 +0100 Subject: [PATCH 41/68] refactor: update docstrings in atomistic transforms --- src/schnetpack/transform/atomistic.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/schnetpack/transform/atomistic.py b/src/schnetpack/transform/atomistic.py index 11103a625..b7a6f2337 100644 --- a/src/schnetpack/transform/atomistic.py +++ b/src/schnetpack/transform/atomistic.py @@ -22,6 +22,7 @@ class SubtractCenterOfMass(Transform): """ Subtract center of mass from positions. + """ is_preprocessor: bool = True @@ -44,6 +45,7 @@ def forward( class SubtractCenterOfGeometry(Transform): """ Subtract center of geometry from positions. + """ is_preprocessor: bool = True @@ -62,8 +64,6 @@ class RemoveOffsets(Transform): Remove offsets from property based on the mean of the training data and/or the single atom reference calculations. - The `mean` and/or `atomref` are automatically obtained from the AtomsDataModule, - when it is used. Otherwise, they have to be provided in the init manually. """ is_preprocessor: bool = True @@ -90,6 +90,7 @@ def __init__( tensor. atomrefs: Provide single-atom references directly. property_mean: Provide mean property value / n_atoms. + estimate_atomref: If true, add estimated atomrefs. """ super().__init__() self._property = property @@ -255,12 +256,6 @@ class AddOffsets(Transform): Add offsets to property based on the mean of the training data and/or the single atom reference calculations. - The `mean` and/or `atomref` are automatically obtained from the AtomsDataModule, - when it is used. Otherwise, they have to be provided in the init manually. - - Hint: - Place this postprocessor after casting to float64 for higher numerical - precision. """ is_preprocessor: bool = False @@ -288,6 +283,7 @@ def __init__( tensor. atomrefs: Provide single-atom references directly. property_mean: Provide mean property value / n_atoms. + estimate_atomref: If true, add estimated atomrefs. """ super().__init__() self._property = property From 2a6ce76fef9a3b6d60f6b7b88599e1190591ed42 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 22:22:54 +0100 Subject: [PATCH 42/68] refactor: add docstrings in ASEAtomsData --- src/schnetpack/data/atoms.py | 89 ++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index 30f4306a9..d05a201f6 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -1,3 +1,16 @@ +""" +This module contains all functionalities required to load atomistic data, +generate batches and compute statistics. It makes use of the ASE database +for atoms [#ase2]_. + +References +---------- +.. [#ase2] Larsen, Mortensen, Blomqvist, Castelli, Christensen, Dułak, Friis, + Groves, Hammer, Hargus: + The atomic simulation environment -- a Python library for working with atoms. + Journal of Physics: Condensed Matter, 9, 27. 2017. +""" + import copy import logging import os @@ -40,6 +53,20 @@ def __init__( property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, ): + """ + Args: + datapath: Path to ASE DB. + load_properties: Set of properties to be loaded and returned. + If None, all properties in the ASE dB will be returned. + load_structure: If True, load structure properties. 
+            transforms: preprocessing torch.nn.Module (see schnetpack.data.transforms)
+            train_transforms: overrides transform_fn for training
+            val_transforms: overrides transform_fn for validation
+            test_transforms: overrides transform_fn for testing
+            subset_idx: List of data indices.
+            property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...)
+            distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...)
+        """
         self.datapath = datapath
         self.subset_idx = subset_idx
         self._check_db()
@@ -195,6 +222,19 @@ def iter_properties(
         load_structure: Optional[bool] = None,
         load_metadata: bool = False,
     ):
+        """
+        Iterate over property dictionaries at the given indices.
+
+        Args:
+            indices: data indices
+            load_properties (sequence or None): subset of available properties to load
+            load_structure: load and return structure
+            load_metadata: load and return metadata
+
+        Returns:
+            properties (dict): dictionary with molecular properties
+
+        """
         if load_properties is None:
             load_properties = self.load_properties
         if load_structure is None:
@@ -230,6 +270,19 @@ def _get_properties(
         load_structure: bool,
         load_metadata: bool = False,
     ):
+        """
+        Load properties of a single system from the ASE database.
+
+        Args:
+            conn: ASE database connection.
+            idx: Zero-based system index.
+            load_properties: Properties to load.
+            load_structure: Whether to load structural information.
+            load_metadata: Whether to load metadata.
+
+        Returns:
+            Dict[str, torch.Tensor]: Dictionary containing the requested properties.
+        """
         row = conn.get(idx + 1)
         # TODO: can the copies be avoided?
         properties: Dict[str, torch.Tensor] = {}
@@ -268,6 +321,23 @@ def create(
         atomrefs: Optional[Dict[str, List[float]]] = None,
         **kwargs,
     ) -> "ASEAtomsData":
+        """
+        Create a new, empty ASE dataset at the given path and register its units.
+
+        Args:
+            datapath: Path to ASE DB.
+            distance_unit: unit of atom positions and cell
+            property_unit_dict: Defines the available properties of the dataset and
+                provides units for ALL properties of the dataset. If a property is
+                unit-less, you can pass "arb. unit" or `None`.
+            atomrefs: dictionary mapping properties (the keys) to lists of single-atom
+                reference values of the property. This is especially useful for
+                extensive properties such as the energy, where the single atom energies
+                contribute a major part to the overall value.
+
+        Returns:
+            newly created ASEAtomsData
+
+        """
         if not datapath.endswith(".db"):
             raise AtomsDataError("Invalid datapath! Add '.db' extension.")
         if os.path.exists(datapath):
@@ -313,6 +383,22 @@ def add_systems(
         atoms_list: Optional[List[Atoms]] = None,
         atoms_metadata_list: Optional[List[Dict[str, Any]]] = None,
     ):
+        """
+        Add atoms data to the dataset.
+
+        Args:
+            property_list: Properties as list of key-value pairs in the same
+                order as the corresponding list of `atoms`.
+                Keys have to match the `available_properties` of the dataset
+                plus additional structure properties, if atoms is None.
+            atoms_list: System composition and geometry. If the Atoms are None,
+                the structure needs to be given as part of the property dicts
+                (using structure.Z, structure.R, structure.cell, structure.pbc).
+            atoms_metadata_list: Metadata of the atoms objects as list of key-value pairs in the same
+                order as the corresponding list of `atoms`.
+                Metadata cannot be used as a training property, but can be used for splitting
+                strategies (e.g. material_id, timestamp, ...).
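Together, create() and add_systems() form the documented workflow for building a custom DB. A minimal sketch following the docstrings above (the filename and property values are invented; note that at this point in the series create() still returns the dataset instance, which a later patch changes):

    import numpy as np
    from ase import Atoms
    from schnetpack.data import ASEAtomsData

    dataset = ASEAtomsData.create(
        "./new_dataset.db",
        distance_unit="Ang",
        property_unit_dict={"energy": "eV"},
    )
    water = Atoms("H2O", positions=[[0.0, 0.0, 0.0], [0.76, 0.59, 0.0], [-0.76, 0.59, 0.0]])
    dataset.add_systems(
        property_list=[{"energy": np.array([-76.4])}],
        atoms_list=[water],
    )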
+ """ if atoms_list is None: atoms_list = [None] * len(property_list) if atoms_metadata_list is None: @@ -329,6 +415,9 @@ def _add_system( atoms_metadata: Optional[Dict[str, Any]] = None, **properties, ): + """ + Add systems to DB. + """ if atoms is None: try: Z = properties[structure.Z] From f0dd4c5ebae688cedf400d24d0ccbad0ed616ace Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 22:32:11 +0100 Subject: [PATCH 43/68] refactor: add docstrings for calculate_stats() and estimate_atomrefs() --- src/schnetpack/data/stats.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/schnetpack/data/stats.py b/src/schnetpack/data/stats.py index 6a7e68b3a..9df2365aa 100644 --- a/src/schnetpack/data/stats.py +++ b/src/schnetpack/data/stats.py @@ -18,6 +18,25 @@ def calculate_stats( num_workers: int = 4, loader_kwargs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]: + """ + Use the incremental Welford algorithm described in [h1]_ to accumulate + the mean and standard deviation over a set of samples. + + References: + ----------- + .. [h1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + + Args: + dataset: Dataset used to compute statistics. + divide_by_atoms: Mapping from property name to bool indicating whether the property + should be divided by the number of atoms before computing statistics. + atomref: Optional single-atom reference values to subtract before computing statistics. + batch_size: Batch size used for the temporary data loader. + num_workers: Number of workers used by the data loader. + + Returns: + Mapping from property name to `(mean, std)` tensors. + """ loader_kwargs = loader_kwargs or {} dataloader = AtomsLoader( @@ -83,6 +102,20 @@ def estimate_atomrefs( num_workers: int = 4, loader_kwargs: Optional[Dict[str, Any]] = None, ) -> Dict[str, torch.Tensor]: + """ + Uses linear regression to estimate the elementwise biases (atomrefs). + + Args: + dataset: Dataset used to estimate atom reference values. + is_extensive: Mapping from property name to bool indicating whether the property is + extensive. If False, atom type counts are divided by the number of atoms before fitting. + z_max: Maximum atomic number used to size the atomref tensors. + batch_size: Batch size used for the temporary data loader. + num_workers: Number of workers used by the data loader. + + Returns: + Mapping from property name to estimated atom reference tensor. 
+ """ loader_kwargs = loader_kwargs or {} dataloader = AtomsLoader( From 4ca36484ac5ff0eb328364ca665d78a4f7a45feb Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 22:59:49 +0100 Subject: [PATCH 44/68] refactor: simplify transform initialization in ASEAtomsData and update docstring format in AtomsDataModuleV2 --- src/schnetpack/data/atoms.py | 18 ++++--------- src/schnetpack/data/datamodule_v2.py | 37 ++++++++++++++------------- src/schnetpack/transform/atomistic.py | 2 +- 3 files changed, 25 insertions(+), 32 deletions(-) diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index d05a201f6..3a4a2d098 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -14,7 +14,6 @@ import copy import logging import os -from enum import Enum from typing import Optional, List, Dict, Any, Iterable, Union import torch @@ -72,18 +71,11 @@ def __init__( self._check_db() self.conn = connect(self.datapath, use_lock_file=False) - self.transforms: List[Transform] = ( - list(transforms) if transforms is not None else [] - ) - self.train_transforms: Optional[List[Transform]] = ( - list(train_transforms) if train_transforms is not None else None - ) - self.val_transforms: Optional[List[Transform]] = ( - list(val_transforms) if val_transforms is not None else None - ) - self.test_transforms: Optional[List[Transform]] = ( - list(test_transforms) if test_transforms is not None else None - ) + self.transforms = list(transforms or []) + self.train_transforms = list(train_transforms) if train_transforms else None + self.val_transforms = list(val_transforms) if val_transforms else None + self.test_transforms = list(test_transforms) if test_transforms else None + self._load_properties: Optional[List[str]] = None self.load_structure = load_structure diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index e932ef71a..ddf68dc49 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -39,24 +39,25 @@ def __init__( **kwargs, ): """ - dataset: prebuilt ASEAtomsData dataset instance - batch_size: (train) batch size - num_train: number of training examples (absolute or relative) - num_val: number of validation examples (absolute or relative) - num_test: number of test examples (absolute or relative) - split_file: path to npz file with data partitions - splitting: Method to generate train/validation/test partitions - (default: RandomSplit) - num_workers: Number of data loader workers - val_batch_size: validation batch size. If None, use test_batch_size, then - batch_size - test_batch_size: test batch size. If None, use val_batch_size, then - batch_size - train_sampler_cls: type of torch training sampler. - This is by default wrapped into a torch.utils.data.BatchSampler. - train_sampler_args: dict of train_sampler keyword arguments. - pin_memory: If true, pin memory of loaded data to GPU. Default: Will be - set to true, when GPUs are used. + Args: + dataset: prebuilt ASEAtomsData dataset instance + batch_size: (train) batch size + num_train: number of training examples (absolute or relative) + num_val: number of validation examples (absolute or relative) + num_test: number of test examples (absolute or relative) + split_file: path to npz file with data partitions + splitting: Method to generate train/validation/test partitions + (default: RandomSplit) + num_workers: Number of data loader workers + val_batch_size: validation batch size. 
If None, use test_batch_size, then + batch_size + test_batch_size: test batch size. If None, use val_batch_size, then + batch_size + train_sampler_cls: type of torch training sampler. + This is by default wrapped into a torch.utils.data.BatchSampler. + train_sampler_args: dict of train_sampler keyword arguments. + pin_memory: If true, pin memory of loaded data to GPU. Default: Will be + set to true, when GPUs are used. """ legacy_args = { "datapath", diff --git a/src/schnetpack/transform/atomistic.py b/src/schnetpack/transform/atomistic.py index b7a6f2337..b6fde9c25 100644 --- a/src/schnetpack/transform/atomistic.py +++ b/src/schnetpack/transform/atomistic.py @@ -22,7 +22,7 @@ class SubtractCenterOfMass(Transform): """ Subtract center of mass from positions. - + """ is_preprocessor: bool = True From 93d9ce8c0ae8a0934e441d8aea4a34ecb60ee70f Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 23:38:39 +0100 Subject: [PATCH 45/68] refactor: restructure configs for all datasets --- src/schnetpack/configs/data/ani1.yaml | 20 ++++++++++--------- src/schnetpack/configs/data/custom.yaml | 15 +++++++------- src/schnetpack/configs/data/iso17.yaml | 8 +++++--- .../configs/data/materials_project.yaml | 10 ++++++---- src/schnetpack/configs/data/md17.yaml | 10 +++++++--- src/schnetpack/configs/data/md22.yaml | 10 +++++++--- src/schnetpack/configs/data/omdb.yaml | 8 +++++--- src/schnetpack/configs/data/qm7x.yaml | 9 +++++++-- src/schnetpack/configs/data/qm9.yaml | 6 +----- src/schnetpack/configs/data/rmd17.yaml | 11 ++++++---- 10 files changed, 63 insertions(+), 44 deletions(-) diff --git a/src/schnetpack/configs/data/ani1.yaml b/src/schnetpack/configs/data/ani1.yaml index a3bdcb48c..2fd1dd23b 100644 --- a/src/schnetpack/configs/data/ani1.yaml +++ b/src/schnetpack/configs/data/ani1.yaml @@ -1,16 +1,18 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.ANI1 +dataset: + _target_: schnetpack.datasets.ANI1 + datapath: ${run.data_dir}/ani1.db # data_dir is specified in train.yaml + num_heavy_atoms: 8 + high_energies: false + distance_unit: Ang + property_units: + energy: eV + transforms: ${data.transforms} + -datapath: ${run.data_dir}/ani1.db # data_dir is specified in train.yaml batch_size: 32 num_train: 10000000 num_val: 100000 -num_heavy_atoms: 8 -high_energies: False - -# convert to typically used units -distance_unit: Ang -property_units: - energy: eV \ No newline at end of file diff --git a/src/schnetpack/configs/data/custom.yaml b/src/schnetpack/configs/data/custom.yaml index 6345d0a44..e71360659 100644 --- a/src/schnetpack/configs/data/custom.yaml +++ b/src/schnetpack/configs/data/custom.yaml @@ -1,26 +1,25 @@ # @package data _target_: schnetpack.data.datamodule_v2.AtomsDataModuleV2 -# dataset must be provided by concrete config dataset: _target_: schnetpack.data.ASEAtomsData datapath: ??? load_properties: null distance_unit: Ang property_units: {} + transforms: ${data.transforms} + train_transforms: null + val_transforms: null + test_transforms: null batch_size: 10 num_train: ??? num_val: ??? 
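In Python, the restructured configs correspond roughly to the following wiring (a sketch mirroring custom.yaml above; the DB path and split fractions are placeholders, and train_dataloader() is the standard LightningDataModule accessor):

    from schnetpack.data import ASEAtomsData, AtomsDataModuleV2

    dataset = ASEAtomsData(datapath="./my_data.db")
    dm = AtomsDataModuleV2(
        dataset=dataset,
        batch_size=10,
        num_train=0.8,
        num_val=0.1,
        split_file="./split.npz",
        num_workers=8,
    )
    dm.setup()
    train_loader = dm.train_dataloader()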
num_test: null - split_file: ${run.data_dir}/split.npz splitting: null - -transforms: ${data.transforms} -train_transforms: null -val_transforms: null -test_transforms: null - num_workers: 8 +train_sampler_cls: null +train_sampler_args: {} +pin_memory: false diff --git a/src/schnetpack/configs/data/iso17.yaml b/src/schnetpack/configs/data/iso17.yaml index b86dc5474..364a65fe4 100644 --- a/src/schnetpack/configs/data/iso17.yaml +++ b/src/schnetpack/configs/data/iso17.yaml @@ -1,10 +1,12 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.ISO17 +dataset: + _target_: schnetpack.datasets.ISO17 + datapath: ${run.data_dir}/${data.folder}.db # data_dir is specified in train.yaml + folder: reference -datapath: ${run.data_dir}/${data.folder}.db # data_dir is specified in train.yaml -folder: reference batch_size: 32 num_train: 0.9 num_val: 0.1 diff --git a/src/schnetpack/configs/data/materials_project.yaml b/src/schnetpack/configs/data/materials_project.yaml index f8516adbe..800cc9532 100644 --- a/src/schnetpack/configs/data/materials_project.yaml +++ b/src/schnetpack/configs/data/materials_project.yaml @@ -1,10 +1,12 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.MaterialsProject +dataset: + _target_: schnetpack.datasets.MaterialsProject + datapath: ${run.data_dir}/materials_project.db # data_dir is specified in train.yaml + apikey: ??? -datapath: ${run.data_dir}/materials_project.db # data_dir is specified in train.yaml batch_size: 32 num_train: 60000 -num_val: 2000 -apikey: ??? \ No newline at end of file +num_val: 2000 \ No newline at end of file diff --git a/src/schnetpack/configs/data/md17.yaml b/src/schnetpack/configs/data/md17.yaml index 59fb35a5a..e2b53be47 100644 --- a/src/schnetpack/configs/data/md17.yaml +++ b/src/schnetpack/configs/data/md17.yaml @@ -1,10 +1,14 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.MD17 - -datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml molecule: aspirin + +dataset: + _target_: schnetpack.datasets.MD17 + datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml + molecule: ${data.molecule} + batch_size: 10 num_train: 950 num_val: 50 diff --git a/src/schnetpack/configs/data/md22.yaml b/src/schnetpack/configs/data/md22.yaml index 89ce71e6f..322042c1f 100644 --- a/src/schnetpack/configs/data/md22.yaml +++ b/src/schnetpack/configs/data/md22.yaml @@ -1,10 +1,14 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.MD22 - -datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml molecule: Ac-Ala3-NHMe + +dataset: + _target_: schnetpack.datasets.MD22 + datapath: ${run.data_dir}/${data.molecule}.db + molecule: ${data.molecule} + batch_size: 10 num_train: 5700 num_val: 300 diff --git a/src/schnetpack/configs/data/omdb.yaml b/src/schnetpack/configs/data/omdb.yaml index 4e885c532..b9bf959c9 100644 --- a/src/schnetpack/configs/data/omdb.yaml +++ b/src/schnetpack/configs/data/omdb.yaml @@ -1,10 +1,12 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.OrganicMaterialsDatabase +dataset: + _target_: schnetpack.datasets.OrganicMaterialsDatabase + datapath: ${run.data_dir}/omdb.db # data_dir is specified in train.yaml + raw_path: null -datapath: ${run.data_dir}/omdb.db # data_dir is specified in train.yaml batch_size: 32 num_train: 0.8 num_val: 0.1 -raw_path: null \ No newline at end of file diff --git a/src/schnetpack/configs/data/qm7x.yaml b/src/schnetpack/configs/data/qm7x.yaml index 
09cd4d805..95a21ed26 100644 --- a/src/schnetpack/configs/data/qm7x.yaml +++ b/src/schnetpack/configs/data/qm7x.yaml @@ -1,9 +1,14 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.QM7X +dataset: + _target_: schnetpack.datasets.QM7X + datapath: ${run.data_dir}/qm7x.db # data_dir is specified in train.yaml + remove_duplicates: true + only_equilibrium: false + only_non_equilibrium: false -datapath: ${run.data_dir}/qm7x.db # data_dir is specified in train.yaml batch_size: 100 num_train: 5550 num_val: 700 \ No newline at end of file diff --git a/src/schnetpack/configs/data/qm9.yaml b/src/schnetpack/configs/data/qm9.yaml index bb12876f7..f24337472 100644 --- a/src/schnetpack/configs/data/qm9.yaml +++ b/src/schnetpack/configs/data/qm9.yaml @@ -2,11 +2,6 @@ defaults: - custom -datapath: ${run.data_dir}/qm9.db -#legacy inputs -train_sampler_cls: null -train_sampler_args: {} - dataset: _target_: schnetpack.datasets.qm9.QM9 datapath: ${run.data_dir}/qm9.db @@ -22,6 +17,7 @@ dataset: lumo: eV gap: eV zpve: eV + transforms: ${data.transforms} batch_size: 100 num_train: 110000 diff --git a/src/schnetpack/configs/data/rmd17.yaml b/src/schnetpack/configs/data/rmd17.yaml index 76614e110..aff1c9194 100644 --- a/src/schnetpack/configs/data/rmd17.yaml +++ b/src/schnetpack/configs/data/rmd17.yaml @@ -1,11 +1,14 @@ +# @package data defaults: - custom +molecule: aspirin -_target_: schnetpack.datasets.rMD17 +dataset: + _target_: schnetpack.datasets.rMD17 + datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml + molecule: ${data.molecule} + split_id: null -datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml -molecule: aspirin batch_size: 10 num_train: 950 num_val: 50 -split_id: null \ No newline at end of file From 99ba77812707e352ddcd253a3735db5a1b21386d Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 12 Mar 2026 23:44:41 +0100 Subject: [PATCH 46/68] refactor: update omdb to support datamodulev2 --- src/schnetpack/datasets/omdb.py | 120 ++++++++++++++++---------------- 1 file changed, 59 insertions(+), 61 deletions(-) diff --git a/src/schnetpack/datasets/omdb.py b/src/schnetpack/datasets/omdb.py index 9de26fb47..51b58cd51 100644 --- a/src/schnetpack/datasets/omdb.py +++ b/src/schnetpack/datasets/omdb.py @@ -2,18 +2,19 @@ import os import tarfile from typing import List, Optional, Dict -from ase.io import read + import numpy as np +from ase.io import read +import torch -from schnetpack.data import * -from schnetpack.data import AtomsDataModuleError, AtomsDataModule +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError from schnetpack.transform.base import Transform __all__ = ["OrganicMaterialsDatabase"] -class OrganicMaterialsDatabase(AtomsDataModule): +class OrganicMaterialsDatabase(ASEAtomsData): """ Organic Materials Database (OMDB) of bulk organic crystals. Registration to the OMDB is free for academic users. 
This database contains DFT @@ -31,21 +32,12 @@ class OrganicMaterialsDatabase(AtomsDataModule): def __init__( self, datapath: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, transforms: Optional[List[Transform]] = None, train_transforms: Optional[List[Transform]] = None, val_transforms: Optional[List[Transform]] = None, test_transforms: Optional[List[Transform]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, raw_path: Optional[str] = None, @@ -54,85 +46,91 @@ def __init__( """ Args: datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. transforms: transform applied to each system separately before batching. train_transforms: overrides transform_fn for training. val_transforms: overrides transform_fn for validation. test_transforms: overrides transform_fn for testing. - num_workers: number of data loader workers. - num_val_workers: number of validation data loader workers (overrides num_workers). - num_test_workers: number of test data loader workers (overrides num_workers). + subset_idx: indices of the subset to load. property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...). raw_path: path to raw tar.gz file with the data """ + self.raw_path = raw_path + + self.download( + datapath=datapath, + distance_unit=distance_unit or "Ang", + ) + super().__init__( datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, + load_structure=True, transforms=transforms, train_transforms=train_transforms, val_transforms=val_transforms, test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, **kwargs, ) - self.raw_path = raw_path - def prepare_data(self): - if not os.path.exists(self.datapath): - property_unit_dict = {OrganicMaterialsDatabase.BandGap: "eV"} - - dataset = ASEAtomsData( - datapath=self.datapath, - distance_unit="Ang", - property_unit_dict=property_unit_dict, - ) + @staticmethod + def _native_property_units() -> Dict[str, str]: + return {OrganicMaterialsDatabase.BandGap: "eV"} - self._convert(dataset) - else: - dataset = ASEAtomsData(self.datapath) - - def _convert(self, dataset): + def download(self, datapath: str, distance_unit: str = "Ang") -> None: """ - Converts .tar.gz to a .db file + Make sure the OMDB database exists. 
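Since the conversion now runs from the constructor whenever the DB file is missing, first-time setup reduces to a single call. A hypothetical example (the archive name is an assumption; OMDB data must be downloaded manually after registration):

    from schnetpack.datasets import OrganicMaterialsDatabase

    omdb = OrganicMaterialsDatabase(
        datapath="./omdb.db",
        raw_path="./OMDB-GAP1.tar.gz",  # assumed local copy of the raw archive
    )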
""" + if os.path.exists(datapath): + _ = ASEAtomsData(datapath=datapath, load_structure=False) + return + if self.raw_path is None or not os.path.exists(self.raw_path): - # TODO: can we download here automatically like QM9? - raise AtomsDataModuleError( + raise AtomsDataError( "The path to the raw dataset is not provided or invalid and the db-file does " "not exist!" ) - logging.info("Converting %s to a .db file.." % self.raw_path) - tar = tarfile.open(self.raw_path, "r:gz") - names = tar.getnames() - tar.extractall() - tar.close() - structures = read("structures.xyz", index=":") - Y = np.loadtxt("bandgaps.csv") - [os.remove(name) for name in names] + dataset = ASEAtomsData.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=self._native_property_units(), + ) + + self._convert(dataset) + + def _convert(self, dataset: ASEAtomsData) -> None: + """ + Converts .tar.gz to a .db file + """ + logging.info("Converting %s to a .db file..", self.raw_path) + + extract_dir = os.path.dirname(self.raw_path) or "." + with tarfile.open(self.raw_path, "r:gz") as tar: + names = tar.getnames() + tar.extractall(path=extract_dir) + + structures_path = os.path.join(extract_dir, "structures.xyz") + bandgaps_path = os.path.join(extract_dir, "bandgaps.csv") + + structures = read(structures_path, index=":") + y = np.loadtxt(bandgaps_path) atoms_list = [] property_list = [] for i, at in enumerate(structures): atoms_list.append(at) - property_list.append({OrganicMaterialsDatabase.BandGap: np.array([Y[i]])}) + property_list.append( + {OrganicMaterialsDatabase.BandGap: np.array([y[i]], dtype=np.float64)} + ) + dataset.add_systems(atoms_list=atoms_list, property_list=property_list) + + for name in names: + path = os.path.join(extract_dir, name) + if os.path.exists(path): + os.remove(path) From 5d44d74c7626262fc324dd6320af94ba1cba2811 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sat, 14 Mar 2026 02:36:16 +0100 Subject: [PATCH 47/68] refactor: streamline transform assignment and initialization in AtomsDataModuleV2 --- src/schnetpack/data/datamodule_v2.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py index ddf68dc49..e00e34b13 100644 --- a/src/schnetpack/data/datamodule_v2.py +++ b/src/schnetpack/data/datamodule_v2.py @@ -149,21 +149,25 @@ def setup(self, stage: Optional[str] = None) -> None: val_transforms = self.dataset.val_transforms or transforms test_transforms = self.dataset.test_transforms or transforms - self._train_dataset.transforms = train_transforms - self._val_dataset.transforms = val_transforms - self._test_dataset.transforms = test_transforms + self._train_dataset.transforms = [] + self._val_dataset.transforms = [] + self._test_dataset.transforms = [] self.provider = StatsAtomrefProvider(self._train_dataset) - self._initialize_transforms(self._train_dataset) - self._initialize_transforms(self._val_dataset) - self._initialize_transforms(self._test_dataset) + self._initialize_transforms(train_transforms) + self._initialize_transforms(val_transforms) + self._initialize_transforms(test_transforms) + + self._train_dataset.transforms = train_transforms + self._val_dataset.transforms = val_transforms + self._test_dataset.transforms = test_transforms - def _initialize_transforms(self, dataset: ASEAtomsData) -> None: - if not dataset.transforms: + def _initialize_transforms(self, transforms) -> None: + if not transforms: return - for t in dataset.transforms: + for t 
in transforms: t.initialize(provider=self.provider, atomrefs=self.provider.train_atomrefs) def _load_partitions(self) -> None: From 43e605d6ed677dce0a37235d0cd41fa2d8ad1cff Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sat, 14 Mar 2026 03:23:06 +0100 Subject: [PATCH 48/68] refactor: update pytest test_stats to accept data and batch parameters directly --- tests/data/test_data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/data/test_data.py b/tests/data/test_data.py index eda02e7e5..f710f23ee 100644 --- a/tests/data/test_data.py +++ b/tests/data/test_data.py @@ -108,9 +108,11 @@ def test_stats(): atomref = {"property1": torch.ones((100,)) / 3.0} for bs in range(1, 7): stats = calculate_stats( - AtomsLoader(data, batch_size=bs), + data, {"property1": True, "property2": False}, atomref=atomref, + batch_size=bs, + num_workers=0, ) assert np.allclose(stats["property1"][0].numpy(), np.array([0.0])) assert np.allclose(stats["property1"][1].numpy(), np.array([1.0])) From db0a96e32ab882f2a7f48b9232f71008c1808c5e Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 15 Mar 2026 15:21:28 +0100 Subject: [PATCH 49/68] refactor: improve ANI1 dataset loading and validation --- src/schnetpack/datasets/ani1.py | 37 +++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/src/schnetpack/datasets/ani1.py b/src/schnetpack/datasets/ani1.py index 7b78d54c9..77a3acfd4 100644 --- a/src/schnetpack/datasets/ani1.py +++ b/src/schnetpack/datasets/ani1.py @@ -5,6 +5,7 @@ import tempfile from typing import Dict, List, Optional from urllib import request as request +from ase.db import connect import h5py import numpy as np @@ -21,6 +22,8 @@ class ANI1(ASEAtomsData): """ ANI1 benchmark dataset. + This class adds convenience functions to download ANI1 from figshare and + load the data into pytorch. References: .. [#ani1] https://arxiv.org/abs/1708.04987 @@ -96,20 +99,32 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: Ensure the ANI1 ASE DB exists. """ if os.path.exists(datapath): - _ = ASEAtomsData(datapath, load_structure=False) + with connect(datapath, use_lock_file=False) as conn: + md = conn.metadata + + if md.get("num_heavy_atoms") != self.num_heavy_atoms: + raise AtomsDataError( + f"Existing ANI1 dataset was created with num_heavy_atoms={md.get('num_heavy_atoms')}, " + f"but requested num_heavy_atoms={self.num_heavy_atoms}." + ) + + if md.get("high_energies") != self.high_energies: + raise AtomsDataError( + f"Existing ANI1 dataset was created with high_energies={md.get('high_energies')}, " + f"but requested high_energies={self.high_energies}." 
+ ) return tmpdir = tempfile.mkdtemp("ani1") - try: - dataset = ASEAtomsData.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=self._native_property_units(), - atomrefs=self._create_atomrefs(), - ) - self._download_data(tmpdir, dataset) - finally: - shutil.rmtree(tmpdir, ignore_errors=True) + + dataset = self.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=self._native_property_units(), + atomrefs=self._create_atomrefs(), + ) + self._download_data(tmpdir, dataset) + shutil.rmtree(tmpdir, ignore_errors=True) def _download_data(self, tmpdir: str, dataset: ASEAtomsData) -> None: logging.info("Downloading ANI-1 data...") From 074bee98910ec8e1b43c7b0ba49e80a73c9ed974 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 15 Mar 2026 15:26:03 +0100 Subject: [PATCH 50/68] refactor: enhance QM9 dataset loading --- src/schnetpack/datasets/qm9.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index cdff9f94d..0d6dcc798 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -11,6 +11,7 @@ import numpy as np from ase import Atoms from ase.io.extxyz import read_xyz +from ase.db import connect from tqdm import tqdm @@ -128,16 +129,17 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: remove_uncharacterized setting. """ if os.path.exists(datapath): - dataset = ASEAtomsData(datapath=datapath, load_structure=False) + with connect(datapath, use_lock_file=False) as conn: + data_count = conn.count() - if self.remove_uncharacterized and len(dataset) == 133885: + if self.remove_uncharacterized and data_count == 133885: raise AtomsDataError( "The dataset at the chosen location contains the uncharacterized 3054 molecules. " "Choose a different location to reload the data or set " "`remove_uncharacterized=False`." ) - if (not self.remove_uncharacterized) and len(dataset) < 133885: + if (not self.remove_uncharacterized) and data_count < 133885: raise AtomsDataError( "The dataset at the chosen location does NOT contain the uncharacterized 3054 molecules. " "Choose a different location to reload the data or set " From d4ca4b06731ef97c007931109d45914b94fc2bf1 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 15 Mar 2026 15:37:59 +0100 Subject: [PATCH 51/68] refactor: simplify download method in ISO17 dataset --- src/schnetpack/datasets/iso17.py | 102 +++++++++++++++++-------------- 1 file changed, 55 insertions(+), 47 deletions(-) diff --git a/src/schnetpack/datasets/iso17.py b/src/schnetpack/datasets/iso17.py index 2bbf058b6..8291f8004 100644 --- a/src/schnetpack/datasets/iso17.py +++ b/src/schnetpack/datasets/iso17.py @@ -95,63 +95,71 @@ def _native_property_units() -> Dict[str, str]: ISO17.forces: "eV/Ang", } - def download(self, datapath: str, distance_unit: str = "Ang") -> None: + def download(self, datapath: str) -> None: """ Ensure the ISO17 DB for the selected fold exists and has proper metadata. """ if os.path.exists(datapath): - _ = ASEAtomsData(datapath, load_structure=False) return + """ + with connect(datapath, use_lock_file=False) as conn: + md = conn.metadata + + if md.get("_property_unit_dict") != self._native_property_units(): + raise AtomsDataError( + f"Existing ISO17 dataset at {datapath} has incompatible property units." + ) + + if md.get("_distance_unit") != "Ang": + raise AtomsDataError( + f"Existing ISO17 dataset at {datapath} has incompatible distance unit." 
+ ) + """ self._download_data() def _download_data(self) -> None: logging.info("Downloading ISO17 database...") tmpdir = tempfile.mkdtemp("iso17") + tarpath = os.path.join(tmpdir, "iso17.tar.gz") + url = "http://www.quantum-machine.org/datasets/iso17.tar.gz" try: - tarpath = os.path.join(tmpdir, "iso17.tar.gz") - url = "http://www.quantum-machine.org/datasets/iso17.tar.gz" - - try: - request.urlretrieve(url, tarpath) - except HTTPError as e: - raise AtomsDataError( - f"HTTP Error {e.code} while downloading {url}" - ) from e - except URLError as e: - raise AtomsDataError( - f"URL Error {e.reason} while downloading {url}" - ) from e - - with tarfile.open(tarpath) as tar: - tar.extractall(self.root_path) - - # update metadata + convert energy into row.data for every fold - for fold in self.existing_folds: - dbpath = os.path.join(self.root_path, "iso17", fold + ".db") - tmp_dbpath = os.path.join(tmpdir, f"{fold}_tmp.db") - - with connect(dbpath) as conn: - with connect(tmp_dbpath) as tmp_conn: - tmp_conn.metadata = { - "_property_unit_dict": self._native_property_units(), - "_distance_unit": "Ang", - "atomrefs": {}, - } - - for idx in tqdm( - range(len(conn)), - desc=f"parsing database file {dbpath}", - ): - atmsrw = conn.get(idx + 1) - data = atmsrw.data - data[self.forces] = np.array(data[self.forces]) - data[self.energy] = np.array([atmsrw.total_energy]) - tmp_conn.write(atmsrw.toatoms(), data=data) - - os.remove(dbpath) - os.rename(tmp_dbpath, dbpath) - - finally: - shutil.rmtree(tmpdir, ignore_errors=True) + request.urlretrieve(url, tarpath) + + except HTTPError as e: + raise AtomsDataError(f"HTTP Error {e.code} while downloading {url}") from e + + except URLError as e: + raise AtomsDataError(f"URL Error {e.reason} while downloading {url}") from e + + with tarfile.open(tarpath) as tar: + tar.extractall(self.root_path) + + # update metadata + convert energy into row.data for every fold + for fold in self.existing_folds: + dbpath = os.path.join(self.root_path, "iso17", fold + ".db") + tmp_dbpath = os.path.join(tmpdir, f"{fold}_tmp.db") + + with connect(dbpath) as conn: + with connect(tmp_dbpath) as tmp_conn: + tmp_conn.metadata = { + "_property_unit_dict": self._native_property_units(), + "_distance_unit": "Ang", + "atomrefs": {}, + } + + for idx in tqdm( + range(len(conn)), + desc=f"parsing database file {dbpath}", + ): + atmsrw = conn.get(idx + 1) + data = atmsrw.data + data[self.forces] = np.array(data[self.forces]) + data[self.energy] = np.array([atmsrw.total_energy]) + tmp_conn.write(atmsrw.toatoms(), data=data) + + os.remove(dbpath) + os.rename(tmp_dbpath, dbpath) + + shutil.rmtree(tmpdir, ignore_errors=True) From 985032e3498263b086875a80d3b11fa6062be1d7 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 15 Mar 2026 15:44:24 +0100 Subject: [PATCH 52/68] refactor: enhance MaterialsProject and update docstring --- src/schnetpack/datasets/materials_project.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/schnetpack/datasets/materials_project.py b/src/schnetpack/datasets/materials_project.py index 400d3c2d7..3790b910b 100644 --- a/src/schnetpack/datasets/materials_project.py +++ b/src/schnetpack/datasets/materials_project.py @@ -14,6 +14,8 @@ class MaterialsProject(ASEAtomsData): """ Materials Project (MP) database of bulk crystals. + This class adds convenient functions to download Materials Project data into + pytorch. 
References: @@ -103,7 +105,6 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: Ensure the Materials Project ASE DB exists. """ if os.path.exists(datapath): - _ = ASEAtomsData(datapath, load_structure=False) return if self.apikey is None: @@ -112,7 +113,7 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: "to get an API key." ) - dataset = ASEAtomsData.create( + dataset = self.create( datapath=datapath, distance_unit=distance_unit, property_unit_dict=self._native_property_units(), From 6bf81d3074308b7c004c9310a9e0909b5af310fc Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 15 Mar 2026 15:54:05 +0100 Subject: [PATCH 53/68] refactor: optimize GDMLDataset download method in md17 --- src/schnetpack/datasets/md17.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/schnetpack/datasets/md17.py b/src/schnetpack/datasets/md17.py index eaf860edd..f7cc431e5 100644 --- a/src/schnetpack/datasets/md17.py +++ b/src/schnetpack/datasets/md17.py @@ -7,6 +7,7 @@ import numpy as np from ase import Atoms +from ase.db import connect import schnetpack.properties as structure from schnetpack.transform.base import Transform @@ -97,8 +98,8 @@ def _native_property_units() -> Dict[str, str]: def download(self, datapath: str, distance_unit: str = "Ang") -> None: if os.path.exists(datapath): - dataset = ASEAtomsData(datapath, load_structure=False) - md = dataset.metadata + with connect(datapath, use_lock_file=False) as conn: + md = conn.metadata if "molecule" not in md: raise AtomsDataError( @@ -113,19 +114,17 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: return tmpdir = tempfile.mkdtemp(self.tmpdir) - try: - dataset = ASEAtomsData.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=self._native_property_units(), - atomrefs=self._native_atomrefs, - ) - dataset.update_metadata(molecule=self.molecule) - self._download_data(tmpdir) - finally: - shutil.rmtree(tmpdir, ignore_errors=True) - - def _download_data(self, tmpdir): + dataset = self.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=self._native_property_units(), + atomrefs=self._native_atomrefs, + ) + dataset.update_metadata(molecule=self.molecule) + self._download_data(tmpdir, dataset) + shutil.rmtree(tmpdir, ignore_errors=True) + + def _download_data(self, tmpdir, dataset: ASEAtomsData) -> None: logging.info("Downloading {} data".format(self.molecule)) rawpath = os.path.join(tmpdir, self.datasets_dict[self.molecule]) url = self.download_url + self.datasets_dict[self.molecule] From 6f30249af6ae9945d08068ae7917ed8838d89305 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 15 Mar 2026 16:12:23 +0100 Subject: [PATCH 54/68] refactor: enhance QM7X dataset docstring and improve download methods --- src/schnetpack/datasets/qm7x.py | 116 +++++++++++++++++++++----------- 1 file changed, 77 insertions(+), 39 deletions(-) diff --git a/src/schnetpack/datasets/qm7x.py b/src/schnetpack/datasets/qm7x.py index e52619e0d..41af189b1 100644 --- a/src/schnetpack/datasets/qm7x.py +++ b/src/schnetpack/datasets/qm7x.py @@ -22,6 +22,9 @@ def show_progress(block_num: int, block_size: int, total_size: int): + """ + progress callback for files downloads + """ global pbar if pbar is None: pbar = progressbar.ProgressBar(maxval=total_size) @@ -36,6 +39,9 @@ def show_progress(block_num: int, block_size: int, total_size: int): def download_and_check(url: str, target_path: str, checksum: str): + """ + 
+    Download file from url to target_path and check md5 checksum.
+    """
     file_name = url.split("/")[-1]

     if os.path.exists(target_path):
@@ -61,6 +67,9 @@ def extract_xz(source: str, target: str):
+    """
+    Helper to extract .xz files.
+    """
     s_file = source.split("/")[-1]
     t_file = target.split("/")[-1]
@@ -82,18 +91,31 @@ class QM7X(ASEAtomsData):
     """
-    QM7-X dataset of equilibrium and non-equilibrium structures of small organic molecules.
-    """
+    QM7-X is a comprehensive dataset of > 40 physicochemical properties for ~4.2 M equilibrium and non-equilibrium
+    structures of small organic molecules with up to seven non-hydrogen (C, N, O, S, Cl) atoms.
+    This class adds convenient functions to download QM7-X and load the data into pytorch.
+
+    References:
+
+    .. [#qm7x_1] https://zenodo.org/record/4288677

-    forces = "forces"
-    energy = "energy"
-    Eat = "Eat"
-    EPBE0 = "EPBE0"
-    EMBD = "EMBD"
-    FPBE0 = "FPBE0"
-    FMBD = "FMBD"
-    RMSD = "rmsd"
+    """
+    # More molecular and atomic properties can be found in the original paper and added here.
+    # Notice that adding more properties can drastically increase the size of the dataset;
+    # adding more properties here requires adding them to the property_unit_dict
+    # and their key mapping in the raw dataset in property_dataset_keys.
+
+    forces = "forces"  # total ePBE0+MBD forces
+    energy = "energy"  # ePBE0+MBD: total energy after convergence of the PBE0 exchange-correlation functional and the MBD dispersion correction
+    Eat = "Eat"  # atomization energy using PBE0 energy per atom and ePBE0+MBD total energy
+    EPBE0 = "EPBE0"  # ePBE0: total energy at the level of PBE0
+    EMBD = "EMBD"  # eMBD: total energy at the level of MBD
+    FPBE0 = "FPBE0"  # FPBE0: total ePBE0 forces
+    FMBD = "FMBD"  # FMBD: total eMBD forces
+    RMSD = "rmsd"  # root mean square deviation of the atomic positions from the equilibrium structure
+
+    # the original keys in the raw dataset to query the properties
     property_dataset_keys = {
         forces: "totFOR",
         energy: "ePBE0+MBD",
@@ -105,6 +127,7 @@ class QM7X(ASEAtomsData):
         RMSD: "sRMSD",
     }

+    # atom energies (atomrefs) from PBE0
     EPBE0_atom = {
         1: -13.641404161,
         6: -1027.592489146,
@@ -183,14 +206,14 @@ def __init__(
     @staticmethod
     def _native_property_units() -> Dict[str, str]:
         return {
-            QM7X.forces: "totFOR",
-            QM7X.energy: "ePBE0+MBD",
-            QM7X.Eat: "eAT",
-            QM7X.EPBE0: "ePBE0",
-            QM7X.EMBD: "eMBD",
-            QM7X.FPBE0: "pbe0FOR",
-            QM7X.FMBD: "vdwFOR",
-            QM7X.RMSD: "sRMSD",
+            QM7X.forces: "eV/Ang",
+            QM7X.energy: "eV",
+            QM7X.Eat: "eV",
+            QM7X.EPBE0: "eV",
+            QM7X.EMBD: "eV",
+            QM7X.FPBE0: "eV/Ang",
+            QM7X.FMBD: "eV/Ang",
+            QM7X.RMSD: "Ang",
         }

     def _apply_structure_filter(self, original_subset_idx: Optional[List[int]]) -> None:
@@ -222,41 +245,42 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None:
         Download the QM7-X dataset and create the ASEAtomsData object.
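download_and_check() at the top of this patch implements a fetch-then-verify step. A generic, self-contained sketch of the same pattern (illustrative names, not the module's API):

    import hashlib
    import urllib.request

    def fetch_and_verify(url: str, target_path: str, expected_md5: str) -> None:
        # download the file, then compare its MD5 digest with the expected checksum
        urllib.request.urlretrieve(url, target_path)
        md5 = hashlib.md5()
        with open(target_path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                md5.update(chunk)
        if md5.hexdigest() != expected_md5:
            raise IOError(f"Checksum mismatch for {target_path}")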
""" if os.path.exists(datapath): - _ = ASEAtomsData(datapath, load_structure=False) return tar_dir = self.raw_data_path or tempfile.mkdtemp("qm7x") - try: - atomrefs = { - QM7X.energy: [ - QM7X.EPBE0_atom[i] if i in QM7X.EPBE0_atom else 0.0 - for i in range(0, 18) - ] - } + atomrefs = { + QM7X.energy: [ + QM7X.EPBE0_atom[i] if i in QM7X.EPBE0_atom else 0.0 + for i in range(0, 18) + ] + } - dataset = ASEAtomsData.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=self._native_property_units(), - atomrefs=atomrefs, - ) + dataset = self.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=self._native_property_units(), + atomrefs=atomrefs, + ) - hd_files = self._download_data(tar_dir) - if self.remove_duplicates: - self._download_duplicates_ids(tar_dir) - self._parse_data(hd_files, dataset) + hd_files = self._download_data(tar_dir) + if self.remove_duplicates: + self._download_duplicates_ids(tar_dir) + self._parse_data(hd_files, dataset) - finally: - if self.raw_data_path is None: - shutil.rmtree(tar_dir, ignore_errors=True) + if self.raw_data_path is None: + shutil.rmtree(tar_dir, ignore_errors=True) def _download_duplicates_ids(self, tar_dir: str): + """ + download duplicates ids for QM7-X + """ url = "https://zenodo.org/record/4288677/files/DupMols.dat" target_path = os.path.join(tar_dir, "DupMols.dat") checksum = "5d886ccac38877c8cb26c07704dd1034" download_and_check(url, target_path, checksum) + # fetch duplicates ids dup_mols = [] with open(target_path, "r") as f: for line in f: @@ -265,7 +289,12 @@ def _download_duplicates_ids(self, tar_dir: str): self.duplicates_ids = dup_mols def _download_data(self, tar_dir: str, ignore_extracted: bool = True) -> List[str]: + """ + download data and extract them + """ file_ids = ["1000", "2000", "3000", "4000", "5000", "6000", "7000", "8000"] + + # file fingerprints to check integrity checksums = [ "b50c6a5d0a4493c274368cf22285503e", "4418a813daf5e0d44aa5a26544249ee6", @@ -292,6 +321,7 @@ def _download_data(self, tar_dir: str, ignore_extracted: bool = True) -> List[st xz_path = os.path.join(tar_dir, f"{file_id}.xz") download_and_check(url, xz_path, checksums[i]) + # extract the compressed files extracted = [] for file_id in file_ids: xz_path = os.path.join(tar_dir, f"{file_id}.xz") @@ -302,6 +332,9 @@ def _download_data(self, tar_dir: str, ignore_extracted: bool = True) -> List[st return extracted def _parse_data(self, files: List[str], dataset: ASEAtomsData): + """ + Parse the downloaded data files and add them to the dataset. 
+ """ for file in files: logging.info(f"Parsing {os.path.basename(file)} ...") @@ -317,6 +350,7 @@ def _parse_data(self, files: List[str], dataset: ASEAtomsData): with h5py.File(file, "r") as mol_dict: for _mol_id, mol in mol_dict.items(): for conf_id, conf in mol.items(): + # exclude equilibrium duplicates trunc_id = conf_id[::-1].split("-", 1)[-1][::-1] if self.remove_duplicates and trunc_id in self.duplicates_ids: continue @@ -329,6 +363,7 @@ def _parse_data(self, files: List[str], dataset: ASEAtomsData): for key in QM7X._native_property_units().keys() } + # get the hierarchical ids for each system if "opt" in conf_id: conf_id = conf_id[:-3] + "d0" @@ -337,17 +372,20 @@ def _parse_data(self, files: List[str], dataset: ASEAtomsData): atoms_list.append(ats) property_list.append(properties) + # save the hierarchical ids for each system in same order as the systems for key, idx in zip(groups_ids.keys(), ids): groups_ids[key].append(idx) logging.info(f"Write parsed data from {os.path.basename(file)} to db ...") dataset.add_systems(property_list=property_list, atoms_list=atoms_list) + # add the hierarchical ids to the metadata md = dataset.metadata if "groups_ids" in md: for key, ids in groups_ids.items(): groups_ids[key] = md["groups_ids"][key] + ids + # add the ids as in the database of the new added systems last_id = md["groups_ids"]["id"][-1] sys_ids = list(range(last_id + 1, last_id + len(atoms_list) + 1)) groups_ids["id"] = md["groups_ids"]["id"] + sys_ids From b5b896979cd57e8245444475034a04b6a7733d67 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Sun, 15 Mar 2026 16:20:21 +0100 Subject: [PATCH 55/68] refactor: improve rMD17 dataset loading and metadata handling --- src/schnetpack/datasets/rmd17.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py index 317cb41ac..ad74fdc02 100644 --- a/src/schnetpack/datasets/rmd17.py +++ b/src/schnetpack/datasets/rmd17.py @@ -9,6 +9,7 @@ import numpy as np from ase import Atoms +from ase.db import connect import schnetpack.properties as structure from schnetpack.data.atoms import ASEAtomsData, AtomsDataError @@ -91,7 +92,7 @@ def __init__( distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) """ - if molecule not in self.datasets_dict: + if molecule not in self.datasets_dict.keys(): raise AtomsDataError(f"Molecule {molecule} is not supported!") self.molecule = molecule @@ -126,8 +127,8 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: Ensure the ASE DB exists and matches the requested molecule. 
""" if os.path.exists(datapath): - dataset = ASEAtomsData(datapath, load_structure=False) - md = dataset.metadata + with connect(datapath, use_lock_file=False) as conn: + md = conn.metadata if "molecule" not in md: raise AtomsDataError( @@ -142,17 +143,15 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: return tmpdir = tempfile.mkdtemp("rmd17") - try: - dataset = ASEAtomsData.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=self._native_property_units(), - atomrefs=self.atomrefs, - ) - dataset.update_metadata(molecule=self.molecule) - self._download_data(tmpdir, dataset) - finally: - shutil.rmtree(tmpdir, ignore_errors=True) + dataset = self.create( + datapath=datapath, + distance_unit=distance_unit, + property_unit_dict=self._native_property_units(), + atomrefs=self.atomrefs, + ) + dataset.update_metadata(molecule=self.molecule) + self._download_data(tmpdir, dataset) + shutil.rmtree(tmpdir, ignore_errors=True) def _download_data(self, tmpdir: str, dataset: ASEAtomsData) -> None: logging.info("Downloading %s data...", self.molecule) From 799d5a8de97603afad98152c45ca4603766639e0 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Tue, 17 Mar 2026 16:58:39 +0100 Subject: [PATCH 56/68] refactor: enhance ASEAtomsData and QM9 dataset handling --- src/schnetpack/data/atoms.py | 20 ++++++++-- src/schnetpack/datasets/qm9.py | 69 ++++++++++++++-------------------- 2 files changed, 44 insertions(+), 45 deletions(-) diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index 3a4a2d098..f9bf05e10 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -68,6 +68,9 @@ def __init__( """ self.datapath = datapath self.subset_idx = subset_idx + if not os.path.exists(self.datapath): + self.download() + self._check_db() self.conn = connect(self.datapath, use_lock_file=False) @@ -166,9 +169,13 @@ def _check_db(self): if not os.path.exists(self.datapath): raise AtomsDataError(f"ASE DB does not exist at {self.datapath}") + with connect(self.datapath, use_lock_file=False) as conn: + n_structures = conn.count() + + if n_structures == 0: + raise AtomsDataError(f"ASE DB at {self.datapath} is empty") + if self.subset_idx is not None: - with connect(self.datapath, use_lock_file=False) as conn: - n_structures = conn.count() if max(self.subset_idx) >= n_structures: raise AtomsDataError("subset_idx contains out-of-range indices") @@ -312,7 +319,7 @@ def create( property_unit_dict: Dict[str, str], atomrefs: Optional[Dict[str, List[float]]] = None, **kwargs, - ) -> "ASEAtomsData": + ) -> None: """ Args: @@ -345,7 +352,7 @@ def create( "atomrefs": atomrefs, } - return ASEAtomsData(datapath, **kwargs) ##NO RETURN HERE + return def add_system( self, @@ -442,3 +449,8 @@ def _add_system( data[pname] = properties[pname] conn.write(atoms, data=data, key_value_pairs=atoms_metadata) + + def download(self): + raise NotImplementedError( + f"{self.__class__.__name__} must implement download()." 
+ ) diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 0d6dcc798..5a054e071 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -83,11 +83,6 @@ def __init__( """ self.remove_uncharacterized = remove_uncharacterized - self.download( - datapath=datapath, - distance_unit=distance_unit or "Ang", - ) - super().__init__( datapath=datapath, load_properties=load_properties, @@ -100,6 +95,7 @@ def __init__( distance_unit=distance_unit, **kwargs, ) + self._check_metadata() @staticmethod def _native_property_units() -> Dict[str, str]: @@ -121,39 +117,32 @@ def _native_property_units() -> Dict[str, str]: QM9.Cv: "cal/mol/K", } - def download(self, datapath: str, distance_unit: str = "Ang") -> None: - """ - Make sure the QM9 database exists. - - If the DB already exists, validate consistency with the - remove_uncharacterized setting. - """ - if os.path.exists(datapath): - with connect(datapath, use_lock_file=False) as conn: - data_count = conn.count() - - if self.remove_uncharacterized and data_count == 133885: - raise AtomsDataError( - "The dataset at the chosen location contains the uncharacterized 3054 molecules. " - "Choose a different location to reload the data or set " - "`remove_uncharacterized=False`." - ) - - if (not self.remove_uncharacterized) and data_count < 133885: - raise AtomsDataError( - "The dataset at the chosen location does NOT contain the uncharacterized 3054 molecules. " - "Choose a different location to reload the data or set " - "`remove_uncharacterized=True`." - ) - return - + def _check_metadata(self) -> None: + with connect(self.datapath, use_lock_file=False) as conn: + data_count = conn.count() + + if self.remove_uncharacterized and data_count == 133885: + raise AtomsDataError( + "The dataset at the chosen location contains the uncharacterized 3054 molecules. " + "Choose a different location to reload the data or set " + "`remove_uncharacterized=False`." + ) + + if (not self.remove_uncharacterized) and data_count < 133885: + raise AtomsDataError( + "The dataset at the chosen location does NOT contain the uncharacterized 3054 molecules. " + "Choose a different location to reload the data or set " + "`remove_uncharacterized=True`." 
+ ) + + def download(self) -> None: tmpdir = tempfile.mkdtemp("qm9") atomrefs = self._download_atomrefs(tmpdir) - dataset = self.create( - datapath=datapath, - distance_unit=distance_unit, + self.create( + datapath=self.datapath, + distance_unit="Ang", property_unit_dict=self._native_property_units(), atomrefs=atomrefs, ) @@ -163,7 +152,7 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: else: uncharacterized = None - self._download_data(tmpdir, dataset, uncharacterized) + self._download_data(tmpdir, uncharacterized) shutil.rmtree(tmpdir, ignore_errors=True) @@ -210,7 +199,6 @@ def _download_atomrefs(self, tmpdir: str) -> Dict[str, List[float]]: def _download_data( self, tmpdir: str, - dataset: ASEAtomsData, uncharacterized: Optional[List[int]], ) -> None: @@ -221,9 +209,8 @@ def _download_data( logging.info("Done.") logging.info("Extracting files...") - tar = tarfile.open(tar_path) - tar.extractall(raw_path) - tar.close() + with tarfile.open(tar_path) as tar: + tar.extractall(raw_path) logging.info("Done.") logging.info("Parse xyz files...") @@ -246,7 +233,7 @@ def _download_data( lines = f.readlines() values = lines[1].split()[2:] - for pname, value in zip(dataset.available_properties, values): + for pname, value in zip(self.available_properties, values): properties[pname] = np.array([float(value)]) for line in lines: @@ -263,5 +250,5 @@ def _download_data( property_list.append(properties) logging.info("Write atoms to db...") - dataset.add_systems(property_list=property_list) + self.add_systems(property_list=property_list) logging.info("Done.") From 6ee96995db8319396aee6c32d73387545e602490 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 19 Mar 2026 01:27:24 +0100 Subject: [PATCH 57/68] refactor: streamline metadata _check_db() and dataset creation in ASEAtomsData --- src/schnetpack/data/atoms.py | 41 ++++++++++++++-------------------- src/schnetpack/datasets/qm9.py | 4 ++-- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index f9bf05e10..726eec5a2 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -84,15 +84,6 @@ def __init__( # units from metadata md = self.metadata - if "_distance_unit" not in md: - raise AtomsDataError( - "Dataset does not have a distance unit set. Please add units to the dataset." - ) - if "_property_unit_dict" not in md: - raise AtomsDataError( - "Dataset does not have property units set. Please add units to the dataset." - ) - if distance_unit: self.distance_conversion = spk.units.convert_units( md["_distance_unit"], distance_unit @@ -171,10 +162,21 @@ def _check_db(self): with connect(self.datapath, use_lock_file=False) as conn: n_structures = conn.count() + md = conn.metadata if n_structures == 0: raise AtomsDataError(f"ASE DB at {self.datapath} is empty") + if "_distance_unit" not in md: + raise AtomsDataError( + "Dataset does not have a distance unit set. Please add units to the dataset." + ) + + if "_property_unit_dict" not in md: + raise AtomsDataError( + "Dataset does not have property units set. Please add units to the dataset." 
+ ) + if self.subset_idx is not None: if max(self.subset_idx) >= n_structures: raise AtomsDataError("subset_idx contains out-of-range indices") @@ -312,18 +314,15 @@ def _get_properties( # ---------- creation / writing ---------- - @staticmethod def create( - datapath: str, + self, distance_unit: str, property_unit_dict: Dict[str, str], atomrefs: Optional[Dict[str, List[float]]] = None, - **kwargs, ) -> None: """ Args: - datapath: Path to ASE DB. distance_unit: unit of atom positions and cell property_unit_dict: Defines the available properties of the datasetseta and provides units for ALL properties of the dataset. If a property is @@ -332,28 +331,22 @@ def create( reference values of the property. This is especially useful for extensive properties such as the energy, where the single atom energies contribute a major part to the overall value. - - Returns: - newly created ASEAtomsData - """ - if not datapath.endswith(".db"): + if not self.datapath.endswith(".db"): raise AtomsDataError("Invalid datapath! Add '.db' extension.") - if os.path.exists(datapath): - raise AtomsDataError(f"Dataset already exists: {datapath}") + if os.path.exists(self.datapath): + raise AtomsDataError(f"Dataset already exists: {self.datapath}") - os.makedirs(os.path.dirname(datapath) or ".", exist_ok=True) + os.makedirs(os.path.dirname(self.datapath) or ".", exist_ok=True) atomrefs = atomrefs or {} - with connect(datapath) as conn: + with connect(self.datapath) as conn: conn.metadata = { "_property_unit_dict": property_unit_dict, "_distance_unit": distance_unit, "atomrefs": atomrefs, } - return - def add_system( self, atoms: Optional[Atoms] = None, diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 5a054e071..e21d5bb3f 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -95,7 +95,6 @@ def __init__( distance_unit=distance_unit, **kwargs, ) - self._check_metadata() @staticmethod def _native_property_units() -> Dict[str, str]: @@ -117,7 +116,8 @@ def _native_property_units() -> Dict[str, str]: QM9.Cv: "cal/mol/K", } - def _check_metadata(self) -> None: + def _check_db(self) -> None: + super()._check_db() with connect(self.datapath, use_lock_file=False) as conn: data_count = conn.count() From 86df4d8312baf8418884c898113e0f1f0b966155 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 19 Mar 2026 02:26:11 +0100 Subject: [PATCH 58/68] refactor: enhance ASEAtomsData split transform and db creation --- src/schnetpack/data/atoms.py | 45 ++++++++++++++++++---------------- src/schnetpack/datasets/qm9.py | 21 ++++++++-------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index 726eec5a2..064392b83 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -69,6 +69,7 @@ def __init__( self.datapath = datapath self.subset_idx = subset_idx if not os.path.exists(self.datapath): + self.create() self.download() self._check_db() @@ -78,6 +79,7 @@ def __init__( self.train_transforms = list(train_transforms) if train_transforms else None self.val_transforms = list(val_transforms) if val_transforms else None self.test_transforms = list(test_transforms) if test_transforms else None + self.split = None self._load_properties: Optional[List[str]] = None self.load_structure = load_structure @@ -109,7 +111,7 @@ def __init__( # ---------- merged ASEAtomsData bits ---------- - def subset(self, subset_idx: List[int]): + def subset(self, subset_idx: List[int], split: Optional[str] = 
None): if subset_idx is None: raise ValueError("subset_idx must be provided.") ds = copy.copy(self) @@ -117,6 +119,7 @@ def subset(self, subset_idx: List[int]): ds.subset_idx = [ds.subset_idx[i] for i in subset_idx] else: ds.subset_idx = subset_idx + ds.split = split return ds @property @@ -152,7 +155,16 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: def _apply_transforms( self, props: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: - for tf in self.transforms: + if self.split == "train" and self.train_transforms is not None: + transforms = self.train_transforms + elif self.split == "val" and self.val_transforms is not None: + transforms = self.val_transforms + elif self.split == "test" and self.test_transforms is not None: + transforms = self.test_transforms + else: + transforms = self.transforms + + for tf in transforms: props = tf(props) return props @@ -314,23 +326,10 @@ def _get_properties( # ---------- creation / writing ---------- - def create( - self, - distance_unit: str, - property_unit_dict: Dict[str, str], - atomrefs: Optional[Dict[str, List[float]]] = None, - ) -> None: + def create(self) -> None: """ + Create a new ASE database at `self.datapath` and initialize its metadata. - Args: - distance_unit: unit of atom positions and cell - property_unit_dict: Defines the available properties of the datasetseta and - provides units for ALL properties of the dataset. If a property is - unit-less, you can pass "arb. unit" or `None`. - atomrefs: dictionary mapping properies (the keys) to lists of single-atom - reference values of the property. This is especially useful for - extensive properties such as the energy, where the single atom energies - contribute a major part to the overall value. """ if not self.datapath.endswith(".db"): raise AtomsDataError("Invalid datapath! Add '.db' extension.") @@ -339,12 +338,16 @@ def create( os.makedirs(os.path.dirname(self.datapath) or ".", exist_ok=True) - atomrefs = atomrefs or {} + if self.property_units is None: + raise AtomsDataError("property_units is not set in dataset class.") + if self.distance_unit is None: + raise AtomsDataError("distance_unit is not set in dataset class.") + with connect(self.datapath) as conn: conn.metadata = { - "_property_unit_dict": property_unit_dict, - "_distance_unit": distance_unit, - "atomrefs": atomrefs, + "_property_unit_dict": self.property_units, + "_distance_unit": self.distance_unit, + "atomrefs": {}, } def add_system( diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index e21d5bb3f..fdcdf1b1c 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -82,6 +82,8 @@ def __init__( distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) 
""" self.remove_uncharacterized = remove_uncharacterized + self.distance_unit = "Ang" + self.property_units = self._native_property_units() super().__init__( datapath=datapath, @@ -139,13 +141,9 @@ def download(self) -> None: tmpdir = tempfile.mkdtemp("qm9") atomrefs = self._download_atomrefs(tmpdir) - - self.create( - datapath=self.datapath, - distance_unit="Ang", - property_unit_dict=self._native_property_units(), - atomrefs=atomrefs, - ) + md = self.metadata + md["atomrefs"] = atomrefs + self._set_metadata(md) if self.remove_uncharacterized: uncharacterized = self._download_uncharacterized(tmpdir) @@ -158,10 +156,11 @@ def download(self) -> None: def _download_file(self, file_id: str, destination: str) -> None: for base_url in self.base_urls: - url = f"{base_url}{file_id}" - request.urlretrieve(url, destination) - return - + try: + request.urlretrieve(f"{base_url}{file_id}", destination) + return + except Exception: + continue raise AtomsDataError( f"Could not download file with id {file_id} from any source." ) From 1690b19889ff06e4b1d99409bf1b4b531b91938c Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 19 Mar 2026 04:17:48 +0100 Subject: [PATCH 59/68] refactor: streamline ANI1 and QM9 dataset handling and download methods --- src/schnetpack/datasets/ani1.py | 68 ++++++++++++++++----------------- src/schnetpack/datasets/qm9.py | 6 +++ 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/src/schnetpack/datasets/ani1.py b/src/schnetpack/datasets/ani1.py index 77a3acfd4..8db3583fd 100644 --- a/src/schnetpack/datasets/ani1.py +++ b/src/schnetpack/datasets/ani1.py @@ -69,11 +69,8 @@ def __init__( """ self.num_heavy_atoms = num_heavy_atoms self.high_energies = high_energies - - self.download( - datapath=datapath, - distance_unit=distance_unit or "Ang", - ) + self.distance_unit = "Ang" + self.property_units = self._native_property_units() super().__init__( datapath=datapath, @@ -94,39 +91,42 @@ def _native_property_units() -> Dict[str, str]: ANI1.energy: "Hartree", } - def download(self, datapath: str, distance_unit: str = "Ang") -> None: + def _check_db(self) -> None: """ Ensure the ANI1 ASE DB exists. """ - if os.path.exists(datapath): - with connect(datapath, use_lock_file=False) as conn: - md = conn.metadata - - if md.get("num_heavy_atoms") != self.num_heavy_atoms: - raise AtomsDataError( - f"Existing ANI1 dataset was created with num_heavy_atoms={md.get('num_heavy_atoms')}, " - f"but requested num_heavy_atoms={self.num_heavy_atoms}." - ) - - if md.get("high_energies") != self.high_energies: - raise AtomsDataError( - f"Existing ANI1 dataset was created with high_energies={md.get('high_energies')}, " - f"but requested high_energies={self.high_energies}." - ) - return - + super()._check_db() + with connect(self.datapath, use_lock_file=False) as conn: + md = conn.metadata + + if md.get("num_heavy_atoms") != self.num_heavy_atoms: + raise AtomsDataError( + f"Existing ANI1 dataset was created with num_heavy_atoms={md.get('num_heavy_atoms')}, " + f"but requested num_heavy_atoms={self.num_heavy_atoms}." + ) + + if md.get("high_energies") != self.high_energies: + raise AtomsDataError( + f"Existing ANI1 dataset was created with high_energies={md.get('high_energies')}, " + f"but requested high_energies={self.high_energies}." + ) + + def download(self) -> None: + """ + Download ANI1 data and populate the ASE DB. 
+ """ tmpdir = tempfile.mkdtemp("ani1") - dataset = self.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=self._native_property_units(), - atomrefs=self._create_atomrefs(), - ) - self._download_data(tmpdir, dataset) + md = self.metadata + md["atomrefs"] = self._create_atomrefs() + md["num_heavy_atoms"] = self.num_heavy_atoms + md["high_energies"] = self.high_energies + self._set_metadata(md) + + self._download_data(tmpdir) shutil.rmtree(tmpdir, ignore_errors=True) - def _download_data(self, tmpdir: str, dataset: ASEAtomsData) -> None: + def _download_data(self, tmpdir: str) -> None: logging.info("Downloading ANI-1 data...") tar_path = os.path.join(tmpdir, "ANI1_release.tar.gz") raw_path = os.path.join(tmpdir, "data") @@ -148,11 +148,11 @@ def _download_data(self, tmpdir: str, dataset: ASEAtomsData) -> None: for i in range(1, self.num_heavy_atoms + 1): file_name = os.path.join(raw_path, "ANI-1_release", f"ani_gdb_s0{i}.h5") logging.info("Start to parse %s", file_name) - self._load_h5_file(file_name, dataset) + self._load_h5_file(file_name) logging.info("Done.") - def _load_h5_file(self, file_name: str, dataset: ASEAtomsData) -> None: + def _load_h5_file(self, file_name: str) -> None: atoms_list = [] properties_list = [] @@ -185,7 +185,7 @@ def _load_h5_file(self, file_name: str, dataset: ASEAtomsData) -> None: atoms_list.append(atm) properties_list.append(properties) - dataset.add_systems(atoms_list=atoms_list, property_list=properties_list) + self.add_systems(atoms_list=atoms_list, property_list=properties_list) def _create_atomrefs(self) -> Dict[str, List[float]]: atref = np.zeros((100,)) diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index fdcdf1b1c..2408f288e 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -119,6 +119,9 @@ def _native_property_units() -> Dict[str, str]: } def _check_db(self) -> None: + """ + Ensure the QM9 ASE DB exists. + """ super()._check_db() with connect(self.datapath, use_lock_file=False) as conn: data_count = conn.count() @@ -138,6 +141,9 @@ def _check_db(self) -> None: ) def download(self) -> None: + """ + Download the QM9 ASE DB. + """ tmpdir = tempfile.mkdtemp("qm9") atomrefs = self._download_atomrefs(tmpdir) From 3962eb8b79e9efb6ec53f1676e1613ebdd1440d6 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 19 Mar 2026 04:31:09 +0100 Subject: [PATCH 60/68] refactor: update ISO17 dataset download method and improve property unit handling --- src/schnetpack/datasets/iso17.py | 33 +++++--------------------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/src/schnetpack/datasets/iso17.py b/src/schnetpack/datasets/iso17.py index 8291f8004..f206447d9 100644 --- a/src/schnetpack/datasets/iso17.py +++ b/src/schnetpack/datasets/iso17.py @@ -70,11 +70,11 @@ def __init__( self.root_path = datapath self.fold = fold + self.distance_unit = "Ang" + self.property_units = self._native_property_units() dbpath = os.path.join(datapath, "iso17", fold + ".db") - self.download(datapath=dbpath, distance_unit=distance_unit or "Ang") - super().__init__( datapath=dbpath, load_properties=load_properties, @@ -95,30 +95,7 @@ def _native_property_units() -> Dict[str, str]: ISO17.forces: "eV/Ang", } - def download(self, datapath: str) -> None: - """ - Ensure the ISO17 DB for the selected fold exists and has proper metadata. 
- """ - if os.path.exists(datapath): - return - """ - with connect(datapath, use_lock_file=False) as conn: - md = conn.metadata - - if md.get("_property_unit_dict") != self._native_property_units(): - raise AtomsDataError( - f"Existing ISO17 dataset at {datapath} has incompatible property units." - ) - - if md.get("_distance_unit") != "Ang": - raise AtomsDataError( - f"Existing ISO17 dataset at {datapath} has incompatible distance unit." - ) - """ - - self._download_data() - - def _download_data(self) -> None: + def download(self) -> None: logging.info("Downloading ISO17 database...") tmpdir = tempfile.mkdtemp("iso17") tarpath = os.path.join(tmpdir, "iso17.tar.gz") @@ -141,8 +118,8 @@ def _download_data(self) -> None: dbpath = os.path.join(self.root_path, "iso17", fold + ".db") tmp_dbpath = os.path.join(tmpdir, f"{fold}_tmp.db") - with connect(dbpath) as conn: - with connect(tmp_dbpath) as tmp_conn: + with connect(dbpath, use_lock_file=False) as conn: + with connect(tmp_dbpath, use_lock_file=False) as tmp_conn: tmp_conn.metadata = { "_property_unit_dict": self._native_property_units(), "_distance_unit": "Ang", From 66ffe8173008fbcaba6badf0f0dc8ba566942534 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Thu, 19 Mar 2026 04:57:06 +0100 Subject: [PATCH 61/68] refactor: improve MaterialsProject API key validation and simplify download method --- src/schnetpack/datasets/materials_project.py | 44 ++++++-------------- 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/src/schnetpack/datasets/materials_project.py b/src/schnetpack/datasets/materials_project.py index 3790b910b..6a8f2015e 100644 --- a/src/schnetpack/datasets/materials_project.py +++ b/src/schnetpack/datasets/materials_project.py @@ -57,14 +57,20 @@ def __init__( distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) apikey: api key use to get data """ - if apikey is not None and len(apikey) == 16: - raise DeprecationWarning( + if apikey is None: + raise AtomsDataError( + "No API key provided. Visit https://next-gen.materialsproject.org/ " + "to get an API key." + ) + + elif len(apikey) == 16: + raise AtomsDataError( "You are using a legacy API key. This API is deprecated and no longer " "supported by Materials Project. Please use the next-gen API instead. " "Visit https://next-gen.materialsproject.org/ to get a valid API key." ) - if apikey is not None and len(apikey) != 32: + elif len(apikey) != 32: raise AtomsDataError( "Invalid API key. MaterialsProject requires an API key of 32 characters. " f"Your API key contains {len(apikey)} characters. " @@ -72,11 +78,8 @@ def __init__( ) self.apikey = apikey - - self.download( - datapath=datapath, - distance_unit=distance_unit or "Ang", - ) + self.distance_unit = "Ang" + self.property_units = self._native_property_units() super().__init__( datapath=datapath, @@ -100,28 +103,7 @@ def _native_property_units() -> Dict[str, str]: MaterialsProject.TotalMagnetization: "None", } - def download(self, datapath: str, distance_unit: str = "Ang") -> None: - """ - Ensure the Materials Project ASE DB exists. - """ - if os.path.exists(datapath): - return - - if self.apikey is None: - raise AtomsDataError( - "No API key provided. Visit https://next-gen.materialsproject.org/ " - "to get an API key." 
-            )
-
-        dataset = self.create(
-            datapath=datapath,
-            distance_unit=distance_unit,
-            property_unit_dict=self._native_property_units(),
-        )
-
-        self._download_data_nextgen(dataset)
-
-    def _download_data_nextgen(self, dataset: ASEAtomsData) -> None:
+    def download(self) -> None:
         """
         Download Materials Project entries and store them in the ASE DB.
         """
@@ -182,7 +164,7 @@ def _download_data_nextgen(self, dataset: ASEAtomsData) -> None:
             )
 
         logging.info("Write atoms to db...")
-        dataset.add_systems(
+        self.add_systems(
             atoms_list=atoms_list,
             property_list=properties_list,
             atoms_metadata_list=atoms_metadata_list,

From f9615264ea6130ff036df36cfbf92ab5da803d8c Mon Sep 17 00:00:00 2001
From: sundusaijaz
Date: Mon, 23 Mar 2026 22:09:10 +0100
Subject: [PATCH 62/68] refactor: streamline GDMLDataset methods and enhance metadata handling in MD17

---
 src/schnetpack/datasets/md17.py | 67 ++++++++++++++++-----------------
 1 file changed, 33 insertions(+), 34 deletions(-)

diff --git a/src/schnetpack/datasets/md17.py b/src/schnetpack/datasets/md17.py
index f7cc431e5..09bf1a268 100644
--- a/src/schnetpack/datasets/md17.py
+++ b/src/schnetpack/datasets/md17.py
@@ -18,9 +18,8 @@
 
 class GDMLDataset(ASEAtomsData):
     """
-    Base class for GDML-type datasets (e.g. MD17 or MD22).
-    Requires a dictionary translating between molecule and filenames
-    and a URL under which the molecular datasets can be found.
+    Base class for GDML-type data (e.g. MD17 or MD22). Requires a dictionary translating between molecule and filenames
+    and a URL under which the molecular datasets can be found.
     """
 
     energy = "energy"
@@ -63,7 +62,7 @@ def __init__(
         """
         self.datasets_dict = datasets_dict
         self.download_url = download_url
-        self._native_atomrefs = atomrefs
+        self._native_atomrefs = atomrefs or {}
         self.tmpdir = tmpdir
 
         if molecule not in self.datasets_dict:
@@ -71,10 +70,8 @@ def __init__(
 
         self.molecule = molecule
 
-        self.download(
-            datapath=datapath,
-            distance_unit=distance_unit or "Ang",
-        )
+        self.distance_unit = "Ang"
+        self.property_units = self._native_property_units()
 
         super().__init__(
             datapath=datapath,
@@ -96,35 +93,32 @@ def _native_property_units() -> Dict[str, str]:
             GDMLDataset.forces: "kcal/mol/Ang",
         }
 
-    def download(self, datapath: str, distance_unit: str = "Ang") -> None:
-        if os.path.exists(datapath):
-            with connect(datapath, use_lock_file=False) as conn:
-                md = conn.metadata
+    def _check_db(self) -> None:
+        super()._check_db()
+        md = self.metadata
 
-            if "molecule" not in md:
-                raise AtomsDataError(
-                    "Not a valid GDML dataset. Metadata must contain `molecule`."
-                )
+        if "molecule" not in md:
+            raise AtomsDataError(
+                "Not a valid GDML dataset. Metadata must contain `molecule`."
+            )
 
-            if md["molecule"] != self.molecule:
-                raise AtomsDataError(
-                    f"The dataset at the given location contains `{md['molecule']}` "
-                    f"instead of `{self.molecule}`."
-                )
-            return
+        if md["molecule"] != self.molecule:
+            raise AtomsDataError(
+                f"The dataset at the given location contains `{md['molecule']}` "
+                f"instead of `{self.molecule}`."
+ ) + def download(self) -> None: tmpdir = tempfile.mkdtemp(self.tmpdir) - dataset = self.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=self._native_property_units(), - atomrefs=self._native_atomrefs, - ) - dataset.update_metadata(molecule=self.molecule) - self._download_data(tmpdir, dataset) + md = self.metadata + md["atomrefs"] = self._native_atomrefs + md["molecule"] = self.molecule + self._set_metadata(md) + + self._download_data(tmpdir) shutil.rmtree(tmpdir, ignore_errors=True) - def _download_data(self, tmpdir, dataset: ASEAtomsData) -> None: + def _download_data(self, tmpdir) -> None: logging.info("Downloading {} data".format(self.molecule)) rawpath = os.path.join(tmpdir, self.datasets_dict[self.molecule]) url = self.download_url + self.datasets_dict[self.molecule] @@ -152,7 +146,7 @@ def _download_data(self, tmpdir, dataset: ASEAtomsData) -> None: property_list.append(properties) logging.info("Write atoms to db...") - dataset.add_systems(property_list=property_list) + self.add_systems(property_list=property_list) logging.info("Done.") @@ -170,7 +164,10 @@ def __init__( datapath: str, molecule: str, load_properties: Optional[List[str]] = None, - transforms=None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -182,10 +179,12 @@ def __init__( molecule: name of the molecule. load_properties: subset of properties to load. transforms: Transform applied to each system separately before batching. + train_transforms: overrides transform_fn for training. + val_transforms: overrides transform_fn for validation. + test_transforms: overrides transform_fn for testing. subset_idx: indices of the subset to load. property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...). - **kwargs: additional keyword arguments. 
""" atomrefs = { self.energy: [ From 716b34eb66e5cc8b760247a91fa266467739e6eb Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Mon, 23 Mar 2026 22:09:31 +0100 Subject: [PATCH 63/68] refactor: add type hint to _check_db() method in ASEAtomsData --- src/schnetpack/data/atoms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index 064392b83..e0d37a733 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -168,7 +168,7 @@ def _apply_transforms( props = tf(props) return props - def _check_db(self): + def _check_db(self) -> None: if not os.path.exists(self.datapath): raise AtomsDataError(f"ASE DB does not exist at {self.datapath}") From 2b51b1c1f7f6da94f700d4a2d7ce8222c10f206c Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Mon, 23 Mar 2026 22:13:00 +0100 Subject: [PATCH 64/68] refactor: simplify download method in omdb --- src/schnetpack/datasets/omdb.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/src/schnetpack/datasets/omdb.py b/src/schnetpack/datasets/omdb.py index 51b58cd51..41a7a491a 100644 --- a/src/schnetpack/datasets/omdb.py +++ b/src/schnetpack/datasets/omdb.py @@ -57,11 +57,8 @@ def __init__( raw_path: path to raw tar.gz file with the data """ self.raw_path = raw_path - - self.download( - datapath=datapath, - distance_unit=distance_unit or "Ang", - ) + self.distance_unit = "Ang" + self.property_units = self._native_property_units() super().__init__( datapath=datapath, @@ -81,27 +78,16 @@ def __init__( def _native_property_units() -> Dict[str, str]: return {OrganicMaterialsDatabase.BandGap: "eV"} - def download(self, datapath: str, distance_unit: str = "Ang") -> None: + def download(self) -> None: """ - Make sure the OMDB database exists. + Convert the OMDB raw archive into an ASE DB. """ - if os.path.exists(datapath): - _ = ASEAtomsData(datapath=datapath, load_structure=False) - return - if self.raw_path is None or not os.path.exists(self.raw_path): raise AtomsDataError( "The path to the raw dataset is not provided or invalid and the db-file does " "not exist!" 
) - - dataset = ASEAtomsData.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=self._native_property_units(), - ) - - self._convert(dataset) + self._convert() def _convert(self, dataset: ASEAtomsData) -> None: """ From ca88623d4d8b5f6ad1ac26a4d6a284d4ab692e6f Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Mon, 23 Mar 2026 22:41:49 +0100 Subject: [PATCH 65/68] refactor: simplify QM7X download method and enhance metadata handling --- src/schnetpack/datasets/qm7x.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/src/schnetpack/datasets/qm7x.py b/src/schnetpack/datasets/qm7x.py index 41af189b1..35ff5249d 100644 --- a/src/schnetpack/datasets/qm7x.py +++ b/src/schnetpack/datasets/qm7x.py @@ -182,10 +182,8 @@ def __init__( self.only_equilibrium = only_equilibrium self.only_non_equilibrium = only_non_equilibrium - self.download( - datapath=datapath, - distance_unit=distance_unit or "Ang", - ) + self.distance_unit = "Ang" + self.property_units = self._native_property_units() # initialize without subset first, then apply dataset-specific filtering super().__init__( @@ -240,13 +238,10 @@ def _apply_structure_filter(self, original_subset_idx: Optional[List[int]]) -> N self.subset_idx = effective_subset - def download(self, datapath: str, distance_unit: str = "Ang") -> None: + def download(self) -> None: """ Download the QM7-X dataset and create the ASEAtomsData object. """ - if os.path.exists(datapath): - return - tar_dir = self.raw_data_path or tempfile.mkdtemp("qm7x") atomrefs = { QM7X.energy: [ @@ -255,17 +250,10 @@ def download(self, datapath: str, distance_unit: str = "Ang") -> None: ] } - dataset = self.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=self._native_property_units(), - atomrefs=atomrefs, - ) - hd_files = self._download_data(tar_dir) if self.remove_duplicates: self._download_duplicates_ids(tar_dir) - self._parse_data(hd_files, dataset) + self._parse_data(hd_files) if self.raw_data_path is None: shutil.rmtree(tar_dir, ignore_errors=True) @@ -331,7 +319,7 @@ def _download_data(self, tar_dir: str, ignore_extracted: bool = True) -> List[st return extracted - def _parse_data(self, files: List[str], dataset: ASEAtomsData): + def _parse_data(self, files: List[str]): """ Parse the downloaded data files and add them to the dataset. 
""" @@ -377,10 +365,10 @@ def _parse_data(self, files: List[str], dataset: ASEAtomsData): groups_ids[key].append(idx) logging.info(f"Write parsed data from {os.path.basename(file)} to db ...") - dataset.add_systems(property_list=property_list, atoms_list=atoms_list) + self.add_systems(property_list=property_list, atoms_list=atoms_list) # add the hierarchical ids to the metadata - md = dataset.metadata + md = self.metadata if "groups_ids" in md: for key, ids in groups_ids.items(): groups_ids[key] = md["groups_ids"][key] + ids @@ -392,5 +380,5 @@ def _parse_data(self, files: List[str], dataset: ASEAtomsData): else: groups_ids["id"] = list(range(1, len(atoms_list) + 1)) - dataset.update_metadata(groups_ids=groups_ids) + self.update_metadata(groups_ids=groups_ids) logging.info("Done.") From 62114315ccddcea34302cb1080d1da66106eaeb8 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Mon, 23 Mar 2026 22:50:01 +0100 Subject: [PATCH 66/68] refactor: remove unused imports and streamline rMD17 dataset methods --- src/schnetpack/datasets/materials_project.py | 1 - src/schnetpack/datasets/md17.py | 1 - src/schnetpack/datasets/omdb.py | 1 - src/schnetpack/datasets/rmd17.py | 60 +++++++++----------- 4 files changed, 27 insertions(+), 36 deletions(-) diff --git a/src/schnetpack/datasets/materials_project.py b/src/schnetpack/datasets/materials_project.py index 6a8f2015e..a87bf730a 100644 --- a/src/schnetpack/datasets/materials_project.py +++ b/src/schnetpack/datasets/materials_project.py @@ -1,5 +1,4 @@ import logging -import os from typing import List, Optional, Dict import numpy as np diff --git a/src/schnetpack/datasets/md17.py b/src/schnetpack/datasets/md17.py index 09bf1a268..6f003d31a 100644 --- a/src/schnetpack/datasets/md17.py +++ b/src/schnetpack/datasets/md17.py @@ -7,7 +7,6 @@ import numpy as np from ase import Atoms -from ase.db import connect import schnetpack.properties as structure from schnetpack.transform.base import Transform diff --git a/src/schnetpack/datasets/omdb.py b/src/schnetpack/datasets/omdb.py index 41a7a491a..34fc918f8 100644 --- a/src/schnetpack/datasets/omdb.py +++ b/src/schnetpack/datasets/omdb.py @@ -5,7 +5,6 @@ import numpy as np from ase.io import read -import torch from schnetpack.data.atoms import ASEAtomsData, AtomsDataError from schnetpack.transform.base import Transform diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py index ad74fdc02..0eb954317 100644 --- a/src/schnetpack/datasets/rmd17.py +++ b/src/schnetpack/datasets/rmd17.py @@ -9,7 +9,6 @@ import numpy as np from ase import Atoms -from ase.db import connect import schnetpack.properties as structure from schnetpack.data.atoms import ASEAtomsData, AtomsDataError @@ -97,10 +96,8 @@ def __init__( self.molecule = molecule - self.download( - datapath=datapath, - distance_unit=distance_unit or "Ang", - ) + self.distance_unit = "Ang" + self.property_units = self._native_property_units() super().__init__( datapath=datapath, @@ -122,38 +119,35 @@ def _native_property_units() -> Dict[str, str]: rMD17.forces: "kcal/mol/Ang", } - def download(self, datapath: str, distance_unit: str = "Ang") -> None: - """ - Ensure the ASE DB exists and matches the requested molecule. - """ - if os.path.exists(datapath): - with connect(datapath, use_lock_file=False) as conn: - md = conn.metadata + def _check_db(self) -> None: + super()._check_db() + md = self.metadata - if "molecule" not in md: - raise AtomsDataError( - "Not a valid rMD17 dataset. Metadata must contain `molecule`." 
- ) + if "molecule" not in md: + raise AtomsDataError( + "Not a valid rMD17 dataset. Metadata must contain `molecule`." + ) - if md["molecule"] != self.molecule: - raise AtomsDataError( - f"The dataset at the given location contains `{md['molecule']}` " - f"instead of `{self.molecule}`." - ) - return + if md["molecule"] != self.molecule: + raise AtomsDataError( + f"The dataset at the given location contains `{md['molecule']}` " + f"instead of `{self.molecule}`." + ) + def download(self) -> None: + """ + Download the requested rMD17 molecule and populate the ASE DB. + """ tmpdir = tempfile.mkdtemp("rmd17") - dataset = self.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=self._native_property_units(), - atomrefs=self.atomrefs, - ) - dataset.update_metadata(molecule=self.molecule) - self._download_data(tmpdir, dataset) + md = self.metadata + md["atomrefs"] = self.atomrefs + md["molecule"] = self.molecule + self._set_metadata(md) + + self._download_data(tmpdir) shutil.rmtree(tmpdir, ignore_errors=True) - def _download_data(self, tmpdir: str, dataset: ASEAtomsData) -> None: + def _download_data(self, tmpdir: str) -> None: logging.info("Downloading %s data...", self.molecule) raw_path = os.path.join(tmpdir, "rmd17") @@ -200,7 +194,7 @@ def _download_data(self, tmpdir: str, dataset: ASEAtomsData) -> None: property_list.append(properties) logging.info("Write atoms to db...") - dataset.add_systems(property_list=property_list) + self.add_systems(property_list=property_list) logging.info("Done.") train_splits = [] @@ -234,7 +228,7 @@ def _download_data(self, tmpdir: str, dataset: ASEAtomsData) -> None: ) test_splits.append(test_split) - dataset.update_metadata(splits={"known": train_splits, "test": test_splits}) + self.update_metadata(splits={"known": train_splits, "test": test_splits}) logging.info("Done.") def _download_archive(self, destination: str) -> None: From cf4c40665da3057268f273c52447621dc572a197 Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Fri, 27 Mar 2026 23:55:02 +0100 Subject: [PATCH 67/68] refactor: remove unused split_id from rMD17 dataset configuration --- src/schnetpack/configs/data/rmd17.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/src/schnetpack/configs/data/rmd17.yaml b/src/schnetpack/configs/data/rmd17.yaml index aff1c9194..1e6b56f66 100644 --- a/src/schnetpack/configs/data/rmd17.yaml +++ b/src/schnetpack/configs/data/rmd17.yaml @@ -7,7 +7,6 @@ dataset: _target_: schnetpack.datasets.rMD17 datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml molecule: ${data.molecule} - split_id: null batch_size: 10 num_train: 950 From 252627ed49a0bd1f6470aa88b9002cd29c03632e Mon Sep 17 00:00:00 2001 From: sundusaijaz Date: Fri, 27 Mar 2026 23:56:08 +0100 Subject: [PATCH 68/68] refactor: adjust train and test split calculations in rMD17 dataset --- src/schnetpack/datasets/rmd17.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py index 0eb954317..3ff05232f 100644 --- a/src/schnetpack/datasets/rmd17.py +++ b/src/schnetpack/datasets/rmd17.py @@ -212,8 +212,8 @@ def _download_data(self, tmpdir: str) -> None: ) .flatten() .astype(int) - .tolist() - ) + - 1 + ).tolist() train_splits.append(train_split) test_split = ( @@ -224,8 +224,8 @@ def _download_data(self, tmpdir: str) -> None: ) .flatten() .astype(int) - .tolist() - ) + - 1 + ).tolist() test_splits.append(test_split) self.update_metadata(splits={"known": train_splits, 
"test": test_splits})