Add a configurable backend timeout to the forge proxy.

* src/forge/proxy/__main__.py: new --backend-timeout flag (float, seconds,
  default 300.0) forwarded to ProxyServer as backend_timeout.
* src/forge/proxy/proxy.py: ProxyServer accepts backend_timeout, raises
  ValueError unless it is finite and greater than 0, and passes it as
  timeout= to the AnthropicClient, VLLMClient, LlamafileClient and
  OllamaClient constructors touched by these hunks.
* tests/unit: cover the validation, the default/override values, and the
  timeout wiring for external and managed clients.

diff --git a/src/forge/proxy/__main__.py b/src/forge/proxy/__main__.py
index 055bae0..4ef0688 100644
--- a/src/forge/proxy/__main__.py
+++ b/src/forge/proxy/__main__.py
@@ -67,6 +67,12 @@ def main() -> None:
     parser.add_argument("--serialize", action="store_true", default=None, help="Force request serialization")
     parser.add_argument("--no-serialize", action="store_true", help="Disable request serialization")
     parser.add_argument("--max-retries", type=int, default=3, help="Max retries per request (default: 3)")
+    parser.add_argument(
+        "--backend-timeout",
+        type=float,
+        default=300.0,
+        help="Backend response timeout in seconds (default: 300)",
+    )
     parser.add_argument("--no-rescue", action="store_true", help="Disable rescue parsing")
     parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging")
 
@@ -104,6 +110,7 @@ def main() -> None:
         rescue_enabled=not args.no_rescue,
         mode=args.mode,
         backend_protocol=args.backend_protocol,
+        backend_timeout=args.backend_timeout,
     )
 
     def _shutdown(sig: int, _frame: object) -> None:
diff --git a/src/forge/proxy/proxy.py b/src/forge/proxy/proxy.py
index d534411..a4d3c8d 100644
--- a/src/forge/proxy/proxy.py
+++ b/src/forge/proxy/proxy.py
@@ -9,6 +9,7 @@
 
 import asyncio
 import logging
+import math
 import threading
 from pathlib import Path
 from typing import Literal
@@ -68,6 +69,7 @@ def __init__(
         rescue_enabled: bool = True,
         mode: Literal["native", "prompt"] = "native",
         backend_protocol: Literal["openai", "anthropic"] = "openai",
+        backend_timeout: float = 300.0,
     ) -> None:
         """
         Args:
@@ -101,6 +103,8 @@ def __init__(
                 for Anthropic-shape downstreams (the official Anthropic API,
                 LiteLLM's /v1/messages, a self-hosted Anthropic proxy). Only
                 meaningful in external mode; ignored in managed mode.
+            backend_timeout: Timeout in seconds for requests from the proxy to
+                the downstream backend.
         """
         if backend_url is None and backend is None:
             raise ValueError("Provide either backend_url (external) or backend (managed)")
@@ -126,6 +130,8 @@ def __init__(
                 "backend='vllm' parses tool calls server-side (native only); "
                 "mode='prompt' is not applicable."
             )
+        if not math.isfinite(backend_timeout) or backend_timeout <= 0:
+            raise ValueError("backend_timeout must be a finite value greater than 0")
         # Managed mode: each backend requires its own identity field. Fail
         # fast at construction with a clear message (mirrors setup_backend).
         if backend_url is None:
@@ -151,6 +157,7 @@ def __init__(
         self._rescue_enabled = rescue_enabled
         self._mode = mode
         self._backend_protocol = backend_protocol
+        self._backend_timeout = backend_timeout
 
         # Auto-detect serialization: managed (no external url) = single local
         # GPU = serialize. External callers manage their own concurrency.
@@ -257,6 +264,7 @@ async def _setup_external(self) -> tuple[LLMClient, ContextManager]:
             client: LLMClient = AnthropicClient(
                 model=self._model or "claude",
                 base_url=self._backend_url.rstrip("/"),
+                timeout=self._backend_timeout,
             )
             # Anthropic models report a known context length; keep the legacy
             # 8192 fallback rather than failing the well-behaved Path-1 case.
@@ -273,7 +281,11 @@ async def _setup_external(self) -> tuple[LLMClient, ContextManager]:
             base = base + "/v1"
 
         if self._backend == "vllm":
-            client = VLLMClient(model_path="default", base_url=base)
+            client = VLLMClient(
+                model_path="default",
+                base_url=base,
+                timeout=self._backend_timeout,
+            )
             # Unlike llama.cpp, vLLM validates the wire `model` field against
             # its --served-model-name aliases (404 on mismatch). External mode
             # has no model path to send, so discover the served identity from
@@ -299,6 +311,7 @@ async def _setup_external(self) -> tuple[LLMClient, ContextManager]:
                 gguf_path=self._model or "default",
                 base_url=base,
                 mode=self._mode,
+                timeout=self._backend_timeout,
             )
 
         if self._budget_tokens is not None:
