Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Shared pytest fixtures and test doubles for the forge unit suite."""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Literal

from forge.clients.base import ChunkType, StreamChunk
from forge.core.workflow import LLMResponse, TextResponse, ToolCall, ToolSpec


# ── Shared LLM client double ─────────────────────────────────────


class MockClient:
"""Configurable mock ``LLMClient`` that replays scripted responses.

A single behavioral superset of three doubles that had drifted apart
across the suite. Behavior is selected entirely via constructor kwargs;
the defaults reproduce the most common case (the runner tests).

Constructor knobs:
responses:
Scripted ``ToolCall`` / ``TextResponse`` entries, consumed one
per ``send`` / ``send_stream`` call. A bare ``ToolCall`` entry is
auto-wrapped to ``[ToolCall]`` to match
``LLMResponse = list[ToolCall] | TextResponse``.
on_exhausted:
What to do once ``responses`` is exhausted. ``"raise"`` (default)
raises ``IndexError``. Otherwise pass a fallback response object
(e.g. ``TextResponse(content="stuck")``) which is returned
instead of raising.
stream_mode:
How ``send_stream`` behaves. ``"deltas"`` (default) yields a
``TEXT_DELTA`` chunk then a ``FINAL`` chunk. ``"final"`` yields
only the ``FINAL`` chunk. ``"unsupported"`` raises
``NotImplementedError``.
api_format:
Wire format string exposed on the instance (default ``"ollama"``).
context_length:
Value returned by ``get_context_length`` (default ``None``).

Call-spy lists ``send_calls`` and ``send_stream_calls`` are always
populated; call sites that never read them are unaffected.
"""

def __init__(
self,
responses: list[ToolCall | TextResponse],
*,
on_exhausted: Literal["raise"] | LLMResponse = "raise",
stream_mode: Literal["deltas", "final", "unsupported"] = "deltas",
api_format: str = "ollama",
context_length: int | None = None,
):
self.responses = list(responses)
self._call_index = 0
self._on_exhausted = on_exhausted
self._stream_mode = stream_mode
self.api_format = api_format
self._context_length = context_length
self.send_calls: list[tuple[list[dict], list[ToolSpec] | None]] = []
self.send_stream_calls: list[tuple[list[dict], list[ToolSpec] | None]] = []

def _next(self) -> LLMResponse:
if self._call_index >= len(self.responses):
if self._on_exhausted == "raise":
raise IndexError("MockClient: scripted responses exhausted")
return self._on_exhausted
resp = self.responses[self._call_index]
self._call_index += 1
if isinstance(resp, ToolCall):
return [resp]
return resp

async def send(
self,
messages: list[dict[str, str]],
tools: list[ToolSpec] | None = None,
sampling: dict[str, object] | None = None,
passthrough: dict[str, object] | None = None,
inbound_anthropic_body: dict[str, object] | None = None,
) -> LLMResponse:
self.send_calls.append((messages, tools))
return self._next()

async def send_stream(
self,
messages: list[dict[str, str]],
tools: list[ToolSpec] | None = None,
sampling: dict[str, object] | None = None,
passthrough: dict[str, object] | None = None,
inbound_anthropic_body: dict[str, object] | None = None,
) -> AsyncIterator[StreamChunk]:
if self._stream_mode == "unsupported":
raise NotImplementedError
self.send_stream_calls.append((messages, tools))
resp = self._next()
if self._stream_mode == "deltas":
yield StreamChunk(type=ChunkType.TEXT_DELTA, content="partial...")
yield StreamChunk(type=ChunkType.FINAL, response=resp)

async def get_context_length(self) -> int | None:
return self._context_length
51 changes: 11 additions & 40 deletions tests/unit/test_eval_budget.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,26 @@

from __future__ import annotations

from collections.abc import AsyncIterator
from functools import partial
from typing import Any

import pytest

from forge.clients.base import ChunkType, StreamChunk
from forge.context.strategies import TieredCompact
from forge.core.workflow import LLMResponse, ToolCall, ToolSpec, TextResponse
from forge.core.workflow import ToolCall, TextResponse

from tests.conftest import MockClient
from tests.eval.eval_runner import EvalConfig, RunResult, run_scenario
from tests.eval.scenarios import compaction_chain_p1, basic_2step


class _MockClient:
"""Minimal client that returns a tool call on each send."""

api_format: str = "ollama"

def __init__(self, calls: list[ToolCall]) -> None:
self._calls = list(calls)
self._idx = 0

async def send(
self,
messages: list[dict[str, str]],
tools: list[ToolSpec] | None = None,
sampling: dict[str, object] | None = None,
passthrough: dict[str, object] | None = None,
inbound_anthropic_body: dict[str, object] | None = None,
) -> LLMResponse:
if self._idx < len(self._calls):
tc = self._calls[self._idx]
self._idx += 1
return [tc]
return TextResponse(content="stuck")

async def send_stream(
self,
messages: list[dict[str, str]],
tools: list[ToolSpec] | None = None,
sampling: dict[str, object] | None = None,
passthrough: dict[str, object] | None = None,
inbound_anthropic_body: dict[str, object] | None = None,
) -> AsyncIterator[StreamChunk]:
resp = await self.send(messages, tools)
yield StreamChunk(type=ChunkType.FINAL, response=resp)

async def get_context_length(self) -> int | None:
return None
# The eval-budget tests expect the client to return a "stuck" TextResponse
# once scripted calls run out (rather than raising) and to stream a single
# FINAL chunk with no text deltas.
_MockClient = partial(
MockClient,
on_exhausted=TextResponse(content="stuck"),
stream_mode="final",
)


class TestBudgetOverride:
Expand Down
53 changes: 2 additions & 51 deletions tests/unit/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from unittest.mock import MagicMock

import pytest
Expand All @@ -15,7 +14,6 @@
from forge.core.runner import WorkflowRunner
from pydantic import BaseModel
from forge.core.workflow import (
LLMResponse,
TextResponse,
ToolCall,
ToolDef,
Expand All @@ -24,6 +22,8 @@
)
from forge.errors import MaxIterationsError, PrerequisiteError, StepEnforcementError, StreamError, ToolCallError, ToolExecutionError, ToolResolutionError, WorkflowCancelledError

from tests.conftest import MockClient


class EmptyParams(BaseModel):
pass
Expand All @@ -32,55 +32,6 @@ class EmptyParams(BaseModel):
# ── Helpers ──────────────────────────────────────────────────────


class MockClient:
"""Mock LLMClient that returns scripted responses.

