diff --git a/.gitignore b/.gitignore index 367285b4c..660e95452 100644 --- a/.gitignore +++ b/.gitignore @@ -128,4 +128,4 @@ interfaces/lammps/examples/*/*.dat interfaces/lammps/examples/*/deployed_model # batchwise optimizer examples -examples/howtos/howto_batchwise_relaxations_outputs/* \ No newline at end of file +examples/howtos/howto_batchwise_relaxations_outputs/* diff --git a/src/schnetpack/cli.py b/src/schnetpack/cli.py index 6025f2ed1..9f371c1c8 100644 --- a/src/schnetpack/cli.py +++ b/src/schnetpack/cli.py @@ -17,7 +17,7 @@ import schnetpack as spk from schnetpack.utils import str2class from schnetpack.utils.script import log_hyperparameters, print_config -from schnetpack.data import BaseAtomsData, AtomsLoader +from schnetpack.data import ASEAtomsData, AtomsLoader from schnetpack.train import PredictionWriter from schnetpack import properties from schnetpack.utils import load_model @@ -178,14 +178,16 @@ def train(config: DictConfig): # Evaluate model on test set after training log.info("Starting testing.") - trainer.test(model=task, datamodule=datamodule, ckpt_path="best") + trainer.test( + model=task, datamodule=datamodule, ckpt_path="best", weights_only=False + ) # Store best model best_path = trainer.checkpoint_callback.best_model_path log.info(f"Best checkpoint path:\n{best_path}") log.info(f"Store best model") - best_task = type(task).load_from_checkpoint(best_path) + best_task = type(task).load_from_checkpoint(best_path, weights_only=False) torch.save(best_task, config.globals.model_path + ".task") best_task.save_model(config.globals.model_path, do_postprocessing=True) @@ -195,7 +197,7 @@ def train(config: DictConfig): @hydra.main(config_path="configs", config_name="predict", version_base="1.2") def predict(config: DictConfig): log.info(f"Load data from `{config.data.datapath}`") - dataset: BaseAtomsData = hydra.utils.instantiate(config.data) + dataset: ASEAtomsData = hydra.utils.instantiate(config.data) loader = AtomsLoader(dataset, batch_size=config.batch_size, num_workers=8) model = load_model("best_model") diff --git a/src/schnetpack/configs/data/ani1.yaml b/src/schnetpack/configs/data/ani1.yaml index a3bdcb48c..2fd1dd23b 100644 --- a/src/schnetpack/configs/data/ani1.yaml +++ b/src/schnetpack/configs/data/ani1.yaml @@ -1,16 +1,18 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.ANI1 +dataset: + _target_: schnetpack.datasets.ANI1 + datapath: ${run.data_dir}/ani1.db # data_dir is specified in train.yaml + num_heavy_atoms: 8 + high_energies: false + distance_unit: Ang + property_units: + energy: eV + transforms: ${data.transforms} + -datapath: ${run.data_dir}/ani1.db # data_dir is specified in train.yaml batch_size: 32 num_train: 10000000 num_val: 100000 -num_heavy_atoms: 8 -high_energies: False - -# convert to typically used units -distance_unit: Ang -property_units: - energy: eV \ No newline at end of file diff --git a/src/schnetpack/configs/data/custom.yaml b/src/schnetpack/configs/data/custom.yaml index 29fc00245..e71360659 100644 --- a/src/schnetpack/configs/data/custom.yaml +++ b/src/schnetpack/configs/data/custom.yaml @@ -1,12 +1,25 @@ -_target_: schnetpack.data.AtomsDataModule +# @package data +_target_: schnetpack.data.datamodule_v2.AtomsDataModuleV2 + +dataset: + _target_: schnetpack.data.ASEAtomsData + datapath: ??? + load_properties: null + distance_unit: Ang + property_units: {} + transforms: ${data.transforms} + train_transforms: null + val_transforms: null + test_transforms: null -datapath: ??? 
-data_workdir: null batch_size: 10 num_train: ??? num_val: ??? num_test: null +split_file: ${run.data_dir}/split.npz +splitting: null num_workers: 8 -num_val_workers: null -num_test_workers: null -train_sampler_cls: null \ No newline at end of file +train_sampler_cls: null +train_sampler_args: {} +pin_memory: false + diff --git a/src/schnetpack/configs/data/iso17.yaml b/src/schnetpack/configs/data/iso17.yaml index b86dc5474..364a65fe4 100644 --- a/src/schnetpack/configs/data/iso17.yaml +++ b/src/schnetpack/configs/data/iso17.yaml @@ -1,10 +1,12 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.ISO17 +dataset: + _target_: schnetpack.datasets.ISO17 + datapath: ${run.data_dir}/${data.folder}.db # data_dir is specified in train.yaml + folder: reference -datapath: ${run.data_dir}/${data.folder}.db # data_dir is specified in train.yaml -folder: reference batch_size: 32 num_train: 0.9 num_val: 0.1 diff --git a/src/schnetpack/configs/data/materials_project.yaml b/src/schnetpack/configs/data/materials_project.yaml index f8516adbe..800cc9532 100644 --- a/src/schnetpack/configs/data/materials_project.yaml +++ b/src/schnetpack/configs/data/materials_project.yaml @@ -1,10 +1,12 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.MaterialsProject +dataset: + _target_: schnetpack.datasets.MaterialsProject + datapath: ${run.data_dir}/materials_project.db # data_dir is specified in train.yaml + apikey: ??? -datapath: ${run.data_dir}/materials_project.db # data_dir is specified in train.yaml batch_size: 32 num_train: 60000 -num_val: 2000 -apikey: ??? \ No newline at end of file +num_val: 2000 \ No newline at end of file diff --git a/src/schnetpack/configs/data/md17.yaml b/src/schnetpack/configs/data/md17.yaml index 59fb35a5a..e2b53be47 100644 --- a/src/schnetpack/configs/data/md17.yaml +++ b/src/schnetpack/configs/data/md17.yaml @@ -1,10 +1,14 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.MD17 - -datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml molecule: aspirin + +dataset: + _target_: schnetpack.datasets.MD17 + datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml + molecule: ${data.molecule} + batch_size: 10 num_train: 950 num_val: 50 diff --git a/src/schnetpack/configs/data/md22.yaml b/src/schnetpack/configs/data/md22.yaml index 89ce71e6f..322042c1f 100644 --- a/src/schnetpack/configs/data/md22.yaml +++ b/src/schnetpack/configs/data/md22.yaml @@ -1,10 +1,14 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.MD22 - -datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml molecule: Ac-Ala3-NHMe + +dataset: + _target_: schnetpack.datasets.MD22 + datapath: ${run.data_dir}/${data.molecule}.db + molecule: ${data.molecule} + batch_size: 10 num_train: 5700 num_val: 300 diff --git a/src/schnetpack/configs/data/omdb.yaml b/src/schnetpack/configs/data/omdb.yaml index 4e885c532..b9bf959c9 100644 --- a/src/schnetpack/configs/data/omdb.yaml +++ b/src/schnetpack/configs/data/omdb.yaml @@ -1,10 +1,12 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.OrganicMaterialsDatabase +dataset: + _target_: schnetpack.datasets.OrganicMaterialsDatabase + datapath: ${run.data_dir}/omdb.db # data_dir is specified in train.yaml + raw_path: null -datapath: ${run.data_dir}/omdb.db # data_dir is specified in train.yaml batch_size: 32 num_train: 0.8 num_val: 0.1 -raw_path: null \ No newline at end of file diff --git 
a/src/schnetpack/configs/data/qm7x.yaml b/src/schnetpack/configs/data/qm7x.yaml index 09cd4d805..95a21ed26 100644 --- a/src/schnetpack/configs/data/qm7x.yaml +++ b/src/schnetpack/configs/data/qm7x.yaml @@ -1,9 +1,14 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.QM7X +dataset: + _target_: schnetpack.datasets.QM7X + datapath: ${run.data_dir}/qm7x.db # data_dir is specified in train.yaml + remove_duplicates: true + only_equilibrium: false + only_non_equilibrium: false -datapath: ${run.data_dir}/qm7x.db # data_dir is specified in train.yaml batch_size: 100 num_train: 5550 num_val: 700 \ No newline at end of file diff --git a/src/schnetpack/configs/data/qm9.yaml b/src/schnetpack/configs/data/qm9.yaml index 02c3d1150..f24337472 100644 --- a/src/schnetpack/configs/data/qm9.yaml +++ b/src/schnetpack/configs/data/qm9.yaml @@ -1,22 +1,26 @@ +# @package data defaults: - custom -_target_: schnetpack.datasets.QM9 +dataset: + _target_: schnetpack.datasets.qm9.QM9 + datapath: ${run.data_dir}/qm9.db + remove_uncharacterized: true + load_properties: null + distance_unit: Ang + property_units: + energy_U0: eV + energy_U: eV + enthalpy_H: eV + free_energy: eV + homo: eV + lumo: eV + gap: eV + zpve: eV + transforms: ${data.transforms} -datapath: ${run.data_dir}/qm9.db # data_dir is specified in train.yaml batch_size: 100 num_train: 110000 num_val: 10000 -remove_uncharacterized: True - -# convert to typically used units -distance_unit: Ang -property_units: - energy_U0: eV - energy_U: eV - enthalpy_H: eV - free_energy: eV - homo: eV - lumo: eV - gap: eV - zpve: eV \ No newline at end of file +num_test: 10000 +num_workers: 2 diff --git a/src/schnetpack/configs/data/rmd17.yaml b/src/schnetpack/configs/data/rmd17.yaml index 76614e110..1e6b56f66 100644 --- a/src/schnetpack/configs/data/rmd17.yaml +++ b/src/schnetpack/configs/data/rmd17.yaml @@ -1,11 +1,13 @@ +# @package data defaults: - custom +molecule: aspirin -_target_: schnetpack.datasets.rMD17 +dataset: + _target_: schnetpack.datasets.rMD17 + datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml + molecule: ${data.molecule} -datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml -molecule: aspirin batch_size: 10 num_train: 950 num_val: 50 -split_id: null \ No newline at end of file diff --git a/src/schnetpack/data/__init__.py b/src/schnetpack/data/__init__.py index d3c6b83fc..328def9b1 100644 --- a/src/schnetpack/data/__init__.py +++ b/src/schnetpack/data/__init__.py @@ -4,3 +4,5 @@ from .splitting import * from .datamodule import * from .sampler import * +from .datamodule_v2 import * +from .provider import * diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index f01f866fd..e0d37a733 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -11,185 +11,32 @@ Journal of Physics: Condensed Matter, 9, 27. 2017. 
""" +import copy import logging import os -from abc import ABC, abstractmethod -from enum import Enum -from typing import Optional, List, Dict, Any, Iterable, Union, Tuple +from typing import Optional, List, Dict, Any, Iterable, Union import torch -import copy from ase import Atoms from ase.db import connect import schnetpack as spk import schnetpack.properties as structure -from schnetpack.transform import Transform +from schnetpack.transform.base import Transform logger = logging.getLogger(__name__) -__all__ = [ - "ASEAtomsData", - "BaseAtomsData", - "AtomsDataFormat", - "resolve_format", - "create_dataset", - "load_dataset", -] - - -class AtomsDataFormat(Enum): - """Enumeration of data formats""" - - ASE = "ase" +__all__ = ["ASEAtomsData", "AtomsDataError"] class AtomsDataError(Exception): pass -extension_map = {AtomsDataFormat.ASE: ".db"} - - -class BaseAtomsData(ABC): - """ - Base mixin class for atomistic data. Use together with PyTorch Dataset or - IterableDataset to implement concrete data formats. - """ - - def __init__( - self, - load_properties: Optional[List[str]] = None, - load_structure: bool = True, - transforms: Optional[List[Transform]] = None, - subset_idx: Optional[List[int]] = None, - ): - """ - Args: - load_properties: Set of properties to be loaded and returned. - If None, all properties in the ASE dB will be returned. - load_structure: If True, load structure properties. - transforms: preprocessing transforms (see schnetpack.data.transforms) - subset: List of data indices. - """ - self._transform_module = None - self.load_properties = load_properties - self.load_structure = load_structure - self.transforms = transforms - self.subset_idx = subset_idx - - def __len__(self) -> int: - raise NotImplementedError - - @property - def transforms(self): - return self._transforms - - @transforms.setter - def transforms(self, value: Optional[List[Transform]]): - self._transforms = [] - self._transform_module = None - - if value is not None: - for tf in value: - self._transforms.append(tf) - self._transform_module = torch.nn.Sequential(*self._transforms) - - def subset(self, subset_idx: List[int]): - assert ( - subset_idx is not None - ), "Indices for creation of the subset need to be provided!" - ds = copy.copy(self) - if ds.subset_idx: - ds.subset_idx = [ds.subset_idx[i] for i in subset_idx] - else: - ds.subset_idx = subset_idx - return ds - - @property - @abstractmethod - def available_properties(self) -> List[str]: - """Available properties in the dataset""" - pass - - @property - @abstractmethod - def units(self) -> Dict[str, str]: - """Property to unit dict""" - pass - - @property - def load_properties(self) -> List[str]: - """Properties to be loaded""" - if self._load_properties is None: - return self.available_properties - else: - return self._load_properties - - @load_properties.setter - def load_properties(self, val: List[str]): - if val is not None: - props = self.available_properties - assert all( - [p in props for p in val] - ), "Not all given properties are available in the dataset!" 
- self._load_properties = val - - @property - @abstractmethod - def metadata(self) -> Dict[str, Any]: - """Global metadata""" - pass - - @property - @abstractmethod - def atomrefs(self) -> Dict[str, torch.Tensor]: - """Single-atom reference values for properties""" - pass - - @abstractmethod - def update_metadata(self, **kwargs): - pass - - @abstractmethod - def iter_properties( - self, - indices: Union[int, Iterable[int]] = None, - load_properties: List[str] = None, - load_structure: Optional[bool] = None, - ): - pass - - @staticmethod - @abstractmethod - def create( - datapath: str, - position_unit: str, - property_unit_dict: Dict[str, str], - atomrefs: Dict[str, List[float]], - **kwargs, - ) -> "BaseAtomsData": - pass - - @abstractmethod - def add_systems( - self, - property_list: List[Dict[str, Any]], - atoms_list: Optional[List[Atoms]] = None, - atoms_metadata_list: Optional[List[Dict[str, Any]]] = None, - ): - pass - - @abstractmethod - def add_system(self, atoms: Optional[Atoms] = None, **properties): - pass - - -class ASEAtomsData(BaseAtomsData): +class ASEAtomsData(torch.utils.data.Dataset): """ PyTorch dataset for atomistic data. The raw data is stored in the specified ASE database. - """ def __init__( @@ -197,7 +44,10 @@ def __init__( datapath: str, load_properties: Optional[List[str]] = None, load_structure: bool = True, - transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, @@ -209,36 +59,33 @@ def __init__( If None, all properties in the ASE dB will be returned. load_structure: If True, load structure properties. transforms: preprocessing torch.nn.Module (see schnetpack.data.transforms) + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing subset_idx: List of data indices. - units: property-> unit string dictionary that overwrites the native units - of the dataset. Units are converted automatically during loading. + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...) + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) """ self.datapath = datapath + self.subset_idx = subset_idx + if not os.path.exists(self.datapath): + self.create() + self.download() + + self._check_db() self.conn = connect(self.datapath, use_lock_file=False) - BaseAtomsData.__init__( - self, - load_properties=load_properties, - load_structure=load_structure, - transforms=transforms, - subset_idx=subset_idx, - ) + self.transforms = list(transforms or []) + self.train_transforms = list(train_transforms) if train_transforms else None + self.val_transforms = list(val_transforms) if val_transforms else None + self.test_transforms = list(test_transforms) if test_transforms else None + self.split = None - self._check_db() + self._load_properties: Optional[List[str]] = None + self.load_structure = load_structure - # initialize units + # units from metadata md = self.metadata - if "_distance_unit" not in md.keys(): - raise AtomsDataError( - "Dataset does not have a distance unit set. Please add units to the " - + "dataset using `spkconvert`!" 
- ) - if "_property_unit_dict" not in md.keys(): - raise AtomsDataError( - "Dataset does not have a property units set. Please add units to the " - + "dataset using `spkconvert`!" - ) - if distance_unit: self.distance_conversion = spk.units.convert_units( md["_distance_unit"], distance_unit @@ -250,6 +97,8 @@ def __init__( self._units = md["_property_unit_dict"] self.conversions = {prop: 1.0 for prop in self._units} + + # apply unit overrides on load only if property_units is not None: for prop, unit in property_units.items(): self.conversions[prop] = spk.units.convert_units( @@ -257,42 +106,132 @@ def __init__( ) self._units[prop] = unit + # now validate load_properties against available_properties + self.load_properties = load_properties + + # ---------- merged ASEAtomsData bits ---------- + + def subset(self, subset_idx: List[int], split: Optional[str] = None): + if subset_idx is None: + raise ValueError("subset_idx must be provided.") + ds = copy.copy(self) + if ds.subset_idx is not None: + ds.subset_idx = [ds.subset_idx[i] for i in subset_idx] + else: + ds.subset_idx = subset_idx + ds.split = split + return ds + + @property + def load_properties(self) -> List[str]: + if self._load_properties is None: + return self.available_properties + return self._load_properties + + @load_properties.setter + def load_properties(self, val: Optional[List[str]]): + if val is not None: + props = self.available_properties + missing = [p for p in val if p not in props] + if missing: + raise AtomsDataError(f"Properties not available in dataset: {missing}") + self._load_properties = val + + # ---------- core dataset API ---------- + def __len__(self) -> int: if self.subset_idx is not None: return len(self.subset_idx) - return self.conn.count() def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: if self.subset_idx is not None: idx = self.subset_idx[idx] - props = self._get_properties( self.conn, idx, self.load_properties, self.load_structure ) - props = self._apply_transforms(props) - - return props + return self._apply_transforms(props) + + def _apply_transforms( + self, props: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + if self.split == "train" and self.train_transforms is not None: + transforms = self.train_transforms + elif self.split == "val" and self.val_transforms is not None: + transforms = self.val_transforms + elif self.split == "test" and self.test_transforms is not None: + transforms = self.test_transforms + else: + transforms = self.transforms - def _apply_transforms(self, props): - if self._transform_module is not None: - props = self._transform_module(props) + for tf in transforms: + props = tf(props) return props - def _check_db(self): + def _check_db(self) -> None: if not os.path.exists(self.datapath): raise AtomsDataError(f"ASE DB does not exist at {self.datapath}") - if self.subset_idx: - with connect(self.datapath, use_lock_file=False) as conn: - n_structures = conn.count() + with connect(self.datapath, use_lock_file=False) as conn: + n_structures = conn.count() + md = conn.metadata + + if n_structures == 0: + raise AtomsDataError(f"ASE DB at {self.datapath} is empty") - assert max(self.subset_idx) < n_structures + if "_distance_unit" not in md: + raise AtomsDataError( + "Dataset does not have a distance unit set. Please add units to the dataset." + ) + + if "_property_unit_dict" not in md: + raise AtomsDataError( + "Dataset does not have property units set. Please add units to the dataset." 
+ ) + + if self.subset_idx is not None: + if max(self.subset_idx) >= n_structures: + raise AtomsDataError("subset_idx contains out-of-range indices") + + # ---------- metadata / units ----------- + + @property + def metadata(self) -> Dict[str, Any]: + with connect(self.datapath, use_lock_file=False) as conn: + return conn.metadata + + def _set_metadata(self, val: Dict[str, Any]): + with connect(self.datapath, use_lock_file=False) as conn: + conn.metadata = val + + def update_metadata(self, **kwargs): + if not all(k and k[0] != "_" for k in kwargs): + raise AtomsDataError("Metadata keys starting with '_' are protected!") + md = self.metadata + md.update(kwargs) + self._set_metadata(md) + + @property + def available_properties(self) -> List[str]: + md = self.metadata + return list(md["_property_unit_dict"].keys()) + + @property + def units(self) -> Dict[str, str]: + return self._units + + @property + def atomrefs(self) -> Dict[str, torch.Tensor]: + md = self.metadata + arefs = md.get("atomrefs", {}) + return {k: self.conversions[k] * torch.tensor(v) for k, v in arefs.items()} + + # ---------- iteration ---------- def iter_properties( self, indices: Union[int, Iterable[int]] = None, - load_properties: List[str] = None, + load_properties: Optional[List[str]] = None, load_structure: Optional[bool] = None, load_metadata: bool = False, ): @@ -311,22 +250,22 @@ def iter_properties( """ if load_properties is None: load_properties = self.load_properties - load_structure = load_structure or self.load_structure + if load_structure is None: + load_structure = self.load_structure - if self.subset_idx: + if self.subset_idx is not None: if indices is None: indices = self.subset_idx - elif type(indices) is int: + elif isinstance(indices, int): indices = [self.subset_idx[indices]] else: indices = [self.subset_idx[i] for i in indices] else: if indices is None: indices = range(len(self)) - elif type(indices) is int: + elif isinstance(indices, int): indices = [indices] - # read from ase db for i in indices: yield self._get_properties( self.conn, @@ -344,12 +283,24 @@ def _get_properties( load_structure: bool, load_metadata: bool = False, ): - row = conn.get(idx + 1) + """ + Load properties of a single system from the ASE database. + + Args: + conn: ASE database connection. + idx: Zero-based system index. + load_properties: Properties to load. + load_structure: Whether to load structural information. + load_metadata: Whether to load metadata. - # extract properties + Returns: + Dict[str, torch.Tensor]: Dictionary containing the requested properties. + """ + row = conn.get(idx + 1) # TODO: can the copies be avoided? - properties = {} + properties: Dict[str, torch.Tensor] = {} properties[structure.idx] = torch.tensor([idx]) + for pname in load_properties: properties[pname] = ( torch.tensor(row.data[pname].copy()) * self.conversions[pname] @@ -373,91 +324,32 @@ def _get_properties( return properties - # Metadata - @property - def metadata(self): - with connect(self.datapath, use_lock_file=False) as conn: - return conn.metadata - - def _set_metadata(self, val: Dict[str, Any]): - with connect(self.datapath, use_lock_file=False) as conn: - conn.metadata = val - - def update_metadata(self, **kwargs): - assert all( - key[0] != 0 for key in kwargs - ), "Metadata keys starting with '_' are protected!" 
- - md = self.metadata - md.update(kwargs) - self._set_metadata(md) + # ---------- creation / writing ---------- - @property - def available_properties(self) -> List[str]: - md = self.metadata - return list(md["_property_unit_dict"].keys()) - - @property - def units(self) -> Dict[str, str]: - """Dictionary of properties to units""" - return self._units - - @property - def atomrefs(self) -> Dict[str, torch.Tensor]: - md = self.metadata - arefs = md["atomrefs"] - arefs = {k: self.conversions[k] * torch.tensor(v) for k, v in arefs.items()} - return arefs - - ## Creation - - @staticmethod - def create( - datapath: str, - distance_unit: str, - property_unit_dict: Dict[str, str], - atomrefs: Optional[Dict[str, List[float]]] = None, - **kwargs, - ) -> "ASEAtomsData": + def create(self) -> None: """ - - Args: - datapath: Path to ASE DB. - distance_unit: unit of atom positions and cell - property_unit_dict: Defines the available properties of the datasetseta and - provides units for ALL properties of the dataset. If a property is - unit-less, you can pass "arb. unit" or `None`. - atomrefs: dictionary mapping properies (the keys) to lists of single-atom - reference values of the property. This is especially useful for - extensive properties such as the energy, where the single atom energies - contribute a major part to the overall value. - kwargs: Pass arguments to init. - - Returns: - newly created ASEAtomsData + Create a new ASE database at `self.datapath` and initialize its metadata. """ - if not datapath.endswith(".db"): - raise AtomsDataError( - "Invalid datapath! Please make sure to add the file extension '.db' to " - "your dbpath." - ) + if not self.datapath.endswith(".db"): + raise AtomsDataError("Invalid datapath! Add '.db' extension.") + if os.path.exists(self.datapath): + raise AtomsDataError(f"Dataset already exists: {self.datapath}") - if os.path.exists(datapath): - raise AtomsDataError(f"Dataset already exists: {datapath}") + os.makedirs(os.path.dirname(self.datapath) or ".", exist_ok=True) - atomrefs = atomrefs or {} + if self.property_units is None: + raise AtomsDataError("property_units is not set in dataset class.") + if self.distance_unit is None: + raise AtomsDataError("distance_unit is not set in dataset class.") - with connect(datapath) as conn: + with connect(self.datapath) as conn: conn.metadata = { - "_property_unit_dict": property_unit_dict, - "_distance_unit": distance_unit, - "atomrefs": atomrefs, + "_property_unit_dict": self.property_units, + "_distance_unit": self.distance_unit, + "atomrefs": {}, } - return ASEAtomsData(datapath, **kwargs) - - # add systems def add_system( self, atoms: Optional[Atoms] = None, @@ -490,13 +382,13 @@ def add_systems( Add atoms data to the dataset. Args: - atoms_list: System composition and geometry. If Atoms are None, - the structure needs to be given as part of the property dicts - (using structure.Z, structure.R, structure.cell, structure.pbc) property_list: Properties as list of key-value pairs in the same order as corresponding list of `atoms`. Keys have to match the `available_properties` of the dataset plus additional structure properties, if atoms is None. + atoms_list: System composition and geometry. If Atoms are None, + the structure needs to be given as part of the property dicts + (using structure.Z, structure.R, structure.cell, structure.pbc) atoms_metadata_list: Metadata of the atoms objects as list of key-value pairs in the same order as corresponding list of `atoms`. 
Metadata can not be used as a training property, but can be used for splitting @@ -504,18 +396,13 @@ def add_systems( """ if atoms_list is None: atoms_list = [None] * len(property_list) - if atoms_metadata_list is None: atoms_metadata_list = [{}] * len(property_list) for atoms, prop, atoms_metadata in zip( atoms_list, property_list, atoms_metadata_list ): - self._add_system( - atoms, - atoms_metadata, - **prop, - ) + self._add_system(atoms, atoms_metadata, **prop) def _add_system( self, @@ -526,7 +413,6 @@ def _add_system( """ Add systems to DB. """ - # create atoms object if not provided if atoms is None: try: Z = properties[structure.Z] @@ -535,9 +421,7 @@ def _add_system( pbc = properties[structure.pbc] atoms = Atoms(numbers=Z, positions=R, cell=cell, pbc=pbc) except KeyError as e: - raise AtomsDataError( - "Property dict does not contain all necessary structure keys" - ) from e + raise AtomsDataError("Missing structure keys in properties") from e if atoms_metadata is None: atoms_metadata = {} @@ -545,104 +429,24 @@ def _add_system( with connect(self.datapath, use_lock_file=False) as conn: prop_keys = conn.metadata["_property_unit_dict"].keys() - valid_props = set().union( - prop_keys, - [structure.Z, structure.R, structure.cell, structure.pbc], + valid_props = set(prop_keys).union( + {structure.Z, structure.R, structure.cell, structure.pbc} ) for pname in properties: if pname not in valid_props: logger.warning( - f"Property `{pname}` is not a defined property for this dataset and " - + f"will be ignored. If it should be included, it has to be " - + f"provided together with its unit when calling " - + f"AseAtomsData.create()." + f"Property `{pname}` is not defined for this dataset and will be ignored." ) data = {} for pname in prop_keys: - if pname in properties: - data[pname] = properties[pname] - else: - raise AtomsDataError("Required property missing:" + pname) + if pname not in properties: + raise AtomsDataError("Required property missing: " + pname) + data[pname] = properties[pname] conn.write(atoms, data=data, key_value_pairs=atoms_metadata) - -def create_dataset( - datapath: str, - format: AtomsDataFormat, - distance_unit: str, - property_unit_dict: Dict[str, str], - **kwargs, -) -> BaseAtomsData: - """ - Create a new atoms dataset. - - Args: - datapath: file path - format: atoms data format - distance_unit: unit of atom positiona etc. as string - property_unit_dict: dictionary that maps properties to units, - e.g. {"energy": "kcal/mol"} - **kwargs: arguments for passed to AtomsData init - - Returns: - - """ - if format is AtomsDataFormat.ASE: - dataset = ASEAtomsData.create( - datapath=datapath, - distance_unit=distance_unit, - property_unit_dict=property_unit_dict, - **kwargs, - ) - else: - raise AtomsDataError(f"Unknown format: {format}") - return dataset - - -def load_dataset(datapath: str, format: AtomsDataFormat, **kwargs) -> BaseAtomsData: - """ - Load dataset. - - Args: - datapath: file path - format: atoms data format - **kwargs: arguments for passed to AtomsData init - - """ - if format is AtomsDataFormat.ASE: - dataset = ASEAtomsData(datapath=datapath, **kwargs) - else: - raise AtomsDataError(f"Unknown format: {format}") - return dataset - - -def resolve_format( - datapath: str, format: Optional[AtomsDataFormat] = None -) -> Tuple[str, AtomsDataFormat]: - """ - Extract data format from file suffix, check for consistency with (optional) given - format, or append suffix to file path. 
-
-    Args:
-        datapath: path to atoms data
-        format: atoms data format
-
-    """
-    file, suffix = os.path.splitext(datapath)
-    if suffix == ".db":
-        if format is None:
-            format = AtomsDataFormat.ASE
-        assert (
-            format is AtomsDataFormat.ASE
-        ), f"File extension {suffix} is not compatible with chosen format {format}"
-    elif len(suffix) == 0 and format:
-        datapath = datapath + extension_map[format]
-    elif len(suffix) == 0 and format is None:
-        raise AtomsDataError(
-            "If format is not given, `datapath` needs a supported file extension!"
+    def download(self):
+        raise NotImplementedError(
+            f"{self.__class__.__name__} must implement download()."
         )
-    else:
-        raise AtomsDataError(f"Unsupported file extension: {suffix}")
-    return datapath, format
diff --git a/src/schnetpack/data/datamodule.py b/src/schnetpack/data/datamodule.py
index ea361e0a5..23d9a08c6 100644
--- a/src/schnetpack/data/datamodule.py
+++ b/src/schnetpack/data/datamodule.py
@@ -10,10 +10,7 @@
 from torch.utils.data import BatchSampler
 
 from schnetpack.data import (
-    AtomsDataFormat,
-    resolve_format,
-    load_dataset,
-    BaseAtomsData,
+    ASEAtomsData,
     AtomsLoader,
     calculate_stats,
     estimate_atomrefs,
@@ -43,7 +40,7 @@ def __init__(
         num_val: Union[int, float] = None,
         num_test: Optional[Union[int, float]] = None,
         split_file: Optional[str] = "split.npz",
-        format: Optional[AtomsDataFormat] = None,
+        format=None,
         load_properties: Optional[List[str]] = None,
         val_batch_size: Optional[int] = None,
         test_batch_size: Optional[int] = None,
@@ -116,7 +113,8 @@ def __init__(
         self.num_test = num_test
         self.splitting = splitting or RandomSplit()
         self.split_file = split_file
-        self.datapath, self.format = resolve_format(datapath, format)
+        self.datapath = datapath
+        self.format = format
         self.load_properties = load_properties
         self.num_workers = num_workers
         self.num_val_workers = self.num_workers
@@ -181,7 +179,6 @@ def setup(self, stage: Optional[str] = None):
 
         # (re)load datasets
         if self.dataset is None:
-            self.dataset = load_dataset(
+            self.dataset = ASEAtomsData(
                 datapath,
-                self.format,
                 property_units=self.property_units,
@@ -386,15 +384,15 @@ def get_atomrefs(
         return {property: atomrefs}
 
     @property
-    def train_dataset(self) -> BaseAtomsData:
+    def train_dataset(self) -> ASEAtomsData:
         return self._train_dataset
 
     @property
-    def val_dataset(self) -> BaseAtomsData:
+    def val_dataset(self) -> ASEAtomsData:
         return self._val_dataset
 
     @property
-    def test_dataset(self) -> BaseAtomsData:
+    def test_dataset(self) -> ASEAtomsData:
         return self._test_dataset
 
     def train_dataloader(self) -> AtomsLoader:
diff --git a/src/schnetpack/data/datamodule_v2.py b/src/schnetpack/data/datamodule_v2.py
new file mode 100644
index 000000000..e00e34b13
--- /dev/null
+++ b/src/schnetpack/data/datamodule_v2.py
@@ -0,0 +1,270 @@
+from typing import Optional, Union, Dict, Any, Type
+import os
+import warnings
+
+import numpy as np
+import pytorch_lightning as pl
+from torch.utils.data import BatchSampler
+
+from schnetpack.data.atoms import ASEAtomsData
+from schnetpack.data.provider import StatsAtomrefProvider
+from schnetpack.data.splitting import RandomSplit, SplittingStrategy
+from schnetpack.data.loader import AtomsLoader
+
+
+class AtomsDataModuleV2(pl.LightningDataModule):
+    """
+    V2 DataModule:
+    - accepts a dataset instance
+    - handles splitting
+    - builds StatsAtomrefProvider from train split
+    - initializes transforms
+    """
+
+    def __init__(
+        self,
+        dataset: ASEAtomsData,
+        batch_size: int,
+        num_train: Union[int, float],
+        num_val: Union[int, float],
+        num_test: Optional[Union[int, float]] = None,
+        split_file: Optional[str] = "split.npz",
+        splitting: Optional[SplittingStrategy] = None,
+        num_workers: int = 0,
+        val_batch_size: Optional[int] = None,
+        test_batch_size: Optional[int] = None,
+        train_sampler_cls: Optional[Type] = None,
+        train_sampler_args: Optional[Dict[str, Any]] = None,
+        pin_memory: bool = False,
+        **kwargs,
+    ):
+        """
+        Args:
+            dataset: prebuilt ASEAtomsData dataset instance
+            batch_size: (train) batch size
+            num_train: number of training examples (absolute or relative)
+            num_val: number of validation examples (absolute or relative)
+            num_test: number of test examples (absolute or relative)
+            split_file: path to npz file with data partitions
+            splitting: Method to generate train/validation/test partitions
+                (default: RandomSplit)
+            num_workers: Number of data loader workers
+            val_batch_size: validation batch size. If None, use test_batch_size, then
+                batch_size
+            test_batch_size: test batch size. If None, use val_batch_size, then
+                batch_size
+            train_sampler_cls: type of torch training sampler.
+                This is by default wrapped into a torch.utils.data.BatchSampler.
+            train_sampler_args: dict of train_sampler keyword arguments.
+            pin_memory: If True, pin memory of loaded data. Default: False.
+        """
+        legacy_args = {
+            "datapath",
+            "format",
+            "load_properties",
+            "transforms",
+            "train_transforms",
+            "val_transforms",
+            "test_transforms",
+            "num_val_workers",
+            "num_test_workers",
+            "property_units",
+            "distance_unit",
+            "data_workdir",
+            "cleanup_workdir_stage",
+        }
+        used_legacy_args = [k for k in legacy_args if k in kwargs]
+
+        if used_legacy_args:
+            warnings.warn(
+                "The following arguments are deprecated in `AtomsDataModuleV2`: "
+                f"{used_legacy_args}. "
+                "Use a prebuilt dataset instance and configure these options on the "
+                "dataset instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        super().__init__()
+
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.val_batch_size = val_batch_size or test_batch_size or batch_size
+        self.test_batch_size = test_batch_size or val_batch_size or batch_size
+
+        self.num_train = num_train
+        self.num_val = num_val
+        self.num_test = num_test
+        self.split_file = split_file
+        self.splitting = splitting or RandomSplit()
+        self.num_workers = num_workers
+        self._pin_memory = pin_memory
+
+        self.train_idx = None
+        self.val_idx = None
+        self.test_idx = None
+
+        self._train_dataset = None
+        self._val_dataset = None
+        self._test_dataset = None
+
+        self._train_dataloader = None
+        self._val_dataloader = None
+        self._test_dataloader = None
+
+        self.provider: Optional[StatsAtomrefProvider] = None
+
+        self.train_sampler_cls = train_sampler_cls
+        self.train_sampler_args = train_sampler_args or {}
+
+    @property
+    def train_dataset(self) -> ASEAtomsData:
+        if self._train_dataset is None:
+            raise RuntimeError("Call setup() before accessing train_dataset.")
+        return self._train_dataset
+
+    @property
+    def val_dataset(self) -> ASEAtomsData:
+        if self._val_dataset is None:
+            raise RuntimeError("Call setup() before accessing val_dataset.")
+        return self._val_dataset
+
+    @property
+    def test_dataset(self) -> ASEAtomsData:
+        if self._test_dataset is None:
+            raise RuntimeError("Call setup() before accessing test_dataset.")
+        return self._test_dataset
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        if self.train_idx is None:
+            self._load_partitions()
+
+        self._train_dataset = self.dataset.subset(self.train_idx)
+        self._val_dataset = self.dataset.subset(self.val_idx)
+        self._test_dataset = self.dataset.subset(self.test_idx)
+
+        transforms = self.dataset.transforms or []
+
+        train_transforms = self.dataset.train_transforms or transforms
+        val_transforms = self.dataset.val_transforms or transforms
+        test_transforms = self.dataset.test_transforms or transforms
+
+        self._train_dataset.transforms = []
+        self._val_dataset.transforms = []
+        self._test_dataset.transforms = []
+
+        self.provider = StatsAtomrefProvider(self._train_dataset)
+
+        self._initialize_transforms(train_transforms)
+        self._initialize_transforms(val_transforms)
+        self._initialize_transforms(test_transforms)
+
+        self._train_dataset.transforms = train_transforms
+        self._val_dataset.transforms = val_transforms
+        self._test_dataset.transforms = test_transforms
+
+    def _initialize_transforms(self, transforms) -> None:
+        if not transforms:
+            return
+
+        for t in transforms:
+            t.initialize(provider=self.provider, atomrefs=self.provider.train_atomrefs)
+
+    def _load_partitions(self) -> None:
+        total_size = len(self.dataset)
+
+        def _to_abs(x: Optional[Union[int, float]]) -> Optional[int]:
+            if x is None:
+                return None
+            if isinstance(x, float) and x <= 1.0:
+                return int(x * total_size)
+            return int(x)
+
+        num_train = _to_abs(self.num_train)
+        num_val = _to_abs(self.num_val)
+        num_test = _to_abs(self.num_test)
+
+        self.num_train = num_train
+        self.num_val = num_val
+        self.num_test = num_test
+
+        if self.split_file is not None and os.path.exists(self.split_file):
+            split_data = np.load(self.split_file)
+            self.train_idx = split_data["train_idx"].tolist()
+            self.val_idx = split_data["val_idx"].tolist()
+            self.test_idx = split_data["test_idx"].tolist()
+            return
+
+        if num_train is None or num_val is None:
+            raise ValueError("num_train and num_val must be set if no split file.")
+
+        train_idx, val_idx, test_idx = self.splitting.split(
+            self.dataset, num_train, num_val, num_test
+        )
+
+        self.train_idx = train_idx
+        self.val_idx = val_idx
+        self.test_idx = test_idx
+
+        if self.split_file is not None:
+            np.savez(
+                self.split_file,
+                train_idx=train_idx,
+                val_idx=val_idx,
+                test_idx=test_idx,
+            )
+
+    def _setup_sampler(self, sampler_cls, sampler_args, dataset):
+        if sampler_cls is None:
+            return None
+
+        return BatchSampler(
+            sampler=sampler_cls(
+                data_source=dataset,
+                num_samples=len(dataset),
+                **sampler_args,
+            ),
+            batch_size=self.batch_size,
+            drop_last=True,
+        )
+
+    def train_dataloader(self):
+        if self._train_dataloader is None:
+            train_batch_sampler = self._setup_sampler(
+                sampler_cls=self.train_sampler_cls,
+                sampler_args=self.train_sampler_args,
+                dataset=self.train_dataset,
+            )
+
+            self._train_dataloader = AtomsLoader(
+                self.train_dataset,
+                batch_size=self.batch_size if train_batch_sampler is None else 1,
+                shuffle=True if train_batch_sampler is None else False,
+                batch_sampler=train_batch_sampler,
+                num_workers=self.num_workers,
+                pin_memory=self._pin_memory,
+            )
+
+        return self._train_dataloader
+
+    def val_dataloader(self):
+        if self._val_dataloader is None:
+            self._val_dataloader = AtomsLoader(
+                self.val_dataset,
+                batch_size=self.val_batch_size,
+                num_workers=self.num_workers,
+                pin_memory=self._pin_memory,
+            )
+
+        return self._val_dataloader
+
+    def test_dataloader(self):
+        if self._test_dataloader is None:
+            self._test_dataloader = AtomsLoader(
+                self.test_dataset,
+                batch_size=self.test_batch_size,
+                num_workers=self.num_workers,
+                pin_memory=self._pin_memory,
+            )
+
+        return self._test_dataloader
diff --git a/src/schnetpack/data/loader.py b/src/schnetpack/data/loader.py
index f116021ee..56f8d95f3 100644
--- a/src/schnetpack/data/loader.py
+++ b/src/schnetpack/data/loader.py
@@ -59,7 +59,7 @@ def _atoms_collate_fn(batch):
 
 
 class AtomsLoader(DataLoader):
-    """Data loader for subclasses of BaseAtomsData"""
+    """Data loader for subclasses of ASEAtomsData"""
 
     def __init__(
         self,
diff --git a/src/schnetpack/data/provider.py b/src/schnetpack/data/provider.py
new file mode 100644
index 000000000..8fa6432b7
--- /dev/null
+++ b/src/schnetpack/data/provider.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+from typing import Dict, Optional, Tuple
+
+import torch
+
+from schnetpack.data.atoms import ASEAtomsData
+from schnetpack.data.stats import calculate_stats, estimate_atomrefs
+
+
+class StatsAtomrefProvider:
+    """
+    Compute and cache statistics and atom references from the training dataset.
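+
+    A minimal usage sketch (illustrative; the prebuilt ``train_dataset`` split
+    and the property name ``energy`` are assumptions)::
+
+        provider = StatsAtomrefProvider(train_dataset)
+        mean, std = provider.get_stats(
+            "energy", divide_by_atoms=True, remove_atomref=True
+        )
+        atomrefs = provider.get_atomrefs("energy", is_extensive=True)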
+ """ + + def __init__(self, train_dataset: ASEAtomsData) -> None: + self.train_dataset = train_dataset + self.train_atomrefs = getattr(train_dataset, "atomrefs", None) + + self._stats_cache: Dict[ + Tuple[str, bool, bool], Tuple[torch.Tensor, torch.Tensor] + ] = {} + self._atomref_cache: Dict[Tuple[str, bool], torch.Tensor] = {} + + def get_stats( + self, property: str, divide_by_atoms: bool, remove_atomref: bool + ) -> Tuple[torch.Tensor, torch.Tensor]: + key = (property, divide_by_atoms, remove_atomref) + if key in self._stats_cache: + return self._stats_cache[key] + + atomref = self.train_atomrefs if remove_atomref else None + + stats = calculate_stats( + self.train_dataset, + divide_by_atoms={property: divide_by_atoms}, + atomref=atomref, + )[property] + + self._stats_cache[key] = stats + return stats + + def get_atomrefs( + self, property: str, is_extensive: bool + ) -> Dict[str, torch.Tensor]: + # 1) If dataset already has atomrefs for this property, use them directly + if self.train_atomrefs is not None and property in self.train_atomrefs: + return {property: self.train_atomrefs[property]} + + # 2) Otherwise estimate and cache + key = (property, is_extensive) + if key in self._atomref_cache: + return {property: self._atomref_cache[key]} + + atomref = estimate_atomrefs( + self.train_dataset, + is_extensive={property: is_extensive}, + )[property] + + self._atomref_cache[key] = atomref + return {property: atomref} diff --git a/src/schnetpack/data/sampler.py b/src/schnetpack/data/sampler.py index 0e353ef88..45d8beac9 100644 --- a/src/schnetpack/data/sampler.py +++ b/src/schnetpack/data/sampler.py @@ -4,7 +4,7 @@ from torch.utils.data import Sampler, WeightedRandomSampler from schnetpack import properties -from schnetpack.data import BaseAtomsData +from schnetpack.data import ASEAtomsData __all__ = [ @@ -53,8 +53,8 @@ class StratifiedSampler(WeightedRandomSampler): def __init__( self, - data_source: BaseAtomsData, - partition_criterion: Callable[[BaseAtomsData], List], + data_source: ASEAtomsData, + partition_criterion: Callable[[ASEAtomsData], List], num_samples: int, num_bins: int = 10, replacement: bool = True, diff --git a/src/schnetpack/data/stats.py b/src/schnetpack/data/stats.py index 7276c8149..9df2365aa 100644 --- a/src/schnetpack/data/stats.py +++ b/src/schnetpack/data/stats.py @@ -1,18 +1,22 @@ -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional, Any import torch from tqdm import tqdm import schnetpack.properties as properties -from schnetpack.data import AtomsLoader +from schnetpack.data.atoms import ASEAtomsData +from schnetpack.data.loader import AtomsLoader __all__ = ["calculate_stats", "estimate_atomrefs"] def calculate_stats( - dataloader: AtomsLoader, + dataset: ASEAtomsData, divide_by_atoms: Dict[str, bool], atomref: Dict[str, torch.Tensor] = None, + batch_size: int = 10000, + num_workers: int = 4, + loader_kwargs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]: """ Use the incremental Welford algorithm described in [h1]_ to accumulate @@ -23,15 +27,26 @@ def calculate_stats( .. [h1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Args: - dataloader: data loader - divide_by_atoms: dict from property name to bool: - If True, divide property by number of atoms before calculating statistics. - atomref: reference values for single atoms to be removed before calculating stats + dataset: Dataset used to compute statistics. 
+ divide_by_atoms: Mapping from property name to bool indicating whether the property + should be divided by the number of atoms before computing statistics. + atomref: Optional single-atom reference values to subtract before computing statistics. + batch_size: Batch size used for the temporary data loader. + num_workers: Number of workers used by the data loader. Returns: - Mean and standard deviation over all samples - + Mapping from property name to `(mean, std)` tensors. """ + loader_kwargs = loader_kwargs or {} + + dataloader = AtomsLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + **loader_kwargs, + ) + property_names = list(divide_by_atoms.keys()) norm_mask = torch.tensor( [float(divide_by_atoms[p]) for p in property_names], dtype=torch.float64 @@ -45,7 +60,8 @@ def calculate_stats( sample_values = [] for p in property_names: val = props[p][None, :] - if atomref and p in atomref.keys(): + + if atomref and p in atomref: ar = atomref[p] ar = ar[props[properties.Z]] idx_m = props[properties.idx_m] @@ -54,90 +70,105 @@ def calculate_stats( val -= v0 sample_values.append(val) + sample_values = torch.cat(sample_values, dim=0) - batch_size = sample_values.shape[1] - new_count = count + batch_size + batch_n = sample_values.shape[1] + new_count = count + batch_n norm = norm_mask[:, None] * props[properties.n_atoms][None, :] + ( 1 - norm_mask[:, None] ) - sample_values /= norm + sample_values = sample_values / norm sample_mean = torch.mean(sample_values, dim=1) sample_m2 = torch.sum((sample_values - sample_mean[:, None]) ** 2, dim=1) delta = sample_mean - mean - mean += delta * batch_size / new_count - corr = batch_size * count / new_count + mean += delta * batch_n / new_count + corr = batch_n * count / new_count M2 += sample_m2 + delta**2 * corr count = new_count stddev = torch.sqrt(M2 / count) - stats = {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)} - return stats + return {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)} -def estimate_atomrefs(dataloader, is_extensive, z_max=100): +def estimate_atomrefs( + dataset: ASEAtomsData, + is_extensive: Dict[str, bool], + z_max: int = 100, + batch_size: int = 10000, + num_workers: int = 4, + loader_kwargs: Optional[Dict[str, Any]] = None, +) -> Dict[str, torch.Tensor]: """ Uses linear regression to estimate the elementwise biases (atomrefs). Args: - dataloader: data loader - is_extensive: If True, divide atom type counts by number of atoms before - calculating statistics. + dataset: Dataset used to estimate atom reference values. + is_extensive: Mapping from property name to bool indicating whether the property is + extensive. If False, atom type counts are divided by the number of atoms before fitting. + z_max: Maximum atomic number used to size the atomref tensors. + batch_size: Batch size used for the temporary data loader. + num_workers: Number of workers used by the data loader. Returns: - Elementwise bias estimates over all samples - + Mapping from property name to estimated atom reference tensor. 
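+
+    A minimal usage sketch (illustrative; the train split and the property name
+    ``energy`` are assumptions)::
+
+        atomrefs = estimate_atomrefs(train_dataset, is_extensive={"energy": True})
+        e_h = atomrefs["energy"][1]  # estimated single-atom contribution of H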
""" + loader_kwargs = loader_kwargs or {} + + dataloader = AtomsLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + **loader_kwargs, + ) + property_names = list(is_extensive.keys()) - n_data = len(dataloader.dataset) + n_data = len(dataset) + all_properties = {pname: torch.zeros(n_data) for pname in property_names} all_atom_types = torch.zeros((n_data, z_max)) data_counter = 0 - # loop over all batches for batch in tqdm(dataloader, "estimating atomrefs"): - # load data idx_m = batch[properties.idx_m] atomic_numbers = batch[properties.Z] - # get counts for atomic numbers - unique_ids = torch.unique(idx_m) - for i in unique_ids: + for i in torch.unique(idx_m): atomic_numbers_i = atomic_numbers[idx_m == i] atom_types, atom_counts = torch.unique(atomic_numbers_i, return_counts=True) - # save atom counts and properties + for atom_type, atom_count in zip(atom_types, atom_counts): all_atom_types[data_counter, atom_type] = atom_count + for pname in property_names: property_value = batch[pname][i] if not is_extensive[pname]: property_value *= batch[properties.n_atoms][i] all_properties[pname][data_counter] = property_value + data_counter += 1 - # perform linear regression to get the elementwise energy contributions existing_atom_types = torch.where(all_atom_types.sum(axis=0) != 0)[0] X = torch.squeeze(all_atom_types[:, existing_atom_types]) - w = dict() + + weights = {} for pname in property_names: if is_extensive[pname]: - w[pname] = torch.linalg.inv(X.T @ X) @ X.T @ all_properties[pname] + weights[pname] = torch.linalg.inv(X.T @ X) @ X.T @ all_properties[pname] else: - w[pname] = ( + weights[pname] = ( torch.linalg.inv(X.T @ X) @ X.T @ (all_properties[pname] / X.sum(axis=1)) ) - # compute energy estimates - elementwise_contributions = { - pname: torch.zeros((z_max)) for pname in property_names - } + out = {pname: torch.zeros((z_max,)) for pname in property_names} for pname in property_names: - for atom_type, weight in zip(existing_atom_types, w[pname]): - elementwise_contributions[pname][atom_type] = weight + for atom_type, weight in zip(existing_atom_types, weights[pname]): + out[pname][atom_type] = weight - return elementwise_contributions + return out diff --git a/src/schnetpack/datasets/ani1.py b/src/schnetpack/datasets/ani1.py index 48fb0f062..8db3583fd 100644 --- a/src/schnetpack/datasets/ani1.py +++ b/src/schnetpack/datasets/ani1.py @@ -1,32 +1,32 @@ import logging import os import shutil +import tarfile import tempfile -from typing import List, Optional, Dict +from typing import Dict, List, Optional from urllib import request as request +from ase.db import connect +import h5py import numpy as np from ase import Atoms -import torch -import tarfile -import h5py +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError +from schnetpack.transform.base import Transform -from schnetpack.data import * +__all__ = ["ANI1"] log = logging.getLogger(__name__) -class ANI1(AtomsDataModule): +class ANI1(ASEAtomsData): """ - ANI1 benchmark database. + ANI1 benchmark dataset. This class adds convenience functions to download ANI1 from figshare and load the data into pytorch. References: - .. 
[#ani1] https://arxiv.org/abs/1708.04987 - """ energy = "energy" @@ -41,166 +41,156 @@ class ANI1(AtomsDataModule): def __init__( self, datapath: str, - batch_size: int, num_heavy_atoms: int = 8, high_energies: bool = False, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, **kwargs, ): """ - Args: datapath: path to dataset - num_heavy_atoms: number of heavy atoms. (See 'Table 1' in Ref. [#ani1]_) - high_energies: add high energy conformations. (See 'Technical Validation' of Ref. [#ani1]_) - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format + num_heavy_atoms: number of heavy atoms + high_energies: whether to include high-energy conformations load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - + transforms: Transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing + subset_idx: indices of the subset to load + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...). 
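+
+        A minimal usage sketch (illustrative; the database path is an
+        assumption, and a missing database triggers the download)::
+
+            data = ANI1("./ani1.db", num_heavy_atoms=4)
+            props = data[0]  # property dict of the first conformation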
""" self.num_heavy_atoms = num_heavy_atoms self.high_energies = high_energies + self.distance_unit = "Ang" + self.property_units = self._native_property_units() super().__init__( datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, - format=format, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, transforms=transforms, train_transforms=train_transforms, val_transforms=val_transforms, test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, **kwargs, ) - def prepare_data(self): - if not os.path.exists(self.datapath): - property_unit_dict = { - ANI1.energy: "Hartree", - } - atomrefs = self._create_atomrefs() - - dataset = create_dataset( - datapath=self.datapath, - format=self.format, - distance_unit="Ang", - property_unit_dict=property_unit_dict, - atomrefs=atomrefs, + @staticmethod + def _native_property_units() -> Dict[str, str]: + return { + ANI1.energy: "Hartree", + } + + def _check_db(self) -> None: + """ + Ensure the ANI1 ASE DB exists. + """ + super()._check_db() + with connect(self.datapath, use_lock_file=False) as conn: + md = conn.metadata + + if md.get("num_heavy_atoms") != self.num_heavy_atoms: + raise AtomsDataError( + f"Existing ANI1 dataset was created with num_heavy_atoms={md.get('num_heavy_atoms')}, " + f"but requested num_heavy_atoms={self.num_heavy_atoms}." + ) + + if md.get("high_energies") != self.high_energies: + raise AtomsDataError( + f"Existing ANI1 dataset was created with high_energies={md.get('high_energies')}, " + f"but requested high_energies={self.high_energies}." ) - tmpdir = tempfile.mkdtemp("ani1") - self._download_data(tmpdir, dataset) - shutil.rmtree(tmpdir) - else: - dataset = load_dataset(self.datapath, self.format) + def download(self) -> None: + """ + Download ANI1 data and populate the ASE DB. 
+ """ + tmpdir = tempfile.mkdtemp("ani1") + + md = self.metadata + md["atomrefs"] = self._create_atomrefs() + md["num_heavy_atoms"] = self.num_heavy_atoms + md["high_energies"] = self.high_energies + self._set_metadata(md) + + self._download_data(tmpdir) + shutil.rmtree(tmpdir, ignore_errors=True) - def _download_data(self, tmpdir, dataset: BaseAtomsData): - logging.info("downloading ANI-1 data...") + def _download_data(self, tmpdir: str) -> None: + logging.info("Downloading ANI-1 data...") tar_path = os.path.join(tmpdir, "ANI1_release.tar.gz") raw_path = os.path.join(tmpdir, "data") url = "https://ndownloader.figshare.com/files/9057631" request.urlretrieve(url, tar_path) + if not os.path.exists(tar_path): + raise AtomsDataError(f"Download failed, file not found: {tar_path}") + + if os.path.getsize(tar_path) == 0: + raise AtomsDataError(f"Downloaded file is empty: {tar_path}") + logging.info("Done.") - tar = tarfile.open(tar_path) - tar.extractall(raw_path) - tar.close() + with tarfile.open(tar_path) as tar: + tar.extractall(raw_path) - logging.info("parse files...") + logging.info("Parsing files...") for i in range(1, self.num_heavy_atoms + 1): - file_name = os.path.join(raw_path, "ANI-1_release", "ani_gdb_s0%d.h5" % i) - logging.info("start to parse %s" % file_name) - self._load_h5_file(file_name, dataset) + file_name = os.path.join(raw_path, "ANI-1_release", f"ani_gdb_s0{i}.h5") + logging.info("Start to parse %s", file_name) + self._load_h5_file(file_name) - logging.info("done...") + logging.info("Done.") - def _load_h5_file(self, file_name, dataset): + def _load_h5_file(self, file_name: str) -> None: atoms_list = [] properties_list = [] - store = h5py.File(file_name) - for file_key in store: - for molecule_key in store[file_key]: - molecule_group = store[file_key][molecule_key] - species = "".join([str(s)[-2] for s in molecule_group["species"]]) - positions = molecule_group["coordinates"] - energies = molecule_group["energies"] - - # loop over conformations - for i in range(energies.shape[0]): - atm = Atoms(species, positions[i]) - energy = energies[i] - properties = {self.energy: np.array([energy])} - atoms_list.append(atm) - properties_list.append(properties) - - # high energy conformations as described in 'Technical Validation' - # section of https://arxiv.org/abs/1708.04987 - if self.high_energies: - high_energy_positions = molecule_group["coordinatesHE"] - high_energies = molecule_group["energiesHE"] - - # loop over high energy conformations - for i in range(high_energies.shape[0]): - atm = Atoms(species, high_energy_positions[i]) - high_energy = high_energies[i] - properties = {self.energy: np.array([high_energy])} + with h5py.File(file_name, "r") as store: + for file_key in store: + for molecule_key in store[file_key]: + molecule_group = store[file_key][molecule_key] + species = "".join([str(s)[-2] for s in molecule_group["species"]]) + positions = molecule_group["coordinates"] + energies = molecule_group["energies"] + + # regular conformations + for i in range(energies.shape[0]): + atm = Atoms(species, positions[i]) + energy = energies[i] + properties = {self.energy: np.array([energy])} atoms_list.append(atm) properties_list.append(properties) - # write data to ase db - dataset.add_systems(atoms_list=atoms_list, property_list=properties_list) + # high-energy conformations + # section of https://arxiv.org/abs/1708.04987 + if self.high_energies: + high_energy_positions = molecule_group["coordinatesHE"] + high_energies = molecule_group["energiesHE"] - def _create_atomrefs(self): - 
atref = np.zeros((100,)) + for i in range(high_energies.shape[0]): + atm = Atoms(species, high_energy_positions[i]) + high_energy = high_energies[i] + properties = {self.energy: np.array([high_energy])} + atoms_list.append(atm) + properties_list.append(properties) - # converts units to eV (which are set to one in ase) + self.add_systems(atoms_list=atoms_list, property_list=properties_list) + + def _create_atomrefs(self) -> Dict[str, List[float]]: + atref = np.zeros((100,)) atref[1] = self.self_energies["H"] atref[6] = self.self_energies["C"] atref[7] = self.self_energies["N"] atref[8] = self.self_energies["O"] - return {ANI1.energy: atref.tolist()} diff --git a/src/schnetpack/datasets/iso17.py b/src/schnetpack/datasets/iso17.py index cfc8bb391..f206447d9 100644 --- a/src/schnetpack/datasets/iso17.py +++ b/src/schnetpack/datasets/iso17.py @@ -1,31 +1,30 @@ import logging import os import shutil +import tarfile import tempfile -from typing import List, Optional, Dict +from typing import Dict, List, Optional from urllib import request as request -from tqdm import tqdm -import numpy as np -from ase.db import connect from urllib.error import HTTPError, URLError -import tarfile -import torch +import numpy as np +from ase.db import connect +from tqdm import tqdm -from schnetpack.data import * +from schnetpack.transform.base import Transform +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError __all__ = ["ISO17"] -class ISO17(AtomsDataModule): +class ISO17(ASEAtomsData): """ - ISO17 benchmark data set for molecular dynamics of C7O2H10 isomers + ISO17 benchmark dataset for molecular dynamics of C7O2H10 isomers containing molecular forces. References: .. [#iso17] http://quantum-machine.org/datasets/ - """ energy = "total_energy" @@ -39,27 +38,16 @@ class ISO17(AtomsDataModule): "test_eq", ] - # properties def __init__( self, datapath: str, fold: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, **kwargs, @@ -68,61 +56,46 @@ def __init__( Args: datapath: path to dataset fold: select a specific dataset of iso17 - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. 
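Editorial usage sketch (not part of the patch): instantiating the refactored ANI1 class with the constructor arguments shown above and in the ani1.yaml config. The DB path is illustrative, and it is assumed that the ASEAtomsData base class triggers download() when the file at datapath does not exist yet.

```python
from schnetpack.datasets import ANI1

# "./ani1.db" is a hypothetical location; num_heavy_atoms and high_energies
# mirror the defaults in configs/data/ani1.yaml.
data = ANI1(
    datapath="./ani1.db",
    num_heavy_atoms=8,
    high_energies=False,  # exclude the high-energy conformations
)
```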
diff --git a/src/schnetpack/datasets/iso17.py b/src/schnetpack/datasets/iso17.py
index cfc8bb391..f206447d9 100644
--- a/src/schnetpack/datasets/iso17.py
+++ b/src/schnetpack/datasets/iso17.py
@@ -1,31 +1,30 @@
 import logging
 import os
 import shutil
+import tarfile
 import tempfile
-from typing import List, Optional, Dict
+from typing import Dict, List, Optional
 from urllib import request as request
-from tqdm import tqdm
-import numpy as np
-from ase.db import connect
 from urllib.error import HTTPError, URLError
-import tarfile
-import torch

-from schnetpack.data import *
+import numpy as np
+from ase.db import connect
+from tqdm import tqdm
+
+from schnetpack.transform.base import Transform
+from schnetpack.data.atoms import ASEAtomsData, AtomsDataError

 __all__ = ["ISO17"]

-class ISO17(AtomsDataModule):
+class ISO17(ASEAtomsData):
     """
-    ISO17 benchmark data set for molecular dynamics of C7O2H10 isomers
+    ISO17 benchmark dataset for molecular dynamics of C7O2H10 isomers
     containing molecular forces.

     References:
     .. [#iso17] http://quantum-machine.org/datasets/
-
     """

     energy = "total_energy"
@@ -39,27 +38,16 @@ class ISO17(AtomsDataModule):
         "test_eq",
     ]

-    # properties
     def __init__(
         self,
         datapath: str,
         fold: str,
-        batch_size: int,
-        num_train: Optional[int] = None,
-        num_val: Optional[int] = None,
-        num_test: Optional[int] = None,
-        split_file: Optional[str] = "split.npz",
-        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
         load_properties: Optional[List[str]] = None,
-        val_batch_size: Optional[int] = None,
-        test_batch_size: Optional[int] = None,
-        transforms: Optional[List[torch.nn.Module]] = None,
-        train_transforms: Optional[List[torch.nn.Module]] = None,
-        val_transforms: Optional[List[torch.nn.Module]] = None,
-        test_transforms: Optional[List[torch.nn.Module]] = None,
-        num_workers: int = 2,
-        num_val_workers: Optional[int] = None,
-        num_test_workers: Optional[int] = None,
+        transforms: Optional[List[Transform]] = None,
+        train_transforms: Optional[List[Transform]] = None,
+        val_transforms: Optional[List[Transform]] = None,
+        test_transforms: Optional[List[Transform]] = None,
+        subset_idx: Optional[List[int]] = None,
         property_units: Optional[Dict[str, str]] = None,
         distance_unit: Optional[str] = None,
         **kwargs,
     ):
@@ -68,61 +56,46 @@ def __init__(
         Args:
             datapath: path to dataset
             fold: select a specific dataset of iso17
-            batch_size: (train) batch size
-            num_train: number of training examples
-            num_val: number of validation examples
-            num_test: number of test examples
-            split_file: path to npz file with data partitions
-            format: dataset format
             load_properties: subset of properties to load
-            val_batch_size: validation batch size. If None, use test_batch_size, then batch_size.
-            test_batch_size: test batch size. If None, use val_batch_size, then batch_size.
-            transforms: Transform applied to each system separately before batching.
-            train_transforms: Overrides transform_fn for training.
-            val_transforms: Overrides transform_fn for validation.
-            test_transforms: Overrides transform_fn for testing.
-            num_workers: Number of data loader workers.
-            num_val_workers: Number of validation data loader workers (overrides num_workers).
-            num_test_workers: Number of test data loader workers (overrides num_workers).
-            distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...).
+            transforms: transform applied to each system separately before batching
+            train_transforms: overrides transform_fn for training
+            val_transforms: overrides transform_fn for validation
+            test_transforms: overrides transform_fn for testing
+            subset_idx: indices of the subset to load
+            property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...)
+            distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...)
         """
         if fold not in self.existing_folds:
-            raise ValueError("Fold {:s} does not exist".format(fold))
+            raise AtomsDataError(f"Fold {fold} does not exist.")

-        self.path = datapath
+        self.root_path = datapath
         self.fold = fold
+        self.distance_unit = "Ang"
+        self.property_units = self._native_property_units()
+
         dbpath = os.path.join(datapath, "iso17", fold + ".db")

         super().__init__(
             datapath=dbpath,
-            batch_size=batch_size,
-            num_train=num_train,
-            num_val=num_val,
-            num_test=num_test,
-            split_file=split_file,
-            format=format,
             load_properties=load_properties,
-            val_batch_size=val_batch_size,
-            test_batch_size=test_batch_size,
             transforms=transforms,
             train_transforms=train_transforms,
             val_transforms=val_transforms,
             test_transforms=test_transforms,
-            num_workers=num_workers,
-            num_val_workers=num_val_workers,
-            num_test_workers=num_test_workers,
+            subset_idx=subset_idx,
             property_units=property_units,
             distance_unit=distance_unit,
             **kwargs,
         )

-    def prepare_data(self):
-        if not os.path.exists(self.datapath):
-            self._download_data()
-        else:
-            dataset = load_dataset(self.datapath, self.format)
+    @staticmethod
+    def _native_property_units() -> Dict[str, str]:
+        return {
+            ISO17.energy: "eV",
+            ISO17.forces: "eV/Ang",
+        }

-    def _download_data(self):
+    def download(self) -> None:
         logging.info("Downloading ISO17 database...")
         tmpdir = tempfile.mkdtemp("iso17")
         tarpath = os.path.join(tmpdir, "iso17.tar.gz")
@@ -130,41 +103,40 @@ def _download_data(self):

         try:
             request.urlretrieve(url, tarpath)
+
         except HTTPError as e:
-            logging.error("HTTP Error:", e.code, url)
-            return False
+            raise AtomsDataError(f"HTTP Error {e.code} while downloading {url}") from e
+
         except URLError as e:
-            logging.error("URL Error:", e.reason, url)
-            return False
-
-        tar = tarfile.open(tarpath)
-        tar.extractall(self.path)
-        tar.close()
-
-        # update metadata
-        for fold in ISO17.existing_folds:
-            dbpath = os.path.join(self.path, "iso17", fold + ".db")
-            tmp_dbpath = os.path.join(tmpdir, "tmp.db")
-            with connect(dbpath) as conn:
-                with connect(tmp_dbpath) as tmp_conn:
+            raise AtomsDataError(f"URL Error {e.reason} while downloading {url}") from e
+
+        with tarfile.open(tarpath) as tar:
+            tar.extractall(self.root_path)
+
+        # update metadata + convert energy into row.data for every fold
+        for fold in self.existing_folds:
+            dbpath = os.path.join(self.root_path, "iso17", fold + ".db")
+            tmp_dbpath = os.path.join(tmpdir, f"{fold}_tmp.db")
+
+            with connect(dbpath, use_lock_file=False) as conn:
+                with connect(tmp_dbpath, use_lock_file=False) as tmp_conn:
                     tmp_conn.metadata = {
-                        "_property_unit_dict": {
-                            ISO17.energy: "eV",
-                            ISO17.forces: "eV/Ang",
-                        },
+                        "_property_unit_dict": self._native_property_units(),
                         "_distance_unit": "Ang",
                         "atomrefs": {},
                     }

-                    # add energy to data dict in db
+
                     for idx in tqdm(
-                        range(len(conn)), f"parsing database file {dbpath}"
+                        range(len(conn)),
+                        desc=f"parsing database file {dbpath}",
                     ):
                         atmsrw = conn.get(idx + 1)
                         data = atmsrw.data
-                        data[ISO17.forces] = np.array(data[ISO17.forces])
-                        data[ISO17.energy] = np.array([atmsrw.total_energy])
+                        data[self.forces] = np.array(data[self.forces])
+                        data[self.energy] = np.array([atmsrw.total_energy])
                         tmp_conn.write(atmsrw.toatoms(), data=data)

             os.remove(dbpath)
             os.rename(tmp_dbpath, dbpath)
-        shutil.rmtree(tmpdir)
+
+        shutil.rmtree(tmpdir, ignore_errors=True)
diff --git a/src/schnetpack/datasets/materials_project.py b/src/schnetpack/datasets/materials_project.py
index 8779e4bff..a87bf730a 100644
--- a/src/schnetpack/datasets/materials_project.py
+++ b/src/schnetpack/datasets/materials_project.py
@@ -1,20 +1,16 @@
 import logging
-import os
 from typing import List, Optional, Dict
-import warnings
-from ase import Atoms
-
-import torch
 import numpy as np

-from schnetpack.data import *
-from schnetpack.data import AtomsDataModuleError, AtomsDataModule
+from ase import Atoms
+from schnetpack.data.atoms import ASEAtomsData, AtomsDataError
+from schnetpack.transform.base import Transform

 __all__ = ["MaterialsProject"]

-class MaterialsProject(AtomsDataModule):
+class MaterialsProject(ASEAtomsData):
     """
     Materials Project (MP) database of bulk crystals.
     This class adds convenient functions to download Materials Project data into
@@ -23,7 +19,6 @@ class MaterialsProject(AtomsDataModule):

     References:
         .. [#matproj] https://materialsproject.org/
-
     """

     # properties
@@ -31,144 +26,101 @@
     EPerAtom = "energy_per_atom"
     BandGap = "band_gap"
     TotalMagnetization = "total_magnetization"
-    MaterialId = ("material_id",)
+    MaterialId = "material_id"
     CreatedAt = "created_at"

     def __init__(
         self,
         datapath: str,
-        batch_size: int,
-        num_train: Optional[int] = None,
-        num_val: Optional[int] = None,
-        num_test: Optional[int] = None,
-        split_file: Optional[str] = "split.npz",
-        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
         load_properties: Optional[List[str]] = None,
-        val_batch_size: Optional[int] = None,
-        test_batch_size: Optional[int] = None,
-        transforms: Optional[List[torch.nn.Module]] = None,
-        train_transforms: Optional[List[torch.nn.Module]] = None,
-        val_transforms: Optional[List[torch.nn.Module]] = None,
-        test_transforms: Optional[List[torch.nn.Module]] = None,
-        num_workers: int = 2,
-        num_val_workers: Optional[int] = None,
-        num_test_workers: Optional[int] = None,
+        transforms: Optional[List[Transform]] = None,
+        train_transforms: Optional[List[Transform]] = None,
+        val_transforms: Optional[List[Transform]] = None,
+        test_transforms: Optional[List[Transform]] = None,
+        subset_idx: Optional[List[int]] = None,
         property_units: Optional[Dict[str, str]] = None,
        distance_unit: Optional[str] = None,
         apikey: Optional[str] = None,
         **kwargs,
     ):
         """
-
         Args:
             datapath: path to dataset
-            batch_size: (train) batch size
-            num_train: number of training examples
-            num_val: number of validation examples
-            num_test: number of test examples
-            split_file: path to npz file with data partitions
-            format: dataset format
             load_properties: subset of properties to load
-            val_batch_size: validation batch size. If None, use test_batch_size, then batch_size.
-            test_batch_size: test batch size. If None, use val_batch_size, then batch_size.
-            transforms: Transform applied to each system separately before batching.
-            train_transforms: Overrides transform_fn for training.
-            val_transforms: Overrides transform_fn for validation.
-            test_transforms: Overrides transform_fn for testing.
-            num_workers: Number of data loader workers.
-            num_val_workers: Number of validation data loader workers (overrides num_workers).
-            num_test_workers: Number of test data loader workers (overrides num_workers).
-            property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...).
-            distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...).
-            apikey: Materials project key needed to download the data.
+            transforms: transform applied to each system separately before batching
+            train_transforms: overrides transform_fn for training
+            val_transforms: overrides transform_fn for validation
+            test_transforms: overrides transform_fn for testing
+            subset_idx: indices of the subset to load
+            property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...)
+            distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...)
+            apikey: API key used to download the data
         """
-        if apikey is not None and len(apikey) == 16:
-            raise DeprecationWarning(
-                "You are using a legacy API key. This API is deprecated and no longer supported by MaterialsProject. "
-                "Please use the nextgen API instead. "
-                "Visit https://next-gen.materialsproject.org/ to get a valid API-key. "
+        if apikey is None:
+            raise AtomsDataError(
+                "No API key provided. Visit https://next-gen.materialsproject.org/ "
+                "to get an API key."
+            )
+
+        elif len(apikey) == 16:
+            raise AtomsDataError(
+                "You are using a legacy API key. This API is deprecated and no longer "
+                "supported by Materials Project. Please use the next-gen API instead. "
+                "Visit https://next-gen.materialsproject.org/ to get a valid API key."
             )
-        if apikey is not None and len(apikey) != 32:
-            raise AtomsDataModuleError(
-                "Invalid API-key. MaterialsProject requires an API-key of 32 characters. "
-                f"Your API-key contains {len(apikey)} characters. "
-                "Visit https://next-gen.materialsproject.org/ to get a valid API-key. "
+
+        elif len(apikey) != 32:
+            raise AtomsDataError(
+                "Invalid API key. MaterialsProject requires an API key of 32 characters. "
+                f"Your API key contains {len(apikey)} characters. "
+                "Visit https://next-gen.materialsproject.org/ to get a valid API key."
             )
+
+        self.apikey = apikey
+        self.distance_unit = "Ang"
+        self.property_units = self._native_property_units()
+
         super().__init__(
             datapath=datapath,
-            batch_size=batch_size,
-            num_train=num_train,
-            num_val=num_val,
-            num_test=num_test,
-            split_file=split_file,
-            format=format,
             load_properties=load_properties,
-            val_batch_size=val_batch_size,
-            test_batch_size=test_batch_size,
             transforms=transforms,
             train_transforms=train_transforms,
             val_transforms=val_transforms,
             test_transforms=test_transforms,
-            num_workers=num_workers,
-            num_val_workers=num_val_workers,
-            num_test_workers=num_test_workers,
+            subset_idx=subset_idx,
             property_units=property_units,
             distance_unit=distance_unit,
             **kwargs,
         )

-        self.apikey = apikey
-
-    def prepare_data(self):
-        if not os.path.exists(self.datapath):
-            # check if apikey is provided
-            if self.apikey is None:
-                raise AtomsDataModuleError(
-                    "No API-key provided, visit https://next-gen.materialsproject.org/ to get an API-key."
-                )
-
-            # initialize dataset
-            property_unit_dict = {
-                MaterialsProject.EformationPerAtom: "eV",
-                MaterialsProject.EPerAtom: "eV",
-                MaterialsProject.BandGap: "eV",
-                MaterialsProject.TotalMagnetization: "None",
-            }
-            dataset = create_dataset(
-                datapath=self.datapath,
-                format=self.format,
-                distance_unit="Ang",
-                property_unit_dict=property_unit_dict,
-            )
-            self._download_data_nextgen(dataset)
-        else:
-            dataset = load_dataset(self.datapath, self.format)
+    @staticmethod
+    def _native_property_units() -> Dict[str, str]:
+        return {
+            MaterialsProject.EformationPerAtom: "eV",
+            MaterialsProject.EPerAtom: "eV",
+            MaterialsProject.BandGap: "eV",
+            MaterialsProject.TotalMagnetization: "None",
+        }

-    def _download_data_nextgen(self, dataset: BaseAtomsData):
+    def download(self) -> None:
         """
-        Downloads dataset provided it does not exist in self.path
-        Returns:
-            works (bool): true if download succeeded or file already exists
+        Download Materials Project entries and store them in the ASE DB.
         """
-        # collect data
-        atms_list = []
+        atoms_list = []
         properties_list = []
         atoms_metadata_list = []
+
         try:
             from pymatgen.core import Structure
-            import pymatgen as pmg
             from mp_api.client import MPRester
-
-        except:
+        except Exception as e:
             raise ImportError(
-                "In order to download Materials Project data, you have to install "
-                "mp-api and pymatgen packages"
-            )
+                "To download Materials Project data, install `mp-api` and `pymatgen`."
+            ) from e

         with MPRester(self.apikey) as m:
             query = m.materials.summary.search(
-                num_sites=(0, 300, 30),
+                num_sites=(0, 300),
                 num_elements=(1, 9),
                 fields=[
                     "structure",
@@ -183,8 +135,8 @@ def _download_data_nextgen(self, dataset: BaseAtomsData):

         for q in query:
             s = q.structure
-            if type(s) is Structure:
-                atms_list.append(
+            if isinstance(s, Structure):
+                atoms_list.append(
                     Atoms(
                         numbers=s.atomic_numbers,
                         positions=s.cart_coords,
@@ -206,14 +158,13 @@
                 )
                 atoms_metadata_list.append(
                     {
-                        "material_id": q.material_id,
+                        MaterialsProject.MaterialId: str(q.material_id),
                     }
                 )

-        # write systems to database
         logging.info("Write atoms to db...")
-        dataset.add_systems(
-            atoms_list=atms_list,
+        self.add_systems(
+            atoms_list=atoms_list,
             property_list=properties_list,
             atoms_metadata_list=atoms_metadata_list,
         )
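Editorial usage sketch (not part of the patch): the API-key checks above now run eagerly in the constructor, so malformed keys fail before any download is attempted. The keys below are dummies; a 32-character dummy would pass the length check but still fail later inside MPRester.

```python
from schnetpack.datasets import MaterialsProject
from schnetpack.data.atoms import AtomsDataError

for bad_key in (None, "0123456789abcdef"):  # missing key / 16-char legacy key
    try:
        MaterialsProject(datapath="./mp.db", apikey=bad_key)  # hypothetical path
    except AtomsDataError as err:
        print(err)  # both messages point to https://next-gen.materialsproject.org/
```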
@@ -25,134 +24,100 @@ class GDMLDataModule(AtomsDataModule):

     energy = "energy"
     forces = "forces"

-    # properties
     def __init__(
         self,
         datasets_dict: Dict[str, str],
         download_url: str,
         datapath: str,
         molecule: str,
-        batch_size: int,
-        num_train: Optional[int] = None,
-        num_val: Optional[int] = None,
-        num_test: Optional[int] = None,
-        split_file: Optional[str] = "split.npz",
-        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
+        tmpdir: str = "gdml_tmp",
+        atomrefs: Optional[Dict[str, List[float]]] = None,
         load_properties: Optional[List[str]] = None,
-        val_batch_size: Optional[int] = None,
-        test_batch_size: Optional[int] = None,
-        transforms: Optional[List[torch.nn.Module]] = None,
-        train_transforms: Optional[List[torch.nn.Module]] = None,
-        val_transforms: Optional[List[torch.nn.Module]] = None,
-        test_transforms: Optional[List[torch.nn.Module]] = None,
-        num_workers: int = 2,
-        num_val_workers: Optional[int] = None,
-        num_test_workers: Optional[int] = None,
+        transforms: Optional[List[Transform]] = None,
+        train_transforms: Optional[List[Transform]] = None,
+        val_transforms: Optional[List[Transform]] = None,
+        test_transforms: Optional[List[Transform]] = None,
+        subset_idx: Optional[List[int]] = None,
         property_units: Optional[Dict[str, str]] = None,
         distance_unit: Optional[str] = None,
-        data_workdir: Optional[str] = None,
-        tmpdir: str = "gdml_tmp",
-        atomrefs: Optional[Dict[str, List[float]]] = None,
         **kwargs,
     ):
         """
         Args:
-            datasets_dict: dictionary mapping molecule names to dataset names.
-            download_url: URL where individual molecule datasets can me found.
+            datasets_dict: dictionary mapping molecule names to dataset names
+            download_url: URL where individual molecule datasets can be found
             datapath: path to dataset
-            batch_size: (train) batch size
-            num_train: number of training examples
-            num_val: number of validation examples
-            num_test: number of test examples
-            split_file: path to npz file with data partitions
-            format: dataset format
-            load_properties: subset of properties to load
-            val_batch_size: validation batch size. If None, use test_batch_size, then batch_size.
-            test_batch_size: test batch size. If None, use val_batch_size, then batch_size.
-            transforms: Transform applied to each system separately before batching.
-            train_transforms: Overrides transform_fn for training.
-            val_transforms: Overrides transform_fn for validation.
-            test_transforms: Overrides transform_fn for testing.
-            num_workers: Number of data loader workers.
-            num_val_workers: Number of validation data loader workers (overrides num_workers).
-            num_test_workers: Number of test data loader workers (overrides num_workers).
-            distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...).
-            data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance.
-            tmpdir: name of temporary directory used for parsing.
+            molecule: name of the molecule
+            tmpdir: name of temporary directory used for parsing
             atomrefs: properties of free atoms
+            load_properties: subset of properties to load
+            transforms: transform applied to each system separately before batching
+            train_transforms: overrides transform_fn for training
+            val_transforms: overrides transform_fn for validation
+            test_transforms: overrides transform_fn for testing
+            subset_idx: indices of the subset to load
+            property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...)
+            distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...)
         """
+        self.datasets_dict = datasets_dict
+        self.download_url = download_url
+        self._native_atomrefs = atomrefs or {}
+        self.tmpdir = tmpdir
+
+        if molecule not in self.datasets_dict:
+            raise AtomsDataError(f"Molecule {molecule} is not supported!")
+
+        self.molecule = molecule
+
+        self.distance_unit = "Ang"
+        self.property_units = self._native_property_units()
+
         super().__init__(
             datapath=datapath,
-            batch_size=batch_size,
-            num_train=num_train,
-            num_val=num_val,
-            num_test=num_test,
-            split_file=split_file,
-            format=format,
             load_properties=load_properties,
-            val_batch_size=val_batch_size,
-            test_batch_size=test_batch_size,
             transforms=transforms,
             train_transforms=train_transforms,
             val_transforms=val_transforms,
             test_transforms=test_transforms,
-            num_workers=num_workers,
-            num_val_workers=num_val_workers,
-            num_test_workers=num_test_workers,
+            subset_idx=subset_idx,
             property_units=property_units,
             distance_unit=distance_unit,
-            data_workdir=data_workdir,
             **kwargs,
         )

-        self.datasets_dict = datasets_dict
-        self.download_url = download_url
-        self.atomrefs = atomrefs
-        self.tmpdir = tmpdir
+    @staticmethod
+    def _native_property_units() -> Dict[str, str]:
+        return {
+            GDMLDataset.energy: "kcal/mol",
+            GDMLDataset.forces: "kcal/mol/Ang",
+        }

-        if molecule not in self.datasets_dict.keys():
-            raise AtomsDataModuleError("Molecule {} is not supported!".format(molecule))
+    def _check_db(self) -> None:
+        super()._check_db()
+        md = self.metadata

-        self.molecule = molecule
-
-    def prepare_data(self):
-        if not os.path.exists(self.datapath):
-            property_unit_dict = {
-                self.energy: "kcal/mol",
-                self.forces: "kcal/mol/Ang",
-            }
-
-            tmpdir = tempfile.mkdtemp(self.tmpdir)
+        if "molecule" not in md:
+            raise AtomsDataError(
+                "Not a valid GDML dataset. Metadata must contain `molecule`."
+            )

-            dataset = create_dataset(
-                datapath=self.datapath,
-                format=self.format,
-                distance_unit="Ang",
-                property_unit_dict=property_unit_dict,
-                atomrefs=self.atomrefs,
+        if md["molecule"] != self.molecule:
+            raise AtomsDataError(
+                f"The dataset at the given location contains `{md['molecule']}` "
+                f"instead of `{self.molecule}`."
             )
-            dataset.update_metadata(molecule=self.molecule)
-            self._download_data(tmpdir, dataset)
-            shutil.rmtree(tmpdir)
-        else:
-            dataset = load_dataset(self.datapath, self.format)
-            md = dataset.metadata
-            if "molecule" not in md:
-                raise AtomsDataModuleError(
-                    "Not a valid GDML dataset! The molecule needs to be specified in the metadata."
-                )
-            if md["molecule"] != self.molecule:
-                raise AtomsDataModuleError(
-                    f"The dataset at the given location does not contain the specified molecule: "
-                    + f"`{md['molecule']}` instead of `{self.molecule}`"
-                )

-    def _download_data(
-        self,
-        tmpdir,
-        dataset: BaseAtomsData,
-    ):
+    def download(self) -> None:
+        tmpdir = tempfile.mkdtemp(self.tmpdir)
+        md = self.metadata
+        md["atomrefs"] = self._native_atomrefs
+        md["molecule"] = self.molecule
+        self._set_metadata(md)
+
+        self._download_data(tmpdir)
+        shutil.rmtree(tmpdir, ignore_errors=True)
+
+    def _download_data(self, tmpdir) -> None:
         logging.info("Downloading {} data".format(self.molecule))
         rawpath = os.path.join(tmpdir, self.datasets_dict[self.molecule])
         url = self.download_url + self.datasets_dict[self.molecule]
@@ -180,66 +145,45 @@ def _download_data(
             property_list.append(properties)

         logging.info("Write atoms to db...")
-        dataset.add_systems(property_list=property_list)
+        self.add_systems(property_list=property_list)
         logging.info("Done.")

-class MD17(GDMLDataModule):
+class MD17(GDMLDataset):
     """
     MD17 benchmark data set for molecular dynamics of small molecules
     containing molecular forces.

     References:
         .. [#md17_1] http://quantum-machine.org/gdml/#datasets
-
     """

     def __init__(
         self,
         datapath: str,
         molecule: str,
-        batch_size: int,
-        num_train: Optional[int] = None,
-        num_val: Optional[int] = None,
-        num_test: Optional[int] = None,
-        split_file: Optional[str] = "split.npz",
-        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
         load_properties: Optional[List[str]] = None,
-        val_batch_size: Optional[int] = None,
-        test_batch_size: Optional[int] = None,
-        transforms: Optional[List[torch.nn.Module]] = None,
-        train_transforms: Optional[List[torch.nn.Module]] = None,
-        val_transforms: Optional[List[torch.nn.Module]] = None,
-        test_transforms: Optional[List[torch.nn.Module]] = None,
-        num_workers: int = 2,
-        num_val_workers: Optional[int] = None,
-        num_test_workers: Optional[int] = None,
+        transforms: Optional[List[Transform]] = None,
+        train_transforms: Optional[List[Transform]] = None,
+        val_transforms: Optional[List[Transform]] = None,
+        test_transforms: Optional[List[Transform]] = None,
+        subset_idx: Optional[List[int]] = None,
         property_units: Optional[Dict[str, str]] = None,
         distance_unit: Optional[str] = None,
-        data_workdir: Optional[str] = None,
         **kwargs,
     ):
         """
         Args:
-            datapath: path to dataset
-            batch_size: (train) batch size
-            num_train: number of training examples
-            num_val: number of validation examples
-            num_test: number of test examples
-            split_file: path to npz file with data partitions
-            format: dataset format
-            load_properties: subset of properties to load
-            val_batch_size: validation batch size. If None, use test_batch_size, then batch_size.
-            test_batch_size: test batch size. If None, use val_batch_size, then batch_size.
+            datapath: path to dataset.
+            molecule: name of the molecule.
+            load_properties: subset of properties to load.
             transforms: Transform applied to each system separately before batching.
-            train_transforms: Overrides transform_fn for training.
-            val_transforms: Overrides transform_fn for validation.
-            test_transforms: Overrides transform_fn for testing.
-            num_workers: Number of data loader workers.
-            num_val_workers: Number of validation data loader workers (overrides num_workers).
-            num_test_workers: Number of test data loader workers (overrides num_workers).
-            distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...).
-            data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance.
+            train_transforms: overrides transform_fn for training.
+            val_transforms: overrides transform_fn for validation.
+            test_transforms: overrides transform_fn for testing.
+            subset_idx: indices of the subset to load.
+            property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...).
+            distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...).
         """
         atomrefs = {
             self.energy: [
@@ -257,46 +201,28 @@ def __init__(

         datasets_dict = dict(
             aspirin="md17_aspirin.npz",
-            # aspirin_ccsd='aspirin_ccsd.zip',
             azobenzene="azobenzene_dft.npz",
             benzene="md17_benzene2017.npz",
             ethanol="md17_ethanol.npz",
-            # ethanol_ccsdt='ethanol_ccsd_t.zip',
             malonaldehyde="md17_malonaldehyde.npz",
-            # malonaldehyde_ccsdt='malonaldehyde_ccsd_t.zip',
             naphthalene="md17_naphthalene.npz",
             paracetamol="paracetamol_dft.npz",
             salicylic_acid="md17_salicylic.npz",
             toluene="md17_toluene.npz",
-            # toluene_ccsdt='toluene_ccsd_t.zip',
             uracil="md17_uracil.npz",
         )

-        super(MD17, self).__init__(
+        super().__init__(
             datasets_dict=datasets_dict,
             download_url="http://www.quantum-machine.org/gdml/data/npz/",
             tmpdir="md17",
             molecule=molecule,
             datapath=datapath,
-            batch_size=batch_size,
-            num_train=num_train,
-            num_val=num_val,
-            num_test=num_test,
-            split_file=split_file,
-            format=format,
             load_properties=load_properties,
-            val_batch_size=val_batch_size,
-            test_batch_size=test_batch_size,
             transforms=transforms,
             train_transforms=train_transforms,
             val_transforms=val_transforms,
             test_transforms=test_transforms,
-            num_workers=num_workers,
-            num_val_workers=num_val_workers,
-            num_test_workers=num_test_workers,
+            subset_idx=subset_idx,
             property_units=property_units,
             distance_unit=distance_unit,
-            data_workdir=data_workdir,
             atomrefs=atomrefs,
             **kwargs,
         )
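Editorial usage sketch (not part of the patch): MD17 stores energies and forces in kcal/mol natively (see _native_property_units above), and property_units can request converted units, mirroring the property_units mapping used in the YAML configs. The path is illustrative.

```python
from schnetpack.datasets import MD17

data = MD17(
    datapath="./md17_aspirin.db",     # hypothetical DB location
    molecule="aspirin",
    property_units={"energy": "eV"},  # convert from the native kcal/mol
)
```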
diff --git a/src/schnetpack/datasets/md22.py b/src/schnetpack/datasets/md22.py
index 9d51008f7..4cc0f8997 100644
--- a/src/schnetpack/datasets/md22.py
+++ b/src/schnetpack/datasets/md22.py
@@ -1,68 +1,44 @@
-import torch
 from typing import Optional, Dict, List

+from schnetpack.datasets.md17 import GDMLDataset
+from schnetpack.transform.base import Transform

-from schnetpack.data import *
-from schnetpack.datasets.md17 import GDMLDataModule
+__all__ = ["MD22"]

-all = ["MD22"]
-
-
-class MD22(GDMLDataModule):
+class MD22(GDMLDataset):
     """
     MD22 benchmark data set for extended molecules
     containing molecular forces.

     References:
         .. [#md22_1] http://quantum-machine.org/gdml/#datasets
-
     """

     def __init__(
         self,
         datapath: str,
         molecule: str,
-        batch_size: int,
-        num_train: Optional[int] = None,
-        num_val: Optional[int] = None,
-        num_test: Optional[int] = None,
-        split_file: Optional[str] = "split.npz",
-        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
         load_properties: Optional[List[str]] = None,
-        val_batch_size: Optional[int] = None,
-        test_batch_size: Optional[int] = None,
-        transforms: Optional[List[torch.nn.Module]] = None,
-        train_transforms: Optional[List[torch.nn.Module]] = None,
-        val_transforms: Optional[List[torch.nn.Module]] = None,
-        test_transforms: Optional[List[torch.nn.Module]] = None,
-        num_workers: int = 2,
-        num_val_workers: Optional[int] = None,
-        num_test_workers: Optional[int] = None,
+        transforms: Optional[List[Transform]] = None,
+        train_transforms: Optional[List[Transform]] = None,
+        val_transforms: Optional[List[Transform]] = None,
+        test_transforms: Optional[List[Transform]] = None,
+        subset_idx: Optional[List[int]] = None,
         property_units: Optional[Dict[str, str]] = None,
         distance_unit: Optional[str] = None,
-        data_workdir: Optional[str] = None,
         **kwargs,
     ):
         """
         Args:
             datapath: path to dataset
-            batch_size: (train) batch size
-            num_train: number of training examples
-            num_val: number of validation examples
-            num_test: number of test examples
-            split_file: path to npz file with data partitions
-            format: dataset format
+            molecule: name of the molecule
             load_properties: subset of properties to load
-            val_batch_size: validation batch size. If None, use test_batch_size, then batch_size.
-            test_batch_size: test batch size. If None, use val_batch_size, then batch_size.
-            transforms: Transform applied to each system separately before batching.
-            train_transforms: Overrides transform_fn for training.
-            val_transforms: Overrides transform_fn for validation.
-            test_transforms: Overrides transform_fn for testing.
-            num_workers: Number of data loader workers.
-            num_val_workers: Number of validation data loader workers (overrides num_workers).
-            num_test_workers: Number of test data loader workers (overrides num_workers).
-            distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...).
-            data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance.
+            transforms: transform applied to each system separately before batching
+            train_transforms: overrides transform_fn for training
+            val_transforms: overrides transform_fn for validation
+            test_transforms: overrides transform_fn for testing
+            subset_idx: indices of the subset to load
+            property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...)
+            distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...)
         """
         atomrefs = {
             self.energy: [
@@ -77,6 +53,7 @@ def __init__(
                 -47069.30768969713,
             ]
         }
+
         datasets_dict = {
             "Ac-Ala3-NHMe": "md22_Ac-Ala3-NHMe.npz",
             "DHA": "md22_DHA.npz",
@@ -87,31 +64,20 @@ def __init__(
             "double-walled_nanotube": "md22_double-walled_nanotube.npz",
         }

-        super(MD22, self).__init__(
+        super().__init__(
             datasets_dict=datasets_dict,
             download_url="http://www.quantum-machine.org/gdml/repo/datasets/",
             tmpdir="md22",
             molecule=molecule,
             datapath=datapath,
-            batch_size=batch_size,
-            num_train=num_train,
-            num_val=num_val,
-            num_test=num_test,
-            split_file=split_file,
-            format=format,
             load_properties=load_properties,
-            val_batch_size=val_batch_size,
-            test_batch_size=test_batch_size,
             transforms=transforms,
             train_transforms=train_transforms,
             val_transforms=val_transforms,
             test_transforms=test_transforms,
-            num_workers=num_workers,
-            num_val_workers=num_val_workers,
-            num_test_workers=num_test_workers,
+            subset_idx=subset_idx,
             property_units=property_units,
             distance_unit=distance_unit,
-            data_workdir=data_workdir,
             atomrefs=atomrefs,
             **kwargs,
         )
diff --git a/src/schnetpack/datasets/omdb.py b/src/schnetpack/datasets/omdb.py
index 413ef5ba0..34fc918f8 100644
--- a/src/schnetpack/datasets/omdb.py
+++ b/src/schnetpack/datasets/omdb.py
@@ -2,19 +2,18 @@
 import os
 import tarfile
 from typing import List, Optional, Dict
-from ase.io import read

 import numpy as np
+from ase.io import read

-import torch
-from schnetpack.data import *
-from schnetpack.data import AtomsDataModuleError, AtomsDataModule
+from schnetpack.data.atoms import ASEAtomsData, AtomsDataError
+from schnetpack.transform.base import Transform

 __all__ = ["OrganicMaterialsDatabase"]

-class OrganicMaterialsDatabase(AtomsDataModule):
+class OrganicMaterialsDatabase(ASEAtomsData):
     """
     Organic Materials Database (OMDB) of bulk organic crystals.
     Registration to the OMDB is free for academic users. This database contains DFT
@@ -32,22 +31,12 @@
     def __init__(
         self,
         datapath: str,
-        batch_size: int,
-        num_train: Optional[int] = None,
-        num_val: Optional[int] = None,
-        num_test: Optional[int] = None,
-        split_file: Optional[str] = "split.npz",
-        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
         load_properties: Optional[List[str]] = None,
-        val_batch_size: Optional[int] = None,
-        test_batch_size: Optional[int] = None,
-        transforms: Optional[List[torch.nn.Module]] = None,
-        train_transforms: Optional[List[torch.nn.Module]] = None,
-        val_transforms: Optional[List[torch.nn.Module]] = None,
-        test_transforms: Optional[List[torch.nn.Module]] = None,
-        num_workers: int = 2,
-        num_val_workers: Optional[int] = None,
-        num_test_workers: Optional[int] = None,
+        transforms: Optional[List[Transform]] = None,
+        train_transforms: Optional[List[Transform]] = None,
+        val_transforms: Optional[List[Transform]] = None,
+        test_transforms: Optional[List[Transform]] = None,
+        subset_idx: Optional[List[int]] = None,
         property_units: Optional[Dict[str, str]] = None,
         distance_unit: Optional[str] = None,
         raw_path: Optional[str] = None,
@@ -56,88 +45,77 @@
         """
         Args:
             datapath: path to dataset
-            batch_size: (train) batch size
-            num_train: number of training examples
-            num_val: number of validation examples
-            num_test: number of test examples
-            split_file: path to npz file with data partitions
-            format: dataset format
             load_properties: subset of properties to load
-            val_batch_size: validation batch size. If None, use test_batch_size, then batch_size.
-            test_batch_size: test batch size. If None, use val_batch_size, then batch_size.
-            transforms: Transform applied to each system separately before batching.
-            train_transforms: Overrides transform_fn for training.
-            val_transforms: Overrides transform_fn for validation.
-            test_transforms: Overrides transform_fn for testing.
-            num_workers: Number of data loader workers.
-            num_val_workers: Number of validation data loader workers (overrides num_workers).
-            num_test_workers: Number of test data loader workers (overrides num_workers).
-            property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...).
-            distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...).
+            transforms: transform applied to each system separately before batching.
+            train_transforms: overrides transform_fn for training.
+            val_transforms: overrides transform_fn for validation.
+            test_transforms: overrides transform_fn for testing.
+            subset_idx: indices of the subset to load.
+            property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...).
+            distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...).
             raw_path: path to raw tar.gz file with the data
         """
+        self.raw_path = raw_path
+        self.distance_unit = "Ang"
+        self.property_units = self._native_property_units()
+
         super().__init__(
             datapath=datapath,
-            batch_size=batch_size,
-            num_train=num_train,
-            num_val=num_val,
-            num_test=num_test,
-            split_file=split_file,
-            format=format,
             load_properties=load_properties,
-            val_batch_size=val_batch_size,
-            test_batch_size=test_batch_size,
+            load_structure=True,
             transforms=transforms,
             train_transforms=train_transforms,
             val_transforms=val_transforms,
             test_transforms=test_transforms,
-            num_workers=num_workers,
-            num_val_workers=num_val_workers,
-            num_test_workers=num_test_workers,
+            subset_idx=subset_idx,
             property_units=property_units,
             distance_unit=distance_unit,
             **kwargs,
         )

-        self.raw_path = raw_path

-    def prepare_data(self):
-        if not os.path.exists(self.datapath):
-            property_unit_dict = {OrganicMaterialsDatabase.BandGap: "eV"}
+    @staticmethod
+    def _native_property_units() -> Dict[str, str]:
+        return {OrganicMaterialsDatabase.BandGap: "eV"}

-            dataset = create_dataset(
-                datapath=self.datapath,
-                format=self.format,
-                distance_unit="Ang",
-                property_unit_dict=property_unit_dict,
-            )
-
-            self._convert(dataset)
-        else:
-            dataset = load_dataset(self.datapath, self.format)
-
-    def _convert(self, dataset):
+    def download(self) -> None:
         """
-        Converts .tar.gz to a .db file
+        Convert the OMDB raw archive into an ASE DB.
         """
         if self.raw_path is None or not os.path.exists(self.raw_path):
-            # TODO: can we download here automatically like QM9?
-            raise AtomsDataModuleError(
+            raise AtomsDataError(
                 "The path to the raw dataset is not provided or invalid and the db-file does "
                 "not exist!"
             )

-        logging.info("Converting %s to a .db file.." % self.raw_path)
-        tar = tarfile.open(self.raw_path, "r:gz")
-        names = tar.getnames()
-        tar.extractall()
-        tar.close()
+        self._convert()

-        structures = read("structures.xyz", index=":")
-        Y = np.loadtxt("bandgaps.csv")
-        [os.remove(name) for name in names]
+    def _convert(self) -> None:
+        """
+        Converts .tar.gz to a .db file
+        """
+        logging.info("Converting %s to a .db file...", self.raw_path)
+
+        extract_dir = os.path.dirname(self.raw_path) or "."
+        with tarfile.open(self.raw_path, "r:gz") as tar:
+            names = tar.getnames()
+            tar.extractall(path=extract_dir)
+
+        structures_path = os.path.join(extract_dir, "structures.xyz")
+        bandgaps_path = os.path.join(extract_dir, "bandgaps.csv")
+
+        structures = read(structures_path, index=":")
+        y = np.loadtxt(bandgaps_path)

         atoms_list = []
         property_list = []
         for i, at in enumerate(structures):
             atoms_list.append(at)
-            property_list.append({OrganicMaterialsDatabase.BandGap: np.array([Y[i]])})
+            property_list.append(
+                {OrganicMaterialsDatabase.BandGap: np.array([y[i]], dtype=np.float64)}
+            )

-        dataset.add_systems(atoms_list=atoms_list, property_list=property_list)
+        self.add_systems(atoms_list=atoms_list, property_list=property_list)
+
+        for name in names:
+            path = os.path.join(extract_dir, name)
+            if os.path.exists(path):
+                os.remove(path)
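Editorial usage sketch (not part of the patch): OMDB has no public download, so the user supplies the registered raw archive via raw_path and download() converts it into the ASE DB. File names are illustrative, and it is assumed the ASEAtomsData base class calls download() when datapath does not exist yet.

```python
from schnetpack.datasets import OrganicMaterialsDatabase

data = OrganicMaterialsDatabase(
    datapath="./omdb.db",           # DB file to be created
    raw_path="./omdb_raw.tar.gz",   # archive obtained after OMDB registration
)
```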
Please try again.") from e logging.info("Done.") -class QM7X(AtomsDataModule): +class QM7X(ASEAtomsData): """ QM7-X a comprehensive dataset of > 40 physicochemical properties for ~4.2 M equilibrium and non-equilibrium structure of small organic molecules with up to seven non-hydrogen (C, N, O, S, Cl) atoms. @@ -124,17 +115,6 @@ class QM7X(AtomsDataModule): FMBD = "FMBD" # FMBD: total eMBD forces RMSD = "rmsd" # root mean square deviation of the atomic positions from the equilibrium structure - property_unit_dict = { - forces: "eV/Ang", - energy: "eV", - Eat: "eV", - EPBE0: "eV", - EMBD: "eV", - FPBE0: "eV/Ang", - FMBD: "eV/Ang", - RMSD: "Ang", - } - # the original keys in the raw dataset to query the properties property_dataset_keys = { forces: "totFOR", @@ -160,105 +140,140 @@ class QM7X(AtomsDataModule): def __init__( self, datapath: str, - batch_size: int, - raw_data_path: str = None, + raw_data_path: Optional[str] = None, remove_duplicates: bool = True, only_equilibrium: bool = False, only_non_equilibrium: bool = False, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - data_workdir: Optional[str] = None, - splitting: Optional[SplittingStrategy] = None, **kwargs, ): """ Args: datapath: path to dataset - batch_size: (train) batch size - raw_data_path: path to raw data. If None use tmp dir otherwise persist data and not remove it. - remove_duplicates: remove duplicated equilibrium structures with different non-equilibrium structures - only_equilibrium: only use equilibrium structures - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format + raw_data_path: path to raw data + remove_duplicates: do not include duplicate molecules + only_equilibrium: only include equilibrium molecules + only_non_equilibrium: only include non-equilibrium molecules load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). 
- property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. - splitting: Method to generate train/validation/test partitions - (default: GroupSplit(splitting_key="smiles_id")) + transforms: transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing + subset_idx: indices of the subset to load + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...) + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) """ + + if only_equilibrium and only_non_equilibrium: + raise AtomsDataError( + "only_equilibrium and only_non_equilibrium cannot both be True." + ) + + self.raw_data_path = raw_data_path + self.remove_duplicates = remove_duplicates + self.duplicates_ids = None + self.only_equilibrium = only_equilibrium + self.only_non_equilibrium = only_non_equilibrium + + self.distance_unit = "Ang" + self.property_units = self._native_property_units() + + # initialize without subset first, then apply dataset-specific filtering super().__init__( datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, - format=format, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, transforms=transforms, train_transforms=train_transforms, val_transforms=val_transforms, test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, - data_workdir=data_workdir, - splitting=splitting or GroupSplit(splitting_key="smiles_id"), **kwargs, ) - self.raw_data_path = raw_data_path - self.remove_duplicates = remove_duplicates - self.duplicates_ids = None - self.only_equilibrium = only_equilibrium - self.only_non_equilibrium = only_non_equilibrium + self._apply_structure_filter(original_subset_idx=subset_idx) + + @staticmethod + def _native_property_units() -> Dict[str, str]: + return { + QM7X.forces: "eV/Ang", + QM7X.energy: "eV", + QM7X.Eat: "eV", + QM7X.EPBE0: "eV", + QM7X.EMBD: "eV", + QM7X.FPBE0: "eV/Ang", + QM7X.FMBD: "eV/Ang", + QM7X.RMSD: "Ang", + } + + def _apply_structure_filter(self, original_subset_idx: Optional[List[int]]) -> None: + effective_subset = original_subset_idx + + if self.only_equilibrium or self.only_non_equilibrium: + step_ids = self.metadata["groups_ids"]["step_id"] + + if len(step_ids) != self.conn.count(): + raise AtomsDataError( + "Dataset size does not match size of step_id metadata." + ) + + if self.only_equilibrium: + filtered = [i for i, s in enumerate(step_ids) if s == 0] + else: + filtered = [i for i, s in enumerate(step_ids) if s != 0] + + if effective_subset is None: + effective_subset = filtered + else: + filtered_set = set(filtered) + effective_subset = [i for i in effective_subset if i in filtered_set] + + self.subset_idx = effective_subset + + def download(self) -> None: + """ + Download the QM7-X dataset and create the ASEAtomsData object. 
+ """ + tar_dir = self.raw_data_path or tempfile.mkdtemp("qm7x") + atomrefs = { + QM7X.energy: [ + QM7X.EPBE0_atom[i] if i in QM7X.EPBE0_atom else 0.0 + for i in range(0, 18) + ] + } + + hd_files = self._download_data(tar_dir) + if self.remove_duplicates: + self._download_duplicates_ids(tar_dir) + self._parse_data(hd_files) + + if self.raw_data_path is None: + shutil.rmtree(tar_dir, ignore_errors=True) def _download_duplicates_ids(self, tar_dir: str): """ download duplicates ids for QM7-X """ - url = f"https://zenodo.org/record/4288677/files/DupMols.dat" - tar_path = os.path.join(tar_dir, "DupMols.dat") + url = "https://zenodo.org/record/4288677/files/DupMols.dat" + target_path = os.path.join(tar_dir, "DupMols.dat") checksum = "5d886ccac38877c8cb26c07704dd1034" - download_and_check(url, tar_path, checksum) + download_and_check(url, target_path, checksum) # fetch duplicates ids dup_mols = [] - for line in open(tar_path, "r"): - dup_mols.append(line.rstrip("\n")[:-4]) + with open(target_path, "r") as f: + for line in f: + dup_mols.append(line.rstrip("\n")[:-4]) + self.duplicates_ids = dup_mols def _download_data(self, tar_dir: str, ignore_extracted: bool = True) -> List[str]: @@ -281,42 +296,35 @@ def _download_data(self, tar_dir: str, ignore_extracted: bool = True) -> List[st logging.info("Downloading QM7-X data files ...") - # download files for i, file_id in enumerate(file_ids): if ignore_extracted and os.path.exists( os.path.join(tar_dir, f"{file_id}.hdf5") ): logging.info( - f"File {file_id}.xz exists in extracted version {file_id}.hdf5 already, skipping download." + f"File {file_id}.hdf5 already exists. Skipping download of {file_id}.xz." ) continue url = f"https://zenodo.org/record/4288677/files/{file_id}.xz" - - tar_path = os.path.join(tar_dir, f"{file_id}.xz") - download_and_check(url, tar_path, checksums[i]) + xz_path = os.path.join(tar_dir, f"{file_id}.xz") + download_and_check(url, xz_path, checksums[i]) # extract the compressed files extracted = [] - for i, file_id in enumerate(file_ids): + for file_id in file_ids: xz_path = os.path.join(tar_dir, f"{file_id}.xz") hd_path = os.path.join(tar_dir, f"{file_id}.hdf5") - extract_xz(xz_path, hd_path) - extracted.append(hd_path) return extracted - def _parse_data(self, files: List[str], dataset: BaseAtomsData): + def _parse_data(self, files: List[str]): """ Parse the downloaded data files and add them to the dataset. 
""" - - # parse the data files - for file in files: - logging.info(f"Parsing {file.split('/')[-1]} ...") + logging.info(f"Parsing {os.path.basename(file)} ...") atoms_list = [] property_list = [] @@ -328,7 +336,7 @@ def _parse_data(self, files: List[str], dataset: BaseAtomsData): } with h5py.File(file, "r") as mol_dict: - for mol_id, mol in tqdm(mol_dict.items()): + for _mol_id, mol in mol_dict.items(): for conf_id, conf in mol.items(): # exclude equilibrium duplicates trunc_id = conf_id[::-1].split("-", 1)[-1][::-1] @@ -340,31 +348,28 @@ def _parse_data(self, files: List[str], dataset: BaseAtomsData): key: np.array( conf[QM7X.property_dataset_keys[key]], dtype=np.float64 ) - for key in QM7X.property_unit_dict.keys() + for key in QM7X._native_property_units().keys() } # get the hierarchical ids for each system if "opt" in conf_id: - conf_id = ( - conf_id[:-3] + "d0" - ) # repalce the 'opt' key with id 'd0' + conf_id = conf_id[:-3] + "d0" + ids = map(lambda x: int(x), re.findall(r"\d+", conf_id)) atoms_list.append(ats) property_list.append(properties) # save the hierarchical ids for each system in same order as the systems - for i, j in zip(groups_ids.keys(), ids): - groups_ids[i].append(j) + for key, idx in zip(groups_ids.keys(), ids): + groups_ids[key].append(idx) - # add the data to the dataset - logging.info(f"Write parsed data from {file.split('/')[-1]} to db ...") - - dataset.add_systems(property_list=property_list, atoms_list=atoms_list) + logging.info(f"Write parsed data from {os.path.basename(file)} to db ...") + self.add_systems(property_list=property_list, atoms_list=atoms_list) # add the hierarchical ids to the metadata - md = dataset.metadata - if "groups_ids" in md.keys(): + md = self.metadata + if "groups_ids" in md: for key, ids in groups_ids.items(): groups_ids[key] = md["groups_ids"][key] + ids @@ -375,80 +380,5 @@ def _parse_data(self, files: List[str], dataset: BaseAtomsData): else: groups_ids["id"] = list(range(1, len(atoms_list) + 1)) - dataset.update_metadata(groups_ids=groups_ids) - + self.update_metadata(groups_ids=groups_ids) logging.info("Done.") - - def prepare_data(self): - """ - prepare data for pytorch lightning data module - """ - if not os.path.exists(self.datapath): - tar_dir = self.raw_data_path or tempfile.mkdtemp("qm7x") - - atomrefs = { - QM7X.energy: [ - QM7X.EPBE0_atom[i] if i in QM7X.EPBE0_atom else 0.0 - for i in range(0, 18) - ] - } - - dataset = create_dataset( - datapath=self.datapath, - format=self.format, - distance_unit="Ang", - property_unit_dict=QM7X.property_unit_dict, - atomrefs=atomrefs, - ) - - hd_files = self._download_data(tar_dir) - if self.remove_duplicates: - self._download_duplicates_ids(tar_dir) - self._parse_data(hd_files, dataset) - - if self.raw_data_path is None: - shutil.rmtree(tar_dir) - - def setup(self, stage=None): - if self.data_workdir is None: - datapath = self.datapath - else: - datapath = self._copy_to_workdir() - - # (re)load datasets - if self.dataset is None: - self.dataset = load_dataset( - datapath, - self.format, - property_units=self.property_units, - distance_unit=self.distance_unit, - load_properties=self.load_properties, - ) - - # use subset of equilibrium structures - - if self.only_equilibrium or self.only_non_equilibrium: - step_ids = self.dataset.metadata["groups_ids"]["step_id"] - - if len(step_ids) != len(self.dataset): - raise ValueError( - "The dataset size does not match the size of step ids arrays in meta data." 
- ) - - if self.only_equilibrium: - eq_indices = [i for i, s in enumerate(step_ids) if s == 0] - else: - eq_indices = [i for i, s in enumerate(step_ids) if s != 0] - - self.dataset = self.dataset.subset(eq_indices) - - # load and generate partitions if needed - if self.train_idx is None: - self._load_partitions() - - # partition dataset - self._train_dataset = self.dataset.subset(self.train_idx) - self._val_dataset = self.dataset.subset(self.val_idx) - self._test_dataset = self.dataset.subset(self.test_idx) - - self._setup_transforms() diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 3085a21a7..2408f288e 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -5,32 +5,27 @@ import shutil import tarfile import tempfile -from typing import List, Optional, Dict +from typing import Dict, List, Optional from urllib import request as request import numpy as np from ase import Atoms from ase.io.extxyz import read_xyz +from ase.db import connect + from tqdm import tqdm -import torch -from schnetpack.data import * import schnetpack.properties as structure -from schnetpack.data import AtomsDataModuleError, AtomsDataModule - -__all__ = ["QM9"] +from schnetpack.data.atoms import ASEAtomsData, AtomsDataError +from schnetpack.transform.base import Transform -class QM9(AtomsDataModule): - """QM9 benchmark database for organic molecules. - - The QM9 database contains small organic molecules with up to nine non-hydrogen atoms - from including C, O, N, F. This class adds convenient functions to download QM9 from - figshare and load the data into pytorch. +__all__ = ["QM9"] - References: - .. [#qm9_1] https://ndownloader.figshare.com/files/3195404 +class QM9(ASEAtomsData): + """ + QM9 benchmark database for organic molecules. 
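Editorial usage sketch (not part of the patch): the equilibrium filtering that previously lived in setup() is now resolved in the constructor through subset indices, and the two filter flags are mutually exclusive. The path is illustrative.

```python
from schnetpack.datasets import QM7X
from schnetpack.data.atoms import AtomsDataError

eq_only = QM7X(datapath="./qm7x.db", only_equilibrium=True)  # hypothetical path

try:
    QM7X(datapath="./qm7x.db", only_equilibrium=True, only_non_equilibrium=True)
except AtomsDataError as err:
    print(err)  # "only_equilibrium and only_non_equilibrium cannot both be True."
```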
""" base_urls = [ @@ -43,7 +38,6 @@ class QM9(AtomsDataModule): "uncharacterized": "3195404", } - # properties A = "rotational_constant_A" B = "rotational_constant_B" C = "rotational_constant_C" @@ -63,142 +57,121 @@ class QM9(AtomsDataModule): def __init__( self, datapath: str, - batch_size: int, - num_train: Optional[int] = None, - num_val: Optional[int] = None, - num_test: Optional[int] = None, - split_file: Optional[str] = "split.npz", - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, - load_properties: Optional[List[str]] = None, remove_uncharacterized: bool = False, - val_batch_size: Optional[int] = None, - test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, - num_workers: int = 2, - num_val_workers: Optional[int] = None, - num_test_workers: Optional[int] = None, + load_properties: Optional[List[str]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, + subset_idx: Optional[List[int]] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - data_workdir: Optional[str] = None, **kwargs, ): """ - Args: datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples - num_val: number of validation examples - num_test: number of test examples - split_file: path to npz file with data partitions - format: dataset format - load_properties: subset of properties to load remove_uncharacterized: do not include uncharacterized molecules. - val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers (overrides num_workers). - num_test_workers: Number of test data loader workers (overrides num_workers). - property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. + load_properties: subset of properties to load + transforms: transform applied to each system separately before batching + train_transforms: overrides transform_fn for training + val_transforms: overrides transform_fn for validation + test_transforms: overrides transform_fn for testing + subset_idx: indices of the subset to load + property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...) + distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...) 
""" + self.remove_uncharacterized = remove_uncharacterized + self.distance_unit = "Ang" + self.property_units = self._native_property_units() + super().__init__( datapath=datapath, - batch_size=batch_size, - num_train=num_train, - num_val=num_val, - num_test=num_test, - split_file=split_file, - format=format, load_properties=load_properties, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, transforms=transforms, train_transforms=train_transforms, val_transforms=val_transforms, test_transforms=test_transforms, - num_workers=num_workers, - num_val_workers=num_val_workers, - num_test_workers=num_test_workers, + subset_idx=subset_idx, property_units=property_units, distance_unit=distance_unit, - data_workdir=data_workdir, **kwargs, ) - self.remove_uncharacterized = remove_uncharacterized + @staticmethod + def _native_property_units() -> Dict[str, str]: + return { + QM9.A: "GHz", + QM9.B: "GHz", + QM9.C: "GHz", + QM9.mu: "Debye", + QM9.alpha: "a0 a0 a0", + QM9.homo: "Ha", + QM9.lumo: "Ha", + QM9.gap: "Ha", + QM9.r2: "a0 a0", + QM9.zpve: "Ha", + QM9.U0: "Ha", + QM9.U: "Ha", + QM9.H: "Ha", + QM9.G: "Ha", + QM9.Cv: "cal/mol/K", + } + + def _check_db(self) -> None: + """ + Ensure the QM9 ASE DB exists. + """ + super()._check_db() + with connect(self.datapath, use_lock_file=False) as conn: + data_count = conn.count() + + if self.remove_uncharacterized and data_count == 133885: + raise AtomsDataError( + "The dataset at the chosen location contains the uncharacterized 3054 molecules. " + "Choose a different location to reload the data or set " + "`remove_uncharacterized=False`." + ) + + if (not self.remove_uncharacterized) and data_count < 133885: + raise AtomsDataError( + "The dataset at the chosen location does NOT contain the uncharacterized 3054 molecules. " + "Choose a different location to reload the data or set " + "`remove_uncharacterized=True`." + ) - def _download_file(self, file_id: str, destination: str): + def download(self) -> None: + """ + Download the QM9 ASE DB. + """ + tmpdir = tempfile.mkdtemp("qm9") + + atomrefs = self._download_atomrefs(tmpdir) + md = self.metadata + md["atomrefs"] = atomrefs + self._set_metadata(md) + + if self.remove_uncharacterized: + uncharacterized = self._download_uncharacterized(tmpdir) + else: + uncharacterized = None + + self._download_data(tmpdir, uncharacterized) + + shutil.rmtree(tmpdir, ignore_errors=True) + + def _download_file(self, file_id: str, destination: str) -> None: for base_url in self.base_urls: - url = f"{base_url}{file_id}" try: - request.urlretrieve(url, destination) + request.urlretrieve(f"{base_url}{file_id}", destination) return except Exception: - logging.warning(f"Could not download from {url}, trying next source...") - raise AtomsDataModuleError( + continue + raise AtomsDataError( f"Could not download file with id {file_id} from any source." 
 
-    def prepare_data(self):
-        if not os.path.exists(self.datapath):
-            property_unit_dict = {
-                QM9.A: "GHz",
-                QM9.B: "GHz",
-                QM9.C: "GHz",
-                QM9.mu: "Debye",
-                QM9.alpha: "a0 a0 a0",
-                QM9.homo: "Ha",
-                QM9.lumo: "Ha",
-                QM9.gap: "Ha",
-                QM9.r2: "a0 a0",
-                QM9.zpve: "Ha",
-                QM9.U0: "Ha",
-                QM9.U: "Ha",
-                QM9.H: "Ha",
-                QM9.G: "Ha",
-                QM9.Cv: "cal/mol/K",
-            }
-
-            tmpdir = tempfile.mkdtemp("qm9")
-            atomrefs = self._download_atomrefs(tmpdir)
-
-            dataset = create_dataset(
-                datapath=self.datapath,
-                format=self.format,
-                distance_unit="Ang",
-                property_unit_dict=property_unit_dict,
-                atomrefs=atomrefs,
-            )
-
-            if self.remove_uncharacterized:
-                uncharacterized = self._download_uncharacterized(tmpdir)
-            else:
-                uncharacterized = None
-            self._download_data(tmpdir, dataset, uncharacterized=uncharacterized)
-            shutil.rmtree(tmpdir)
-        else:
-            dataset = load_dataset(self.datapath, self.format)
-            if self.remove_uncharacterized and len(dataset) == 133885:
-                raise AtomsDataModuleError(
-                    "The dataset at the chosen location contains the uncharacterized 3054 molecules. "
-                    + "Choose a different location to reload the data or set `remove_uncharacterized=False`!"
-                )
-            elif not self.remove_uncharacterized and len(dataset) < 133885:
-                raise AtomsDataModuleError(
-                    "The dataset at the chosen location does NOT contain the uncharacterized 3054 molecules. "
-                    + "Choose a different location to reload the data or set `remove_uncharacterized=True`!"
-                )
-
-    def _download_uncharacterized(self, tmpdir):
+    def _download_uncharacterized(self, tmpdir: str) -> List[int]:
         logging.info("Downloading list of uncharacterized molecules...")
         tmp_path = os.path.join(tmpdir, "uncharacterized.txt")
         self._download_file(self.file_ids["uncharacterized"], tmp_path)
@@ -211,7 +184,7 @@
             uncharacterized.append(int(line.split()[0]))
         return uncharacterized
 
-    def _download_atomrefs(self, tmpdir):
+    def _download_atomrefs(self, tmpdir: str) -> Dict[str, List[float]]:
         logging.info("Downloading GDB-9 atom references...")
         tmp_path = os.path.join(tmpdir, "atomrefs.txt")
         self._download_file(self.file_ids["atomrefs"], tmp_path)
@@ -219,17 +192,21 @@
         props = [QM9.zpve, QM9.U0, QM9.U, QM9.H, QM9.G, QM9.Cv]
         atref = {p: np.zeros((100,)) for p in props}
+
         with open(tmp_path) as f:
             lines = f.readlines()
             for z, l in zip([1, 6, 7, 8, 9], lines[5:10]):
                 for i, p in enumerate(props):
                     atref[p][z] = float(l.split()[i + 1])
-        atref = {k: v.tolist() for k, v in atref.items()}
-        return atref
+
+        return {k: v.tolist() for k, v in atref.items()}
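
The parser above turns the GDB-9 atom reference table into one length-100 array per property, indexed by atomic number. The resulting structure looks like this (the numerical values are the published GDB-9 U0 atom references in Hartree; double-check against the downloaded atomrefs.txt before relying on them):

# Sketch: shape of a parsed atomref array, index = atomic number.
import numpy as np

atref_u0 = np.zeros((100,))
atref_u0[1] = -0.500273   # H
atref_u0[6] = -37.846772  # C
atref_u0[8] = -75.064579  # O
# atomization energy of a structure with atomic numbers `numbers`:
# e_atomization = e_total - atref_u0[numbers].sum()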
 
     def _download_data(
-        self, tmpdir, dataset: BaseAtomsData, uncharacterized: List[int]
-    ):
+        self,
+        tmpdir: str,
+        uncharacterized: Optional[List[int]],
+    ) -> None:
+
         logging.info("Downloading GDB-9 data...")
         tar_path = os.path.join(tmpdir, "gdb9.tar.gz")
         raw_path = os.path.join(tmpdir, "gdb9_xyz")
@@ -237,9 +214,8 @@
         logging.info("Done.")
 
         logging.info("Extracting files...")
-        tar = tarfile.open(tar_path)
-        tar.extractall(raw_path)
-        tar.close()
+        with tarfile.open(tar_path) as tar:
+            tar.extractall(raw_path)
         logging.info("Done.")
 
         logging.info("Parse xyz files...")
@@ -248,32 +224,36 @@
         )
 
         property_list = []
+        indices = np.arange(len(ordered_files), dtype=int)
 
-        irange = np.arange(len(ordered_files), dtype=int)
         if uncharacterized is not None:
-            irange = np.setdiff1d(irange, np.array(uncharacterized, dtype=int) - 1)
+            indices = np.setdiff1d(indices, np.array(uncharacterized, dtype=int) - 1)
 
-        for i in tqdm(irange):
+        for i in tqdm(indices):
             xyzfile = os.path.join(raw_path, ordered_files[i])
             properties = {}
             tmp = io.StringIO()
             with open(xyzfile, "r") as f:
                 lines = f.readlines()
-                l = lines[1].split()[2:]
-                for pn, p in zip(dataset.available_properties, l):
-                    properties[pn] = np.array([float(p)])
+                values = lines[1].split()[2:]
+
+                for pname, value in zip(self.available_properties, values):
+                    properties[pname] = np.array([float(value)])
+
                 for line in lines:
                     tmp.write(line.replace("*^", "e"))
 
             tmp.seek(0)
-            ats: Atoms = list(read_xyz(tmp, 0))[0]
-            properties[structure.Z] = ats.numbers
-            properties[structure.R] = ats.positions
-            properties[structure.cell] = ats.cell
-            properties[structure.pbc] = ats.pbc
+            atoms: Atoms = list(read_xyz(tmp, 0))[0]
+
+            properties[structure.Z] = atoms.numbers
+            properties[structure.R] = atoms.positions
+            properties[structure.cell] = atoms.cell
+            properties[structure.pbc] = atoms.pbc
+
            property_list.append(properties)
 
         logging.info("Write atoms to db...")
-        dataset.add_systems(property_list=property_list)
+        self.add_systems(property_list=property_list)
         logging.info("Done.")
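
After download, entries are read back like any other ASEAtomsData item. A short sketch, continuing the QM9 usage example above (property keys per this file; units are whatever was requested via property_units):

# Sketch: reading one parsed entry from the populated DB.
import schnetpack.properties as structure
from schnetpack.datasets import QM9

props = dataset[0]            # `dataset` from the earlier sketch
print(props[structure.Z])     # atomic numbers
print(props[QM9.U0])          # inner energy at 0 K, in the requested unit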
diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py
index 50cf59385..3ff05232f 100644
--- a/src/schnetpack/datasets/rmd17.py
+++ b/src/schnetpack/datasets/rmd17.py
@@ -1,32 +1,31 @@
 import logging
 import os
 import shutil
-import tempfile
 import tarfile
-from typing import List, Optional, Dict
-from urllib import request as request
+import tempfile
+from typing import Dict, List, Optional
+from urllib.request import Request, urlopen
+from urllib.error import HTTPError, URLError
 
 import numpy as np
 from ase import Atoms
-import torch
 
 import schnetpack.properties as structure
-
-from schnetpack.data import *
+from schnetpack.data.atoms import ASEAtomsData, AtomsDataError
+from schnetpack.transform.base import Transform
 
 __all__ = ["rMD17"]
 
 
-class rMD17(AtomsDataModule):
+class rMD17(ASEAtomsData):
     """
-    Revised MD17 benchmark data set for molecular dynamics of small molecules
+    Revised MD17 benchmark dataset for molecular dynamics of small molecules
     containing molecular forces.
 
     References:
         .. [#md17_1] https://figshare.com/articles/dataset/
                    Revised_MD17_dataset_rMD17_/12672038?file=24013628
         .. [#md17_2] http://quantum-machine.org/gdml/#datasets
-
     """
 
     energy = "energy"
@@ -46,219 +45,222 @@ class rMD17(AtomsDataModule):
         ]
     }
 
-    datasets_dict = dict(
-        aspirin="rmd17_aspirin.npz",
-        azobenzene="rmd17_azobenzene.npz",
-        benzene="rmd17_benzene.npz",
-        ethanol="rmd17_ethanol.npz",
-        malonaldehyde="rmd17_malonaldehyde.npz",
-        naphthalene="rmd17_naphthalene.npz",
-        paracetamol="rmd17_paracetamol.npz",
-        salicylic_acid="rmd17_salicylic.npz",
-        toluene="rmd17_toluene.npz",
-        uracil="rmd17_uracil.npz",
-    )
-
-    # properties
+    datasets_dict = {
+        "aspirin": "rmd17_aspirin.npz",
+        "azobenzene": "rmd17_azobenzene.npz",
+        "benzene": "rmd17_benzene.npz",
+        "ethanol": "rmd17_ethanol.npz",
+        "malonaldehyde": "rmd17_malonaldehyde.npz",
+        "naphthalene": "rmd17_naphthalene.npz",
+        "paracetamol": "rmd17_paracetamol.npz",
+        "salicylic_acid": "rmd17_salicylic.npz",
+        "toluene": "rmd17_toluene.npz",
+        "uracil": "rmd17_uracil.npz",
+    }
+
+    download_urls = [
+        "https://figshare.com/ndownloader/files/23950376",
+        "https://archive.materialscloud.org/records/pfffs-fff86/files/rmd17.tar.bz2?download=1",
+    ]
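
datasets_dict is the single source of truth mapping molecule names to npz files, and download_urls lists the primary figshare source plus a Materials Cloud mirror. Resolving a user-supplied name against it looks like this (trivial, but it is the validation the constructor below performs):

# Sketch: resolving and validating a molecule name.
from schnetpack.datasets import rMD17

molecule = "aspirin"
if molecule not in rMD17.datasets_dict:
    raise ValueError(
        f"Unknown molecule {molecule!r}; choose one of {sorted(rMD17.datasets_dict)}"
    )
npz_name = rMD17.datasets_dict[molecule]  # -> "rmd17_aspirin.npz"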
 
     def __init__(
         self,
         datapath: str,
         molecule: str,
-        batch_size: int,
-        num_train: Optional[int] = None,
-        num_val: Optional[int] = None,
-        num_test: Optional[int] = None,
-        split_file: Optional[str] = "split.npz",
-        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
         load_properties: Optional[List[str]] = None,
-        val_batch_size: Optional[int] = None,
-        test_batch_size: Optional[int] = None,
-        transforms: Optional[List[torch.nn.Module]] = None,
-        train_transforms: Optional[List[torch.nn.Module]] = None,
-        val_transforms: Optional[List[torch.nn.Module]] = None,
-        test_transforms: Optional[List[torch.nn.Module]] = None,
-        num_workers: int = 2,
-        num_val_workers: Optional[int] = None,
-        num_test_workers: Optional[int] = None,
+        transforms: Optional[List[Transform]] = None,
+        train_transforms: Optional[List[Transform]] = None,
+        val_transforms: Optional[List[Transform]] = None,
+        test_transforms: Optional[List[Transform]] = None,
+        subset_idx: Optional[List[int]] = None,
         property_units: Optional[Dict[str, str]] = None,
         distance_unit: Optional[str] = None,
-        data_workdir: Optional[str] = None,
-        split_id: Optional[int] = None,
         **kwargs,
     ):
         """
         Args:
             datapath: path to dataset
-            batch_size: (train) batch size
-            num_train: number of training examples
-            num_val: number of validation examples
-            num_test: number of test examples
-            split_file: path to npz file with data partitions
-            format: dataset format
+            molecule: name of the molecule
             load_properties: subset of properties to load
-            val_batch_size: validation batch size. If None, use test_batch_size, then
-                batch_size.
-            test_batch_size: test batch size. If None, use val_batch_size, then
-                batch_size.
-            transforms: Transform applied to each system separately before batching.
-            train_transforms: Overrides transform_fn for training.
-            val_transforms: Overrides transform_fn for validation.
-            test_transforms: Overrides transform_fn for testing.
-            num_workers: Number of data loader workers.
-            num_val_workers: Number of validation data loader workers
-                (overrides num_workers).
-            num_test_workers: Number of test data loader workers
-                (overrides num_workers).
-            distance_unit: Unit of the atom positions and cell as a string
-                (Ang, Bohr, ...).
-            data_workdir: Copy data here as part of setup, e.g. cluster scratch for
-                faster performance.
-            split_id: The id of the predefined rMD17 train/test splits (0-4).
+            transforms: transform applied to each system separately before batching
+            train_transforms: overrides transform_fn for training
+            val_transforms: overrides transform_fn for validation
+            test_transforms: overrides transform_fn for testing
+            subset_idx: indices of the subset to load
+            property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...)
+            distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...)
         """
-        if split_id is not None:
-            splitting = SubsamplePartitions(
-                split_partition_sources=["known", "known", "test"], split_id=split_id
-            )
-        else:
-            splitting = RandomSplit()
+        if molecule not in self.datasets_dict:
+            raise AtomsDataError(f"Molecule {molecule} is not supported!")
+
+        self.molecule = molecule
+
+        self.distance_unit = "Ang"
+        self.property_units = self._native_property_units()
 
         super().__init__(
             datapath=datapath,
-            batch_size=batch_size,
-            num_train=num_train,
-            num_val=num_val,
-            num_test=num_test,
-            split_file=split_file,
-            format=format,
             load_properties=load_properties,
-            val_batch_size=val_batch_size,
-            test_batch_size=test_batch_size,
             transforms=transforms,
             train_transforms=train_transforms,
             val_transforms=val_transforms,
             test_transforms=test_transforms,
-            num_workers=num_workers,
-            num_val_workers=num_val_workers,
-            num_test_workers=num_test_workers,
+            subset_idx=subset_idx,
             property_units=property_units,
             distance_unit=distance_unit,
-            data_workdir=data_workdir,
-            splitting=splitting,
             **kwargs,
         )
 
-        if molecule not in rMD17.datasets_dict.keys():
-            raise AtomsDataModuleError("Molecule {} is not supported!".format(molecule))
-
-        self.molecule = molecule
+    @staticmethod
+    def _native_property_units() -> Dict[str, str]:
+        return {
+            rMD17.energy: "kcal/mol",
+            rMD17.forces: "kcal/mol/Ang",
+        }
 
-    def prepare_data(self):
-        if not os.path.exists(self.datapath):
-            property_unit_dict = {
-                rMD17.energy: "kcal/mol",
-                rMD17.forces: "kcal/mol/Ang",
-            }
-
-            tmpdir = tempfile.mkdtemp("md17")
-
-            dataset = create_dataset(
-                datapath=self.datapath,
-                format=self.format,
-                distance_unit="Ang",
-                property_unit_dict=property_unit_dict,
-                atomrefs=rMD17.atomrefs,
-            )
-            dataset.update_metadata(molecule=self.molecule)
-
-            self._download_data(tmpdir, dataset)
-            shutil.rmtree(tmpdir)
-        else:
-            dataset = load_dataset(self.datapath, self.format)
-            md = dataset.metadata
-            if "molecule" not in md:
-                raise AtomsDataModuleError(
-                    "Not a valid rMD17 dataset! The molecule needs to be specified in "
-                    + "the metadata."
-                )
-            if md["molecule"] != self.molecule:
-                raise AtomsDataModuleError(
-                    f"The dataset at the given location does not contain the specified "
-                    + f"molecule: `{md['molecule']}` instead of `{self.molecule}`"
-                )
+    def _check_db(self) -> None:
+        super()._check_db()
+        md = self.metadata
 
+        if "molecule" not in md:
+            raise AtomsDataError(
+                "Not a valid rMD17 dataset. Metadata must contain `molecule`."
+            )
+
+        if md["molecule"] != self.molecule:
+            raise AtomsDataError(
+                f"The dataset at the given location contains `{md['molecule']}` "
+                f"instead of `{self.molecule}`."
+            )
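
The _check_db guard replaces the old prepare_data branching: an existing DB is only reused if its metadata says it was built for the same molecule, and anything else fails fast instead of silently mixing trajectories. The pattern in isolation (names hypothetical):

# Sketch: metadata-based identity check for a reusable on-disk dataset.
def check_molecule(metadata: dict, expected: str) -> None:
    found = metadata.get("molecule")
    if found is None:
        raise ValueError("Not an rMD17 DB: no `molecule` key in metadata.")
    if found != expected:
        raise ValueError(f"DB holds `{found}`, expected `{expected}`.")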
 
-    def _download_data(
-        self,
-        tmpdir,
-        dataset: BaseAtomsData,
-    ):
-        logging.info("Downloading {} data".format(self.molecule))
-        raw_path = os.path.join(tmpdir, "rmd17")
-        tar_path = os.path.join(tmpdir, "rmd17.tar.gz")
-        url = "https://figshare.com/ndownloader/files/23950376"
-        request.urlretrieve(url, tar_path)
-        logging.info("Done.")
-
-        logging.info("Extracting data...")
-        tar = tarfile.open(tar_path)
-        tar.extract(
-            path=raw_path, member=f"rmd17/npz_data/{self.datasets_dict[self.molecule]}"
-        )
-
-        logging.info("Parsing molecule {:s}".format(self.molecule))
-
-        data = np.load(
-            os.path.join(
-                raw_path, "rmd17", "npz_data", self.datasets_dict[self.molecule]
-            )
-        )
-
-        numbers = data["nuclear_charges"]
-        property_list = []
-        for positions, energies, forces in zip(
-            data["coords"], data["energies"], data["forces"]
-        ):
-            ats = Atoms(positions=positions, numbers=numbers)
-            properties = {
-                rMD17.energy: np.array([energies]),
-                rMD17.forces: forces,
-                structure.Z: ats.numbers,
-                structure.R: ats.positions,
-                structure.cell: ats.cell,
-                structure.pbc: ats.pbc,
-            }
-            property_list.append(properties)
-
-        logging.info("Write atoms to db...")
-        dataset.add_systems(property_list=property_list)
-        logging.info("Done.")
-
-        train_splits = []
-        test_splits = []
-        for i in range(1, 6):
-            tar.extract(path=raw_path, member=f"rmd17/splits/index_train_0{i}.csv")
-            tar.extract(path=raw_path, member=f"rmd17/splits/index_test_0{i}.csv")
-
-            train_split = (
-                np.loadtxt(
-                    os.path.join(raw_path, "rmd17", "splits", f"index_train_0{i}.csv")
-                )
-                .flatten()
-                .astype(int)
-                .tolist()
-            )
-            train_splits.append(train_split)
-            test_split = (
-                np.loadtxt(
-                    os.path.join(raw_path, "rmd17", "splits", f"index_test_0{i}.csv")
-                )
-                .flatten()
-                .astype(int)
-                .tolist()
-            )
-            test_splits.append(test_split)
-
-        dataset.update_metadata(splits={"known": train_splits, "test": test_splits})
-
-        tar.close()
-        logging.info("Done.")
+    def download(self) -> None:
+        """
+        Download the requested rMD17 molecule and populate the ASE DB.
+        """
+        tmpdir = tempfile.mkdtemp("rmd17")
+
+        md = self.metadata
+        md["atomrefs"] = self.atomrefs
+        md["molecule"] = self.molecule
+        self._set_metadata(md)
+
+        self._download_data(tmpdir)
+        shutil.rmtree(tmpdir, ignore_errors=True)
+
+    def _download_data(self, tmpdir: str) -> None:
+        logging.info("Downloading %s data...", self.molecule)
+
+        raw_path = os.path.join(tmpdir, "rmd17")
+        tar_path = os.path.join(tmpdir, "rmd17.tar")
+
+        self._download_archive(tar_path)
+        logging.info("Done.")
+
+        logging.info("Extracting data...")
+        os.makedirs(raw_path, exist_ok=True)
+
+        with tarfile.open(tar_path, mode="r:*") as tar:
+            tar.extract(
+                path=raw_path,
+                member=f"rmd17/npz_data/{self.datasets_dict[self.molecule]}",
+            )
+            # extract the predefined split indices while the archive is still
+            # open; the tarfile is closed once this block exits
+            for i in range(1, 6):
+                tar.extract(
+                    path=raw_path, member=f"rmd17/splits/index_train_0{i}.csv"
+                )
+                tar.extract(
+                    path=raw_path, member=f"rmd17/splits/index_test_0{i}.csv"
+                )
+
+        logging.info("Parsing molecule %s", self.molecule)
+
+        data = np.load(
+            os.path.join(
+                raw_path,
+                "rmd17",
+                "npz_data",
+                self.datasets_dict[self.molecule],
+            )
+        )
+
+        numbers = data["nuclear_charges"]
+        property_list = []
+
+        for positions, energies, forces in zip(
+            data["coords"], data["energies"], data["forces"]
+        ):
+            ats = Atoms(positions=positions, numbers=numbers)
+            properties = {
+                rMD17.energy: np.array([energies]),
+                rMD17.forces: forces,
+                structure.Z: ats.numbers,
+                structure.R: ats.positions,
+                structure.cell: ats.cell,
+                structure.pbc: ats.pbc,
+            }
+            property_list.append(properties)
+
+        logging.info("Write atoms to db...")
+        self.add_systems(property_list=property_list)
+        logging.info("Done.")
+
+        train_splits = []
+        test_splits = []
+
+        for i in range(1, 6):
+            train_split = (
+                np.loadtxt(
+                    os.path.join(
+                        raw_path, "rmd17", "splits", f"index_train_0{i}.csv"
+                    )
+                )
+                .flatten()
+                .astype(int)
+                - 1
+            ).tolist()
+            train_splits.append(train_split)
+
+            test_split = (
+                np.loadtxt(
+                    os.path.join(
+                        raw_path, "rmd17", "splits", f"index_test_0{i}.csv"
+                    )
+                )
+                .flatten()
+                .astype(int)
+                - 1
+            ).tolist()
+            test_splits.append(test_split)
+
+        self.update_metadata(splits={"known": train_splits, "test": test_splits})
+        logging.info("Done.")
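
The parsing loop zips three arrays from the npz archive against a constant list of nuclear charges. The field layout, as implied by that loop (shapes inferred, not taken from the rMD17 documentation):

# Sketch: rMD17 npz layout consumed by _download_data above.
import numpy as np

data = np.load("rmd17_ethanol.npz")  # hypothetical local copy
data["nuclear_charges"]   # (n_atoms,)            constant across frames
data["coords"]            # (n_frames, n_atoms, 3)
data["energies"]          # (n_frames,)
data["forces"]            # (n_frames, n_atoms, 3)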
f"index_test_0{i}.csv" + ) + ) + .flatten() + .astype(int) + - 1 + ).tolist() + test_splits.append(test_split) + + self.update_metadata(splits={"known": train_splits, "test": test_splits}) logging.info("Done.") + + def _download_archive(self, destination: str) -> None: + last_error = None + + for url in self.download_urls: + try: + logging.info("Downloading from: %s", url) + req = Request(url) + with urlopen(req, timeout=600) as resp, open(destination, "wb") as f: + shutil.copyfileobj(resp, f) + + if not os.path.exists(destination): + raise RuntimeError("Download did not create a file.") + + size = os.path.getsize(destination) + ctype = (resp.headers.get("Content-Type") or "").lower() + + if size == 0: + raise RuntimeError("Downloaded file is empty.") + + if "text/html" in ctype: + raise RuntimeError( + f"Got HTML instead of archive (Content-Type={ctype})." + ) + + return + + except (HTTPError, URLError, RuntimeError) as e: + last_error = e + logging.warning("Download failed from %s: %s", url, e) + + raise AtomsDataError( + f"rMD17 download failed from all sources. Last error: {last_error}" + ) diff --git a/src/schnetpack/datasets/tmqm.py b/src/schnetpack/datasets/tmqm.py index 17c856ec7..cbf75dfd5 100644 --- a/src/schnetpack/datasets/tmqm.py +++ b/src/schnetpack/datasets/tmqm.py @@ -1,24 +1,17 @@ -import io -import logging import os -import re import shutil -import tarfile import tempfile from typing import List, Optional, Dict from urllib import request as request import gzip import numpy as np -from ase import Atoms -from ase.io.extxyz import read_xyz from ase.io import read -from tqdm import tqdm -import torch from schnetpack.data import * -import schnetpack.properties as structure -from schnetpack.data import AtomsDataModuleError, AtomsDataModule +from schnetpack.data import AtomsDataModule +from schnetpack.transform.base import Transform + __all__ = ["TMQM"] @@ -59,14 +52,13 @@ def __init__( num_val: Optional[int] = None, num_test: Optional[int] = None, split_file: Optional[str] = "split.npz", - format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, val_batch_size: Optional[int] = None, test_batch_size: Optional[int] = None, - transforms: Optional[List[torch.nn.Module]] = None, - train_transforms: Optional[List[torch.nn.Module]] = None, - val_transforms: Optional[List[torch.nn.Module]] = None, - test_transforms: Optional[List[torch.nn.Module]] = None, + transforms: Optional[List[Transform]] = None, + train_transforms: Optional[List[Transform]] = None, + val_transforms: Optional[List[Transform]] = None, + test_transforms: Optional[List[Transform]] = None, num_workers: int = 2, num_val_workers: Optional[int] = None, num_test_workers: Optional[int] = None, @@ -84,21 +76,20 @@ def __init__( num_val: number of validation examples num_test: number of test examples split_file: path to npz file with data partitions - format: dataset format load_properties: subset of properties to load remove_uncharacterized: do not include uncharacterized molecules. val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. test_batch_size: test batch size. If None, use val_batch_size, then batch_size. - transforms: Transform applied to each system separately before batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. 
diff --git a/src/schnetpack/datasets/tmqm.py b/src/schnetpack/datasets/tmqm.py
index 17c856ec7..cbf75dfd5 100644
--- a/src/schnetpack/datasets/tmqm.py
+++ b/src/schnetpack/datasets/tmqm.py
@@ -1,24 +1,17 @@
-import io
-import logging
 import os
-import re
 import shutil
-import tarfile
 import tempfile
 from typing import List, Optional, Dict
 from urllib import request as request
 import gzip
 
 import numpy as np
-from ase import Atoms
-from ase.io.extxyz import read_xyz
 from ase.io import read
-from tqdm import tqdm
-import torch
 
 from schnetpack.data import *
-import schnetpack.properties as structure
-from schnetpack.data import AtomsDataModuleError, AtomsDataModule
+from schnetpack.data import AtomsDataModule
+from schnetpack.transform.base import Transform
+
 
 __all__ = ["TMQM"]
 
@@ -59,14 +52,13 @@ def __init__(
         num_val: Optional[int] = None,
         num_test: Optional[int] = None,
         split_file: Optional[str] = "split.npz",
-        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
         load_properties: Optional[List[str]] = None,
         val_batch_size: Optional[int] = None,
         test_batch_size: Optional[int] = None,
-        transforms: Optional[List[torch.nn.Module]] = None,
-        train_transforms: Optional[List[torch.nn.Module]] = None,
-        val_transforms: Optional[List[torch.nn.Module]] = None,
-        test_transforms: Optional[List[torch.nn.Module]] = None,
+        transforms: Optional[List[Transform]] = None,
+        train_transforms: Optional[List[Transform]] = None,
+        val_transforms: Optional[List[Transform]] = None,
+        test_transforms: Optional[List[Transform]] = None,
         num_workers: int = 2,
         num_val_workers: Optional[int] = None,
         num_test_workers: Optional[int] = None,
@@ -84,21 +76,20 @@ def __init__(
             num_val: number of validation examples
             num_test: number of test examples
             split_file: path to npz file with data partitions
-            format: dataset format
             load_properties: subset of properties to load
             remove_uncharacterized: do not include uncharacterized molecules.
             val_batch_size: validation batch size. If None, use test_batch_size, then batch_size.
             test_batch_size: test batch size. If None, use val_batch_size, then batch_size.
-            transforms: Transform applied to each system separately before batching.
-            train_transforms: Overrides transform_fn for training.
-            val_transforms: Overrides transform_fn for validation.
-            test_transforms: Overrides transform_fn for testing.
-            num_workers: Number of data loader workers.
-            num_val_workers: Number of validation data loader workers (overrides num_workers).
-            num_test_workers: Number of test data loader workers (overrides num_workers).
-            property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...).
-            distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...).
-            data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance.
+            transforms: transform applied to each system separately before batching.
+            train_transforms: overrides transform_fn for training.
+            val_transforms: overrides transform_fn for validation.
+            test_transforms: overrides transform_fn for testing.
+            num_workers: number of data loader workers.
+            num_val_workers: number of validation data loader workers (overrides num_workers).
+            num_test_workers: number of test data loader workers (overrides num_workers).
+            property_units: dictionary from property to corresponding unit as a string (eV, kcal/mol, ...).
+            distance_unit: unit of the atom positions and cell as a string (Ang, Bohr, ...).
+            data_workdir: copy data here as part of setup, e.g. cluster scratch for faster performance.
         """
         super().__init__(
             datapath=datapath,
@@ -107,7 +98,6 @@ def __init__(
             num_val=num_val,
             num_test=num_test,
             split_file=split_file,
-            format=format,
             load_properties=load_properties,
             val_batch_size=val_batch_size,
             test_batch_size=test_batch_size,
@@ -139,9 +129,8 @@ def prepare_data(self):
 
         tmpdir = tempfile.mkdtemp("tmQM")
 
-        dataset = create_dataset(
+        dataset = ASEAtomsData.create(
             datapath=self.datapath,
-            format=self.format,
             distance_unit="Ang",
             property_unit_dict=property_unit_dict,
         )
@@ -149,9 +138,9 @@ def prepare_data(self):
         self._download_data(tmpdir, dataset)
         shutil.rmtree(tmpdir)
     else:
-        dataset = load_dataset(self.datapath, self.format)
+        dataset = ASEAtomsData(self.datapath)
 
-    def _download_data(self, tmpdir, dataset: BaseAtomsData):
+    def _download_data(self, tmpdir, dataset: ASEAtomsData):
         tar_path = os.path.join(tmpdir, "tmQM_X1.xyz.gz")
         url = [
             "https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X1.xyz.gz",
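
TMQM keeps the datamodule-style prepare_data, so it needs the create-versus-load distinction explicitly: a fresh ASE DB must be created together with its unit declarations, while an existing one is simply opened. A sketch of that pattern (ASEAtomsData.create as in current SchNetPack; the property dict is illustrative, not the full tmQM property set):

# Sketch: create a new ASE DB with unit metadata, or open an existing one.
import os
from schnetpack.data import ASEAtomsData

if not os.path.exists("tmqm.db"):
    dataset = ASEAtomsData.create(
        datapath="tmqm.db",
        distance_unit="Ang",
        property_unit_dict={"energy": "Hartree"},
    )
else:
    dataset = ASEAtomsData("tmqm.db")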
""" super().__init__() self._property = property @@ -117,24 +121,40 @@ def __init__( property_mean = property_mean or torch.zeros((1,)) self.register_buffer("mean", property_mean) - def datamodule(self, _datamodule): + def initialize(self, provider, atomrefs=None, **kwargs) -> None: """ - Sets mean and atomref automatically when using PyTorchLightning integration. + Initialize mean and/or atomref using a StatsAtomrefProvider. """ if self.remove_atomrefs and not self._atomrefs_initialized: if self.estimate_atomref: - atrefs = _datamodule.get_atomrefs( - property=self._property, is_extensive=self.is_extensive - ) + atrefs = provider.get_atomrefs(self._property, self.is_extensive) else: - atrefs = _datamodule.train_dataset.atomrefs + if atomrefs is None: + raise RuntimeError( + "RemoveOffsets requires dataset atomrefs when estimate_atomref=False." + ) + atrefs = atomrefs self.atomref = atrefs[self._property].detach() if self.remove_mean and not self._mean_initialized: - stats = _datamodule.get_stats( + mean, _std = provider.get_stats( self._property, self.is_extensive, self.remove_atomrefs ) - self.mean = stats[0].detach() + self.mean = mean.detach() + + # legacy hook for old AtomsDataModule + def datamodule(self, _datamodule): + """ + Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. + """ + warnings.warn( + "RemoveOffsets.datamodule(...) is deprecated and will be removed in a future " + "release. Use initialize(provider=..., atomrefs=...) instead.", + DeprecationWarning, + stacklevel=2, + ) + provider = StatsAtomrefProvider(_datamodule.train_dataset) + return self.initialize(provider, atomrefs=provider.train_atomrefs) def forward( self, @@ -147,6 +167,7 @@ def forward( else self.mean ) inputs[self._property] -= mean + if self.remove_atomrefs: atomref_bias = torch.sum(self.atomref[inputs[structure.Z]]) if not self.is_extensive: @@ -200,12 +221,28 @@ def __init__( scale = scale or torch.ones((1,)) self.register_buffer("scale", scale) - def datamodule(self, _datamodule): + def initialize(self, provider, atomrefs=None) -> None: + """ + Initialize scaling using training statistics. + """ if not self._initialized: - stats = _datamodule.get_stats(self._target_key, True, False) - scale = stats[0] if self._scale_by_mean else stats[1] + mean, std = provider.get_stats(self._target_key, True, False) + scale = mean if self._scale_by_mean else std self.scale = torch.abs(scale).detach() + def datamodule(self, _datamodule): + """ + Legacy hook for old AtomsDataModule. Safe to remove once legacy DM is removed. + """ + warnings.warn( + "ScaleProperty.datamodule(...) is deprecated and will be removed in a future " + "release. Use initialize(provider=..., atomrefs=...) instead.", + DeprecationWarning, + stacklevel=2, + ) + provider = StatsAtomrefProvider(_datamodule.train_dataset) + return self.initialize(provider, atomrefs=None) + def forward( self, inputs: Dict[str, torch.Tensor], @@ -219,15 +256,9 @@ class AddOffsets(Transform): Add offsets to property based on the mean of the training data and/or the single atom reference calculations. - The `mean` and/or `atomref` are automatically obtained from the AtomsDataModule, - when it is used. Otherwise, they have to be provided in the init manually. - - Hint: - Place this postprocessor after casting to float64 for higher numerical - precision. """ - is_preprocessor: bool = True + is_preprocessor: bool = False is_postprocessor: bool = True atomref: torch.Tensor @@ -252,6 +283,7 @@ def __init__( tensor. 
@@ -219,15 +256,9 @@ class AddOffsets(Transform):
     """
     Add offsets to property based on the mean of the training data and/or the
     single atom reference calculations.
-
-    The `mean` and/or `atomref` are automatically obtained from the AtomsDataModule,
-    when it is used. Otherwise, they have to be provided in the init manually.
-
-    Hint:
-        Place this postprocessor after casting to float64 for higher numerical
-        precision.
     """
 
-    is_preprocessor: bool = True
+    is_preprocessor: bool = False
     is_postprocessor: bool = True
     atomref: torch.Tensor
@@ -252,6 +283,7 @@ def __init__(
                 tensor.
             atomrefs: Provide single-atom references directly.
             property_mean: Provide mean property value / n_atoms.
+            estimate_atomref: if true, estimate atomrefs from the training data.
         """
         super().__init__()
         self._property = property
@@ -281,21 +313,39 @@ def __init__(
         self.register_buffer("atomref", atomrefs)
         self.register_buffer("mean", property_mean)
 
-    def datamodule(self, _datamodule):
+    def initialize(self, provider, atomrefs=None) -> None:
+        """
+        Initialize mean and/or atomref using a StatsAtomrefProvider.
+        """
         if self.add_atomrefs and not self._atomrefs_initialized:
             if self.estimate_atomref:
-                atrefs = _datamodule.get_atomrefs(
-                    property=self._property, is_extensive=self.is_extensive
-                )
+                atrefs = provider.get_atomrefs(self._property, self.is_extensive)
             else:
-                atrefs = _datamodule.train_dataset.atomrefs
+                if atomrefs is None:
+                    raise RuntimeError(
+                        "AddOffsets requires dataset atomrefs when estimate_atomref=False."
+                    )
+                atrefs = atomrefs
             self.atomref = atrefs[self._property].detach()
 
         if self.add_mean and not self._mean_initialized:
-            stats = _datamodule.get_stats(
+            mean, _std = provider.get_stats(
                 self._property, self.is_extensive, self.add_atomrefs
             )
-            self.mean = stats[0].detach()
+            self.mean = mean.detach()
+
+    def datamodule(self, _datamodule):
+        """
+        Legacy hook for the old AtomsDataModule. Safe to remove once the legacy DM is removed.
+        """
+        warnings.warn(
+            "AddOffsets.datamodule(...) is deprecated and will be removed in a future "
+            "release. Use initialize(provider=..., atomrefs=...) instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        provider = StatsAtomrefProvider(_datamodule.train_dataset)
+        return self.initialize(provider, atomrefs=provider.train_atomrefs)
 
     def forward(
         self,
diff --git a/src/schnetpack/transform/base.py b/src/schnetpack/transform/base.py
index 77535c200..4dbff1f61 100644
--- a/src/schnetpack/transform/base.py
+++ b/src/schnetpack/transform/base.py
@@ -1,9 +1,8 @@
-from typing import Optional, Dict
+from typing import Dict
 
 import torch
 import torch.nn as nn
 
-import schnetpack as spk
 
 __all__ = [
     "Transform",
@@ -32,6 +32,7 @@ class Transform(nn.Module):
     def datamodule(self, value):
         """
+        Legacy hook for transforms initialized from an old AtomsDataModule.
         Extract all required information from data module automatically when
         using PyTorch Lightning integration. The transform should also implement a
         way to set these things manually, to make it usable independent of PL.
@@ -48,3 +49,9 @@ def forward(
 
     def teardown(self):
         pass
+
+    def initialize(self, **kwargs) -> None:
+        """
+        Initialization hook for transforms that require training-data statistics.
+        """
+        return
diff --git a/tests/data/test_data.py b/tests/data/test_data.py
index eda02e7e5..f710f23ee 100644
--- a/tests/data/test_data.py
+++ b/tests/data/test_data.py
@@ -108,9 +108,11 @@ def test_stats():
     atomref = {"property1": torch.ones((100,)) / 3.0}
     for bs in range(1, 7):
         stats = calculate_stats(
-            AtomsLoader(data, batch_size=bs),
+            data,
             {"property1": True, "property2": False},
             atomref=atomref,
+            batch_size=bs,
+            num_workers=0,
         )
         assert np.allclose(stats["property1"][0].numpy(), np.array([0.0]))
         assert np.allclose(stats["property1"][1].numpy(), np.array([1.0]))
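
Per the updated test, calculate_stats now takes the dataset plus loader options directly instead of a prebuilt AtomsLoader, and still returns a property-to-(mean, std) mapping. A usage sketch with the signature as exercised above (import path as in current SchNetPack; property name illustrative):

# Sketch: computing per-property statistics with the refactored entry point.
from schnetpack.data import calculate_stats

stats = calculate_stats(
    dataset,                   # an ASEAtomsData instance
    {"energy": True},          # property -> is_extensive (divide by n_atoms)
    atomref=None,
    batch_size=32,
    num_workers=0,
)
mean, std = stats["energy"]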