Skip to content

Commit 1303fd5

Browse files
committed
v0.6.0: Add 3 metrics for regression tasks: Spearman's r, Pearson's r, and Kendall's tau.
1 parent 94abec6 commit 1303fd5

4 files changed

Lines changed: 13 additions & 8 deletions

File tree

molalkit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# -*- coding: utf-8 -*-
33

44

5-
__version__ = '0.5.2'
5+
__version__ = '0.6.0'

molalkit/al/learner.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
from logging import Logger
1111
from sklearn.metrics import *
12+
import scipy
1213
from ..args import Metric
1314
from .selection_method import BaseSelectionMethod, RandomSelectionMethod
1415
from .forgetter import BaseForgetter, RandomForgetter, FirstForgetter
@@ -40,9 +41,15 @@ def eval_metric_func(y, y_pred, metric: str) -> float:
4041
elif metric == 'mse':
4142
return mean_squared_error(y, y_pred)
4243
elif metric == 'rmse':
43-
return np.sqrt(eval_metric_func(y, y_pred, 'mse'))
44+
return mean_squared_error(y, y_pred, squared=False)
4445
elif metric == 'max':
4546
return np.max(abs(y - y_pred))
47+
elif metric == 'spearman':
48+
return scipy.stats.spearmanr(y, y_pred)[0]
49+
elif metric == 'kendall':
50+
return scipy.stats.kendalltau(y, y_pred)[0]
51+
elif metric == 'pearson':
52+
return scipy.stats.pearsonr(y, y_pred)[0]
4653
else:
4754
raise RuntimeError(f'Unsupported metrics {metric}')
4855

@@ -254,7 +261,7 @@ def evaluate(self, alr: ActiveLearningResult):
254261
pd.DataFrame({'true': self.dataset_val_selector.y.ravel(), 'pred': y_pred}).to_csv(
255262
os.path.join(self.save_dir, f'selector_{self.current_iter}.csv'), index=False)
256263
for metric in self.metrics:
257-
metric_value = eval_metric_func(self.dataset_val_selector.y, y_pred, metric=metric)
264+
metric_value = eval_metric_func(self.dataset_val_selector.y.ravel(), y_pred, metric=metric)
258265
alr.results[f'{metric}_selector'] = metric_value
259266
# evaluate the percentage of top-k data selected in the training set
260267
if self.top_k_id is not None:
@@ -268,7 +275,7 @@ def evaluate(self, alr: ActiveLearningResult):
268275
pd.DataFrame({'true': self.dataset_val_selector.y.ravel(), 'pred': y_pred}).to_csv(
269276
os.path.join(self.save_dir, f'evaluator_{i}_{self.current_iter}.csv'), index=False)
270277
for metric in self.metrics:
271-
metric_value = eval_metric_func(self.dataset_val_evaluators[i].y, y_pred, metric=metric)
278+
metric_value = eval_metric_func(self.dataset_val_evaluators[i].y.ravel(), y_pred, metric=metric)
272279
alr.results[f'{metric}_evaluator_{i}'] = metric_value
273280
self.info('evaluating model performance finished.')
274281

molalkit/args.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,13 @@
1111
import numpy as np
1212
from mgktools.features_mol import FeaturesGenerator
1313
from mgktools.data.split import data_split_index
14+
from mgktools.evaluators.metric import Metric
1415
from molalkit.logging import create_logger
1516
from molalkit.utils import get_data, get_model, get_kernel
1617
from molalkit.data.datasets import DATA_DIR
1718
from molalkit.al.selection_method import *
1819
from molalkit.al.forgetter import *
19-
2020
CWD = os.path.dirname(os.path.abspath(__file__))
21-
Metric = Literal['roc-auc', 'accuracy', 'precision', 'recall', 'f1_score', 'mcc',
22-
'rmse', 'mae', 'mse', 'r2', 'max']
2321

2422

2523
class CommonArgs(Tap):

test/test_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_classification(metric):
3131
al_results_check(save_dir)
3232

3333

34-
@pytest.mark.parametrize('metric', ['rmse', 'mae', 'mse', 'r2', 'max'])
34+
@pytest.mark.parametrize('metric', ['rmse', 'mae', 'mse', 'r2', 'max', 'spearman', 'kendall', 'pearson'])
3535
def test_regression(metric):
3636
save_dir = os.path.join(CWD, 'test')
3737
arguments = [

0 commit comments

Comments (0)