diff --git a/src/tabpfn_extensions/scoring/scoring_utils.py b/src/tabpfn_extensions/scoring/scoring_utils.py index 57ed5156..3d6f52e4 100644 --- a/src/tabpfn_extensions/scoring/scoring_utils.py +++ b/src/tabpfn_extensions/scoring/scoring_utils.py @@ -130,10 +130,10 @@ def score_classification( if optimize_metric is None: optimize_metric = "roc" - if (optimize_metric == "roc") and len(np.unique(y_true)) == 2: + if (optimize_metric in ("roc", "auroc")) and len(np.unique(y_true)) == 2: y_pred = y_pred[:, 1] - if (not y_pred_is_labels) and (optimize_metric not in ["roc", "log_loss"]): + if (not y_pred_is_labels) and (optimize_metric not in ["roc", "auroc", "log_loss"]): y_pred = np.argmax(y_pred, axis=1) if optimize_metric in ("roc", "auroc"):