From d1c7c65feeb84ed828114538cc08e9a966ac8582 Mon Sep 17 00:00:00 2001 From: Viraat Chandra Date: Mon, 22 Jun 2026 14:52:54 -0700 Subject: [PATCH] feat(client): round-robin outbound connections across source IPs Streaming LLM benchmarks hold one TCP connection per in-flight request for the request's whole duration, so concurrency on a single client host is capped by the OS ephemeral-port range (default 32768-60999, ~28k ports). Past that ceiling new connections fail with EADDRNOTAVAIL, and the wait for a port to free up is charged to TTFT rather than surfaced as an error. TCP connections are unique per 4-tuple (src_ip, src_port, dst_ip, dst_port), so the same src_port can be reused under a different src_ip; binding outbound connections across N local source IPs therefore gives N independent ephemeral-port spaces to the same destination -- multiplying the usable connection budget by ~N on a single host (using multiple NICs, or 127.0.0.0/8 aliases on Linux). - config: add `source_ips` (default empty = unchanged OS default source selection); scale the ephemeral-port budget clamp by the source-IP count. - http: ConnectionPool round-robins `local_addr` across `source_ips` per new connection; empty list is normalized to None (no hot-path overhead). - worker: thread `source_ips` from HTTPClientConfig into each per-worker pool. Verified: 8 unit tests (round-robin binding + budget scaling), plus an OS-level check showing 5 source IPs yield exactly 5x established connections to a single destination, the same client port coexisting across source IPs, and EADDRNOTAVAIL raised only when one IP's port window is exhausted. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../endpoint_client/config.py | 43 ++++++++- .../endpoint_client/http.py | 17 ++++ .../endpoint_client/worker.py | 1 + tests/unit/endpoint_client/test_http.py | 90 +++++++++++++++++++ .../test_http_client_config.py | 60 +++++++++++++ 5 files changed, 207 insertions(+), 4 deletions(-) diff --git a/src/inference_endpoint/endpoint_client/config.py b/src/inference_endpoint/endpoint_client/config.py index 6f93aba9d..b7a0a6202 100644 --- a/src/inference_endpoint/endpoint_client/config.py +++ b/src/inference_endpoint/endpoint_client/config.py @@ -107,6 +107,27 @@ class HTTPClientConfig(WithUpdatesMixin, BaseModel): ), ] = Field(-1, ge=-1) + # Client source IPs to bind outbound connections to, round-robined per new + # connection (see ConnectionPool._create_connection). + # + # Each distinct source IP has an independent ephemeral-port space to a given + # destination — TCP 4-tuple (src_ip, src_port, dst_ip, dst_port) uniqueness + # is per src_ip — so binding across N source IPs multiplies the usable + # connection budget by ~N on a single host, lifting the ~28k single-IP + # ephemeral-port ceiling that otherwise caps concurrency for streaming runs. + # + # IPs must be assigned to local interfaces (multiple NICs, or loopback + # aliases in 127.0.0.0/8 on Linux). Empty = OS default source selection + # (unchanged single-IP behavior). + source_ips: Annotated[ + list[str], + cyclopts.Parameter( + alias="--source-ips", + help="Client source IPs to round-robin connections across " + "(multiplies ephemeral-port budget by IP count). 
Empty=OS default.", + ), + ] = Field(default_factory=list) + # Transport configuration transport: AnyTransportConfig = Field( default_factory=TransportConfig.create_default @@ -222,6 +243,12 @@ def _resolve_zeros(cls, v: int, info: Any) -> int: raise ValueError(f"{info.field_name} must be -1 (auto) or >= 1, got 0") return v + @field_validator("source_ips") + @classmethod + def _clean_source_ips(cls, v: list[str]) -> list[str]: + # Drop blank entries so a stray "" can't bind to INADDR_ANY (all interfaces). + return [ip.strip() for ip in v if ip and ip.strip()] + @model_validator(mode="after") def _resolve_defaults(self) -> HTTPClientConfig: """Resolve auto-detect values and lazy defaults.""" @@ -253,13 +280,21 @@ def _resolve_defaults(self) -> HTTPClientConfig: system_maximum_ports = high - low + 1 available_ports = get_ephemeral_port_limit() + # Each distinct source IP has its own ephemeral-port space to a given + # destination (4-tuple uniqueness is per src_ip), so binding across N + # source IPs scales the connection budget by ~N. See `source_ips`. + source_ip_multiplier = max(1, len(self.source_ips)) + port_budget = available_ports * source_ip_multiplier + if self.max_connections == -1: - object.__setattr__(self, "max_connections", available_ports) + object.__setattr__(self, "max_connections", port_budget) elif self.max_connections > 0: - if self.max_connections > available_ports: + if self.max_connections > port_budget: raise RuntimeError( - f"--max-connections ({self.max_connections}) exceeds ephemeral port limit ({available_ports}). " - f"Either reduce --max-connections or increase system port limit." + f"--max-connections ({self.max_connections}) exceeds ephemeral port " + f"budget ({port_budget} = {available_ports} ports x {source_ip_multiplier} " + f"source IP(s)). Either reduce --max-connections, add --source-ips, or " + f"increase the system port limit." 
) if self.min_required_connections == -1: diff --git a/src/inference_endpoint/endpoint_client/http.py b/src/inference_endpoint/endpoint_client/http.py index 210d3e192..30dee945f 100644 --- a/src/inference_endpoint/endpoint_client/http.py +++ b/src/inference_endpoint/endpoint_client/http.py @@ -471,6 +471,8 @@ def __init__( max_connections: int | None = None, # None means no limit max_idle_time: float = 4.0, # Discard connections idle longer than this ssl_context: ssl.SSLContext | None = None, + source_ips: list[str] + | None = None, # round-robin bind src IPs; None=OS default ): self._host = host self._port = port @@ -478,6 +480,11 @@ def __init__( self._max_connections = max_connections self._max_idle_time = max_idle_time self._ssl_context = ssl_context + # Round-robin outbound binds across these source IPs to multiply the + # ephemeral-port budget (one independent port space per src_ip). Empty + # list normalized to None so the hot path stays a single is-None check. + self._source_ips: list[str] | None = source_ips or None + self._source_ip_idx: int = 0 # Connection tracking self._idle_stack: list[PooledConnection] = [] self._all_connections: set[PooledConnection] = set() @@ -560,12 +567,22 @@ async def _create_connection(self) -> PooledConnection: def protocol_factory() -> HttpResponseProtocol: return HttpResponseProtocol(self._loop) + # Round-robin the outbound source IP so each new connection draws + # from a different src_ip's ephemeral-port space. port 0 => kernel + # picks the ephemeral port; None => OS default source selection. 
+ local_addr: tuple[str, int] | None = None + if self._source_ips: + src_ip = self._source_ips[self._source_ip_idx % len(self._source_ips)] + self._source_ip_idx += 1 + local_addr = (src_ip, 0) + # Create connection without timeout transport, protocol = await self._loop.create_connection( protocol_factory, host=self._host, port=self._port, ssl=self._ssl_context, + local_addr=local_addr, ) # Apply/Override socket defaults diff --git a/src/inference_endpoint/endpoint_client/worker.py b/src/inference_endpoint/endpoint_client/worker.py index a87b6889f..fa066cdb4 100644 --- a/src/inference_endpoint/endpoint_client/worker.py +++ b/src/inference_endpoint/endpoint_client/worker.py @@ -223,6 +223,7 @@ async def run(self) -> None: max_connections=connections_per_worker, max_idle_time=self.http_config.max_idle_time, ssl_context=self._ssl_context, + source_ips=self.http_config.source_ips, ) # Signal handlers for graceful shutdown diff --git a/tests/unit/endpoint_client/test_http.py b/tests/unit/endpoint_client/test_http.py index d9be6ddab..d3b20c740 100644 --- a/tests/unit/endpoint_client/test_http.py +++ b/tests/unit/endpoint_client/test_http.py @@ -811,3 +811,93 @@ def is_closing(self) -> bool: def get_extra_info(self, name: str, default=None): return default + + +class _RecordingLoop: + """Wraps a real event loop, records the ``local_addr`` passed to + create_connection, then delegates WITHOUT it so the real connection to the + echo server still succeeds even when the recorded source IPs are fake.""" + + def __init__(self, real): + self._real = real + self.local_addrs: list = [] + + def __getattr__(self, name): + return getattr(self._real, name) + + async def create_connection( + self, + protocol_factory, + host=None, + port=None, + ssl=None, + local_addr=None, + **kwargs, + ): + self.local_addrs.append(local_addr) + return await self._real.create_connection( + protocol_factory, host=host, port=port, ssl=ssl + ) + + +class TestConnectionPoolSourceIps: + """source_ips 
round-robin binding (multiplies the ephemeral-port budget).""" + + @pytest.mark.asyncio + async def test_round_robins_local_addr_across_source_ips(self, echo_server): + loop = _RecordingLoop(asyncio.get_running_loop()) + parsed = urlparse(echo_server.url) + pool = ConnectionPool( + host=parsed.hostname, + port=parsed.port, + loop=loop, + max_connections=8, + source_ips=["10.0.0.1", "10.0.0.2", "10.0.0.3"], + ) + conns = [] + try: + conns = [await pool.acquire() for _ in range(7)] + # Cycles through the IPs in order, wrapping around. + assert loop.local_addrs == [ + ("10.0.0.1", 0), + ("10.0.0.2", 0), + ("10.0.0.3", 0), + ("10.0.0.1", 0), + ("10.0.0.2", 0), + ("10.0.0.3", 0), + ("10.0.0.1", 0), + ] + finally: + for c in conns: + pool.release(c) + await pool.close() + + @pytest.mark.asyncio + async def test_no_source_ips_uses_default_local_addr_none(self, echo_server): + loop = _RecordingLoop(asyncio.get_running_loop()) + parsed = urlparse(echo_server.url) + pool = ConnectionPool( + host=parsed.hostname, port=parsed.port, loop=loop, max_connections=4 + ) + conns = [] + try: + conns = [await pool.acquire() for _ in range(2)] + assert loop.local_addrs == [None, None] + finally: + for c in conns: + pool.release(c) + await pool.close() + + @pytest.mark.asyncio + async def test_empty_source_ips_normalized_to_none(self, echo_server): + loop = asyncio.get_running_loop() + parsed = urlparse(echo_server.url) + pool = ConnectionPool( + host=parsed.hostname, + port=parsed.port, + loop=loop, + max_connections=2, + source_ips=[], + ) + assert pool._source_ips is None + await pool.close() diff --git a/tests/unit/endpoint_client/test_http_client_config.py b/tests/unit/endpoint_client/test_http_client_config.py index 22e251f36..152a56a6a 100644 --- a/tests/unit/endpoint_client/test_http_client_config.py +++ b/tests/unit/endpoint_client/test_http_client_config.py @@ -9,6 +9,7 @@ from unittest.mock import patch +import pytest from inference_endpoint.endpoint_client import config as 
cfg from inference_endpoint.endpoint_client.cpu_affinity import UnsupportedPlatformError @@ -43,3 +44,62 @@ def test_http_client_config_constructs_when_numa_unsupported(self): ): c = cfg.HTTPClientConfig() assert c.num_workers == 10 + + +class TestSourceIpsBudgetScaling: + """source_ips multiplies the ephemeral-port budget by the IP count.""" + + # num_workers is pinned (>=1) so config resolution skips NUMA auto-probe. + + def test_blank_source_ips_are_dropped(self): + c = cfg.HTTPClientConfig(source_ips=["127.0.0.1", " ", "", "127.0.0.2"]) + assert c.source_ips == ["127.0.0.1", "127.0.0.2"] + + def test_auto_max_connections_scales_by_source_ip_count(self): + with ( + patch.object(cfg, "get_ephemeral_port_range", return_value=(32768, 60999)), + patch.object(cfg, "get_ephemeral_port_limit", return_value=10000), + ): + c = cfg.HTTPClientConfig( + endpoint_urls=["http://localhost:8000"], + num_workers=10, + source_ips=["1.1.1.1", "2.2.2.2", "3.3.3.3"], + ) + assert c.max_connections == 30000 # 10000 available x 3 source IPs + + def test_auto_max_connections_unchanged_without_source_ips(self): + with ( + patch.object(cfg, "get_ephemeral_port_range", return_value=(32768, 60999)), + patch.object(cfg, "get_ephemeral_port_limit", return_value=10000), + ): + c = cfg.HTTPClientConfig( + endpoint_urls=["http://localhost:8000"], num_workers=10 + ) + assert c.max_connections == 10000 # single-IP budget, unchanged + + def test_explicit_max_connections_within_scaled_budget_ok(self): + # 25000 exceeds the single-IP budget (10000) but fits 3 IPs (30000). 
+ with ( + patch.object(cfg, "get_ephemeral_port_range", return_value=(32768, 60999)), + patch.object(cfg, "get_ephemeral_port_limit", return_value=10000), + ): + c = cfg.HTTPClientConfig( + endpoint_urls=["http://localhost:8000"], + num_workers=10, + max_connections=25000, + source_ips=["1.1.1.1", "2.2.2.2", "3.3.3.3"], + ) + assert c.max_connections == 25000 + + def test_explicit_max_connections_exceeding_scaled_budget_raises(self): + with ( + patch.object(cfg, "get_ephemeral_port_range", return_value=(32768, 60999)), + patch.object(cfg, "get_ephemeral_port_limit", return_value=10000), + ): + with pytest.raises(RuntimeError, match="exceeds ephemeral port"): + cfg.HTTPClientConfig( + endpoint_urls=["http://localhost:8000"], + num_workers=10, + max_connections=40000, + source_ips=["1.1.1.1", "2.2.2.2", "3.3.3.3"], + )