Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/tabpfn_extensions/scoring/scoring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def score_classification(
if optimize_metric is None:
optimize_metric = "roc"

if (optimize_metric == "roc") and len(np.unique(y_true)) == 2:
if (optimize_metric in ("roc", "auroc")) and len(np.unique(y_true)) == 2:
y_pred = y_pred[:, 1]
Comment on lines +133 to 134
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The block at lines 133-134 is redundant and potentially problematic for several reasons:

  1. Redundancy: safe_roc_auc_score already handles binary classification by checking the shape of y_pred (lines 60-61).
  2. Efficiency: Calling np.unique(y_true) is computationally expensive for large arrays and it is already called inside safe_roc_auc_score (line 48).
  3. Correctness: Slicing y_pred[:, 1] based solely on the number of unique classes in y_true can lead to incorrect results in multiclass problems where a specific subset of data happens to contain only two classes. In such cases, index 1 might not correspond to the correct 'positive' class. safe_roc_auc_score is designed to handle this correctly by identifying and adjusting for missing classes.

Normalizing auroc to roc here simplifies the logic and ensures consistency across all subsequent checks.

Suggested change
if (optimize_metric in ("roc", "auroc")) and len(np.unique(y_true)) == 2:
y_pred = y_pred[:, 1]
if optimize_metric == "auroc":
optimize_metric = "roc"


if (not y_pred_is_labels) and (optimize_metric not in ["roc", "log_loss"]):
if (not y_pred_is_labels) and (optimize_metric not in ["roc", "auroc", "log_loss"]):
y_pred = np.argmax(y_pred, axis=1)

Comment on lines +136 to 138
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This check can be simplified now that auroc is normalized to roc.

    if (not y_pred_is_labels) and (optimize_metric not in ["roc", "log_loss"]):

if optimize_metric in ("roc", "auroc"):
Expand Down
Loading