1- import sys
2- import types
3- from dataclasses import dataclass
4- from typing import Any , Dict , List
5-
61import asyncio
7- import pytest
8- from pydantic import BaseModel
2+ from types import SimpleNamespace
3+ from typing import Any , Dict , List
94from unittest import mock
5+ from openai .types .chat .chat_completion_message import (
6+ ChatCompletionMessageToolCall ,
7+ ChatCompletionMessageToolCallFunction ,
8+ )
109
11-
12- # ---- Stub external dependencies ----
13- openai = types .ModuleType ("openai" )
14- types_mod = types .ModuleType ("openai.types" )
15- chat_mod = types .ModuleType ("openai.types.chat" )
16- chat_msg_mod = types .ModuleType ("openai.types.chat.chat_completion_message" )
17-
18-
19- class FunctionCall (BaseModel ):
20- name : str
21- arguments : str
22-
23-
24- class ToolFunction (BaseModel ):
25- name : str
26- arguments : str
27-
28-
29- class ChatCompletionMessageToolCall (BaseModel ):
30- id : str
31- type : str
32- function : ToolFunction
33-
34-
35- class CompletionUsage (BaseModel ):
36- prompt_tokens : int = 0
37- completion_tokens : int = 0
38- total_tokens : int = 0
39-
40-
41- chat_msg_mod .FunctionCall = FunctionCall
42- chat_msg_mod .ChatCompletionMessageToolCall = ChatCompletionMessageToolCall
43- chat_mod .chat_completion_message = chat_msg_mod
44- openai .types = types_mod
45- types_mod .chat = chat_mod
46- types_mod .CompletionUsage = CompletionUsage
47- sys .modules ["openai" ] = openai
48- sys .modules ["openai.types" ] = types_mod
49- sys .modules ["openai.types.chat" ] = chat_mod
50- sys .modules ["openai.types.chat.chat_completion_message" ] = chat_msg_mod
51-
52-
53- # Stub litellm
54- litellm = types .ModuleType ("litellm" )
55-
56-
57- async def acompletion (** kwargs ):
58- raise NotImplementedError
59-
60-
61- litellm .acompletion = acompletion
62- sys .modules ["litellm" ] = litellm
63-
64-
65- # Stub eval_protocol models and types
66- class Message (BaseModel ):
67- role : str
68- content : Any = ""
69- name : str | None = None
70- tool_call_id : str | None = None
71- tool_calls : List [ChatCompletionMessageToolCall ] | None = None
72- function_call : FunctionCall | None = None
73-
74-
75- class EvaluationRow (BaseModel ):
76- messages : List [Message ]
77- tools : Any = None
78- ground_truth : Any = None
79-
80-
81- @dataclass
82- class RolloutProcessorConfig :
83- model : str
84- input_params : Dict [str , Any ]
85- mcp_config_path : str
86- server_script_path : str | None = None
87- max_concurrent_rollouts : int = 8
88- steps : int = 30
89-
90-
91- # Register stub modules
92- import_path = "/workspace/python-sdk/eval_protocol"
93- eval_protocol_pkg = types .ModuleType ("eval_protocol" )
94- eval_protocol_pkg .__path__ = [import_path ]
95- models_module = types .ModuleType ("eval_protocol.models" )
96- models_module .Message = Message
97- models_module .EvaluationRow = EvaluationRow
98- pytest_pkg = types .ModuleType ("eval_protocol.pytest" )
99- pytest_pkg .__path__ = [f"{ import_path } /pytest" ]
100- types_module = types .ModuleType ("eval_protocol.pytest.types" )
101- types_module .RolloutProcessorConfig = RolloutProcessorConfig
102-
103- sys .modules ["eval_protocol" ] = eval_protocol_pkg
104- sys .modules ["eval_protocol.models" ] = models_module
105- sys .modules ["eval_protocol.pytest" ] = pytest_pkg
106- sys .modules ["eval_protocol.pytest.types" ] = types_module
107-
108-
109- # Now we can import the rollout processor
10+ from eval_protocol .models import EvaluationRow , Message
11+ from eval_protocol .pytest .types import RolloutProcessorConfig
11012from eval_protocol .pytest .default_single_turn_rollout_process import (
11113 default_single_turn_rollout_processor ,
11214)
11315
11416
115- def test_handles_function_call_messages ():
116- async def run_test ():
17+ def test_handles_function_call_messages () -> None :
18+ async def run_test () -> None :
11719 tool_call = ChatCompletionMessageToolCall (
11820 id = "call_1" ,
11921 type = "function" ,
120- function = ToolFunction (name = "get_weather" , arguments = "{}" ),
22+ function = ChatCompletionMessageToolCallFunction (
23+ name = "get_weather" , arguments = "{}"
24+ ),
12125 )
12226 row = EvaluationRow (
12327 messages = [
@@ -133,19 +37,21 @@ async def run_test():
13337
13438 captured_messages : List [Dict [str , Any ]] = []
13539
136- async def fake_acompletion (** kwargs ) :
40+ async def fake_acompletion (** kwargs : Any ) -> Any :
13741 nonlocal captured_messages
13842 captured_messages = kwargs ["messages" ]
139- return types . SimpleNamespace (
43+ return SimpleNamespace (
14044 choices = [
141- types . SimpleNamespace (
142- message = types . SimpleNamespace (
45+ SimpleNamespace (
46+ message = SimpleNamespace (
14347 content = "done" ,
14448 tool_calls = [
14549 ChatCompletionMessageToolCall (
14650 id = "call_2" ,
14751 type = "function" ,
148- function = ToolFunction (name = "foo" , arguments = "{}" ),
52+ function = ChatCompletionMessageToolCallFunction (
53+ name = "foo" , arguments = "{}"
54+ ),
14955 )
15056 ],
15157 function_call = None ,
@@ -154,9 +60,6 @@ async def fake_acompletion(**kwargs):
15460 ]
15561 )
15662
157- with pytest .raises (NotImplementedError ):
158- await acompletion ()
159-
16063 with mock .patch (
16164 "eval_protocol.pytest.default_single_turn_rollout_process.acompletion" ,
16265 side_effect = fake_acompletion ,
0 commit comments