Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/forge/proxy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def main() -> None:
parser.add_argument("--serialize", action="store_true", default=None, help="Force request serialization")
parser.add_argument("--no-serialize", action="store_true", help="Disable request serialization")
parser.add_argument("--max-retries", type=int, default=3, help="Max retries per request (default: 3)")
parser.add_argument(
"--backend-timeout",
type=float,
default=300.0,
help="Backend response timeout in seconds (default: 300)",
)
parser.add_argument("--no-rescue", action="store_true", help="Disable rescue parsing")
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging")

Expand Down Expand Up @@ -104,6 +110,7 @@ def main() -> None:
rescue_enabled=not args.no_rescue,
mode=args.mode,
backend_protocol=args.backend_protocol,
backend_timeout=args.backend_timeout,
)

def _shutdown(sig: int, _frame: object) -> None:
Expand Down
22 changes: 20 additions & 2 deletions src/forge/proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import asyncio
import logging
import math
import threading
from pathlib import Path
from typing import Literal
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
rescue_enabled: bool = True,
mode: Literal["native", "prompt"] = "native",
backend_protocol: Literal["openai", "anthropic"] = "openai",
backend_timeout: float = 300.0,
) -> None:
"""
Args:
Expand Down Expand Up @@ -101,6 +103,8 @@ def __init__(
for Anthropic-shape downstreams (the official Anthropic API,
LiteLLM's /v1/messages, a self-hosted Anthropic proxy).
Only meaningful in external mode; ignored in managed mode.
backend_timeout: Timeout in seconds for requests from the proxy to
the downstream backend.
"""
if backend_url is None and backend is None:
raise ValueError("Provide either backend_url (external) or backend (managed)")
Expand All @@ -126,6 +130,8 @@ def __init__(
"backend='vllm' parses tool calls server-side (native only); "
"mode='prompt' is not applicable."
)
if not math.isfinite(backend_timeout) or backend_timeout <= 0:
raise ValueError("backend_timeout must be a finite value greater than 0")
# Managed mode: each backend requires its own identity field. Fail
# fast at construction with a clear message (mirrors setup_backend).
if backend_url is None:
Expand All @@ -151,6 +157,7 @@ def __init__(
self._rescue_enabled = rescue_enabled
self._mode = mode
self._backend_protocol = backend_protocol
self._backend_timeout = backend_timeout

# Auto-detect serialization: managed (no external url) = single local
# GPU = serialize. External callers manage their own concurrency.
Expand Down Expand Up @@ -257,6 +264,7 @@ async def _setup_external(self) -> tuple[LLMClient, ContextManager]:
client: LLMClient = AnthropicClient(
model=self._model or "claude",
base_url=self._backend_url.rstrip("/"),
timeout=self._backend_timeout,
)
# Anthropic models report a known context length; keep the legacy
# 8192 fallback rather than failing the well-behaved Path-1 case.
Expand All @@ -273,7 +281,11 @@ async def _setup_external(self) -> tuple[LLMClient, ContextManager]:
base = base + "/v1"

if self._backend == "vllm":
client = VLLMClient(model_path="default", base_url=base)
client = VLLMClient(
model_path="default",
base_url=base,
timeout=self._backend_timeout,
)
# Unlike llama.cpp, vLLM validates the wire `model` field against
# its --served-model-name aliases (404 on mismatch). External mode
# has no model path to send, so discover the served identity from
Expand All @@ -299,6 +311,7 @@ async def _setup_external(self) -> tuple[LLMClient, ContextManager]:
gguf_path=self._model or "default",
base_url=base,
mode=self._mode,
timeout=self._backend_timeout,
)

if self._budget_tokens is not None:
Expand Down Expand Up @@ -348,18 +361,23 @@ def _build_managed_client(self) -> LLMClient:
base_url = f"http://localhost:{self._backend_port}/v1"
if self._backend == "ollama":
assert self._model is not None
return OllamaClient(model=self._model)
return OllamaClient(
model=self._model,
timeout=self._backend_timeout,
)
if self._backend in ("llamaserver", "llamafile"):
return LlamafileClient(
gguf_path=self._gguf or "default",
base_url=base_url,
mode=self._mode,
timeout=self._backend_timeout,
)
if self._backend == "vllm":
assert self._model_path is not None
return VLLMClient(
model_path=self._model_path,
base_url=base_url,
timeout=self._backend_timeout,
)
raise ValueError(f"unsupported backend: {self._backend!r}")

Expand Down
24 changes: 23 additions & 1 deletion tests/unit/test_proxy_path1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

Expand Down Expand Up @@ -55,6 +55,28 @@ def test_openai_default_unchanged(self):
)
assert proxy._backend_protocol == "openai"

@pytest.mark.asyncio
async def test_anthropic_external_receives_backend_timeout(self):
    """External Anthropic setup passes ``backend_timeout`` to the client.

    ``AnthropicClient`` is patched so the test can inspect the exact
    keyword arguments used to construct it. The mocked client reports a
    context length of 200000, which the returned context manager is
    expected to expose as its token budget.
    """
    proxy = ProxyServer(
        backend_url="http://localhost:8080",
        backend_protocol="anthropic",
        backend_timeout=1800.0,
    )

    # Stand-in client whose only awaited method is the context-length probe.
    fake_client = MagicMock()
    fake_client.get_context_length = AsyncMock(return_value=200000)

    with patch(
        "forge.clients.anthropic.AnthropicClient",
        return_value=fake_client,
    ) as client_cls:
        client, ctx = await proxy._setup_external()

    assert client is fake_client
    assert ctx.budget_tokens == 200000
    client_cls.assert_called_once_with(
        model="claude",
        base_url="http://localhost:8080",
        timeout=1800.0,
    )


# ── AnthropicClient verbatim path ────────────────────────────