Accepts ToolCall or TextResponse in the response list. Single ToolCall
entries are automatically wrapped in a list to match the runner's
expected LLMResponse = list[ToolCall] | TextResponse.
"""

def __init__(self, responses: list[ToolCall | TextResponse]):
self.responses = list(responses)
self._call_index = 0
self.send_calls: list[tuple[list[dict], list[ToolSpec] | None]] = []
self.send_stream_calls: list[tuple[list[dict], list[ToolSpec] | None]] = []

def _next(self) -> LLMResponse:
resp = self.responses[self._call_index]
self._call_index += 1
if isinstance(resp, ToolCall):
return [resp]
return resp

async def send(
self,
messages: list[dict[str, str]],
tools: list[ToolSpec] | None = None,
sampling: dict[str, object] | None = None,
passthrough: dict[str, object] | None = None,
inbound_anthropic_body: dict[str, object] | None = None,
) -> LLMResponse:
self.send_calls.append((messages, tools))
return self._next()

async def send_stream(
self,
messages: list[dict[str, str]],
tools: list[ToolSpec] | None = None,
sampling: dict[str, object] | None = None,
passthrough: dict[str, object] | None = None,
inbound_anthropic_body: dict[str, object] | None = None,
) -> AsyncIterator[StreamChunk]:
self.send_stream_calls.append((messages, tools))
resp = self._next()
yield StreamChunk(type=ChunkType.TEXT_DELTA, content="partial...")
yield StreamChunk(type=ChunkType.FINAL, response=resp)

async def get_context_length(self) -> int | None:
return None


def _make_tool(name: str, fn=None) -> ToolDef:
"""Create a minimal ToolDef for testing."""
if fn is None:
Expand Down
30 changes: 8 additions & 22 deletions tests/unit/test_slot_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
from collections.abc import AsyncIterator
from functools import partial
from unittest.mock import AsyncMock

import pytest
Expand All @@ -23,6 +24,12 @@
from forge.errors import WorkflowCancelledError
from pydantic import BaseModel

from tests.conftest import MockClient as _SharedMockClient

# SlotWorker tests never stream; the shared double's send_stream is left
# unsupported to match the original local double's behavior exactly.
MockClient = partial(_SharedMockClient, stream_mode="unsupported")


class EmptyParams(BaseModel):
pass
Expand Down Expand Up @@ -54,28 +61,7 @@ def _make_workflow(terminal_tool: str = "submit") -> Workflow:
)


class MockClient:
"""Mock LLMClient that returns scripted responses."""

def __init__(self, responses: list):
self.responses = list(responses)
self._call_index = 0

async def send(self, messages, tools=None, sampling=None, passthrough=None, inbound_anthropic_body=None):
resp = self.responses[self._call_index]
self._call_index += 1
if isinstance(resp, ToolCall):
return [resp]
return resp

async def send_stream(self, messages, tools=None, sampling=None, passthrough=None, inbound_anthropic_body=None):
raise NotImplementedError

async def get_context_length(self):
return None


def _make_worker(client: MockClient) -> SlotWorker:
def _make_worker(client: _SharedMockClient) -> SlotWorker:
ctx = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
runner = WorkflowRunner(client=client, context_manager=ctx)
return SlotWorker(runner)
Expand Down
Loading