diff --git a/src/tabpfn_client/client.py b/src/tabpfn_client/client.py index 2d1df51..4e0fc32 100644 --- a/src/tabpfn_client/client.py +++ b/src/tabpfn_client/client.py @@ -152,6 +152,41 @@ def _serialize_to_parquet(df: pd.DataFrame) -> tuple[bytes, str]: return parquet_bytes, crc32c_b64 +def _thinking_aware_dedup_hash( + content_hash: str, + *, + thinking_effort: str | None, + thinking_timeout_s: float | None, + thinking_metric: str | None, +) -> str: + """Return a dedup hash that partitions by thinking config. + + When thinking is disabled the content hash is returned + unchanged so non-thinking calls keep the existing dedup semantics. When + thinking is enabled, the thinking config is folded into the hash so that + `(dataset, thinking_config)` becomes the cache unit — same config hits + the existing fit, different config misses it. + """ + if ( + thinking_effort is None + and thinking_timeout_s is None + and thinking_metric is None + ): + return content_hash + discriminator = json.dumps( + { + "thinking_effort": thinking_effort, + "thinking_timeout_s": ( + # normalise to float for cache stability (0 vs 0.0) + float(thinking_timeout_s) if thinking_timeout_s is not None else None + ), + "thinking_effort_metric": thinking_metric, + }, + sort_keys=True, + ) + return _get_crc32c_hash(f"{content_hash}|{discriminator}".encode("utf-8")) + + class NeedsRefittingError(Exception): """ Exception raised when the server is not able to predict given the current state. @@ -315,12 +350,35 @@ def fit( f"the server limit of {limits.dataset_max_size_bytes} bytes." 
) + thinking_enabled = bool(tabpfn_config) and ( + bool(tabpfn_config.get("thinking_mode")) + or tabpfn_config.get("thinking_effort") is not None + ) + if thinking_enabled: + thinking_effort = tabpfn_config.get("thinking_effort") or "medium" + thinking_timeout_s = tabpfn_config.get("thinking_timeout_s") + thinking_metric = tabpfn_config.get("thinking_metric") + else: + thinking_effort = None + thinking_timeout_s = None + thinking_metric = None + x_bytes, x_crc32c_hash = _serialize_to_parquet(df_X) y_bytes, y_crc32c_hash = _serialize_to_parquet(df_y) if dedup_datasets_enabled(): - x_dedup_hash = x_crc32c_hash - y_dedup_hash = y_crc32c_hash + x_dedup_hash = _thinking_aware_dedup_hash( + x_crc32c_hash, + thinking_effort=thinking_effort, + thinking_timeout_s=thinking_timeout_s, + thinking_metric=thinking_metric, + ) + y_dedup_hash = _thinking_aware_dedup_hash( + y_crc32c_hash, + thinking_effort=thinking_effort, + thinking_timeout_s=thinking_timeout_s, + thinking_metric=thinking_metric, + ) else: x_dedup_hash = None y_dedup_hash = None @@ -382,14 +440,6 @@ def fit( raise tabpfn_systems = ["preprocessing", "text"] - # Thinking is enabled when either flag is set: explicit `thinking_mode=True`, - # or any non-None `thinking_effort`. Setting `thinking_effort` alone is - # enough — the server-side validator on FitRequest also normalises this, - # but doing it here means the request body itself is consistent. - thinking_enabled = bool(tabpfn_config) and ( - bool(tabpfn_config.get("thinking_mode")) - or tabpfn_config.get("thinking_effort") is not None - ) if tabpfn_config: if tabpfn_config.get("paper_version") is True: tabpfn_systems = [] @@ -399,18 +449,10 @@ def fit( tabpfn_systems = ["preprocessing", "text", "thinking"] # The client-side `thinking_*` knobs forward 1:1 to the server's - # top-level FitRequest fields. When the user enabled thinking via - # `thinking_mode=True` without picking a level, default to "medium". 
class TestThinkingAwareDedupHash(unittest.TestCase):
    """Locks down how thinking settings partition the dedup/fit cache.

    Thinking mode is deterministic, so the same dataset with the same
    thinking config has to produce the same dedup hash (a cache hit returns
    an identical result). Any *different* thinking config (for example effort
    medium -> high) has to produce a different hash so that the fit really
    runs at the requested effort.
    """

    CONTENT = "content-hash-abc"

    def _hash(self, effort=None, timeout_s=None, metric=None):
        """Hash ``CONTENT`` with the given thinking settings (default: none)."""
        return _thinking_aware_dedup_hash(
            self.CONTENT,
            thinking_effort=effort,
            thinking_timeout_s=timeout_s,
            thinking_metric=metric,
        )

    def test_no_thinking_returns_content_hash_unchanged(self):
        # Non-thinking fits must keep the dedup hash they always had.
        unchanged = self._hash(effort=None, timeout_s=None, metric=None)
        self.assertEqual(unchanged, self.CONTENT)

    def test_same_thinking_config_is_stable(self):
        # Identical (dataset, thinking config) pairs have to agree, otherwise
        # the server-side cache could never hit.
        config = {"effort": "medium", "timeout_s": 60.0, "metric": "rmse"}
        first = self._hash(**config)
        second = self._hash(**config)
        self.assertEqual(first, second)

    def test_effort_change_partitions_cache(self):
        # The bug fix: going from medium to high on one dataset must NOT
        # land on the same hash.
        medium = self._hash(effort="medium")
        high = self._hash(effort="high")
        self.assertNotEqual(medium, high)

    def test_timeout_and_metric_also_partition(self):
        base = self._hash(effort="medium")
        different_timeout = self._hash(effort="medium", timeout_s=120.0)
        different_metric = self._hash(effort="medium", metric="rmse")
        self.assertNotEqual(base, different_timeout)
        self.assertNotEqual(base, different_metric)
        self.assertNotEqual(different_timeout, different_metric)

    def test_int_and_float_timeout_hash_identically(self):
        # `json.dumps(60)` and `json.dumps(60.0)` produce different text; the
        # helper normalizes the timeout so equivalent values do not cause
        # spurious cache misses.
        h_int = self._hash(effort="medium", timeout_s=60)
        h_float = self._hash(effort="medium", timeout_s=60.0)
        self.assertEqual(h_int, h_float)

    def test_thinking_hash_differs_from_content_hash(self):
        # Turning thinking on has to alter the hash; otherwise an earlier
        # non-thinking fit of the same dataset would be served instead.
        with_thinking = self._hash(effort="medium")
        self.assertNotEqual(with_thinking, self.CONTENT)