@@ -348,18 +361,23 @@ def _build_managed_client(self) -> LLMClient:
         base_url = f"http://localhost:{self._backend_port}/v1"
         if self._backend == "ollama":
             assert self._model is not None
-            return OllamaClient(model=self._model)
+            return OllamaClient(
+                model=self._model,
+                timeout=self._backend_timeout,
+            )
         if self._backend in ("llamaserver", "llamafile"):
             return LlamafileClient(
                 gguf_path=self._gguf or "default",
                 base_url=base_url,
                 mode=self._mode,
+                timeout=self._backend_timeout,
             )
         if self._backend == "vllm":
             assert self._model_path is not None
             return VLLMClient(
                 model_path=self._model_path,
                 base_url=base_url,
+                timeout=self._backend_timeout,
             )
         raise ValueError(f"unsupported backend: {self._backend!r}")
 
diff --git a/tests/unit/test_proxy_path1.py b/tests/unit/test_proxy_path1.py
index 5ba902b..6a1d591 100644
--- a/tests/unit/test_proxy_path1.py
+++ b/tests/unit/test_proxy_path1.py
@@ -12,7 +12,7 @@
 
 from __future__ import annotations
 
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
@@ -55,6 +55,28 @@ def test_openai_default_unchanged(self):
         )
         assert proxy._backend_protocol == "openai"
 
+    @pytest.mark.asyncio
+    async def test_anthropic_external_receives_backend_timeout(self):
+        proxy = ProxyServer(
+            backend_url="http://localhost:8080",
+            backend_protocol="anthropic",
+            backend_timeout=1800.0,
+        )
+        with patch("forge.clients.anthropic.AnthropicClient") as mock_client_cls:
+            mock_client = MagicMock()
+            mock_client.get_context_length = AsyncMock(return_value=200000)
+            mock_client_cls.return_value = mock_client
+
+            client, ctx = await proxy._setup_external()
+
+        assert client is mock_client
+        assert ctx.budget_tokens == 200000
+        mock_client_cls.assert_called_once_with(
+            model="claude",
+            base_url="http://localhost:8080",
+            timeout=1800.0,
+        )
+
 
 # ── AnthropicClient verbatim path ────────────────────────────
 
diff --git a/tests/unit/test_proxy_proxy.py b/tests/unit/test_proxy_proxy.py
index 7cd8e32..56f8c34 100644
--- a/tests/unit/test_proxy_proxy.py
+++ b/tests/unit/test_proxy_proxy.py
@@ -62,6 +62,15 @@ def test_managed_vllm_requires_model_path(self) -> None:
         with pytest.raises(ValueError, match="requires model_path"):
             ProxyServer(backend="vllm")
 
+    @pytest.mark.parametrize("backend_timeout", [0, -1, float("nan"), float("inf")])
+    def test_backend_timeout_must_be_finite_and_positive(
+        self, backend_timeout: float,
+    ) -> None:
+        with pytest.raises(
+            ValueError, match="backend_timeout must be a finite value greater than 0",
+        ):
+            ProxyServer(backend_url="http://x:8000", backend_timeout=backend_timeout)
+
     def test_managed_ok(self) -> None:
         ProxyServer(backend="llamaserver", gguf="m.gguf")
         ProxyServer(backend="llamafile", gguf="m.gguf")
@@ -75,6 +84,11 @@ def test_external_ok(self) -> None:
         proxy2 = ProxyServer(backend_url="http://x:8000", backend="vllm")
         assert proxy2._backend == "vllm"
 
+    def test_backend_timeout_default_and_override(self) -> None:
+        assert ProxyServer(backend_url="http://x:8000")._backend_timeout == 300.0
+        proxy = ProxyServer(backend_url="http://x:8000", backend_timeout=1800.0)
+        assert proxy._backend_timeout == 1800.0
+
     # Serialize auto-detection: managed (no url) serializes, external does not.
     def test_serialize_auto_managed_true(self) -> None:
         assert ProxyServer(backend="vllm", model_path="/m")._serialize is True
