Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions src/inference_endpoint/endpoint_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,27 @@ class HTTPClientConfig(WithUpdatesMixin, BaseModel):
),
] = Field(-1, ge=-1)

# Client source IPs to bind outbound connections to, round-robined per new
# connection (see ConnectionPool._create_connection).
#
# Each distinct source IP has an independent ephemeral-port space to a given
# destination — TCP 4-tuple (src_ip, src_port, dst_ip, dst_port) uniqueness
# is per src_ip — so binding across N source IPs multiplies the usable
# connection budget by ~N on a single host, lifting the ~28k single-IP
# ephemeral-port ceiling that otherwise caps concurrency for streaming runs.
#
# IPs must be assigned to local interfaces (multiple NICs, or loopback
# aliases in 127.0.0.0/8 on Linux). Empty = OS default source selection
# (unchanged single-IP behavior).
source_ips: Annotated[
list[str],
cyclopts.Parameter(
alias="--source-ips",
help="Client source IPs to round-robin connections across "
"(multiplies ephemeral-port budget by IP count). Empty=OS default.",
),
] = Field(default_factory=list)

# Transport configuration
transport: AnyTransportConfig = Field(
default_factory=TransportConfig.create_default
Expand Down Expand Up @@ -222,6 +243,12 @@ def _resolve_zeros(cls, v: int, info: Any) -> int:
raise ValueError(f"{info.field_name} must be -1 (auto) or >= 1, got 0")
return v

@field_validator("source_ips")
@classmethod
def _clean_source_ips(cls, v: list[str]) -> list[str]:
# Drop blank entries so a stray "" can't bind to INADDR_ANY (all interfaces).
return [ip.strip() for ip in v if ip and ip.strip()]
Comment on lines +246 to +250

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If duplicate IP addresses are provided in --source-ips (e.g., ["127.0.0.1", "127.0.0.1"]), they are not filtered out. This will cause len(self.source_ips) to count duplicates, leading to an incorrect (overestimated) ephemeral port budget calculation in _resolve_defaults.

Deduplicating the list while preserving order using dict.fromkeys ensures the budget is scaled accurately.

Suggested change
@field_validator("source_ips")
@classmethod
def _clean_source_ips(cls, v: list[str]) -> list[str]:
# Drop blank entries so a stray "" can't bind to INADDR_ANY (all interfaces).
return [ip.strip() for ip in v if ip and ip.strip()]
@field_validator("source_ips")
@classmethod
def _clean_source_ips(cls, v: list[str]) -> list[str]:
# Drop blank entries so a stray "" can't bind to INADDR_ANY (all interfaces).
# Also deduplicate to avoid overestimating the ephemeral port budget.
return list(dict.fromkeys(ip.strip() for ip in v if ip and ip.strip()))


@model_validator(mode="after")
def _resolve_defaults(self) -> HTTPClientConfig:
"""Resolve auto-detect values and lazy defaults."""
Expand Down Expand Up @@ -253,13 +280,21 @@ def _resolve_defaults(self) -> HTTPClientConfig:
system_maximum_ports = high - low + 1
available_ports = get_ephemeral_port_limit()

# Each distinct source IP has its own ephemeral-port space to a given
# destination (4-tuple uniqueness is per src_ip), so binding across N
# source IPs scales the connection budget by ~N. See `source_ips`.
source_ip_multiplier = max(1, len(self.source_ips))
port_budget = available_ports * source_ip_multiplier

if self.max_connections == -1:
object.__setattr__(self, "max_connections", available_ports)
object.__setattr__(self, "max_connections", port_budget)
elif self.max_connections > 0:
if self.max_connections > available_ports:
if self.max_connections > port_budget:
raise RuntimeError(
f"--max-connections ({self.max_connections}) exceeds ephemeral port limit ({available_ports}). "
f"Either reduce --max-connections or increase system port limit."
f"--max-connections ({self.max_connections}) exceeds ephemeral port "
f"budget ({port_budget} = {available_ports} ports x {source_ip_multiplier} "
f"source IP(s)). Either reduce --max-connections, add --source-ips, or "
f"increase the system port limit."
)

if self.min_required_connections == -1:
Expand Down
17 changes: 17 additions & 0 deletions src/inference_endpoint/endpoint_client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,13 +471,20 @@ def __init__(
max_connections: int | None = None, # None means no limit
max_idle_time: float = 4.0, # Discard connections idle longer than this
ssl_context: ssl.SSLContext | None = None,
source_ips: list[str]
| None = None, # round-robin bind src IPs; None=OS default
):
self._host = host
self._port = port
self._loop = loop
self._max_connections = max_connections
self._max_idle_time = max_idle_time
self._ssl_context = ssl_context
# Round-robin outbound binds across these source IPs to multiply the
# ephemeral-port budget (one independent port space per src_ip). Empty
# list normalized to None so the hot path stays a single is-None check.
self._source_ips: list[str] | None = source_ips or None
self._source_ip_idx: int = 0
# Connection tracking
self._idle_stack: list[PooledConnection] = []
self._all_connections: set[PooledConnection] = set()
Expand Down Expand Up @@ -560,12 +567,22 @@ async def _create_connection(self) -> PooledConnection:
def protocol_factory() -> HttpResponseProtocol:
return HttpResponseProtocol(self._loop)

# Round-robin the outbound source IP so each new connection draws
# from a different src_ip's ephemeral-port space. port 0 => kernel
# picks the ephemeral port; None => OS default source selection.
local_addr: tuple[str, int] | None = None
if self._source_ips:
src_ip = self._source_ips[self._source_ip_idx % len(self._source_ips)]
self._source_ip_idx += 1
local_addr = (src_ip, 0)

# Create connection without timeout
transport, protocol = await self._loop.create_connection(
protocol_factory,
host=self._host,
port=self._port,
ssl=self._ssl_context,
local_addr=local_addr,
)

# Apply/Override socket defaults
Expand Down
1 change: 1 addition & 0 deletions src/inference_endpoint/endpoint_client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ async def run(self) -> None:
max_connections=connections_per_worker,
max_idle_time=self.http_config.max_idle_time,
ssl_context=self._ssl_context,
source_ips=self.http_config.source_ips,
)