Expand Down
38 changes: 35 additions & 3 deletions tests/unit/test_proxy_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def test_managed_vllm_requires_model_path(self) -> None:
with pytest.raises(ValueError, match="requires model_path"):
ProxyServer(backend="vllm")

@pytest.mark.parametrize("backend_timeout", [0, -1, float("nan"), float("inf")])
def test_backend_timeout_must_be_finite_and_positive(
    self, backend_timeout: float,
) -> None:
    """Zero, negative, NaN and infinite timeouts are rejected at construction."""
    expected_message = "backend_timeout must be a finite value greater than 0"
    with pytest.raises(ValueError, match=expected_message):
        ProxyServer(backend_url="http://x:8000", backend_timeout=backend_timeout)

def test_managed_ok(self) -> None:
ProxyServer(backend="llamaserver", gguf="m.gguf")
ProxyServer(backend="llamafile", gguf="m.gguf")
Expand All @@ -75,6 +84,11 @@ def test_external_ok(self) -> None:
proxy2 = ProxyServer(backend_url="http://x:8000", backend="vllm")
assert proxy2._backend == "vllm"

def test_backend_timeout_default_and_override(self) -> None:
    """The stored timeout is 300.0 by default and follows an explicit override."""
    default_proxy = ProxyServer(backend_url="http://x:8000")
    assert default_proxy._backend_timeout == 300.0

    overridden_proxy = ProxyServer(
        backend_url="http://x:8000",
        backend_timeout=1800.0,
    )
    assert overridden_proxy._backend_timeout == 1800.0

# Serialize auto-detection: managed (no url) serializes, external does not.
def test_serialize_auto_managed_true(self) -> None:
assert ProxyServer(backend="vllm", model_path="/m")._serialize is True
Expand All @@ -92,10 +106,15 @@ class TestSetupExternal:

@pytest.mark.asyncio
async def test_llamaserver_uses_llamafile_client(self) -> None:
proxy = ProxyServer(backend_url="http://localhost:8080", budget_tokens=8192)
proxy = ProxyServer(
backend_url="http://localhost:8080",
budget_tokens=8192,
backend_timeout=1800.0,
)
client, ctx = await proxy._setup_external()
assert isinstance(client, LlamafileClient)
assert client.base_url == "http://localhost:8080/v1"
assert client._http.timeout.read == 1800.0
assert ctx.budget_tokens == 8192

@pytest.mark.asyncio
Expand All @@ -109,14 +128,18 @@ async def test_explicit_llamafile_backend_uses_llamafile_client(self) -> None:
@pytest.mark.asyncio
async def test_vllm_uses_vllm_client(self) -> None:
proxy = ProxyServer(
backend_url="http://localhost:8000", backend="vllm", budget_tokens=8192,
backend_url="http://localhost:8000",
backend="vllm",
budget_tokens=8192,
backend_timeout=1800.0,
)
with patch.object(
VLLMClient, "get_served_model_name", new_callable=AsyncMock, return_value=None,
):
client, ctx = await proxy._setup_external()
assert isinstance(client, VLLMClient)
assert client.base_url == "http://localhost:8000/v1"
assert client._http.timeout.read == 1800.0
assert ctx.budget_tokens == 8192

@pytest.mark.asyncio
Expand Down Expand Up @@ -186,6 +209,7 @@ async def test_llamaserver_wiring(self) -> None:
backend_port=8080,
budget_mode=BudgetMode.FORGE_FAST,
extra_flags=["-ngl", "99"],
backend_timeout=1800.0,
)
mock_ctx = ContextManager.__new__(ContextManager)
mock_ctx.budget_tokens = 16384
Expand All @@ -199,6 +223,7 @@ async def test_llamaserver_wiring(self) -> None:

assert isinstance(client, LlamafileClient)
assert client.base_url == "http://localhost:8080/v1"
assert client._http.timeout.read == 1800.0
kwargs = mock_setup.await_args.kwargs
assert kwargs["backend"] == "llamaserver"
assert kwargs["gguf_path"] == "/models/x.gguf"
Expand All @@ -217,6 +242,7 @@ async def test_vllm_wiring(self) -> None:
proxy = ProxyServer(
backend="vllm", model_path="/models/awq", backend_port=8000,
budget_tokens=113000, budget_mode=BudgetMode.MANUAL,
backend_timeout=1800.0,
)
mock_ctx = ContextManager.__new__(ContextManager)
mock_ctx.budget_tokens = 113000
Expand All @@ -228,6 +254,7 @@ async def test_vllm_wiring(self) -> None:

assert isinstance(client, VLLMClient)
assert client.base_url == "http://localhost:8000/v1"
assert client._http.timeout.read == 1800.0
kwargs = mock_setup.await_args.kwargs
assert kwargs["backend"] == "vllm"
assert kwargs["model_path"] == "/models/awq"
Expand All @@ -238,7 +265,11 @@ async def test_vllm_wiring(self) -> None:

@pytest.mark.asyncio
async def test_ollama_wiring(self) -> None:
proxy = ProxyServer(backend="ollama", model="ministral-3:14b")
proxy = ProxyServer(
backend="ollama",
model="ministral-3:14b",
backend_timeout=1800.0,
)
mock_ctx = ContextManager.__new__(ContextManager)
mock_ctx.budget_tokens = 4096
with patch(
Expand All @@ -247,6 +278,7 @@ async def test_ollama_wiring(self) -> None:
) as mock_setup:
client, _ = await proxy._setup_managed()
assert isinstance(client, OllamaClient)
assert client._http.timeout.read == 1800.0
kwargs = mock_setup.await_args.kwargs
assert kwargs["backend"] == "ollama"
assert kwargs["model"] == "ministral-3:14b"
Expand Down
Loading