68 commits
08980dc
chore: clean up .gitignore by removing unnecessary entries
sundusaijaz Feb 21, 2026
f8879c0
feat: add new AtomsDataModuleV2 and StatsAtomrefProvider for enhanced…
sundusaijaz Feb 22, 2026
b043e7e
style: improve code formatting and readability in multiple files
sundusaijaz Feb 22, 2026
6e67e01
refactor: simplify AtomsDataModuleV2 by removing unused parameters an…
sundusaijaz Feb 22, 2026
f7dd667
docs: update AtomsDataModuleV2 docstring for clarity by removing redu…
sundusaijaz Feb 22, 2026
8b5a014
test: add pytests for AtomsDataset and AtomsDataModuleV2 functionality
sundusaijaz Feb 22, 2026
215587e
refactor: update QM9 and StatsAtomrefProvider docstrings for clarity …
sundusaijaz Feb 22, 2026
db21004
feat: refactor AtomsDataModuleV2
sundusaijaz Mar 2, 2026
01fb1dd
refactor: update StatsAtomrefProvider to use BaseAtomsData and simpli…
sundusaijaz Mar 2, 2026
6d897a9
refactor: update calculate_stats and estimate_atomrefs to use BaseAto…
sundusaijaz Mar 2, 2026
7af99a8
refactor: simplify initialization in StatsAtomrefProvider
sundusaijaz Mar 2, 2026
e592743
refactor: update Transform class
sundusaijaz Mar 2, 2026
ba8d46f
refactor: QM9 class by removing unused parameters and simplifying doc…
sundusaijaz Mar 2, 2026
ee24a7e
refactor: enhance AtomsDataModuleV2 and QM9 class by simplifying init…
sundusaijaz Mar 3, 2026
8a90055
fix: black format
sundusaijaz Mar 3, 2026
ee32d85
refactor: update custom and qm9 config files
sundusaijaz Mar 4, 2026
ecaca86
refactor: improve model testing and checkpoint handling in cli
sundusaijaz Mar 4, 2026
1339fb6
refactor: merged ASEAtomsData class and BaseAtomsData
sundusaijaz Mar 4, 2026
cc3cb3c
refactor: update data handling
sundusaijaz Mar 7, 2026
b48ca02
refactor: simplify ASEAtomsData by removing unused methods and proper…
sundusaijaz Mar 7, 2026
bfc2c97
refactor: update dataset method signatures to use ASEAtomsData
sundusaijaz Mar 7, 2026
860fbee
refactor: update references from BaseAtomsData to ASEAtomsData in dat…
sundusaijaz Mar 7, 2026
c83501a
refactor: update checkpoint loading in training process and adjust da…
sundusaijaz Mar 7, 2026
3fa44d3
refactor: clean up code formatting and remove legacy QM9 dataset file
sundusaijaz Mar 8, 2026
7406cef
refactor: update rMD17 dataset class
sundusaijaz Mar 8, 2026
091a0e5
refactor: remove irrelevant refactor pytest
sundusaijaz Mar 8, 2026
2d92e05
refactor: update md17, md22, qm7x, rmd17
sundusaijaz Mar 8, 2026
61740d9
refactor: update dataset classes mp, ani1, iso17
sundusaijaz Mar 8, 2026
c242648
refactor: fix format error in MaterialsProject
sundusaijaz Mar 8, 2026
fb2bd3d
refactor: remove format parameter in QM7X dataset loading
sundusaijaz Mar 8, 2026
1fb9d97
refactor: remove legacy atoms_legacy.py file and streamline dataset l…
sundusaijaz Mar 12, 2026
bc8cb02
refactor: change ASEAtomsData class with additional transform options…
sundusaijaz Mar 12, 2026
4efe59b
refactor: remove format parameter from all dataset classes
sundusaijaz Mar 12, 2026
de23e5d
refactor: simplify format handling in AtomsDataModule (old)
sundusaijaz Mar 12, 2026
dadbf1b
refactor: removed dict in ASEAtomsData and simplify download method i…
sundusaijaz Mar 12, 2026
f788a91
refactor: add deprecation warnings for legacy datamodule methods in a…
sundusaijaz Mar 12, 2026
70cb3e1
refactor: add docstring and deprecation warnings for legacy argument…
sundusaijaz Mar 12, 2026
3bbc31d
refactor: replace property_unit_dict with _native_property_units meth…
sundusaijaz Mar 12, 2026
ed73284
refactor: enhance QM9 dataset with train/val/test transform options
sundusaijaz Mar 12, 2026
e627a69
refactor: add train/val/test transform options and docstring across m…
sundusaijaz Mar 12, 2026
607eab1
refactor: update docstrings in atomistic transforms
sundusaijaz Mar 12, 2026
2a6ce76
refactor: add docstrings in ASEAtomsData
sundusaijaz Mar 12, 2026
f0dd4c5
refactor: add docstrings for calculate_stats() and estimate_atomrefs()
sundusaijaz Mar 12, 2026
4ca3648
refactor: simplify transform initialization in ASEAtomsData and updat…
sundusaijaz Mar 12, 2026
93d9ce8
refactor: restructure configs for all datasets
sundusaijaz Mar 12, 2026
99ba778
refactor: update omdb to support datamodulev2
sundusaijaz Mar 12, 2026
5d44d74
refactor: streamline transform assignment and initialization in Atoms…
sundusaijaz Mar 14, 2026
43e605d
refactor: update pytest test_stats to accept data and batch parameter…
sundusaijaz Mar 14, 2026
db0a96e
refactor: improve ANI1 dataset loading and validation
sundusaijaz Mar 15, 2026
074bee9
refactor: enhance QM9 dataset loading
sundusaijaz Mar 15, 2026
d4ca4b0
refactor: simplify download method in ISO17 dataset
sundusaijaz Mar 15, 2026
985032e
refactor: enhance MaterialsProject and update docstring
sundusaijaz Mar 15, 2026
6bf81d3
refactor: optimize GDMLDataset download method in md17
sundusaijaz Mar 15, 2026
6f30249
refactor: enhance QM7X dataset docstring and improve download methods
sundusaijaz Mar 15, 2026
b5b8969
refactor: improve rMD17 dataset loading and metadata handling
sundusaijaz Mar 15, 2026
799d5a8
refactor: enhance ASEAtomsData and QM9 dataset handling
sundusaijaz Mar 17, 2026
6ee9699
refactor: streamline metadata _check_db() and dataset creation in ASE…
sundusaijaz Mar 19, 2026
86df4d8
refactor: enhance ASEAtomsData split transform and db creation
sundusaijaz Mar 19, 2026
1690b19
refactor: streamline ANI1 and QM9 dataset handling and download methods
sundusaijaz Mar 19, 2026
3962eb8
refactor: update ISO17 dataset download method and improve property u…
sundusaijaz Mar 19, 2026
66ffe81
refactor: improve MaterialsProject API key validation and simplify do…
sundusaijaz Mar 19, 2026
f961526
refactor: streamline GDMLDataset methods and enhance metadata handlin…
sundusaijaz Mar 23, 2026
716b34e
refactor: add type hint to _check_db() method in ASEAtomsData
sundusaijaz Mar 23, 2026
2b51b1c
refactor: simplify download method in omdb
sundusaijaz Mar 23, 2026
ca88623
refactor: simplify QM7X download method and enhance metadata handling
sundusaijaz Mar 23, 2026
6211431
refactor: remove unused imports and streamline rMD17 dataset methods
sundusaijaz Mar 23, 2026
cf4c406
refactor: remove unused split_id from rMD17 dataset configuration
sundusaijaz Mar 27, 2026
252627e
refactor: adjust train and test split calculations in rMD17 dataset
sundusaijaz Mar 27, 2026
2 changes: 1 addition & 1 deletion .gitignore
@@ -128,4 +128,4 @@ interfaces/lammps/examples/*/*.dat
interfaces/lammps/examples/*/deployed_model

# batchwise optimizer examples
examples/howtos/howto_batchwise_relaxations_outputs/*
examples/howtos/howto_batchwise_relaxations_outputs/*
10 changes: 6 additions & 4 deletions src/schnetpack/cli.py
@@ -17,7 +17,7 @@
import schnetpack as spk
from schnetpack.utils import str2class
from schnetpack.utils.script import log_hyperparameters, print_config
from schnetpack.data import BaseAtomsData, AtomsLoader
from schnetpack.data import ASEAtomsData, AtomsLoader
from schnetpack.train import PredictionWriter
from schnetpack import properties
from schnetpack.utils import load_model
@@ -178,14 +178,16 @@ def train(config: DictConfig):

# Evaluate model on test set after training
log.info("Starting testing.")
trainer.test(model=task, datamodule=datamodule, ckpt_path="best")
trainer.test(
model=task, datamodule=datamodule, ckpt_path="best", weights_only=False
)

# Store best model
best_path = trainer.checkpoint_callback.best_model_path
log.info(f"Best checkpoint path:\n{best_path}")

log.info(f"Store best model")
best_task = type(task).load_from_checkpoint(best_path)
best_task = type(task).load_from_checkpoint(best_path, weights_only=False)
torch.save(best_task, config.globals.model_path + ".task")

best_task.save_model(config.globals.model_path, do_postprocessing=True)
@@ -195,7 +197,7 @@ def train(config: DictConfig):
@hydra.main(config_path="configs", config_name="predict", version_base="1.2")
def predict(config: DictConfig):
log.info(f"Load data from `{config.data.datapath}`")
dataset: BaseAtomsData = hydra.utils.instantiate(config.data)
dataset: ASEAtomsData = hydra.utils.instantiate(config.data)
loader = AtomsLoader(dataset, batch_size=config.batch_size, num_workers=8)

model = load_model("best_model")
20 changes: 11 additions & 9 deletions src/schnetpack/configs/data/ani1.yaml
@@ -1,16 +1,18 @@
# @package data
defaults:
- custom

_target_: schnetpack.datasets.ANI1
dataset:
_target_: schnetpack.datasets.ANI1
datapath: ${run.data_dir}/ani1.db # data_dir is specified in train.yaml
num_heavy_atoms: 8
high_energies: false
distance_unit: Ang
property_units:
energy: eV
transforms: ${data.transforms}


datapath: ${run.data_dir}/ani1.db # data_dir is specified in train.yaml
batch_size: 32
num_train: 10000000
num_val: 100000
num_heavy_atoms: 8
high_energies: False

# convert to typically used units
distance_unit: Ang
property_units:
energy: eV
25 changes: 19 additions & 6 deletions src/schnetpack/configs/data/custom.yaml
@@ -1,12 +1,25 @@
_target_: schnetpack.data.AtomsDataModule
# @package data
_target_: schnetpack.data.datamodule_v2.AtomsDataModuleV2

dataset:
_target_: schnetpack.data.ASEAtomsData
datapath: ???
load_properties: null
distance_unit: Ang
property_units: {}
transforms: ${data.transforms}
train_transforms: null
val_transforms: null
test_transforms: null

datapath: ???
data_workdir: null
batch_size: 10
num_train: ???
num_val: ???
num_test: null
split_file: ${run.data_dir}/split.npz
splitting: null
num_workers: 8
num_val_workers: null
num_test_workers: null
train_sampler_cls: null
train_sampler_cls: null
train_sampler_args: {}
pin_memory: false
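
The config diffs in this change appear to move dataset-specific keys (`datapath`, `distance_unit`, `property_units`, and so on) under a nested `dataset:` block, while loader settings such as `batch_size` and `num_train` stay at the top level. A minimal sketch of that reshaping with plain dictionaries rather than Hydra follows. The helper `nest_dataset_keys` and the `DATASET_KEYS` list are hypothetical names for illustration and are not part of schnetpack; the values are shortened from the `qm9.yaml` diff.

```python
from copy import deepcopy

# Hypothetical key list for illustration only; the real split between dataset
# and data-module arguments is defined by the classes in the PR.
DATASET_KEYS = ("datapath", "load_properties", "distance_unit", "property_units")


def nest_dataset_keys(flat_cfg, dataset_target, module_target,
                      dataset_keys=DATASET_KEYS):
    """Return a new config dict with dataset keys moved under ``dataset``."""
    cfg = deepcopy(flat_cfg)
    # In the flat layout the top-level _target_ pointed at the dataset class.
    cfg.pop("_target_", None)
    dataset = {"_target_": dataset_target}
    for key in dataset_keys:
        if key in cfg:
            dataset[key] = cfg.pop(key)
    nested = {"_target_": module_target, "dataset": dataset}
    # Remaining keys (batch_size, num_train, ...) stay at the top level.
    nested.update(cfg)
    return nested


flat = {
    "_target_": "schnetpack.datasets.QM9",
    "datapath": "qm9.db",
    "batch_size": 100,
    "num_train": 110000,
    "num_val": 10000,
    "distance_unit": "Ang",
    "property_units": {"energy_U0": "eV"},
}

nested = nest_dataset_keys(
    flat,
    dataset_target="schnetpack.datasets.qm9.QM9",
    module_target="schnetpack.data.datamodule_v2.AtomsDataModuleV2",
)
print(nested["dataset"]["datapath"])
print(nested["batch_size"])
```

The sketch copies its input instead of mutating it, so the flat dictionary can still be compared against the nested result when checking a migrated config by hand.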

8 changes: 5 additions & 3 deletions src/schnetpack/configs/data/iso17.yaml
@@ -1,10 +1,12 @@
# @package data
defaults:
- custom

_target_: schnetpack.datasets.ISO17
dataset:
_target_: schnetpack.datasets.ISO17
datapath: ${run.data_dir}/${data.folder}.db # data_dir is specified in train.yaml
folder: reference

datapath: ${run.data_dir}/${data.folder}.db # data_dir is specified in train.yaml
folder: reference
batch_size: 32
num_train: 0.9
num_val: 0.1
10 changes: 6 additions & 4 deletions src/schnetpack/configs/data/materials_project.yaml
@@ -1,10 +1,12 @@
# @package data
defaults:
- custom

_target_: schnetpack.datasets.MaterialsProject
dataset:
_target_: schnetpack.datasets.MaterialsProject
datapath: ${run.data_dir}/materials_project.db # data_dir is specified in train.yaml
apikey: ???

datapath: ${run.data_dir}/materials_project.db # data_dir is specified in train.yaml
batch_size: 32
num_train: 60000
num_val: 2000
apikey: ???
num_val: 2000
10 changes: 7 additions & 3 deletions src/schnetpack/configs/data/md17.yaml
@@ -1,10 +1,14 @@
# @package data
defaults:
- custom

_target_: schnetpack.datasets.MD17

datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml
molecule: aspirin

dataset:
_target_: schnetpack.datasets.MD17
datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml
molecule: ${data.molecule}

batch_size: 10
num_train: 950
num_val: 50
10 changes: 7 additions & 3 deletions src/schnetpack/configs/data/md22.yaml
@@ -1,10 +1,14 @@
# @package data
defaults:
- custom

_target_: schnetpack.datasets.MD22

datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml
molecule: Ac-Ala3-NHMe

dataset:
_target_: schnetpack.datasets.MD22
datapath: ${run.data_dir}/${data.molecule}.db
molecule: ${data.molecule}

batch_size: 10
num_train: 5700
num_val: 300
8 changes: 5 additions & 3 deletions src/schnetpack/configs/data/omdb.yaml
@@ -1,10 +1,12 @@
# @package data
defaults:
- custom

_target_: schnetpack.datasets.OrganicMaterialsDatabase
dataset:
_target_: schnetpack.datasets.OrganicMaterialsDatabase
datapath: ${run.data_dir}/omdb.db # data_dir is specified in train.yaml
raw_path: null

datapath: ${run.data_dir}/omdb.db # data_dir is specified in train.yaml
batch_size: 32
num_train: 0.8
num_val: 0.1
raw_path: null
9 changes: 7 additions & 2 deletions src/schnetpack/configs/data/qm7x.yaml
@@ -1,9 +1,14 @@
# @package data
defaults:
- custom

_target_: schnetpack.datasets.QM7X
dataset:
_target_: schnetpack.datasets.QM7X
datapath: ${run.data_dir}/qm7x.db # data_dir is specified in train.yaml
remove_duplicates: true
only_equilibrium: false
only_non_equilibrium: false

datapath: ${run.data_dir}/qm7x.db # data_dir is specified in train.yaml
batch_size: 100
num_train: 5550
num_val: 700
34 changes: 19 additions & 15 deletions src/schnetpack/configs/data/qm9.yaml
@@ -1,22 +1,26 @@
# @package data
defaults:
- custom

_target_: schnetpack.datasets.QM9
dataset:
_target_: schnetpack.datasets.qm9.QM9
datapath: ${run.data_dir}/qm9.db
remove_uncharacterized: true
load_properties: null
distance_unit: Ang
property_units:
energy_U0: eV
energy_U: eV
enthalpy_H: eV
free_energy: eV
homo: eV
lumo: eV
gap: eV
zpve: eV
transforms: ${data.transforms}

datapath: ${run.data_dir}/qm9.db # data_dir is specified in train.yaml
batch_size: 100
num_train: 110000
num_val: 10000
remove_uncharacterized: True

# convert to typically used units
distance_unit: Ang
property_units:
energy_U0: eV
energy_U: eV
enthalpy_H: eV
free_energy: eV
homo: eV
lumo: eV
gap: eV
zpve: eV
num_test: 10000
num_workers: 2
10 changes: 6 additions & 4 deletions src/schnetpack/configs/data/rmd17.yaml
@@ -1,11 +1,13 @@
# @package data
defaults:
- custom
molecule: aspirin

_target_: schnetpack.datasets.rMD17
dataset:
_target_: schnetpack.datasets.rMD17
datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml
molecule: ${data.molecule}

datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml
molecule: aspirin
batch_size: 10
num_train: 950
num_val: 50
split_id: null
2 changes: 2 additions & 0 deletions src/schnetpack/data/__init__.py
@@ -4,3 +4,5 @@
from .splitting import *
from .datamodule import *
from .sampler import *
from .datamodule_v2 import *
from .provider import *