# Signal handlers for graceful shutdown
Expand Down
90 changes: 90 additions & 0 deletions tests/unit/endpoint_client/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,3 +811,93 @@ def is_closing(self) -> bool:

def get_extra_info(self, name: str, default=None):
return default


class _RecordingLoop:
"""Wraps a real event loop, records the ``local_addr`` passed to
create_connection, then delegates WITHOUT it so the real connection to the
echo server still succeeds even when the recorded source IPs are fake."""

def __init__(self, real):
self._real = real
self.local_addrs: list = []

def __getattr__(self, name):
return getattr(self._real, name)

async def create_connection(
self,
protocol_factory,
host=None,
port=None,
ssl=None,
local_addr=None,
**kwargs,
):
self.local_addrs.append(local_addr)
return await self._real.create_connection(
protocol_factory, host=host, port=port, ssl=ssl
)


class TestConnectionPoolSourceIps:
"""source_ips round-robin binding (multiplies the ephemeral-port budget)."""

@pytest.mark.asyncio
async def test_round_robins_local_addr_across_source_ips(self, echo_server):
loop = _RecordingLoop(asyncio.get_running_loop())
parsed = urlparse(echo_server.url)
pool = ConnectionPool(
host=parsed.hostname,
port=parsed.port,
loop=loop,
max_connections=8,
source_ips=["10.0.0.1", "10.0.0.2", "10.0.0.3"],
)
conns = []
try:
conns = [await pool.acquire() for _ in range(7)]
# Cycles through the IPs in order, wrapping around.
assert loop.local_addrs == [
("10.0.0.1", 0),
("10.0.0.2", 0),
("10.0.0.3", 0),
("10.0.0.1", 0),
("10.0.0.2", 0),
("10.0.0.3", 0),
("10.0.0.1", 0),
]
finally:
for c in conns:
pool.release(c)
await pool.close()

@pytest.mark.asyncio
async def test_no_source_ips_uses_default_local_addr_none(self, echo_server):
loop = _RecordingLoop(asyncio.get_running_loop())
parsed = urlparse(echo_server.url)
pool = ConnectionPool(
host=parsed.hostname, port=parsed.port, loop=loop, max_connections=4
)
conns = []
try:
conns = [await pool.acquire() for _ in range(2)]
assert loop.local_addrs == [None, None]
finally:
for c in conns:
pool.release(c)
await pool.close()

