22from abc import ABC
33from collections .abc import Callable
44from pprint import pformat
5+ from typing import Optional
56
67import pandas as pd
78import torch
@@ -281,7 +282,9 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
281282 molecule_attr = molecule_attr ,
282283 )
283284
284- def load_processed_data_from_file (self , filename : str ) -> list [dict ]:
285+ def load_processed_data (
286+ self , kind : Optional [str ] = None , filename : Optional [str ] = None
287+ ) -> list [dict ]:
285288 """
286289 Load dataset and merge cached properties into base features.
287290
@@ -291,7 +294,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
291294 Returns:
292295 List of data entries, each a dictionary.
293296 """
294- base_data = super ().load_processed_data_from_file ( filename )
297+ base_data = super ().load_processed_data ( kind , filename )
295298 base_df = pd .DataFrame (base_data )
296299
297300 for property in self .properties :
@@ -379,7 +382,9 @@ def __init__(self, properties=None, transform=None, **kwargs):
379382 f"Data module uses these properties (ordered): { ', ' .join ([str (p ) for p in self .properties ])} " ,
380383 )
381384
382- def load_processed_data_from_file (self , filename : str ) -> list [dict ]:
385+ def load_processed_data (
386+ self , kind : Optional [str ] = None , filename : Optional [str ] = None
387+ ) -> list [dict ]:
383388 """
384389 Load dataset and merge cached properties into base features.
385390
@@ -389,9 +394,8 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
389394 Returns:
390395 List of data entries, each a dictionary.
391396 """
392- base_data = super ().load_processed_data_from_file ( filename )
397+ base_data = super ().load_processed_data ( kind , filename )
393398 base_df = pd .DataFrame (base_data )
394-
395399 props_categories = {
396400 "AllNodeTypeProperties" : [],
397401 "FGNodeTypeProperties" : [],
@@ -442,6 +446,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
442446 )
443447
444448 for property in self .properties :
449+ rank_zero_info (f"Loading property { property .name } ..." )
445450 property_data = torch .load (
446451 self .get_property_path (property ), weights_only = False
447452 )
0 commit comments