
Adding thinking effort to the deduplication logic#276

Closed
eliott-kalfon wants to merge 3 commits into main from eliott/caching_fix

Conversation

@eliott-kalfon

Change Description

Include the thinking effort in the cache deduplication logic, so that users get different results when they generate predictions with medium effort and then retry with high effort.
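A minimal sketch of the intended behavior (function and key names here are illustrative, not the actual tabpfn_client implementation): the deduplication key folds the thinking configuration into the train-set fingerprint, so fits with different efforts no longer collide in the cache.

```python
import hashlib
import json


def dedup_hash(train_set_hash, thinking_effort=None,
               thinking_timeout_s=None, thinking_metric=None):
    """Illustrative only: mix the thinking config into the cache key.

    Casting the timeout to float keeps the key stable whether the
    caller passes 600 or 600.0.
    """
    config = {
        "train_set_hash": train_set_hash,
        "thinking_effort": thinking_effort,
        "thinking_timeout_s": (
            float(thinking_timeout_s) if thinking_timeout_s is not None else None
        ),
        "thinking_metric": thinking_metric,
    }
    # sort_keys gives a deterministic serialization for hashing
    payload = json.dumps(config, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


# Different efforts now produce different cache keys:
medium = dedup_hash("abc123", "medium", 600, "roc_auc")
high = dedup_hash("abc123", "high", 600, "roc_auc")
assert medium != high
```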

If you used new dependencies: Did you add them to requirements.txt? No new dependencies added.

Breaking changes

No breaking changes.

If you made any breaking changes, please update the version number.
Breaking changes are totally fine, we just need to make sure to keep the users informed and the server in sync.

Does this PR break the API? If so, what is the corresponding server commit?

Does this PR break the user interface? If so, why?


Please do not mark comments/conversations as resolved unless you are the assigned reviewer. This helps maintain clarity during the review process.

@eliott-kalfon eliott-kalfon requested a review from a team as a code owner May 12, 2026 02:49
@eliott-kalfon eliott-kalfon requested review from ggprior and removed request for a team May 12, 2026 02:49
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository-wide code reviews.

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

@eliott-kalfon eliott-kalfon changed the title initial commit Adding thinking effort to the deduplication logic May 12, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a thinking-aware deduplication hash to ensure that changes in thinking configuration (effort, timeout, and metric) correctly partition the server-side fit cache. The fit method was refactored to resolve these configurations earlier and apply the new hashing logic. Feedback suggests improving the stability of the hash by casting the timeout to a float and using consistent key names for the wire format. Additionally, a redundant check in the configuration resolution logic was identified.
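The reviewer's point about casting the timeout to a float can be seen with a toy example (the key name below is hypothetical, not the actual wire format): JSON serializes `600` and `600.0` differently, so an int/float mismatch between callers would yield two different hashes for the same logical configuration unless the value is normalized first.

```python
import hashlib
import json


def key_for(timeout):
    # Hash a JSON payload containing the timeout (illustrative key name)
    return hashlib.sha256(
        json.dumps({"thinking_timeout_s": timeout}).encode()
    ).hexdigest()


# Without normalization, 600 and 600.0 serialize (and hash) differently:
assert key_for(600) != key_for(600.0)
# Casting to float before hashing makes the key stable:
assert key_for(float(600)) == key_for(float(600.0))
```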

Comment thread src/tabpfn_client/client.py
Comment thread src/tabpfn_client/client.py Outdated
@ggprior
Contributor

ggprior commented May 12, 2026

Thanks, @eliott-kalfon !

Using the following reproducer on your branch:

import numpy as np
import openml
from tabpfn_client import init, TabPFNClassifier, TabPFNRegressor
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    log_loss,
    confusion_matrix,
    classification_report,
)

# OpenML task 168868: APSFailure (binary classification, ~60k rows, 170 numeric features)
task = openml.tasks.get_task(168868)
X, y = task.get_X_and_y(dataset_format="array")
train_idx, test_idx = task.get_train_test_split_indices(fold=0, repeat=0, sample=0)
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# Shuffle to change the train-set hash (workaround for upload-cache conflict)
rng = np.random.default_rng(0)
perm_train = rng.permutation(len(X_train))
perm_test = rng.permutation(len(X_test))
X_train, y_train = X_train[perm_train], y_train[perm_train]
X_test, y_test = X_test[perm_test], y_test[perm_test]

print(
    f"Dataset: APSFailure (OpenML task 168868, fold 0) "
    f"n={X.shape[0]}, d={X.shape[1]}, "
    f"train/test={X_train.shape[0]}/{X_test.shape[0]}, "
    f"classes={np.bincount(y.astype(int)).tolist()}"
)


def evaluate(name, model):
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    probabilities = model.predict_proba(X_test)
    return {
        "name": name,
        "accuracy": accuracy_score(y_test, predictions),
        "f1": f1_score(y_test, predictions),
        "roc_auc": roc_auc_score(y_test, probabilities[:, 1]),
        "log_loss": log_loss(y_test, probabilities),
    }


runs = [
    evaluate("default", TabPFNClassifier()),
    evaluate(
        "thinking_effort=medium",
        TabPFNClassifier(
            thinking_effort="medium",
            thinking_timeout_s=600,
            thinking_metric="roc_auc",
        ),
    ),
    evaluate(
        "thinking_effort=high",
        TabPFNClassifier(
            thinking_effort="high",
            thinking_timeout_s=600,
            thinking_metric="roc_auc",
        ),
    ),
]

baseline_auc = runs[0]["roc_auc"]
print()
print(f"{'run':<25} {'acc':>7} {'f1':>7} {'roc_auc':>9} {'Δauc':>8} {'logloss':>9}")
for r in runs:
    delta = r["roc_auc"] - baseline_auc
    print(
        f"{r['name']:<25} {r['accuracy']:>7.4f} {r['f1']:>7.4f} "
        f"{r['roc_auc']:>9.4f} {delta:>+8.4f} {r['log_loss']:>9.4f}"
    )


I'm getting

georggrab@Georgs-MacBook-Pro-2 tmp-client3 % uv run main.py 
/Users/georggrab/code/tmp-client3/main.py:17: FutureWarning: Support for `dataset_format='array'` will be removed in 0.15,start using `dataset_format='dataframe' to ensure your code will continue to work. You can use the dataframe's `to_numpy` function to continue using numpy arrays.
  X, y = task.get_X_and_y(dataset_format="array")
/Users/georggrab/code/tmp-client3/.venv/lib/python3.13/site-packages/openml/tasks/task.py:334: FutureWarning: Support for `dataset_format='array'` will be removed in 0.15,start using `dataset_format='dataframe' to ensure your code will continue to work. You can use the dataframe's `to_numpy` function to continue using numpy arrays.
  X, y, _, _ = dataset.get_data(
Dataset: APSFailure (OpenML task 168868, fold 0) n=76000, d=170, train/test=68400/7600, classes=[1375, 74625]
Found existing access token, reusing it for authentication.
00:00 Fitting... \The provided train set hashes match previously uploaded train sets.
00:00 Fitting... Done!
00:01 Predicting... \The provided test set hash matches a previously uploaded test set.
00:25 Predicting... Done!
00:00 Predicting... -The provided test set hash matches a previously uploaded test set.
00:23 Predicting... Done!
/Users/georggrab/code/tmp-client3/.venv/lib/python3.13/site-packages/sklearn/metrics/_classification.py:3001: UserWarning: The y_pred values do not sum to one. Make sure to pass probabilities.
  warnings.warn(
00:05 Fitting... \Traceback (most recent call last):
  File "/Users/georggrab/code/tmp-client3/main.py", line 52, in <module>
    evaluate(
    ~~~~~~~~^
        "thinking_effort=medium",
        ^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<4 lines>...
        ),
        ^^
    ),
    ^
  File "/Users/georggrab/code/tmp-client3/main.py", line 38, in evaluate
    model.fit(X_train, y_train)
    ~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/Users/georggrab/code/tmp-client3/.venv/lib/python3.13/site-packages/tabpfn_client/estimator.py", line 360, in fit
    self.last_fitted_train_set_id = run_task(fit_task, "Fitting")
                                    ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
  File "/Users/georggrab/code/tmp-client3/.venv/lib/python3.13/site-packages/tabpfn_client/estimator.py", line 960, in run_task
    result = future.result()
  File "/Users/georggrab/.local/share/uv/python/cpython-3.13.12-macos-aarch64-none/lib/python3.13/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
           ~~~~~~~~~~~~~~~~~^^
  File "/Users/georggrab/.local/share/uv/python/cpython-3.13.12-macos-aarch64-none/lib/python3.13/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/Users/georggrab/.local/share/uv/python/cpython-3.13.12-macos-aarch64-none/lib/python3.13/concurrent/futures/thread.py", line 59, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/Users/georggrab/code/tmp-client3/.venv/lib/python3.13/site-packages/tabpfn_client/estimator.py", line 350, in fit_task
    return InferenceClient.fit(
           ~~~~~~~~~~~~~~~~~~~^
        X,
        ^^
    ...<5 lines>...
        client_options=self.client_options,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/georggrab/code/tmp-client3/.venv/lib/python3.13/site-packages/tabpfn_client/service_wrapper.py", line 270, in fit
    return ServiceClient.fit(
           ~~~~~~~~~~~~~~~~~^
        X,
        ^^
    ...<5 lines>...
        client_options=client_options,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/georggrab/code/tmp-client3/.venv/lib/python3.13/site-packages/tabpfn_client/client.py", line 492, in fit
    cls._validate_response(
    ~~~~~~~~~~~~~~~~~~~~~~^
        res,
        ^^^^
        "fit",
        ^^^^^^
        response_models={200: FitResponse},
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ),
    ^
  File "/Users/georggrab/code/tmp-client3/.venv/lib/python3.13/site-packages/tabpfn_client/client.py", line 882, in _validate_response
    raise RuntimeError(
        f"Fail to call {method_name} with error: {status_code}, {message}"
    )
RuntimeError: Fail to call fit with error: 409, There was a conflict while composing the upload, multiple parallel requests might be targeting the same set. Wait for all conflicting requests to complete and try again.

Did you see the same during your debugging?
I'm guessing busting the client-side cache is not enough here; the cache needs to be invalidated on the server as well.

@ggprior
Contributor

ggprior commented May 12, 2026

Thanks again for the great catch! It turns out this issue is better solved in the server-side change detection logic, which I've just shipped out. So closing the client PR. Happy if you can retest though if you have your scripts handy :D

@ggprior ggprior closed this May 12, 2026
@eliott-kalfon
Author

> Thanks again for the great catch! It turns out this issue is better solved in the server-side change detection logic, which I've just shipped out. So closing the client PR. Happy if you can retest though if you have your scripts handy :D

Hi @ggprior, tested with my scripts again, works for me! Thank you
