diff --git a/tests/conftest.py b/tests/conftest.py
index e69de29..5a6c745 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -0,0 +1,126 @@
+"""Shared pytest fixtures and test doubles for the forge unit suite."""
+
+from __future__ import annotations
+
+from collections.abc import AsyncIterator
+from typing import Literal
+
+from forge.clients.base import ChunkType, StreamChunk
+from forge.core.workflow import LLMResponse, TextResponse, ToolCall, ToolSpec
+
+
+# ── Shared LLM client double ─────────────────────────────────────
+
+
+class MockClient:
+    """Configurable mock ``LLMClient`` that replays scripted responses.
+
+    A single behavioral superset of three doubles that had drifted apart
+    across the suite. Behavior is selected entirely via constructor kwargs;
+    the defaults reproduce the most common case (the runner tests).
+
+    Constructor knobs:
+        responses:
+            Scripted ``ToolCall`` / ``TextResponse`` entries, consumed one
+            per ``send`` / ``send_stream`` call. A bare ``ToolCall`` entry is
+            auto-wrapped to ``[ToolCall]`` to match
+            ``LLMResponse = list[ToolCall] | TextResponse``.
+        on_exhausted:
+            What to do once ``responses`` is exhausted. ``"raise"`` (default)
+            raises ``IndexError``. Otherwise pass a fallback response object
+            (e.g. ``TextResponse(content="stuck")``) which is returned
+            instead of raising. Any other string is rejected with
+            ``ValueError``, since a string is never a valid response.
+        stream_mode:
+            How ``send_stream`` behaves. ``"deltas"`` (default) yields a
+            ``TEXT_DELTA`` chunk then a ``FINAL`` chunk. ``"final"`` yields
+            only the ``FINAL`` chunk. ``"unsupported"`` raises
+            ``NotImplementedError`` on first iteration of the stream, before
+            any scripted response is consumed. Unknown modes are rejected
+            with ``ValueError``.
+        api_format:
+            Wire format string exposed on the instance (default ``"ollama"``).
+        context_length:
+            Value returned by ``get_context_length`` (default ``None``).
+
+    Call-spy lists ``send_calls`` and ``send_stream_calls`` record the
+    ``(messages, tools)`` pair of every call, including stream attempts
+    rejected in ``"unsupported"`` mode; call sites that never read them are
+    unaffected.
+    """
+
+    # Accepted ``stream_mode`` values; checked eagerly in ``__init__``.
+    _STREAM_MODES = ("deltas", "final", "unsupported")
+
+    def __init__(
+        self,
+        responses: list[ToolCall | TextResponse],
+        *,
+        on_exhausted: Literal["raise"] | LLMResponse = "raise",
+        stream_mode: Literal["deltas", "final", "unsupported"] = "deltas",
+        api_format: str = "ollama",
+        context_length: int | None = None,
+    ):
+        # Fail fast on typos: an unknown mode would otherwise silently stream
+        # like "final", and a misspelled "raise" would be returned to the code
+        # under test as if it were a response.
+        if stream_mode not in self._STREAM_MODES:
+            raise ValueError(f"MockClient: unknown stream_mode {stream_mode!r}")
+        if isinstance(on_exhausted, str) and on_exhausted != "raise":
+            raise ValueError(f"MockClient: unknown on_exhausted {on_exhausted!r}")
+        # Copy so later mutation of the caller's list cannot change the script.
+        self.responses = list(responses)
+        self._call_index = 0
+        self._on_exhausted = on_exhausted
+        self._stream_mode = stream_mode
+        self.api_format = api_format
+        self._context_length = context_length
+        self.send_calls: list[tuple[list[dict], list[ToolSpec] | None]] = []
+        self.send_stream_calls: list[tuple[list[dict], list[ToolSpec] | None]] = []
+
+    def _next(self) -> LLMResponse:
+        """Return the next scripted response, or apply ``on_exhausted``."""
+        if self._call_index >= len(self.responses):
+            if self._on_exhausted == "raise":
+                raise IndexError("MockClient: scripted responses exhausted")
+            return self._on_exhausted
+        resp = self.responses[self._call_index]
+        self._call_index += 1
+        if isinstance(resp, ToolCall):
+            # A bare ToolCall is wrapped to satisfy the list[ToolCall] arm.
+            return [resp]
+        return resp
+
+    async def send(
+        self,
+        messages: list[dict[str, str]],
+        tools: list[ToolSpec] | None = None,
+        sampling: dict[str, object] | None = None,
+        passthrough: dict[str, object] | None = None,
+        inbound_anthropic_body: dict[str, object] | None = None,
+    ) -> LLMResponse:
+        """Record ``(messages, tools)`` and return the next response."""
+        self.send_calls.append((messages, tools))
+        return self._next()
+
+    async def send_stream(
+        self,
+        messages: list[dict[str, str]],
+        tools: list[ToolSpec] | None = None,
+        sampling: dict[str, object] | None = None,
+        passthrough: dict[str, object] | None = None,
+        inbound_anthropic_body: dict[str, object] | None = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Record the call, then stream the next response per ``stream_mode``."""
+        # Record first so the spy also sees attempts rejected just below.
+        self.send_stream_calls.append((messages, tools))
+        if self._stream_mode == "unsupported":
+            raise NotImplementedError
+        resp = self._next()
+        if self._stream_mode == "deltas":
+            yield StreamChunk(type=ChunkType.TEXT_DELTA, content="partial...")
+        yield StreamChunk(type=ChunkType.FINAL, response=resp)
+
+    async def get_context_length(self) -> int | None:
+        """Return the ``context_length`` given at construction."""
+        return self._context_length
diff --git
a/tests/unit/test_eval_budget.py b/tests/unit/test_eval_budget.py index cc30832..cd2893f 100644 --- a/tests/unit/test_eval_budget.py +++ b/tests/unit/test_eval_budget.py @@ -2,55 +2,26 @@ from __future__ import annotations -from collections.abc import AsyncIterator +from functools import partial from typing import Any import pytest -from forge.clients.base import ChunkType, StreamChunk from forge.context.strategies import TieredCompact -from forge.core.workflow import LLMResponse, ToolCall, ToolSpec, TextResponse +from forge.core.workflow import ToolCall, TextResponse +from tests.conftest import MockClient from tests.eval.eval_runner import EvalConfig, RunResult, run_scenario from tests.eval.scenarios import compaction_chain_p1, basic_2step - -class _MockClient: - """Minimal client that returns a tool call on each send.""" - - api_format: str = "ollama" - - def __init__(self, calls: list[ToolCall]) -> None: - self._calls = list(calls) - self._idx = 0 - - async def send( - self, - messages: list[dict[str, str]], - tools: list[ToolSpec] | None = None, - sampling: dict[str, object] | None = None, - passthrough: dict[str, object] | None = None, - inbound_anthropic_body: dict[str, object] | None = None, - ) -> LLMResponse: - if self._idx < len(self._calls): - tc = self._calls[self._idx] - self._idx += 1 - return [tc] - return TextResponse(content="stuck") - - async def send_stream( - self, - messages: list[dict[str, str]], - tools: list[ToolSpec] | None = None, - sampling: dict[str, object] | None = None, - passthrough: dict[str, object] | None = None, - inbound_anthropic_body: dict[str, object] | None = None, - ) -> AsyncIterator[StreamChunk]: - resp = await self.send(messages, tools) - yield StreamChunk(type=ChunkType.FINAL, response=resp) - - async def get_context_length(self) -> int | None: - return None +# The eval-budget tests expect the client to return a "stuck" TextResponse +# once scripted calls run out (rather than raising) and to stream a single +# FINAL 
chunk with no text deltas. +_MockClient = partial( + MockClient, + on_exhausted=TextResponse(content="stuck"), + stream_mode="final", +) class TestBudgetOverride: diff --git a/tests/unit/test_runner.py b/tests/unit/test_runner.py index 9a33c07..e3e4d55 100644 --- a/tests/unit/test_runner.py +++ b/tests/unit/test_runner.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncIterator from unittest.mock import MagicMock import pytest @@ -15,7 +14,6 @@ from forge.core.runner import WorkflowRunner from pydantic import BaseModel from forge.core.workflow import ( - LLMResponse, TextResponse, ToolCall, ToolDef, @@ -24,6 +22,8 @@ ) from forge.errors import MaxIterationsError, PrerequisiteError, StepEnforcementError, StreamError, ToolCallError, ToolExecutionError, ToolResolutionError, WorkflowCancelledError +from tests.conftest import MockClient + class EmptyParams(BaseModel): pass @@ -32,55 +32,6 @@ class EmptyParams(BaseModel): # ── Helpers ────────────────────────────────────────────────────── -class MockClient: - """Mock LLMClient that returns scripted responses. - - Accepts ToolCall or TextResponse in the response list. Single ToolCall - entries are automatically wrapped in a list to match the runner's - expected LLMResponse = list[ToolCall] | TextResponse. 
- """ - - def __init__(self, responses: list[ToolCall | TextResponse]): - self.responses = list(responses) - self._call_index = 0 - self.send_calls: list[tuple[list[dict], list[ToolSpec] | None]] = [] - self.send_stream_calls: list[tuple[list[dict], list[ToolSpec] | None]] = [] - - def _next(self) -> LLMResponse: - resp = self.responses[self._call_index] - self._call_index += 1 - if isinstance(resp, ToolCall): - return [resp] - return resp - - async def send( - self, - messages: list[dict[str, str]], - tools: list[ToolSpec] | None = None, - sampling: dict[str, object] | None = None, - passthrough: dict[str, object] | None = None, - inbound_anthropic_body: dict[str, object] | None = None, - ) -> LLMResponse: - self.send_calls.append((messages, tools)) - return self._next() - - async def send_stream( - self, - messages: list[dict[str, str]], - tools: list[ToolSpec] | None = None, - sampling: dict[str, object] | None = None, - passthrough: dict[str, object] | None = None, - inbound_anthropic_body: dict[str, object] | None = None, - ) -> AsyncIterator[StreamChunk]: - self.send_stream_calls.append((messages, tools)) - resp = self._next() - yield StreamChunk(type=ChunkType.TEXT_DELTA, content="partial...") - yield StreamChunk(type=ChunkType.FINAL, response=resp) - - async def get_context_length(self) -> int | None: - return None - - def _make_tool(name: str, fn=None) -> ToolDef: """Create a minimal ToolDef for testing.""" if fn is None: diff --git a/tests/unit/test_slot_worker.py b/tests/unit/test_slot_worker.py index 1fd2145..41fc240 100644 --- a/tests/unit/test_slot_worker.py +++ b/tests/unit/test_slot_worker.py @@ -4,6 +4,7 @@ import asyncio from collections.abc import AsyncIterator +from functools import partial from unittest.mock import AsyncMock import pytest @@ -23,6 +24,12 @@ from forge.errors import WorkflowCancelledError from pydantic import BaseModel +from tests.conftest import MockClient as _SharedMockClient + +# SlotWorker tests never stream; the shared 
double's send_stream is left +# unsupported, mirroring the original local double's NotImplementedError stub. +MockClient = partial(_SharedMockClient, stream_mode="unsupported") + class EmptyParams(BaseModel): pass @@ -54,28 +61,7 @@ def _make_workflow(terminal_tool: str = "submit") -> Workflow: ) -class MockClient: - """Mock LLMClient that returns scripted responses.""" - - def __init__(self, responses: list): - self.responses = list(responses) - self._call_index = 0 - - async def send(self, messages, tools=None, sampling=None, passthrough=None, inbound_anthropic_body=None): - resp = self.responses[self._call_index] - self._call_index += 1 - if isinstance(resp, ToolCall): - return [resp] - return resp - - async def send_stream(self, messages, tools=None, sampling=None, passthrough=None, inbound_anthropic_body=None): - raise NotImplementedError - - async def get_context_length(self): - return None - - -def _make_worker(client: MockClient) -> SlotWorker: +def _make_worker(client: _SharedMockClient) -> SlotWorker: ctx = ContextManager(strategy=NoCompact(), budget_tokens=100_000) runner = WorkflowRunner(client=client, context_manager=ctx) return SlotWorker(runner)