From 5bad4979b128e500e6febe844c43e3ec2362f4f6 Mon Sep 17 00:00:00 2001 From: Lucas Gerads Date: Sat, 23 May 2026 15:17:07 +0200 Subject: [PATCH 1/2] Add OpenAICompatClient for OpenAI-compatible endpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Provider-agnostic client for any backend exposing the OpenAI /v1/chat/completions shape with optional bearer auth — Cloudflare Workers AI, Fireworks, OpenRouter, OpenAI itself, etc. Native function calling, SSE streaming, full LLMClient protocol surface. Addresses the review feedback in #88: - Mirrors VLLMClient sampling kwargs (top_k, min_p, repeat_penalty, presence_penalty, chat_template_kwargs) with recommended_sampling defaulting to False (hosted model IDs aren't in forge's registry). - Accepts passthrough and inbound_anthropic_body to satisfy the LLMClient protocol; inbound_anthropic_body is a no-op here. - extra_headers kwarg for provider quirks (e.g. OpenRouter's HTTP-Referer) without a per-provider registry. - aclose() for httpx pool cleanup. - Docstring sets the contributor expectation: file an issue rather than adding per-provider if/else branches. Includes BACKEND_SETUP.md section and CHANGELOG entry. --- CHANGELOG.md | 6 + docs/BACKEND_SETUP.md | 34 ++ src/forge/__init__.py | 2 + src/forge/clients/__init__.py | 2 + src/forge/clients/openai_compat.py | 278 ++++++++++++++ tests/unit/test_openai_compat_client.py | 478 ++++++++++++++++++++++++ 6 files changed, 800 insertions(+) create mode 100644 src/forge/clients/openai_compat.py create mode 100644 tests/unit/test_openai_compat_client.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 29aa581..fbae77f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ All notable changes to forge are documented here. +## [Unreleased] + +### Added +- **OpenAI-compatible hosted-provider client (`OpenAICompatClient`).** Adapter for any backend exposing `/v1/chat/completions` with optional bearer auth — covers Cloudflare Workers AI, Fireworks, OpenRouter, OpenAI itself, and similar. Native function calling, SSE streaming, full `LLMClient` protocol surface. Supports `extra_headers` for provider quirks (e.g. OpenRouter's `HTTP-Referer` / `X-Title`) without a per-provider registry. Exported from `forge` and `forge.clients`. #88. +- **Hosted-providers section in [Backend Setup](docs/BACKEND_SETUP.md)** covering bearer-auth setup, the `get_context_length() → None` contract, and the per-model nature of function-calling support on hosted providers. + ## [0.7.2] — 2026-05-24 vLLM backend support — serve AWQ/GPTQ and other vLLM-hosted models behind forge's guardrails, in both proxy modes and via `WorkflowRunner`. diff --git a/docs/BACKEND_SETUP.md b/docs/BACKEND_SETUP.md index f7df0e2..ab85f01 100644 --- a/docs/BACKEND_SETUP.md +++ b/docs/BACKEND_SETUP.md @@ -201,6 +201,40 @@ No server to smoke-test — first inference call surfaces auth/network issues. --- +## Hosted OpenAI-compatible providers + +Any backend exposing `/v1/chat/completions` with bearer auth — Cloudflare Workers AI, Fireworks, OpenRouter, Together, OpenAI itself, and similar. The client is provider-agnostic: caller supplies the `base_url` and `api_key`; forge has no per-provider knowledge. + +Forge client (Cloudflare Workers AI): + +```python +from forge.clients import OpenAICompatClient + +client = OpenAICompatClient( + model="@cf/mistralai/mistral-small-3.1-24b-instruct", + base_url=f"https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/v1", + api_key=API_TOKEN, +) +``` + +Provider-specific request headers ride on `extra_headers` (e.g. OpenRouter's attribution): + +```python +client = OpenAICompatClient( + model="mistralai/mistral-small-3.1-24b-instruct", + base_url="https://openrouter.ai/api/v1", + api_key=API_KEY, + extra_headers={"HTTP-Referer": "https://your-app.example", "X-Title": "Your App"}, +) +``` + +Notes: +- **`get_context_length()` returns `None`.** Hosted providers don't expose `max_model_len`. Pass `budget_tokens` explicitly when constructing the `ContextManager` (or `--budget-tokens` to the proxy). +- **Native function calling is per-model, not per-provider.** Many hosted providers serve dozens of models; only the ones with a tool-calling chat template will return structured `tool_calls`. Check the provider's per-model capability docs. +- **Sampling defaults are opt-in.** `recommended_sampling=False` (default) skips the registry lookup, since hosted-provider model identifiers usually aren't in forge's per-model sampling map. Pass explicit `temperature` / `top_p` / etc. as needed. + +--- + ## Gotcha: reasoning budget on recent llama.cpp builds llama.cpp builds after April 10 2026 activate a reasoning budget sampler for models with thinking tags (Gemma 4, Qwen 3.5, Ministral Reasoning). The default budget is unlimited, which causes some runs to hang indefinitely or fill the KV cache until the server crashes. diff --git a/src/forge/__init__.py b/src/forge/__init__.py index 3bf5719..b6b8a27 100644 --- a/src/forge/__init__.py +++ b/src/forge/__init__.py @@ -23,6 +23,7 @@ from forge.clients.base import ChunkType, LLMClient, StreamChunk, TokenUsage from forge.clients.llamafile import LlamafileClient from forge.clients.ollama import OllamaClient +from forge.clients.openai_compat import OpenAICompatClient from forge.clients.vllm import VLLMClient from forge.context import ( CompactEvent, @@ -96,6 +97,7 @@ "LLMClient", "LlamafileClient", "OllamaClient", + "OpenAICompatClient", "VLLMClient", "StreamChunk", "TokenUsage", diff --git a/src/forge/clients/__init__.py b/src/forge/clients/__init__.py index d1be4a0..bc2865d 100644 --- a/src/forge/clients/__init__.py +++ b/src/forge/clients/__init__.py @@ -3,6 +3,7 @@ from forge.clients.base import ChunkType, LLMClient, StreamChunk from forge.clients.llamafile import LlamafileClient from forge.clients.ollama import OllamaClient +from forge.clients.openai_compat import OpenAICompatClient from forge.clients.vllm import VLLMClient from forge.clients.sampling_defaults import ( MODEL_SAMPLING_DEFAULTS, @@ -16,6 +17,7 @@ "LlamafileClient", "MODEL_SAMPLING_DEFAULTS", "OllamaClient", + "OpenAICompatClient", "StreamChunk", "VLLMClient", "apply_sampling_defaults", diff --git a/src/forge/clients/openai_compat.py b/src/forge/clients/openai_compat.py new file mode 100644 index 0000000..2c0b131 --- /dev/null +++ b/src/forge/clients/openai_compat.py @@ -0,0 +1,278 @@ +"""OpenAI-compatible client adapter using native function calling. + +Works with any backend that exposes the OpenAI ``/v1/chat/completions`` +endpoint: llama-server's OpenAI mode, Ollama's ``/v1`` shim, Cloudflare +Workers AI, Groq, Together, Fireworks, OpenRouter, OpenAI itself, etc. + +This client is provider-agnostic by design. It knows the *protocol* +(base_url + bearer key + chat/completions), not any specific provider. +The caller is responsible for constructing the ``base_url`` and supplying +the ``api_key``. +""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from typing import Any + +import httpx + +from forge.clients.base import ChunkType, StreamChunk, TokenUsage, format_tool +from forge.clients.sampling_defaults import apply_sampling_defaults +from forge.core.workflow import LLMResponse, TextResponse, ToolCall, ToolSpec +from forge.errors import BackendError + + +class OpenAICompatClient: + """Native function calling via an OpenAI-compatible chat endpoint. + + Posts to ``{base_url}/chat/completions`` with the standard OpenAI + request shape. Bearer auth is sent when ``api_key`` is provided + (omit it for unauthenticated local servers). Provider-specific + headers (e.g. OpenRouter's ``HTTP-Referer``) ride on + ``extra_headers`` without a per-provider quirks registry. + + If a provider's quirks require diverging the parse or stream path, + file an issue rather than adding if/else branches — we'll subclass + or extract a base at that point. + """ + + api_format: str = "openai" + + def __init__( + self, + model: str, + base_url: str, + *, + api_key: str = "", + extra_headers: dict[str, str] | None = None, + temperature: float | None = None, + top_p: float | None = None, + top_k: int | None = None, + min_p: float | None = None, + repeat_penalty: float | None = None, + presence_penalty: float | None = None, + chat_template_kwargs: dict[str, Any] | None = None, + timeout: float = 120.0, + recommended_sampling: bool = False, + ) -> None: + self.base_url = base_url.rstrip("/") + self.model = model + + # Apply per-model recommended sampling defaults. Caller's explicit + # (non-None) kwargs win over the map field-by-field. With + # recommended_sampling=False (default) and an unknown model stem, + # apply_sampling_defaults returns an empty dict silently — which + # is the common case for hosted providers whose model identifiers + # aren't in forge's registry. + defaults = apply_sampling_defaults(self.model, strict=recommended_sampling) + self.temperature = temperature if temperature is not None else defaults.get("temperature") + self.top_p = top_p if top_p is not None else defaults.get("top_p") + self.top_k = top_k if top_k is not None else defaults.get("top_k") + self.min_p = min_p if min_p is not None else defaults.get("min_p") + self.repeat_penalty = repeat_penalty if repeat_penalty is not None else defaults.get("repeat_penalty") + self.presence_penalty = presence_penalty if presence_penalty is not None else defaults.get("presence_penalty") + # chat_template_kwargs is a nested dict of Jinja template variables + # — whole-value replacement at this field level (no nested merge). + self.chat_template_kwargs = ( + chat_template_kwargs if chat_template_kwargs is not None + else defaults.get("chat_template_kwargs") + ) + + # Auth header is set when api_key is provided; extra_headers ride + # on top and can override (kept open so a provider with a different + # scheme doesn't need a new constructor kwarg). + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if extra_headers: + headers.update(extra_headers) + self._http = httpx.AsyncClient(headers=headers, timeout=timeout) + self.last_usage: dict[int, TokenUsage] = {} + + async def aclose(self) -> None: + """Close the underlying httpx connection pool.""" + await self._http.aclose() + + # ── request building ───────────────────────────────────────────── + + # Sampling fields recognized in per-call overrides. ``seed`` is + # accepted only as a per-call override (not an instance field). + # ``chat_template_kwargs`` is a nested dict — whole-value replacement + # at this field level (no nested merge). + _SAMPLING_FIELDS = ( + "temperature", "top_p", "top_k", "min_p", + "repeat_penalty", "presence_penalty", "seed", + "chat_template_kwargs", + ) + + def _build_body( + self, + messages: list[dict[str, str]], + tools: list[ToolSpec] | None, + sampling: dict[str, Any] | None, + stream: bool, + passthrough: dict[str, Any] | None = None, + ) -> dict[str, Any]: + # Passthrough fields (max_tokens, stop, tool_choice, model, etc. + # extracted by the proxy from the inbound body) seed the outbound + # body first. Forge-owned fields then overlay on top so the + # client's model/messages/stream/tools/sampling invariants win. + body: dict[str, Any] = dict(passthrough or {}) + body["model"] = self.model + body["messages"] = messages + body["stream"] = stream + for field in self._SAMPLING_FIELDS: + override = (sampling or {}).get(field) + if override is not None: + body[field] = override + else: + instance_val = getattr(self, field, None) + if instance_val is not None: + body[field] = instance_val + if tools: + body["tools"] = [format_tool(t) for t in tools] + return body + + def _record_usage(self, data: dict[str, Any]) -> None: + usage = data.get("usage") + if not usage: + return + prompt = usage.get("prompt_tokens") or 0 + completion = usage.get("completion_tokens") or 0 + self.last_usage[0] = TokenUsage( + prompt_tokens=prompt, + completion_tokens=completion, + total_tokens=usage.get("total_tokens") or (prompt + completion), + ) + + @staticmethod + def _parse_tool_calls(tool_calls: list[dict[str, Any]]) -> list[ToolCall]: + """Parse OpenAI tool_calls — arguments are JSON strings.""" + parsed: list[ToolCall] = [] + for tc in tool_calls: + fn = tc.get("function", {}) + raw_args = fn.get("arguments") or "{}" + try: + args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args + except json.JSONDecodeError: + args = {} + parsed.append(ToolCall(tool=fn.get("name", ""), args=args)) + return parsed + + # ── send ───────────────────────────────────────────────────────── + + async def send( + self, + messages: list[dict[str, str]], + tools: list[ToolSpec] | None = None, + sampling: dict[str, Any] | None = None, + passthrough: dict[str, Any] | None = None, + inbound_anthropic_body: dict[str, Any] | None = None, + ) -> LLMResponse: + """Send messages via /chat/completions and parse the response. + + ``inbound_anthropic_body`` is accepted to satisfy the LLMClient + protocol but ignored — Path-1 Anthropic forwarding doesn't apply + to OpenAI-shape clients. + """ + del inbound_anthropic_body # protocol-only, never read here + body = self._build_body(messages, tools, sampling, stream=False, passthrough=passthrough) + try: + resp = await self._http.post(f"{self.base_url}/chat/completions", json=body) + except httpx.ReadTimeout as exc: + raise BackendError(408, "Read timeout") from exc + + if resp.status_code != 200: + raise BackendError(resp.status_code, resp.text) + + data = resp.json() + self._record_usage(data) + + msg = data["choices"][0]["message"] + tool_calls = msg.get("tool_calls") + if tool_calls: + return self._parse_tool_calls(tool_calls) + return TextResponse(content=msg.get("content") or "") + + # ── streaming ──────────────────────────────────────────────────── + + async def send_stream( + self, + messages: list[dict[str, str]], + tools: list[ToolSpec] | None = None, + sampling: dict[str, Any] | None = None, + passthrough: dict[str, Any] | None = None, + inbound_anthropic_body: dict[str, Any] | None = None, + ) -> AsyncIterator[StreamChunk]: + """Stream via SSE from /chat/completions. + + ``inbound_anthropic_body`` is accepted to satisfy the LLMClient + protocol but ignored — see :meth:`send`. + """ + del inbound_anthropic_body # protocol-only, never read here + body = self._build_body(messages, tools, sampling, stream=True, passthrough=passthrough) + + accumulated_content = "" + tool_calls: dict[int, dict[str, Any]] = {} + usage: dict[str, Any] | None = None + + async with self._http.stream( + "POST", f"{self.base_url}/chat/completions", json=body + ) as response: + if response.status_code != 200: + error_body = await response.aread() + raise BackendError(response.status_code, error_body.decode(errors="replace")) + + async for line in response.aiter_lines(): + if not line.startswith("data: "): + continue + data_str = line[6:] + if data_str == "[DONE]": + break + chunk = json.loads(data_str) + if chunk.get("usage"): + usage = chunk["usage"] + choices = chunk.get("choices") or [] + if not choices: + continue + delta = choices[0].get("delta", {}) + + content = delta.get("content") + if content is not None: + if not isinstance(content, str): + content = str(content) + if content: + accumulated_content += content + yield StreamChunk(type=ChunkType.TEXT_DELTA, content=content) + + for tc in delta.get("tool_calls") or []: + idx = tc.get("index", 0) + slot = tool_calls.setdefault( + idx, {"function": {"name": "", "arguments": ""}} + ) + fn = tc.get("function", {}) + if fn.get("name"): + slot["function"]["name"] += str(fn["name"]) + # OpenAI streaming sends `arguments` as string fragments + # that we concatenate into the final JSON string. Anything + # non-string would be a non-compliant provider; drop it + # rather than json.dumps a partial object into the buffer. + args_frag = fn.get("arguments") + if isinstance(args_frag, str): + slot["function"]["arguments"] += args_frag + + if usage: + self._record_usage({"usage": usage}) + + if tool_calls: + ordered = [tool_calls[i] for i in sorted(tool_calls)] + final: LLMResponse = self._parse_tool_calls(ordered) + else: + final = TextResponse(content=accumulated_content) + yield StreamChunk(type=ChunkType.FINAL, response=final) + + async def get_context_length(self) -> int | None: + """OpenAI-compatible endpoints don't expose context length. Returns None.""" + return None diff --git a/tests/unit/test_openai_compat_client.py b/tests/unit/test_openai_compat_client.py new file mode 100644 index 0000000..8461277 --- /dev/null +++ b/tests/unit/test_openai_compat_client.py @@ -0,0 +1,478 @@ +"""Tests for forge.clients.openai_compat — OpenAICompatClient with mocked HTTP.""" + +import json + +import pytest +from pydantic import BaseModel, Field +from unittest.mock import AsyncMock, MagicMock + +from forge.clients.openai_compat import OpenAICompatClient +from forge.clients.base import ChunkType +from forge.core.workflow import TextResponse, ToolCall, ToolSpec +from forge.errors import BackendError + + +class PartParams(BaseModel): + part: str = Field(description="Part number") + + +def _make_spec(name: str = "get_pricing") -> ToolSpec: + return ToolSpec(name=name, description=f"Get {name}", parameters=PartParams) + + +def _make_client(model: str = "test-model", api_key: str = "tok") -> OpenAICompatClient: + client = OpenAICompatClient( + base_url="https://api.example.com/v1", model=model, api_key=api_key + ) + mock_http = AsyncMock() + mock_http.stream = MagicMock() # sync method returning async context manager + client._http = mock_http + return client + + +def _mock_response(data: dict, status_code: int = 200) -> MagicMock: + resp = MagicMock() + resp.status_code = status_code + resp.json.return_value = data + resp.text = json.dumps(data) + return resp + + +class _MockStreamResponse: + """Mock for httpx streaming response with aiter_lines / aread.""" + + def __init__(self, lines: list[str], status_code: int = 200) -> None: + self._lines = lines + self.status_code = status_code + + async def aiter_lines(self): + for line in self._lines: + yield line + + async def aread(self) -> bytes: + return "".join(self._lines).encode() + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + +# ── send ───────────────────────────────────────────────────────── + + +class TestSend: + @pytest.mark.asyncio + async def test_returns_tool_call(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{ + "message": { + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": {"name": "get_pricing", "arguments": '{"part": "X123"}'}, + }], + } + }] + }) + result = await client.send( + [{"role": "user", "content": "test"}], tools=[_make_spec()] + ) + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].tool == "get_pricing" + assert result[0].args == {"part": "X123"} + + @pytest.mark.asyncio + async def test_returns_text_response(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": "I need more info"}}] + }) + result = await client.send([{"role": "user", "content": "test"}]) + assert isinstance(result, TextResponse) + assert result.content == "I need more info" + + @pytest.mark.asyncio + async def test_null_content_returns_empty_text(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": None}}] + }) + result = await client.send([{"role": "user", "content": "test"}]) + assert isinstance(result, TextResponse) + assert result.content == "" + + @pytest.mark.asyncio + async def test_malformed_tool_args_fall_back_to_empty(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{ + "message": { + "role": "assistant", + "tool_calls": [{ + "function": {"name": "get_pricing", "arguments": "{not json"}, + }], + } + }] + }) + result = await client.send([{"role": "user", "content": "test"}]) + assert isinstance(result, list) + assert result[0].args == {} + + @pytest.mark.asyncio + async def test_formats_tools_in_request(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": "ok"}}] + }) + await client.send([{"role": "user", "content": "test"}], tools=[_make_spec()]) + + body = client._http.post.call_args.kwargs["json"] + assert "tools" in body + tool = body["tools"][0] + assert tool["type"] == "function" + assert tool["function"]["name"] == "get_pricing" + assert "parameters" in tool["function"] + + @pytest.mark.asyncio + async def test_request_body_structure(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": "ok"}}] + }) + await client.send([{"role": "user", "content": "hi"}]) + + body = client._http.post.call_args.kwargs["json"] + assert body["model"] == "test-model" + assert body["stream"] is False + assert body["messages"] == [{"role": "user", "content": "hi"}] + # No temperature passed → not in body + assert "temperature" not in body + + @pytest.mark.asyncio + async def test_posts_to_chat_completions(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": "ok"}}] + }) + await client.send([{"role": "user", "content": "hi"}]) + url = client._http.post.call_args.args[0] + assert url == "https://api.example.com/v1/chat/completions" + + @pytest.mark.asyncio + async def test_sampling_override(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": "ok"}}] + }) + await client.send( + [{"role": "user", "content": "hi"}], sampling={"temperature": 0.2} + ) + body = client._http.post.call_args.kwargs["json"] + assert body["temperature"] == 0.2 + + @pytest.mark.asyncio + async def test_http_error_raises_backend_error(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({"error": "bad"}, status_code=401) + with pytest.raises(BackendError) as exc: + await client.send([{"role": "user", "content": "test"}]) + assert exc.value.status_code == 401 + + @pytest.mark.asyncio + async def test_records_usage(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": "ok"}}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + }) + await client.send([{"role": "user", "content": "test"}]) + assert client.last_usage[0].prompt_tokens == 10 + assert client.last_usage[0].completion_tokens == 5 + assert client.last_usage[0].total_tokens == 15 + + +# ── auth ───────────────────────────────────────────────────────── + + +class TestAuth: + def test_bearer_header_set_when_key_provided(self) -> None: + client = OpenAICompatClient( + base_url="https://x/v1", model="m", api_key="secret" + ) + assert client._http.headers["Authorization"] == "Bearer secret" + + def test_no_auth_header_when_no_key(self) -> None: + client = OpenAICompatClient(base_url="https://x/v1", model="m") + assert "Authorization" not in client._http.headers + + def test_base_url_trailing_slash_stripped(self) -> None: + client = OpenAICompatClient(base_url="https://x/v1/", model="m") + assert client.base_url == "https://x/v1" + + +# ── send_stream ────────────────────────────────────────────────── + + +class TestSendStream: + @pytest.mark.asyncio + async def test_yields_text_deltas_and_final(self) -> None: + client = _make_client() + lines = [ + 'data: ' + json.dumps({"choices": [{"delta": {"content": "Hello"}}]}), + 'data: ' + json.dumps({"choices": [{"delta": {"content": " world"}}]}), + 'data: [DONE]', + ] + client._http.stream.return_value = _MockStreamResponse(lines) + + chunks = [] + async for chunk in client.send_stream([{"role": "user", "content": "hi"}]): + chunks.append(chunk) + + text_deltas = [c for c in chunks if c.type == ChunkType.TEXT_DELTA] + assert [c.content for c in text_deltas] == ["Hello", " world"] + + finals = [c for c in chunks if c.type == ChunkType.FINAL] + assert len(finals) == 1 + assert isinstance(finals[0].response, TextResponse) + assert finals[0].response.content == "Hello world" + + @pytest.mark.asyncio + async def test_yields_final_with_tool_call(self) -> None: + client = _make_client() + lines = [ + 'data: ' + json.dumps({"choices": [{"delta": {"tool_calls": [ + {"index": 0, "id": "c1", "function": {"name": "get_pricing", "arguments": ""}} + ]}}]}), + 'data: ' + json.dumps({"choices": [{"delta": {"tool_calls": [ + {"index": 0, "function": {"arguments": '{"part": '}} + ]}}]}), + 'data: ' + json.dumps({"choices": [{"delta": {"tool_calls": [ + {"index": 0, "function": {"arguments": '"X"}'}} + ]}}]}), + 'data: [DONE]', + ] + client._http.stream.return_value = _MockStreamResponse(lines) + + chunks = [] + async for chunk in client.send_stream( + [{"role": "user", "content": "test"}], tools=[_make_spec()] + ): + chunks.append(chunk) + + finals = [c for c in chunks if c.type == ChunkType.FINAL] + assert len(finals) == 1 + assert isinstance(finals[0].response, list) + assert finals[0].response[0].tool == "get_pricing" + assert finals[0].response[0].args == {"part": "X"} + + @pytest.mark.asyncio + async def test_stream_http_error_raises(self) -> None: + client = _make_client() + client._http.stream.return_value = _MockStreamResponse( + ['{"error": "nope"}'], status_code=500 + ) + with pytest.raises(BackendError) as exc: + async for _ in client.send_stream([{"role": "user", "content": "x"}]): + pass + assert exc.value.status_code == 500 + + @pytest.mark.asyncio + async def test_ignores_non_data_lines(self) -> None: + client = _make_client() + lines = [ + '', + ': keep-alive comment', + 'data: ' + json.dumps({"choices": [{"delta": {"content": "hi"}}]}), + 'data: [DONE]', + ] + client._http.stream.return_value = _MockStreamResponse(lines) + chunks = [c async for c in client.send_stream([{"role": "user", "content": "x"}])] + finals = [c for c in chunks if c.type == ChunkType.FINAL] + assert finals[0].response.content == "hi" + + +class TestContextLength: + @pytest.mark.asyncio + async def test_returns_none(self) -> None: + client = _make_client() + assert await client.get_context_length() is None + + +# ── constructor ────────────────────────────────────────────────── + + +class TestConstructor: + def test_extra_headers_merged_alongside_bearer(self) -> None: + client = OpenAICompatClient( + model="test-model", + base_url="https://x/v1", + api_key="tok", + extra_headers={"HTTP-Referer": "https://example.com", "X-Title": "MyApp"}, + ) + # httpx normalizes header names to lowercase. + assert client._http.headers["authorization"] == "Bearer tok" + assert client._http.headers["http-referer"] == "https://example.com" + assert client._http.headers["x-title"] == "MyApp" + + def test_extra_headers_can_override_authorization(self) -> None: + client = OpenAICompatClient( + model="test-model", + base_url="https://x/v1", + api_key="ignored", + extra_headers={"Authorization": "ApiKey custom-scheme"}, + ) + assert client._http.headers["authorization"] == "ApiKey custom-scheme" + + def test_sampling_kwargs_stored_as_instance_fields(self) -> None: + client = OpenAICompatClient( + model="test-model", + base_url="https://x/v1", + top_k=40, + min_p=0.05, + repeat_penalty=1.1, + presence_penalty=0.2, + chat_template_kwargs={"enable_thinking": True}, + ) + assert client.top_k == 40 + assert client.min_p == 0.05 + assert client.repeat_penalty == 1.1 + assert client.presence_penalty == 0.2 + assert client.chat_template_kwargs == {"enable_thinking": True} + + def test_recommended_sampling_off_is_silent_for_unknown_model(self) -> None: + # Default behavior: unknown model -> empty defaults, explicit kwargs flow through. + client = OpenAICompatClient( + model="definitely-not-in-registry-zzz", + base_url="https://x/v1", + ) + assert client.temperature is None + assert client.top_p is None + assert client.top_k is None + + def test_recommended_sampling_on_raises_for_unknown_model(self) -> None: + from forge.errors import UnsupportedModelError + with pytest.raises(UnsupportedModelError): + OpenAICompatClient( + model="definitely-not-in-registry-zzz", + base_url="https://x/v1", + recommended_sampling=True, + ) + + +# ── instance sampling flows into request body ──────────────────── + + +class TestSendInstanceSampling: + @pytest.mark.asyncio + @pytest.mark.parametrize("field,value", [ + ("top_k", 40), + ("min_p", 0.05), + ("repeat_penalty", 1.1), + ("presence_penalty", 0.2), + ("chat_template_kwargs", {"enable_thinking": True}), + ]) + async def test_instance_sampling_flows_into_body(self, field, value) -> None: + client = OpenAICompatClient( + model="test-model", + base_url="https://x/v1", + api_key="tok", + **{field: value}, + ) + mock_http = AsyncMock() + client._http = mock_http + mock_http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": "ok"}}] + }) + await client.send([{"role": "user", "content": "hi"}]) + body = mock_http.post.call_args.kwargs["json"] + assert body[field] == value + + +# ── passthrough + inbound_anthropic_body ───────────────────────── + + +class TestPassthrough: + @pytest.mark.asyncio + async def test_passthrough_fields_appear_in_body(self) -> None: + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": "ok"}}] + }) + await client.send( + [{"role": "user", "content": "hi"}], + passthrough={"max_tokens": 512, "stop": ["END"], "tool_choice": "auto"}, + ) + body = client._http.post.call_args.kwargs["json"] + assert body["max_tokens"] == 512 + assert body["stop"] == ["END"] + assert body["tool_choice"] == "auto" + + @pytest.mark.asyncio + async def test_forge_owned_fields_override_passthrough(self) -> None: + # Proxy may include "model"/"messages"/"stream" in passthrough; forge's + # values must win to keep its invariants. + client = _make_client(model="forge-model") + client._http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": "ok"}}] + }) + await client.send( + [{"role": "user", "content": "hi"}], + passthrough={ + "model": "evil", + "messages": [{"role": "system", "content": "evil"}], + "stream": True, + }, + ) + body = client._http.post.call_args.kwargs["json"] + assert body["model"] == "forge-model" + assert body["messages"] == [{"role": "user", "content": "hi"}] + assert body["stream"] is False + + @pytest.mark.asyncio + async def test_inbound_anthropic_body_accepted_and_ignored(self) -> None: + # Protocol shape compatibility: accept the kwarg, never let it leak + # into the outbound OpenAI-shape body. + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{"message": {"role": "assistant", "content": "ok"}}] + }) + await client.send( + [{"role": "user", "content": "hi"}], + inbound_anthropic_body={"some_anthropic_field": "value", "shape": "data"}, + ) + body = client._http.post.call_args.kwargs["json"] + assert "some_anthropic_field" not in body + assert "shape" not in body + + @pytest.mark.asyncio + async def test_send_stream_accepts_passthrough_and_inbound(self) -> None: + client = _make_client() + lines = [ + 'data: ' + json.dumps({"choices": [{"delta": {"content": "ok"}}]}), + 'data: [DONE]', + ] + client._http.stream.return_value = _MockStreamResponse(lines) + chunks = [c async for c in client.send_stream( + [{"role": "user", "content": "x"}], + passthrough={"max_tokens": 64}, + inbound_anthropic_body={"ignored": True}, + )] + # If the kwargs were rejected, we'd never get here. + assert any(c.type == ChunkType.FINAL for c in chunks) + + +# ── aclose ─────────────────────────────────────────────────────── + + +class TestAclose: + @pytest.mark.asyncio + async def test_aclose_closes_http_pool(self) -> None: + client = _make_client() + await client.aclose() + client._http.aclose.assert_awaited_once() From 660a53cfd26c3c6a15dfbb05a740f2dc154b00c0 Mon Sep 17 00:00:00 2001 From: Lucas Gerads Date: Sat, 30 May 2026 09:09:54 +0200 Subject: [PATCH 2/2] Fail loud on malformed tool-call args in OpenAICompatClient MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses Antoine's review comments on PR #89 (CHANGES_REQUESTED). Per Forge's fail-fast/fail-loud preference, malformed tool-call argument JSON must not be coerced to {} and executed as fn(**{}) — that's a quiet false success. Route it through the retry-driving TextResponse path instead, matching LlamafileClient (the reference; not VLLMClient). One fix in the two spots the inline comments pointed at: - _parse_tool_calls (non-streaming): on JSONDecodeError, return a TextResponse instead of {}, so the validator's rescue-parse + retry/ nudge loop can drive a correction. Surfaces assistant text when present, else the raw malformed args. Matches llamafile.py:523-528. - send_stream: final assembly now routes the same way (matches llamafile.py:400-414), and a non-string arg fragment is serialized into the buffer rather than silently dropped (dropping left a gap that could parse into wrong-but-valid args). Tests: flip the mis-asserting malformed-args test to expect TextResponse; add assistant-text-surfaced, dict-args-accepted, and two streaming cases. --- src/forge/clients/openai_compat.py | 55 ++++++++--- tests/unit/test_openai_compat_client.py | 118 +++++++++++++++++++++++- 2 files changed, 157 insertions(+), 16 deletions(-) diff --git a/src/forge/clients/openai_compat.py b/src/forge/clients/openai_compat.py index 2c0b131..96ead28 100644 --- a/src/forge/clients/openai_compat.py +++ b/src/forge/clients/openai_compat.py @@ -148,16 +148,34 @@ def _record_usage(self, data: dict[str, Any]) -> None: ) @staticmethod - def _parse_tool_calls(tool_calls: list[dict[str, Any]]) -> list[ToolCall]: - """Parse OpenAI tool_calls — arguments are JSON strings.""" + def _parse_tool_calls( + tool_calls: list[dict[str, Any]], fallback_content: str = "" + ) -> LLMResponse: + """Parse OpenAI ``tool_calls`` into ``ToolCall`` objects. + + Tool-call ``arguments`` arrive as JSON strings. Forge is fail-loud: + malformed argument JSON must NOT be coerced into executable empty args, + or a provider/model can emit invalid arguments and Forge proceeds with + ``fn(**{})`` — exactly the quiet false success the library avoids. + Instead we return a ``TextResponse``, which routes the response back + into the validator's rescue-parse + retry/nudge loop, matching + ``LlamafileClient`` (see ``llamafile.py`` ``_send_native``). + + ``fallback_content`` is the assistant message text to surface for the + rescue attempt; we fall back to the raw malformed args when there is no + text, so the rescue parser still has the original JSON to work with. + """ parsed: list[ToolCall] = [] for tc in tool_calls: fn = tc.get("function", {}) raw_args = fn.get("arguments") or "{}" - try: - args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args - except json.JSONDecodeError: - args = {} + if isinstance(raw_args, str): + try: + args = json.loads(raw_args) + except json.JSONDecodeError: + return TextResponse(content=fallback_content or raw_args) + else: + args = raw_args parsed.append(ToolCall(tool=fn.get("name", ""), args=args)) return parsed @@ -193,7 +211,7 @@ async def send( msg = data["choices"][0]["message"] tool_calls = msg.get("tool_calls") if tool_calls: - return self._parse_tool_calls(tool_calls) + return self._parse_tool_calls(tool_calls, fallback_content=msg.get("content") or "") return TextResponse(content=msg.get("content") or "") # ── streaming ──────────────────────────────────────────────────── @@ -255,20 +273,29 @@ async def send_stream( fn = tc.get("function", {}) if fn.get("name"): slot["function"]["name"] += str(fn["name"]) - # OpenAI streaming sends `arguments` as string fragments - # that we concatenate into the final JSON string. Anything - # non-string would be a non-compliant provider; drop it - # rather than json.dumps a partial object into the buffer. + # OpenAI streaming sends `arguments` as JSON-string + # fragments we concatenate into the final JSON string. A + # non-string fragment is a non-compliant provider; serialize + # it into the buffer rather than silently dropping it. + # Dropping leaves a gap in the assembled JSON that may parse + # into wrong-but-valid args (a quiet false success); folding + # it in instead means the single parse at stream end either + # recovers a whole-object fragment or fails loud into the + # TextResponse/retry path below, matching LlamafileClient. args_frag = fn.get("arguments") - if isinstance(args_frag, str): - slot["function"]["arguments"] += args_frag + if args_frag is not None: + slot["function"]["arguments"] += ( + args_frag if isinstance(args_frag, str) else json.dumps(args_frag) + ) if usage: self._record_usage({"usage": usage}) if tool_calls: ordered = [tool_calls[i] for i in sorted(tool_calls)] - final: LLMResponse = self._parse_tool_calls(ordered) + final: LLMResponse = self._parse_tool_calls( + ordered, fallback_content=accumulated_content + ) else: final = TextResponse(content=accumulated_content) yield StreamChunk(type=ChunkType.FINAL, response=final) diff --git a/tests/unit/test_openai_compat_client.py b/tests/unit/test_openai_compat_client.py index 8461277..9aaff77 100644 --- a/tests/unit/test_openai_compat_client.py +++ b/tests/unit/test_openai_compat_client.py @@ -108,7 +108,12 @@ async def test_null_content_returns_empty_text(self) -> None: assert result.content == "" @pytest.mark.asyncio - async def test_malformed_tool_args_fall_back_to_empty(self) -> None: + async def test_malformed_tool_args_return_text_response(self) -> None: + # Fail-loud: malformed argument JSON must NOT become an executable + # empty-args tool call. It returns a TextResponse so the runner's + # rescue-parse + retry/nudge loop can drive a correction. With no + # assistant text present, the raw malformed args are surfaced so the + # rescue parser still has the original JSON to work with. client = _make_client() client._http.post.return_value = _mock_response({ "choices": [{ @@ -121,8 +126,69 @@ async def test_malformed_tool_args_fall_back_to_empty(self) -> None: }] }) result = await client.send([{"role": "user", "content": "test"}]) + assert isinstance(result, TextResponse) + assert result.content == "{not json" + + @pytest.mark.asyncio + async def test_malformed_tool_args_surface_assistant_text(self) -> None: + # When the assistant also produced text alongside the malformed + # tool-call args, the TextResponse carries that text (the more useful + # signal for rescue), not the raw broken JSON. + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{ + "message": { + "role": "assistant", + "content": "Let me look that up.", + "tool_calls": [{ + "function": {"name": "get_pricing", "arguments": "{not json"}, + }], + } + }] + }) + result = await client.send([{"role": "user", "content": "test"}]) + assert isinstance(result, TextResponse) + assert result.content == "Let me look that up." + + @pytest.mark.asyncio + async def test_one_malformed_among_several_bails_whole_batch(self) -> None: + # Fail-loud for parallel tool calls: if ANY call's args are malformed, + # the entire response becomes a TextResponse — we never execute the + # valid sibling calls alongside a broken one (no partial execution). + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{ + "message": { + "role": "assistant", + "tool_calls": [ + {"function": {"name": "get_pricing", "arguments": '{"part": "A"}'}}, + {"function": {"name": "get_pricing", "arguments": "{broken"}}, + ], + } + }] + }) + result = await client.send([{"role": "user", "content": "test"}]) + assert isinstance(result, TextResponse) + assert result.content == "{broken" + + @pytest.mark.asyncio + async def test_dict_tool_args_accepted(self) -> None: + # A provider that returns already-parsed (non-string) arguments is + # non-compliant but unambiguous — accept the dict rather than failing. + client = _make_client() + client._http.post.return_value = _mock_response({ + "choices": [{ + "message": { + "role": "assistant", + "tool_calls": [{ + "function": {"name": "get_pricing", "arguments": {"part": "X123"}}, + }], + } + }] + }) + result = await client.send([{"role": "user", "content": "test"}]) assert isinstance(result, list) - assert result[0].args == {} + assert result[0].args == {"part": "X123"} @pytest.mark.asyncio async def test_formats_tools_in_request(self) -> None: @@ -271,6 +337,54 @@ async def test_yields_final_with_tool_call(self) -> None: assert finals[0].response[0].tool == "get_pricing" assert finals[0].response[0].args == {"part": "X"} + @pytest.mark.asyncio + async def test_stream_malformed_tool_args_return_text_response(self) -> None: + # Streaming counterpart of the non-streaming fail-loud case: arg + # fragments that never assemble into valid JSON must end as a + # TextResponse final, not a {}-args tool call. With no streamed text, + # the raw assembled args are surfaced for rescue. + client = _make_client() + lines = [ + 'data: ' + json.dumps({"choices": [{"delta": {"tool_calls": [ + {"index": 0, "function": {"name": "get_pricing", "arguments": "{not"}} + ]}}]}), + 'data: ' + json.dumps({"choices": [{"delta": {"tool_calls": [ + {"index": 0, "function": {"arguments": " json"}} + ]}}]}), + 'data: [DONE]', + ] + client._http.stream.return_value = _MockStreamResponse(lines) + + chunks = [c async for c in client.send_stream( + [{"role": "user", "content": "test"}], tools=[_make_spec()] + )] + finals = [c for c in chunks if c.type == ChunkType.FINAL] + assert len(finals) == 1 + assert isinstance(finals[0].response, TextResponse) + assert finals[0].response.content == "{not json" + + @pytest.mark.asyncio + async def test_stream_non_string_arg_fragment_not_dropped(self) -> None: + # A non-compliant provider that streams the whole arguments object as a + # single non-string fragment must not be silently skipped (which would + # leave empty args). It is serialized into the buffer and recovered. + client = _make_client() + lines = [ + 'data: ' + json.dumps({"choices": [{"delta": {"tool_calls": [ + {"index": 0, "function": {"name": "get_pricing", "arguments": {"part": "X9"}}} + ]}}]}), + 'data: [DONE]', + ] + client._http.stream.return_value = _MockStreamResponse(lines) + + chunks = [c async for c in client.send_stream( + [{"role": "user", "content": "test"}], tools=[_make_spec()] + )] + finals = [c for c in chunks if c.type == ChunkType.FINAL] + assert len(finals) == 1 + assert isinstance(finals[0].response, list) + assert finals[0].response[0].args == {"part": "X9"} + @pytest.mark.asyncio async def test_stream_http_error_raises(self) -> None: client = _make_client()