@@ -96,9 +96,9 @@ def __init__(
9696 self .prediction_kind = prediction_kind
9797 self .data_limit = data_limit
9898 self .label_filter = label_filter
99- assert (balance_after_filter is not None ) or (
100- self . label_filter is None
101- ), "Filter balancing requires a filter"
99+ assert (balance_after_filter is not None ) or (self . label_filter is None ), (
100+ "Filter balancing requires a filter"
101+ )
102102 self .balance_after_filter = balance_after_filter
103103 self .num_workers = num_workers
104104 self .persistent_workers : bool = bool (persistent_workers )
@@ -108,13 +108,13 @@ def __init__(
108108 self .use_inner_cross_validation = (
109109 inner_k_folds > 1
110110 ) # only use cv if there are at least 2 folds
111- assert (
112- fold_index is None or self . use_inner_cross_validation is not None
113- ), "fold_index can only be set if cross validation is used"
111+ assert fold_index is None or self . use_inner_cross_validation is not None , (
112+ " fold_index can only be set if cross validation is used"
113+ )
114114 if fold_index is not None and self .inner_k_folds is not None :
115- assert (
116- fold_index < self . inner_k_folds
117- ), "fold_index can't be larger than the total number of folds"
115+ assert fold_index < self . inner_k_folds , (
116+ " fold_index can't be larger than the total number of folds"
117+ )
118118 self .fold_index = fold_index
119119 self ._base_dir = base_dir
120120 self .n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
137137
138138 @property
139139 def feature_vector_size (self ):
140- assert (
141- self . _feature_vector_size is not None
142- ), "size of feature vector must be set"
140+ assert self . _feature_vector_size is not None , (
141+ "size of feature vector must be set"
142+ )
143143 return self ._feature_vector_size
144144
145145 @property
@@ -1173,9 +1173,7 @@ def _retrieve_splits_from_csv(self) -> None:
11731173 splits_df = pd .read_csv (self .splits_file_path )
11741174
11751175 filename = self .processed_file_names_dict ["data" ]
1176- data = self .load_processed_data_from_file (
1177- os .path .join (self .processed_dir , filename )
1178- )
1176+ data = self .load_processed_data_from_file (filename )
11791177 df_data = pd .DataFrame (data )
11801178
11811179 if self .apply_id_filter :
@@ -1255,7 +1253,9 @@ def load_processed_data(
12551253 return self .load_processed_data_from_file (filename )
12561254
12571255 def load_processed_data_from_file (self , filename ):
1258- return torch .load (os .path .join (filename ), weights_only = False )
1256+ return torch .load (
1257+ os .path .join (self .processed_dir , filename ), weights_only = False
1258+ )
12591259
12601260 # ------------------------------ Phase: Raw Properties -----------------------------------
12611261 @property
0 commit comments