 import numpy as np
 from logging import Logger
 from sklearn.metrics import *
+import scipy.stats
 from ..args import Metric
 from .selection_method import BaseSelectionMethod, RandomSelectionMethod
 from .forgetter import BaseForgetter, RandomForgetter, FirstForgetter
@@ -40,9 +41,15 @@ def eval_metric_func(y, y_pred, metric: str) -> float:
     elif metric == 'mse':
         return mean_squared_error(y, y_pred)
     elif metric == 'rmse':
-        return np.sqrt(eval_metric_func(y, y_pred, 'mse'))
+        return mean_squared_error(y, y_pred, squared=False)
     elif metric == 'max':
         return np.max(abs(y - y_pred))
+    elif metric == 'spearman':
+        return scipy.stats.spearmanr(y, y_pred)[0]
+    elif metric == 'kendall':
+        return scipy.stats.kendalltau(y, y_pred)[0]
+    elif metric == 'pearson':
+        return scipy.stats.pearsonr(y, y_pred)[0]
     else:
         raise RuntimeError(f'Unsupported metric: {metric}')
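Note on the hunk above: each of the scipy.stats calls returns a result holding both a statistic and a p-value, which is why the patch keeps only element [0]. A minimal, self-contained sketch of what the new branches compute (the array values are illustrative only):

    import numpy as np
    import scipy.stats

    y = np.array([1.0, 2.0, 3.0, 4.0])
    y_pred = np.array([1.2, 1.9, 3.3, 3.7])

    # Each result object is indexable: [0] is the statistic, [1] the p-value.
    rho = scipy.stats.spearmanr(y, y_pred)[0]   # rank correlation, in [-1, 1]
    tau = scipy.stats.kendalltau(y, y_pred)[0]  # ordinal concordance
    r = scipy.stats.pearsonr(y, y_pred)[0]      # linear correlation

The rmse branch now uses mean_squared_error(y, y_pred, squared=False), which computes the root directly; on scikit-learn >= 1.4 the squared argument is deprecated in favor of root_mean_squared_error.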
@@ -254,7 +261,7 @@ def evaluate(self, alr: ActiveLearningResult):
         pd.DataFrame({'true': self.dataset_val_selector.y.ravel(), 'pred': y_pred}).to_csv(
             os.path.join(self.save_dir, f'selector_{self.current_iter}.csv'), index=False)
         for metric in self.metrics:
-            metric_value = eval_metric_func(self.dataset_val_selector.y, y_pred, metric=metric)
+            metric_value = eval_metric_func(self.dataset_val_selector.y.ravel(), y_pred, metric=metric)
             alr.results[f'{metric}_selector'] = metric_value
         # evaluate the percentage of top-k data selected in the training set
         if self.top_k_id is not None:
@@ -268,7 +275,7 @@ def evaluate(self, alr: ActiveLearningResult):
             pd.DataFrame({'true': self.dataset_val_evaluators[i].y.ravel(), 'pred': y_pred}).to_csv(
                 os.path.join(self.save_dir, f'evaluator_{i}_{self.current_iter}.csv'), index=False)
             for metric in self.metrics:
-                metric_value = eval_metric_func(self.dataset_val_evaluators[i].y, y_pred, metric=metric)
+                metric_value = eval_metric_func(self.dataset_val_evaluators[i].y.ravel(), y_pred, metric=metric)
                 alr.results[f'{metric}_evaluator_{i}'] = metric_value
         self.info('evaluating model performance finished.')
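The .ravel() calls added in the two hunks above matter because the scipy correlation routines expect matching 1-D arrays, while the dataset targets are presumably stored as (n, 1) column arrays, a shape the sklearn metrics accept silently. A minimal sketch, assuming that column-vector shape:

    import numpy as np
    import scipy.stats

    y = np.array([[1.0], [2.0], [3.0]])  # assumed (n, 1) target array
    y_pred = np.array([1.1, 2.2, 2.8])

    # Mixing the (3, 1) column with the (3,) vector makes pearsonr raise
    # (or misbehave, depending on the SciPy version); flattening to shape
    # (n,) first keeps every metric branch consistent.
    r = scipy.stats.pearsonr(y.ravel(), y_pred)[0]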