99from eval_protocol .models import EvaluationRow , Message
1010from eval_protocol .pytest .rollout_processor import RolloutProcessor
1111from eval_protocol .pytest .types import RolloutProcessorConfig
12- from openai .types .chat import ChatCompletion , ChatCompletionMessageParam
12+ from openai .types .chat import ChatCompletion , ChatCompletionMessage , ChatCompletionMessageParam
1313from openai .types .chat .chat_completion import Choice as ChatCompletionChoice
1414from pydantic_ai .models .anthropic import AnthropicModel
1515from pydantic_ai .models .openai import OpenAIModel
@@ -36,7 +36,7 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
3636
3737 def __init__ (self ):
3838 # dummy model used for its helper functions for processing messages
39- self .util = OpenAIModel ("dummy-model" , provider = OpenAIProvider (api_key = "dummy" ))
39+ self .util : OpenAIModel = OpenAIModel ("dummy-model" , provider = OpenAIProvider (api_key = "dummy" ))
4040
4141 def __call__ (self , rows : List [EvaluationRow ], config : RolloutProcessorConfig ) -> List [asyncio .Task [EvaluationRow ]]:
4242 """Create agent rollout tasks and return them for external handling."""
@@ -60,7 +60,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
6060 raise ValueError (
6161 "completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
6262 )
63- kwargs = {}
63+ kwargs : dict = {}
6464 for k , v in config .completion_params ["model" ].items ():
6565 if v ["model" ] and v ["model" ].startswith ("anthropic:" ):
6666 kwargs [k ] = AnthropicModel (
@@ -75,10 +75,10 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
7575 v ["model" ],
7676 provider = v ["provider" ],
7777 )
78- agent = setup_agent (** kwargs )
78+ agent_instance : Agent = setup_agent (** kwargs )
7979 model = None
8080 else :
81- agent = config .kwargs ["agent" ]
81+ agent_instance = config .kwargs ["agent" ]
8282 model = OpenAIModel (
8383 config .completion_params ["model" ],
8484 provider = config .completion_params ["provider" ],
@@ -87,7 +87,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
8787 async def process_row (row : EvaluationRow ) -> EvaluationRow :
8888 """Process a single row with agent rollout."""
8989 model_messages = [self .convert_ep_message_to_pyd_message (m , row ) for m in row .messages ]
90- response = await agent .run (
90+ response = await agent_instance .run (
9191 message_history = model_messages , model = model , usage_limits = config .kwargs .get ("usage_limits" )
9292 )
9393 row .messages = await self .convert_pyd_message_to_ep_message (response .all_messages ())
@@ -104,11 +104,11 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
104104
105105 async def convert_pyd_message_to_ep_message (self , messages : list [ModelMessage ]) -> list [Message ]:
106106 oai_messages : list [ChatCompletionMessageParam ] = await self .util ._map_messages (messages )
107- return [Message (** m ) for m in oai_messages ]
107+ return [Message (role = m [ "role" ], ** m ) for m in oai_messages ]
108108
109109 def convert_ep_message_to_pyd_message (self , message : Message , row : EvaluationRow ) -> ModelMessage :
110110 if message .role == "assistant" :
111- type_adapter = TypeAdapter (ChatCompletionAssistantMessageParam )
111+ type_adapter = TypeAdapter (ChatCompletionMessage )
112112 oai_message = type_adapter .validate_python (message )
113113 # Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
114114 return self .util ._process_response (
@@ -117,23 +117,23 @@ def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow
117117 object = "chat.completion" ,
118118 model = "" ,
119119 id = "" ,
120- created = (
121- int (row .created_at .timestamp ())
122- if hasattr (row .created_at , "timestamp" )
123- else int (row .created_at )
124- ),
120+ created = int (row .created_at .timestamp ()),
125121 )
126122 )
127123 elif message .role == "user" :
128124 if isinstance (message .content , str ):
129125 return ModelRequest (parts = [UserPromptPart (content = message .content )])
130126 elif isinstance (message .content , list ):
131127 return ModelRequest (parts = [UserPromptPart (content = message .content [0 ].text )])
128+ else :
129+ raise ValueError (f"Unsupported content type for user message: { type (message .content )} " )
132130 elif message .role == "system" :
133131 if isinstance (message .content , str ):
134132 return ModelRequest (parts = [SystemPromptPart (content = message .content )])
135133 elif isinstance (message .content , list ):
136134 return ModelRequest (parts = [SystemPromptPart (content = message .content [0 ].text )])
135+ else :
136+ raise ValueError (f"Unsupported content type for system message: { type (message .content )} " )
137137 elif message .role == "tool" :
138138 return ModelRequest (
139139 parts = [
0 commit comments