@pytest.mark.asyncio
async def test_empty_source_ips_normalized_to_none(self, echo_server):
loop = asyncio.get_running_loop()
parsed = urlparse(echo_server.url)
pool = ConnectionPool(
host=parsed.hostname,
port=parsed.port,
loop=loop,
max_connections=2,
source_ips=[],
)
assert pool._source_ips is None
await pool.close()
60 changes: 60 additions & 0 deletions tests/unit/endpoint_client/test_http_client_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from unittest.mock import patch

import pytest
from inference_endpoint.endpoint_client import config as cfg
from inference_endpoint.endpoint_client.cpu_affinity import UnsupportedPlatformError

Expand Down Expand Up @@ -43,3 +44,62 @@ def test_http_client_config_constructs_when_numa_unsupported(self):
):
c = cfg.HTTPClientConfig()
assert c.num_workers == 10


class TestSourceIpsBudgetScaling:
"""source_ips multiplies the ephemeral-port budget by the IP count."""

# num_workers is pinned (>=1) so config resolution skips NUMA auto-probe.

def test_blank_source_ips_are_dropped(self):
c = cfg.HTTPClientConfig(source_ips=["127.0.0.1", " ", "", "127.0.0.2"])
assert c.source_ips == ["127.0.0.1", "127.0.0.2"]
Comment on lines +54 to +56

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the test to verify that duplicate source IPs are also correctly deduplicated and cleaned.

Suggested change
def test_blank_source_ips_are_dropped(self):
c = cfg.HTTPClientConfig(source_ips=["127.0.0.1", " ", "", "127.0.0.2"])
assert c.source_ips == ["127.0.0.1", "127.0.0.2"]
def test_blank_and_duplicate_source_ips_are_cleaned(self):
c = cfg.HTTPClientConfig(source_ips=["127.0.0.1", " ", "", "127.0.0.1", "127.0.0.2"])
assert c.source_ips == ["127.0.0.1", "127.0.0.2"]


def test_auto_max_connections_scales_by_source_ip_count(self):
with (
patch.object(cfg, "get_ephemeral_port_range", return_value=(32768, 60999)),
patch.object(cfg, "get_ephemeral_port_limit", return_value=10000),
):
c = cfg.HTTPClientConfig(
endpoint_urls=["http://localhost:8000"],
num_workers=10,
source_ips=["1.1.1.1", "2.2.2.2", "3.3.3.3"],
)
assert c.max_connections == 30000 # 10000 available x 3 source IPs

def test_auto_max_connections_unchanged_without_source_ips(self):
with (
patch.object(cfg, "get_ephemeral_port_range", return_value=(32768, 60999)),
patch.object(cfg, "get_ephemeral_port_limit", return_value=10000),
):
c = cfg.HTTPClientConfig(
endpoint_urls=["http://localhost:8000"], num_workers=10
)
assert c.max_connections == 10000 # single-IP budget, unchanged

def test_explicit_max_connections_within_scaled_budget_ok(self):
# 25000 exceeds the single-IP budget (10000) but fits 3 IPs (30000).
with (
patch.object(cfg, "get_ephemeral_port_range", return_value=(32768, 60999)),
patch.object(cfg, "get_ephemeral_port_limit", return_value=10000),
):
c = cfg.HTTPClientConfig(
endpoint_urls=["http://localhost:8000"],
num_workers=10,
max_connections=25000,
source_ips=["1.1.1.1", "2.2.2.2", "3.3.3.3"],
)
assert c.max_connections == 25000

def test_explicit_max_connections_exceeding_scaled_budget_raises(self):
with (
patch.object(cfg, "get_ephemeral_port_range", return_value=(32768, 60999)),
patch.object(cfg, "get_ephemeral_port_limit", return_value=10000),
):
with pytest.raises(RuntimeError, match="exceeds ephemeral port"):
cfg.HTTPClientConfig(
endpoint_urls=["http://localhost:8000"],
num_workers=10,
max_connections=40000,
source_ips=["1.1.1.1", "2.2.2.2", "3.3.3.3"],
)
Loading