86 changes: 64 additions & 22 deletions src/tabpfn_client/client.py
@@ -152,6 +152,41 @@ def _serialize_to_parquet(df: pd.DataFrame) -> tuple[bytes, str]:
    return parquet_bytes, crc32c_b64


+def _thinking_aware_dedup_hash(
+    content_hash: str,
+    *,
+    thinking_effort: str | None,
+    thinking_timeout_s: float | None,
+    thinking_metric: str | None,
+) -> str:
+    """Return a dedup hash that partitions by thinking config.
+
+    When thinking is disabled, the content hash is returned unchanged, so
+    non-thinking calls keep the existing dedup semantics. When thinking is
+    enabled, the thinking config is folded into the hash so that
+    `(dataset, thinking_config)` becomes the cache unit — same config hits
+    the existing fit, different config misses it.
+    """
+    if (
+        thinking_effort is None
+        and thinking_timeout_s is None
+        and thinking_metric is None
+    ):
+        return content_hash
+    discriminator = json.dumps(
+        {
+            "thinking_effort": thinking_effort,
+            "thinking_timeout_s": (
+                # normalise to float for cache stability (0 vs 0.0)
+                float(thinking_timeout_s) if thinking_timeout_s is not None else None
+            ),
+            "thinking_effort_metric": thinking_metric,
+        },
+        sort_keys=True,
+    )
+    return _get_crc32c_hash(f"{content_hash}|{discriminator}".encode("utf-8"))
+
+
class NeedsRefittingError(Exception):
    """
    Exception raised when the server is not able to predict given the current state.
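
For illustration, a minimal sketch of how the new helper partitions the dedup cache (the hash value is a placeholder; this snippet is not part of the diff):

    content = "content-hash-abc"  # a precomputed CRC32C content hash
    medium = _thinking_aware_dedup_hash(
        content, thinking_effort="medium", thinking_timeout_s=None, thinking_metric=None
    )
    high = _thinking_aware_dedup_hash(
        content, thinking_effort="high", thinking_timeout_s=None, thinking_metric=None
    )
    assert medium != high      # different thinking config -> cache miss
    assert medium != content   # a thinking fit never collides with a plain fit
    no_thinking = _thinking_aware_dedup_hash(
        content, thinking_effort=None, thinking_timeout_s=None, thinking_metric=None
    )
    assert no_thinking == content  # thinking disabled -> existing dedup semantics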
@@ -315,12 +350,35 @@ def fit(
f"the server limit of {limits.dataset_max_size_bytes} bytes."
)

thinking_enabled = bool(tabpfn_config) and (
bool(tabpfn_config.get("thinking_mode"))
or tabpfn_config.get("thinking_effort") is not None
)
if thinking_enabled:
thinking_effort = tabpfn_config.get("thinking_effort") or "medium"
thinking_timeout_s = tabpfn_config.get("thinking_timeout_s")
thinking_metric = tabpfn_config.get("thinking_metric")
else:
thinking_effort = None
thinking_timeout_s = None
thinking_metric = None

x_bytes, x_crc32c_hash = _serialize_to_parquet(df_X)
y_bytes, y_crc32c_hash = _serialize_to_parquet(df_y)

if dedup_datasets_enabled():
x_dedup_hash = x_crc32c_hash
y_dedup_hash = y_crc32c_hash
x_dedup_hash = _thinking_aware_dedup_hash(
x_crc32c_hash,
thinking_effort=thinking_effort,
thinking_timeout_s=thinking_timeout_s,
thinking_metric=thinking_metric,
)
y_dedup_hash = _thinking_aware_dedup_hash(
y_crc32c_hash,
thinking_effort=thinking_effort,
thinking_timeout_s=thinking_timeout_s,
thinking_metric=thinking_metric,
)
else:
x_dedup_hash = None
y_dedup_hash = None
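
A quick sketch of how the resolution above behaves for a caller-supplied config (the dicts are hypothetical; only the keys shown in this diff are assumed):

    # `thinking_mode=True` without an explicit level defaults the effort to "medium".
    cfg = {"thinking_mode": True}
    enabled = bool(cfg) and (
        bool(cfg.get("thinking_mode")) or cfg.get("thinking_effort") is not None
    )
    assert enabled and (cfg.get("thinking_effort") or "medium") == "medium"

    # Setting `thinking_effort` alone is also enough to enable thinking.
    cfg = {"thinking_effort": "high"}
    assert bool(cfg.get("thinking_mode")) or cfg.get("thinking_effort") is not None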
@@ -382,14 +440,6 @@ def fit(
            raise

        tabpfn_systems = ["preprocessing", "text"]
-        # Thinking is enabled when either flag is set: explicit `thinking_mode=True`,
-        # or any non-None `thinking_effort`. Setting `thinking_effort` alone is
-        # enough — the server-side validator on FitRequest also normalises this,
-        # but doing it here means the request body itself is consistent.
-        thinking_enabled = bool(tabpfn_config) and (
-            bool(tabpfn_config.get("thinking_mode"))
-            or tabpfn_config.get("thinking_effort") is not None
-        )
        if tabpfn_config:
            if tabpfn_config.get("paper_version") is True:
                tabpfn_systems = []
@@ -399,18 +449,10 @@
tabpfn_systems = ["preprocessing", "text", "thinking"]

# The client-side `thinking_*` knobs forward 1:1 to the server's
# top-level FitRequest fields. When the user enabled thinking via
# `thinking_mode=True` without picking a level, default to "medium".
# The user-facing kwarg is `thinking_metric`; on the wire it is sent
# as `thinking_effort_metric` (matching the server's FitRequest schema).
if thinking_enabled and tabpfn_config:
thinking_effort = tabpfn_config.get("thinking_effort") or "medium"
thinking_timeout_s = tabpfn_config.get("thinking_timeout_s")
thinking_metric = tabpfn_config.get("thinking_metric")
else:
thinking_effort = None
thinking_timeout_s = None
thinking_metric = None
# top-level FitRequest fields. The user-facing kwarg is
# `thinking_metric`; on the wire it is sent as `thinking_effort_metric`
# (matching the server's FitRequest schema). The effective values were
# resolved above so they could feed into the dedup hash.

# Strip client-only keys that the server does not expect (mirrors
# the predict path's filter below).
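
For context, a sketch of the rename the comment above describes; the actual request assembly sits outside this diff, so the dict below is illustrative only:

    # Client kwarg -> wire field. `thinking_metric` travels as
    # `thinking_effort_metric` to match the server's FitRequest schema.
    thinking_fields = {
        "thinking_effort": thinking_effort,         # e.g. "medium"
        "thinking_timeout_s": thinking_timeout_s,   # e.g. 60.0, or None
        "thinking_effort_metric": thinking_metric,  # user-facing name: thinking_metric
    }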
109 changes: 109 additions & 0 deletions tests/unit/test_client.py
@@ -11,6 +11,7 @@
    GetModelLimitsResponse,
    NeedsRefittingError,
    ServiceClient,
+    _thinking_aware_dedup_hash,
)
from tests.mock_tabpfn_server import with_mock_server

@@ -530,3 +531,111 @@ def test_predict_converts_none_in_dict_prediction_to_nan(self, mock_server):
            np.array([[1.0, np.nan], [np.nan, 4.0]]),
            equal_nan=True,
        )
+
+
+class TestThinkingAwareDedupHash(unittest.TestCase):
+    """Pins the cache-partitioning rules for thinking-mode fits.
+
+    Thinking mode is deterministic: same dataset + same thinking config must
+    collide on the server's dedup/fit cache (cache hit -> identical result),
+    but a *different* thinking config (e.g. effort medium -> high) must miss
+    so the fit actually runs at the requested effort.
+    """
+
+    CONTENT = "content-hash-abc"
+
+    def test_no_thinking_returns_content_hash_unchanged(self):
+        # Preserves existing dedup behavior for non-thinking fits.
+        self.assertEqual(
+            _thinking_aware_dedup_hash(
+                self.CONTENT,
+                thinking_effort=None,
+                thinking_timeout_s=None,
+                thinking_metric=None,
+            ),
+            self.CONTENT,
+        )
+
+    def test_same_thinking_config_is_stable(self):
+        # Two calls with identical (dataset, thinking config) must hash to the
+        # same value so the server's cache hits.
+        h1 = _thinking_aware_dedup_hash(
+            self.CONTENT,
+            thinking_effort="medium",
+            thinking_timeout_s=60.0,
+            thinking_metric="rmse",
+        )
+        h2 = _thinking_aware_dedup_hash(
+            self.CONTENT,
+            thinking_effort="medium",
+            thinking_timeout_s=60.0,
+            thinking_metric="rmse",
+        )
+        self.assertEqual(h1, h2)
+
+    def test_effort_change_partitions_cache(self):
+        # The bug fix: medium -> high on the same dataset must NOT collide.
+        medium = _thinking_aware_dedup_hash(
+            self.CONTENT,
+            thinking_effort="medium",
+            thinking_timeout_s=None,
+            thinking_metric=None,
+        )
+        high = _thinking_aware_dedup_hash(
+            self.CONTENT,
+            thinking_effort="high",
+            thinking_timeout_s=None,
+            thinking_metric=None,
+        )
+        self.assertNotEqual(medium, high)
+
+    def test_timeout_and_metric_also_partition(self):
+        base = _thinking_aware_dedup_hash(
+            self.CONTENT,
+            thinking_effort="medium",
+            thinking_timeout_s=None,
+            thinking_metric=None,
+        )
+        different_timeout = _thinking_aware_dedup_hash(
+            self.CONTENT,
+            thinking_effort="medium",
+            thinking_timeout_s=120.0,
+            thinking_metric=None,
+        )
+        different_metric = _thinking_aware_dedup_hash(
+            self.CONTENT,
+            thinking_effort="medium",
+            thinking_timeout_s=None,
+            thinking_metric="rmse",
+        )
+        self.assertNotEqual(base, different_timeout)
+        self.assertNotEqual(base, different_metric)
+        self.assertNotEqual(different_timeout, different_metric)
+
+    def test_int_and_float_timeout_hash_identically(self):
+        # `json.dumps(60)` and `json.dumps(60.0)` differ; the helper normalizes
+        # so callers don't suffer spurious cache misses on equivalent values.
+        h_int = _thinking_aware_dedup_hash(
+            self.CONTENT,
+            thinking_effort="medium",
+            thinking_timeout_s=60,
+            thinking_metric=None,
+        )
+        h_float = _thinking_aware_dedup_hash(
+            self.CONTENT,
+            thinking_effort="medium",
+            thinking_timeout_s=60.0,
+            thinking_metric=None,
+        )
+        self.assertEqual(h_int, h_float)
+
+    def test_thinking_hash_differs_from_content_hash(self):
+        # Enabling thinking must change the hash, otherwise a prior
+        # non-thinking fit on the same dataset would be served.
+        with_thinking = _thinking_aware_dedup_hash(
+            self.CONTENT,
+            thinking_effort="medium",
+            thinking_timeout_s=None,
+            thinking_metric=None,
+        )
+        self.assertNotEqual(with_thinking, self.CONTENT)
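
The int/float test above guards against a real json.dumps quirk; a quick illustration (standard-library behavior, not part of the diff):

    import json
    # Equivalent timeouts serialize to different strings without normalization,
    # which would split the cache on a meaningless distinction.
    assert json.dumps({"t": 60}) == '{"t": 60}'
    assert json.dumps({"t": 60.0}) == '{"t": 60.0}'
    assert json.dumps({"t": float(60)}) == json.dumps({"t": 60.0})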