Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
]
license = { file = "LICENSE" }
dependencies = [
"httpx[http2]>=0.25.0,<=0.28.1",
"httpx>=0.25.0,<=0.28.1",
"omegaconf>=2.1.2,<=2.3.0",
"pandas>=2.1.2,<=2.3.3",
"password-strength>=0.0.3.post2,<=0.0.3.post2",
Expand Down
34 changes: 14 additions & 20 deletions src/tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import backoff
import httpx
from httpx._transports.default import HTTPTransport
from omegaconf import OmegaConf
from tabpfn_client.browser_auth import BrowserAuthHandler
from tabpfn_client.constants import (
Expand Down Expand Up @@ -187,22 +186,6 @@ class PredictionResult:
metadata: dict[str, Any] = field(default_factory=dict)


class SelectiveHTTP2Transport(HTTPTransport):
def __init__(self, http2_paths=None, *args, **kwargs):
self.http2_paths = http2_paths or []
self.http1 = HTTPTransport(http2=False, *args, **kwargs)
self.http2 = HTTPTransport(http2=True, *args, **kwargs)

def handle_request(self, request):
if request.url.path in self.http2_paths:
return self.http2.handle_request(request)
return self.http1.handle_request(request)

def close(self) -> None:
self.http1.close()
self.http2.close()


class ServiceClient(Singleton):
"""
Singleton class for handling communication with the server.
Expand All @@ -215,13 +198,24 @@ class ServiceClient(Singleton):
TABPFN_CLIENT_API_URL
or f"{server_config.protocol}://{server_config.host}:{server_config.port}"
)
fit_path = SERVER_CONFIG["endpoints"]["fit"]["path"]
predict_path = SERVER_CONFIG["endpoints"]["predict"]["path"]
# NOTE: HTTP/1.1 only. HTTP/2 used to be selectively enabled for the
# /tabpfn/fit and /tabpfn/predict endpoints, but the long-running
# thinking-mode fit kept the stream open for 5-15 min, which raced
# against intermediate keepalive PINGs from Cloud Run's LB. The
# `h2` state machine treats a PING received while the connection
# is CLOSED as a protocol violation and surfaces it as
# `httpx.LocalProtocolError("Invalid input ConnectionInputs.RECV_PING
# in state ConnectionState.CLOSED")` — which is NOT in the SDK's
# retry tuple, so the request fails hard instead of retrying.
# HTTP/1.1 has no PING frames and no equivalent state machine, so
# the race disappears. Unary POSTs against /fit and /predict don't
# benefit from HTTP/2's multiplexing or HPACK in any measurable way
# (one request per fit, dominated by the multipart body), so the
# tradeoff is one-sided.
httpx_client = httpx.Client(
base_url=base_url,
timeout=TABPFN_CLIENT_TIMEOUT,
headers={"Prior-Client-Version": get_client_version()},
transport=SelectiveHTTP2Transport(http2_paths=[fit_path, predict_path]),
follow_redirects=True,
)
_access_token = None
Expand Down
Loading