|
1 | | -from typing import Optional, List, Callable |
| 1 | +import importlib |
| 2 | +import os |
| 3 | +from typing import Callable, List, Optional |
2 | 4 |
|
| 5 | +import pandas as pd |
| 6 | +import torch |
| 7 | +import tqdm |
| 8 | +from chebai.preprocessing.datasets.base import XYBaseDataModule |
3 | 9 | from chebai.preprocessing.datasets.chebi import ( |
4 | 10 | ChEBIOver50, |
5 | 11 | ChEBIOver100, |
6 | 12 | ChEBIOverXPartial, |
7 | 13 | ) |
8 | | -from chebai.preprocessing.datasets.base import XYBaseDataModule |
9 | 14 | from lightning_utilities.core.rank_zero import rank_zero_info |
| 15 | +from torch_geometric.data.data import Data as GeomData |
10 | 16 |
|
11 | | -from chebai_graph.preprocessing.reader import GraphReader, GraphPropertyReader |
| 17 | +import chebai_graph.preprocessing.properties as graph_properties |
12 | 18 | from chebai_graph.preprocessing.properties import ( |
13 | 19 | AtomProperty, |
14 | 20 | BondProperty, |
15 | 21 | MolecularProperty, |
16 | 22 | ) |
17 | | -import pandas as pd |
18 | | -from torch_geometric.data.data import Data as GeomData |
19 | | -import torch |
20 | | -import chebai_graph.preprocessing.properties as graph_properties |
21 | | -import importlib |
22 | | -import os |
23 | | -import tqdm |
| 23 | +from chebai_graph.preprocessing.reader import GraphPropertyReader, GraphReader |
24 | 24 |
|
25 | 25 |
|
26 | 26 | class ChEBI50GraphData(ChEBIOver50): |
@@ -84,18 +84,20 @@ def _setup_properties(self): |
84 | 84 | for file in file_names: |
85 | 85 | # processed_dir_main only exists for ChEBI datasets |
86 | 86 | path = os.path.join( |
87 | | - self.processed_dir_main |
88 | | - if hasattr(self, "processed_dir_main") |
89 | | - else self.raw_dir, |
| 87 | + ( |
| 88 | + self.processed_dir_main |
| 89 | + if hasattr(self, "processed_dir_main") |
| 90 | + else self.raw_dir |
| 91 | + ), |
90 | 92 | file, |
91 | 93 | ) |
92 | 94 | raw_data += list(self._load_dict(path)) |
93 | 95 | idents = [row["ident"] for row in raw_data] |
94 | 96 | features = [row["features"] for row in raw_data] |
95 | 97 |
|
96 | 98 | # use vectorized version of encode function, apply only if value is present |
97 | | - enc_if_not_none = ( |
98 | | - lambda encode, value: [encode(atom_v) for atom_v in value] |
| 99 | + enc_if_not_none = lambda encode, value: ( |
| 100 | + [encode(atom_v) for atom_v in value] |
99 | 101 | if value is not None and len(value) > 0 |
100 | 102 | else None |
101 | 103 | ) |
@@ -134,11 +136,14 @@ def get_property_path(self, property: MolecularProperty): |
134 | 136 | f"{property.name}_{property.encoder.name}.pt", |
135 | 137 | ) |
136 | 138 |
|
137 | | - def setup(self, **kwargs): |
138 | | - super().setup(keep_reader=True, **kwargs) |
139 | | - self._setup_properties() |
| 139 | + def _after_setup(self, **kwargs): |
| 140 | + """ |
| 141 | + Finalize the setup process after ensuring the processed data is available. |
140 | 142 |
|
141 | | - self.reader.on_finish() |
| 143 | + This method performs post-setup tasks like finalizing the reader and setting internal properties. |
| 144 | + """ |
| 145 | + self._setup_properties() |
| 146 | + super()._after_setup(**kwargs) |
142 | 147 |
|
143 | 148 | def _merge_props_into_base(self, row): |
144 | 149 | geom_data = row["features"] |
|
0 commit comments