@@ -92,10 +106,15 @@ class TestSetupExternal:
 
     @pytest.mark.asyncio
     async def test_llamaserver_uses_llamafile_client(self) -> None:
-        proxy = ProxyServer(backend_url="http://localhost:8080", budget_tokens=8192)
+        proxy = ProxyServer(
+            backend_url="http://localhost:8080",
+            budget_tokens=8192,
+            backend_timeout=1800.0,
+        )
         client, ctx = await proxy._setup_external()
         assert isinstance(client, LlamafileClient)
         assert client.base_url == "http://localhost:8080/v1"
+        assert client._http.timeout.read == 1800.0
         assert ctx.budget_tokens == 8192
 
     @pytest.mark.asyncio
@@ -109,7 +128,10 @@ async def test_explicit_llamafile_backend_uses_llamafile_client(self) -> None:
     @pytest.mark.asyncio
     async def test_vllm_uses_vllm_client(self) -> None:
         proxy = ProxyServer(
-            backend_url="http://localhost:8000", backend="vllm", budget_tokens=8192,
+            backend_url="http://localhost:8000",
+            backend="vllm",
+            budget_tokens=8192,
+            backend_timeout=1800.0,
         )
         with patch.object(
             VLLMClient, "get_served_model_name", new_callable=AsyncMock, return_value=None,
@@ -117,6 +139,7 @@ async def test_vllm_uses_vllm_client(self) -> None:
             client, ctx = await proxy._setup_external()
         assert isinstance(client, VLLMClient)
         assert client.base_url == "http://localhost:8000/v1"
+        assert client._http.timeout.read == 1800.0
         assert ctx.budget_tokens == 8192
 
     @pytest.mark.asyncio
@@ -186,6 +209,7 @@ async def test_llamaserver_wiring(self) -> None:
             backend_port=8080,
             budget_mode=BudgetMode.FORGE_FAST,
             extra_flags=["-ngl", "99"],
+            backend_timeout=1800.0,
         )
         mock_ctx = ContextManager.__new__(ContextManager)
         mock_ctx.budget_tokens = 16384
@@ -199,6 +223,7 @@ async def test_llamaserver_wiring(self) -> None:
 
         assert isinstance(client, LlamafileClient)
         assert client.base_url == "http://localhost:8080/v1"
+        assert client._http.timeout.read == 1800.0
         kwargs = mock_setup.await_args.kwargs
         assert kwargs["backend"] == "llamaserver"
         assert kwargs["gguf_path"] == "/models/x.gguf"
@@ -217,6 +242,7 @@ async def test_vllm_wiring(self) -> None:
         proxy = ProxyServer(
             backend="vllm", model_path="/models/awq", backend_port=8000,
             budget_tokens=113000, budget_mode=BudgetMode.MANUAL,
+            backend_timeout=1800.0,
         )
         mock_ctx = ContextManager.__new__(ContextManager)
         mock_ctx.budget_tokens = 113000
@@ -228,6 +254,7 @@ async def test_vllm_wiring(self) -> None:
 
         assert isinstance(client, VLLMClient)
         assert client.base_url == "http://localhost:8000/v1"
+        assert client._http.timeout.read == 1800.0
         kwargs = mock_setup.await_args.kwargs
         assert kwargs["backend"] == "vllm"
         assert kwargs["model_path"] == "/models/awq"
@@ -238,7 +265,11 @@ async def test_vllm_wiring(self) -> None:
 
     @pytest.mark.asyncio
     async def test_ollama_wiring(self) -> None:
-        proxy = ProxyServer(backend="ollama", model="ministral-3:14b")
+        proxy = ProxyServer(
+            backend="ollama",
+            model="ministral-3:14b",
+            backend_timeout=1800.0,
+        )
         mock_ctx = ContextManager.__new__(ContextManager)
         mock_ctx.budget_tokens = 4096
         with patch(
@@ -247,6 +278,7 @@ async def test_ollama_wiring(self) -> None:
         ) as mock_setup:
             client, _ = await proxy._setup_managed()
         assert isinstance(client, OllamaClient)
+        assert client._http.timeout.read == 1800.0
         kwargs = mock_setup.await_args.kwargs
         assert kwargs["backend"] == "ollama"
         assert kwargs["model"] == "ministral-3:14b"