Skip to content

Commit bfb5c2d

Browse files
committed
test: use real dependencies for rollout processor
1 parent c37ca45 commit bfb5c2d

File tree

1 file changed

+20
-117
lines changed

1 file changed

+20
-117
lines changed

tests/test_default_single_turn_rollout_processor.py

Lines changed: 20 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,27 @@
1-
import sys
2-
import types
3-
from dataclasses import dataclass
4-
from typing import Any, Dict, List
5-
61
import asyncio
7-
import pytest
8-
from pydantic import BaseModel
2+
from types import SimpleNamespace
3+
from typing import Any, Dict, List
94
from unittest import mock
5+
from openai.types.chat.chat_completion_message import (
6+
ChatCompletionMessageToolCall,
7+
ChatCompletionMessageToolCallFunction,
8+
)
109

11-
12-
# ---- Stub external dependencies ----
13-
openai = types.ModuleType("openai")
14-
types_mod = types.ModuleType("openai.types")
15-
chat_mod = types.ModuleType("openai.types.chat")
16-
chat_msg_mod = types.ModuleType("openai.types.chat.chat_completion_message")
17-
18-
19-
class FunctionCall(BaseModel):
20-
name: str
21-
arguments: str
22-
23-
24-
class ToolFunction(BaseModel):
25-
name: str
26-
arguments: str
27-
28-
29-
class ChatCompletionMessageToolCall(BaseModel):
30-
id: str
31-
type: str
32-
function: ToolFunction
33-
34-
35-
class CompletionUsage(BaseModel):
36-
prompt_tokens: int = 0
37-
completion_tokens: int = 0
38-
total_tokens: int = 0
39-
40-
41-
chat_msg_mod.FunctionCall = FunctionCall
42-
chat_msg_mod.ChatCompletionMessageToolCall = ChatCompletionMessageToolCall
43-
chat_mod.chat_completion_message = chat_msg_mod
44-
openai.types = types_mod
45-
types_mod.chat = chat_mod
46-
types_mod.CompletionUsage = CompletionUsage
47-
sys.modules["openai"] = openai
48-
sys.modules["openai.types"] = types_mod
49-
sys.modules["openai.types.chat"] = chat_mod
50-
sys.modules["openai.types.chat.chat_completion_message"] = chat_msg_mod
51-
52-
53-
# Stub litellm
54-
litellm = types.ModuleType("litellm")
55-
56-
57-
async def acompletion(**kwargs):
58-
raise NotImplementedError
59-
60-
61-
litellm.acompletion = acompletion
62-
sys.modules["litellm"] = litellm
63-
64-
65-
# Stub eval_protocol models and types
66-
class Message(BaseModel):
67-
role: str
68-
content: Any = ""
69-
name: str | None = None
70-
tool_call_id: str | None = None
71-
tool_calls: List[ChatCompletionMessageToolCall] | None = None
72-
function_call: FunctionCall | None = None
73-
74-
75-
class EvaluationRow(BaseModel):
76-
messages: List[Message]
77-
tools: Any = None
78-
ground_truth: Any = None
79-
80-
81-
@dataclass
82-
class RolloutProcessorConfig:
83-
model: str
84-
input_params: Dict[str, Any]
85-
mcp_config_path: str
86-
server_script_path: str | None = None
87-
max_concurrent_rollouts: int = 8
88-
steps: int = 30
89-
90-
91-
# Register stub modules
92-
import_path = "/workspace/python-sdk/eval_protocol"
93-
eval_protocol_pkg = types.ModuleType("eval_protocol")
94-
eval_protocol_pkg.__path__ = [import_path]
95-
models_module = types.ModuleType("eval_protocol.models")
96-
models_module.Message = Message
97-
models_module.EvaluationRow = EvaluationRow
98-
pytest_pkg = types.ModuleType("eval_protocol.pytest")
99-
pytest_pkg.__path__ = [f"{import_path}/pytest"]
100-
types_module = types.ModuleType("eval_protocol.pytest.types")
101-
types_module.RolloutProcessorConfig = RolloutProcessorConfig
102-
103-
sys.modules["eval_protocol"] = eval_protocol_pkg
104-
sys.modules["eval_protocol.models"] = models_module
105-
sys.modules["eval_protocol.pytest"] = pytest_pkg
106-
sys.modules["eval_protocol.pytest.types"] = types_module
107-
108-
109-
# Now we can import the rollout processor
10+
from eval_protocol.models import EvaluationRow, Message
11+
from eval_protocol.pytest.types import RolloutProcessorConfig
11012
from eval_protocol.pytest.default_single_turn_rollout_process import (
11113
default_single_turn_rollout_processor,
11214
)
11315

11416

115-
def test_handles_function_call_messages():
116-
async def run_test():
17+
def test_handles_function_call_messages() -> None:
18+
async def run_test() -> None:
11719
tool_call = ChatCompletionMessageToolCall(
11820
id="call_1",
11921
type="function",
120-
function=ToolFunction(name="get_weather", arguments="{}"),
22+
function=ChatCompletionMessageToolCallFunction(
23+
name="get_weather", arguments="{}"
24+
),
12125
)
12226
row = EvaluationRow(
12327
messages=[
@@ -133,19 +37,21 @@ async def run_test():
13337

13438
captured_messages: List[Dict[str, Any]] = []
13539

136-
async def fake_acompletion(**kwargs):
40+
async def fake_acompletion(**kwargs: Any) -> Any:
13741
nonlocal captured_messages
13842
captured_messages = kwargs["messages"]
139-
return types.SimpleNamespace(
43+
return SimpleNamespace(
14044
choices=[
141-
types.SimpleNamespace(
142-
message=types.SimpleNamespace(
45+
SimpleNamespace(
46+
message=SimpleNamespace(
14347
content="done",
14448
tool_calls=[
14549
ChatCompletionMessageToolCall(
14650
id="call_2",
14751
type="function",
148-
function=ToolFunction(name="foo", arguments="{}"),
52+
function=ChatCompletionMessageToolCallFunction(
53+
name="foo", arguments="{}"
54+
),
14955
)
15056
],
15157
function_call=None,
@@ -154,9 +60,6 @@ async def fake_acompletion(**kwargs):
15460
]
15561
)
15662

157-
with pytest.raises(NotImplementedError):
158-
await acompletion()
159-
16063
with mock.patch(
16164
"eval_protocol.pytest.default_single_turn_rollout_process.acompletion",
16265
side_effect=fake_acompletion,

0 commit comments

Comments
 (0)