-
Notifications
You must be signed in to change notification settings - Fork 251
AtomsDataset and DataModule Refactor #781
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sundusaijaz
wants to merge
68
commits into
dev
Choose a base branch
from
sa/dataset_refactor
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
68 commits
Select commit
Hold shift + click to select a range
08980dc
chore: clean up .gitignore by removing unnecessary entries
sundusaijaz f8879c0
feat: add new AtomsDataModuleV2 and StatsAtomrefProvider for enhanced…
sundusaijaz b043e7e
style: improve code formatting and readability in multiple files
sundusaijaz 6e67e01
refactor: simplify AtomsDataModuleV2 by removing unused parameters an…
sundusaijaz f7dd667
docs: update AtomsDataModuleV2 docstring for clarity by removing redu…
sundusaijaz 8b5a014
test: add pytests for AtomsDataset and AtomsDataModuleV2 functionality
sundusaijaz 215587e
refactor: update QM9 and StatsAtomrefProvider docstrings for clarity …
sundusaijaz db21004
feat: refactor AtomsDataModuleV2
sundusaijaz 01fb1dd
refactor: update StatsAtomrefProvider to use BaseAtomsData and simpli…
sundusaijaz 6d897a9
refactor: update calculate_stats and estimate_atomrefs to use BaseAto…
sundusaijaz 7af99a8
refactor: simplify initialization in StatsAtomrefProvider
sundusaijaz e592743
refactor: update Transform class
sundusaijaz ba8d46f
refactor: QM9 class by removing unused parameters and simplifying doc…
sundusaijaz ee24a7e
refactor: enhance AtomsDataModuleV2 and QM9 class by simplifying init…
sundusaijaz 8a90055
fix: black format
sundusaijaz ee32d85
refactor: update custom and qm9 config files
sundusaijaz ecaca86
refactor: improve model testing and checkpoint handling in cli
sundusaijaz 1339fb6
refactor: merged ASEAtomsData class and BaseAtomsData
sundusaijaz cc3cb3c
refactor: update data handling
sundusaijaz b48ca02
refactor: simplify ASEAtomsData by removing unused methods and proper…
sundusaijaz bfc2c97
refactor: update dataset method signatures to use ASEAtomsData
sundusaijaz 860fbee
refactor: update references from BaseAtomsData to ASEAtomsData in dat…
sundusaijaz c83501a
refactor: update checkpoint loading in training process and adjust da…
sundusaijaz 3fa44d3
refactor: clean up code formatting and remove legacy QM9 dataset file
sundusaijaz 7406cef
refactor: update rMD17 dataset class
sundusaijaz 091a0e5
refactor: remove irrelevant refactor pytest
sundusaijaz 2d92e05
refactor: update md17, md22, qm7x, rmd17
sundusaijaz 61740d9
refactor: update dataset classes mp, ani1, iso17
sundusaijaz c242648
refactor: fix format error in MaterialsProject
sundusaijaz fb2bd3d
refactor: remove format parameter in QM7X dataset loading
sundusaijaz 1fb9d97
refactor: remove legacy atoms_legacy.py file and streamline dataset l…
sundusaijaz bc8cb02
refactor: change ASEAtomsData class with additional transform options…
sundusaijaz 4efe59b
refactor: remove format parameter from all dataset classes
sundusaijaz de23e5d
refactor: simplify format handling in AtomsDataModule (old)
sundusaijaz dadbf1b
refactor: removed dict in ASEAtomsData and simplify download method i…
sundusaijaz f788a91
refactor: add deprecation warnings for legacy datamodule methods in a…
sundusaijaz 70cb3e1
refactor: add docstring and deprecation warnings for legacy argument…
sundusaijaz 3bbc31d
refactor: replace property_unit_dict with _native_property_units meth…
sundusaijaz ed73284
refactor: enhance QM9 dataset with train/val/test transform options
sundusaijaz e627a69
refactor: add train/val/test transform options and docstring across m…
sundusaijaz 607eab1
refactor: update docstrings in atomistic transforms
sundusaijaz 2a6ce76
refactor: add docstrings in ASEAtomsData
sundusaijaz f0dd4c5
refactor: add docstrings for calculate_stats() and estimate_atomrefs()
sundusaijaz 4ca3648
refactor: simplify transform initialization in ASEAtomsData and updat…
sundusaijaz 93d9ce8
refactor: restructure configs for all datasets
sundusaijaz 99ba778
refactor: update omdb to support datamodulev2
sundusaijaz 5d44d74
refactor: streamline transform assignment and initialization in Atoms…
sundusaijaz 43e605d
refactor: update pytest test_stats to accept data and batch parameter…
sundusaijaz db0a96e
refactor: improve ANI1 dataset loading and validation
sundusaijaz 074bee9
refactor: enhance QM9 dataset loading
sundusaijaz d4ca4b0
refactor: simplify download method in ISO17 dataset
sundusaijaz 985032e
refactor: enhance MaterialsProject and update docstring
sundusaijaz 6bf81d3
refactor: optimize GDMLDataset download method in md17
sundusaijaz 6f30249
refactor: enhance QM7X dataset docstring and improve download methods
sundusaijaz b5b8969
refactor: improve rMD17 dataset loading and metadata handling
sundusaijaz 799d5a8
refactor: enhance ASEAtomsData and QM9 dataset handling
sundusaijaz 6ee9699
refactor: streamline metadata _check_db() and dataset creation in ASE…
sundusaijaz 86df4d8
refactor: enhance ASEAtomsData split transform and db creation
sundusaijaz 1690b19
refactor: streamline ANI1 and QM9 dataset handling and download methods
sundusaijaz 3962eb8
refactor: update ISO17 dataset download method and improve property u…
sundusaijaz 66ffe81
refactor: improve MaterialsProject API key validation and simplify do…
sundusaijaz f961526
refactor: streamline GDMLDataset methods and enhance metadata handlin…
sundusaijaz 716b34e
refactor: add type hint to _check_db() method in ASEAtomsData
sundusaijaz 2b51b1c
refactor: simplify download method in omdb
sundusaijaz ca88623
refactor: simplify QM7X download method and enhance metadata handling
sundusaijaz 6211431
refactor: remove unused imports and streamline rMD17 dataset methods
sundusaijaz cf4c406
refactor: remove unused split_id from rMD17 dataset configuration
sundusaijaz 252627e
refactor: adjust train and test split calculations in rMD17 dataset
sundusaijaz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,16 +1,18 @@ | ||
| # @package data | ||
| defaults: | ||
| - custom | ||
|
|
||
| _target_: schnetpack.datasets.ANI1 | ||
| dataset: | ||
| _target_: schnetpack.datasets.ANI1 | ||
| datapath: ${run.data_dir}/ani1.db # data_dir is specified in train.yaml | ||
| num_heavy_atoms: 8 | ||
| high_energies: false | ||
| distance_unit: Ang | ||
| property_units: | ||
| energy: eV | ||
| transforms: ${data.transforms} | ||
|
|
||
|
|
||
| datapath: ${run.data_dir}/ani1.db # data_dir is specified in train.yaml | ||
| batch_size: 32 | ||
| num_train: 10000000 | ||
| num_val: 100000 | ||
| num_heavy_atoms: 8 | ||
| high_energies: False | ||
|
|
||
| # convert to typically used units | ||
| distance_unit: Ang | ||
| property_units: | ||
| energy: eV |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,25 @@ | ||
| _target_: schnetpack.data.AtomsDataModule | ||
| # @package data | ||
| _target_: schnetpack.data.datamodule_v2.AtomsDataModuleV2 | ||
|
|
||
| dataset: | ||
| _target_: schnetpack.data.ASEAtomsData | ||
| datapath: ??? | ||
| load_properties: null | ||
| distance_unit: Ang | ||
| property_units: {} | ||
| transforms: ${data.transforms} | ||
| train_transforms: null | ||
| val_transforms: null | ||
| test_transforms: null | ||
|
|
||
| datapath: ??? | ||
| data_workdir: null | ||
| batch_size: 10 | ||
| num_train: ??? | ||
| num_val: ??? | ||
| num_test: null | ||
| split_file: ${run.data_dir}/split.npz | ||
| splitting: null | ||
| num_workers: 8 | ||
| num_val_workers: null | ||
| num_test_workers: null | ||
| train_sampler_cls: null | ||
| train_sampler_cls: null | ||
| train_sampler_args: {} | ||
| pin_memory: false | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,12 @@ | ||
| # @package data | ||
| defaults: | ||
| - custom | ||
|
|
||
| _target_: schnetpack.datasets.ISO17 | ||
| dataset: | ||
| _target_: schnetpack.datasets.ISO17 | ||
| datapath: ${run.data_dir}/${data.folder}.db # data_dir is specified in train.yaml | ||
| folder: reference | ||
|
|
||
| datapath: ${run.data_dir}/${data.folder}.db # data_dir is specified in train.yaml | ||
| folder: reference | ||
| batch_size: 32 | ||
| num_train: 0.9 | ||
| num_val: 0.1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,12 @@ | ||
| # @package data | ||
| defaults: | ||
| - custom | ||
|
|
||
| _target_: schnetpack.datasets.MaterialsProject | ||
| dataset: | ||
| _target_: schnetpack.datasets.MaterialsProject | ||
| datapath: ${run.data_dir}/materials_project.db # data_dir is specified in train.yaml | ||
| apikey: ??? | ||
|
|
||
| datapath: ${run.data_dir}/materials_project.db # data_dir is specified in train.yaml | ||
| batch_size: 32 | ||
| num_train: 60000 | ||
| num_val: 2000 | ||
| apikey: ??? | ||
| num_val: 2000 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,14 @@ | ||
| # @package data | ||
| defaults: | ||
| - custom | ||
|
|
||
| _target_: schnetpack.datasets.MD17 | ||
|
|
||
| datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml | ||
| molecule: aspirin | ||
|
|
||
| dataset: | ||
| _target_: schnetpack.datasets.MD17 | ||
| datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml | ||
| molecule: ${data.molecule} | ||
|
|
||
| batch_size: 10 | ||
| num_train: 950 | ||
| num_val: 50 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,14 @@ | ||
| # @package data | ||
| defaults: | ||
| - custom | ||
|
|
||
| _target_: schnetpack.datasets.MD22 | ||
|
|
||
| datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml | ||
| molecule: Ac-Ala3-NHMe | ||
|
|
||
| dataset: | ||
| _target_: schnetpack.datasets.MD22 | ||
| datapath: ${run.data_dir}/${data.molecule}.db | ||
| molecule: ${data.molecule} | ||
|
|
||
| batch_size: 10 | ||
| num_train: 5700 | ||
| num_val: 300 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,12 @@ | ||
| # @package data | ||
| defaults: | ||
| - custom | ||
|
|
||
| _target_: schnetpack.datasets.OrganicMaterialsDatabase | ||
| dataset: | ||
| _target_: schnetpack.datasets.OrganicMaterialsDatabase | ||
| datapath: ${run.data_dir}/omdb.db # data_dir is specified in train.yaml | ||
| raw_path: null | ||
|
|
||
| datapath: ${run.data_dir}/omdb.db # data_dir is specified in train.yaml | ||
| batch_size: 32 | ||
| num_train: 0.8 | ||
| num_val: 0.1 | ||
| raw_path: null |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,14 @@ | ||
| # @package data | ||
| defaults: | ||
| - custom | ||
|
|
||
| _target_: schnetpack.datasets.QM7X | ||
| dataset: | ||
| _target_: schnetpack.datasets.QM7X | ||
| datapath: ${run.data_dir}/qm7x.db # data_dir is specified in train.yaml | ||
| remove_duplicates: true | ||
| only_equilibrium: false | ||
| only_non_equilibrium: false | ||
|
|
||
| datapath: ${run.data_dir}/qm7x.db # data_dir is specified in train.yaml | ||
| batch_size: 100 | ||
| num_train: 5550 | ||
| num_val: 700 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,22 +1,26 @@ | ||
| # @package data | ||
| defaults: | ||
| - custom | ||
|
|
||
| _target_: schnetpack.datasets.QM9 | ||
| dataset: | ||
| _target_: schnetpack.datasets.qm9.QM9 | ||
| datapath: ${run.data_dir}/qm9.db | ||
| remove_uncharacterized: true | ||
| load_properties: null | ||
| distance_unit: Ang | ||
| property_units: | ||
| energy_U0: eV | ||
| energy_U: eV | ||
| enthalpy_H: eV | ||
| free_energy: eV | ||
| homo: eV | ||
| lumo: eV | ||
| gap: eV | ||
| zpve: eV | ||
| transforms: ${data.transforms} | ||
|
|
||
| datapath: ${run.data_dir}/qm9.db # data_dir is specified in train.yaml | ||
| batch_size: 100 | ||
| num_train: 110000 | ||
| num_val: 10000 | ||
| remove_uncharacterized: True | ||
|
|
||
| # convert to typically used units | ||
| distance_unit: Ang | ||
| property_units: | ||
| energy_U0: eV | ||
| energy_U: eV | ||
| enthalpy_H: eV | ||
| free_energy: eV | ||
| homo: eV | ||
| lumo: eV | ||
| gap: eV | ||
| zpve: eV | ||
| num_test: 10000 | ||
| num_workers: 2 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,13 @@ | ||
| # @package data | ||
| defaults: | ||
| - custom | ||
| molecule: aspirin | ||
|
|
||
| _target_: schnetpack.datasets.rMD17 | ||
| dataset: | ||
| _target_: schnetpack.datasets.rMD17 | ||
| datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml | ||
| molecule: ${data.molecule} | ||
|
|
||
| datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml | ||
| molecule: aspirin | ||
| batch_size: 10 | ||
| num_train: 950 | ||
| num_val: 50 | ||
| split_id: null |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.