44from collections .abc import Callable
55import logging
66import time
7- import types
8- from pydantic_ai .models import Model
97from pydantic_ai .usage import UsageLimits
108from typing_extensions import override
119from eval_protocol .models import EvaluationRow , Message
12- from openai .types import CompletionUsage
1310from eval_protocol .pytest .rollout_processor import RolloutProcessor
1411from eval_protocol .pytest .types import RolloutProcessorConfig
1512from openai .types .chat import ChatCompletion , ChatCompletionMessage , ChatCompletionMessageParam
1613from openai .types .chat .chat_completion import Choice as ChatCompletionChoice
17- from openai .types .chat .chat_completion_assistant_message_param import (
18- ChatCompletionAssistantMessageParam ,
19- )
2014from pydantic import TypeAdapter
2115from pydantic_ai import Agent
2216from pydantic_ai ._utils import generate_tool_call_id
2721 ToolReturnPart ,
2822 UserPromptPart ,
2923)
30- from pydantic_ai .models .anthropic import AnthropicModel
31- from pydantic_ai .models .google import GoogleModel
3224from pydantic_ai .models .openai import OpenAIModel
3325from pydantic_ai .providers .openai import OpenAIProvider
3426
@@ -39,64 +31,27 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
3931 """Rollout processor for Pydantic AI agents. Mainly converts
4032 EvaluationRow.messages to and from Pydantic AI ModelMessage format."""
4133
42- def __init__ (self , setup_agent : Callable [..., Agent ] | Agent , usage_limits : UsageLimits | None = None ):
34+ def __init__ (
35+ self , agent_factory : Callable [[RolloutProcessorConfig ], Agent ], usage_limits : UsageLimits | None = None
36+ ):
4337 # dummy model used for its helper functions for processing messages
44- self .util : OpenAIModel = OpenAIModel ("dummy-model" , provider = OpenAIProvider (api_key = "dummy" ))
38+ self ._util : OpenAIModel = OpenAIModel ("dummy-model" , provider = OpenAIProvider (api_key = "dummy" ))
39+ self ._setup_agent = agent_factory
4540
4641 @override
4742 def __call__ (self , rows : list [EvaluationRow ], config : RolloutProcessorConfig ) -> list [asyncio .Task [EvaluationRow ]]:
4843 """Create agent rollout tasks and return them for external handling."""
4944
5045 semaphore = config .semaphore
5146
52- # validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
53- if "agent" not in config .kwargs :
54- raise ValueError ("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance" )
55- if not isinstance (config .kwargs ["agent" ], Agent ) and not isinstance (
56- config .kwargs ["agent" ], types .FunctionType
57- ):
58- raise ValueError (
59- "kwargs['agent'] must be a valid Pydantic AI Agent instance or a function that returns an Agent"
60- )
61-
62- if isinstance (config .kwargs ["agent" ], types .FunctionType ):
63- setup_agent = config .kwargs ["agent" ]
64- if not isinstance (config .completion_params ["model" ], dict ):
65- raise ValueError (
66- "completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
67- )
68- kwargs : dict [str , Model ] = {}
69- for k , v in config .completion_params ["model" ].items (): # pyright: ignore[reportUnknownVariableType]
70- if v ["model" ] and v ["model" ].startswith ("anthropic:" ): # pyright: ignore[reportUnknownMemberType]
71- kwargs [k ] = AnthropicModel (
72- v ["model" ].removeprefix ("anthropic:" ), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
73- )
74- elif v ["model" ] and v ["model" ].startswith ("google:" ): # pyright: ignore[reportUnknownMemberType]
75- kwargs [k ] = GoogleModel (
76- v ["model" ].removeprefix ("google:" ), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
77- )
78- else :
79- kwargs [k ] = OpenAIModel (
80- v ["model" ], # pyright: ignore[reportUnknownArgumentType]
81- provider = v ["provider" ], # pyright: ignore[reportUnknownArgumentType]
82- )
83- agent_instance : Agent = setup_agent (** kwargs ) # pyright: ignore[reportAny]
84- model = None
85- else :
86- agent_instance = config .kwargs ["agent" ] # pyright: ignore[reportAssignmentType]
87- model = OpenAIModel (
88- config .completion_params ["model" ], # pyright: ignore[reportAny]
89- provider = config .completion_params ["provider" ], # pyright: ignore[reportAny]
90- )
47+ agent = self ._setup_agent (config )
9148
9249 async def process_row (row : EvaluationRow ) -> EvaluationRow :
9350 """Process a single row with agent rollout."""
9451 start_time = time .perf_counter ()
9552
9653 model_messages = [self .convert_ep_message_to_pyd_message (m , row ) for m in row .messages ]
97- response = await agent_instance .run (
98- message_history = model_messages , model = model , usage_limits = config .kwargs .get ("usage_limits" )
99- )
54+ response = await agent .run (message_history = model_messages , usage_limits = config .kwargs .get ("usage_limits" ))
10055 row .messages = await self .convert_pyd_message_to_ep_message (response .all_messages ())
10156
10257 # TODO: pydantic ai accumulates usage info across all models in multi-agent setup, so this simple tracking doesn't work for cost. to discuss with @dphuang2 when he's back.
@@ -121,15 +76,15 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
12176 return tasks
12277
12378 async def convert_pyd_message_to_ep_message (self , messages : list [ModelMessage ]) -> list [Message ]:
124- oai_messages : list [ChatCompletionMessageParam ] = await self .util ._map_messages (messages )
79+ oai_messages : list [ChatCompletionMessageParam ] = await self ._util ._map_messages (messages )
12580 return [Message (** m ) for m in oai_messages ] # pyright: ignore[reportArgumentType]
12681
12782 def convert_ep_message_to_pyd_message (self , message : Message , row : EvaluationRow ) -> ModelMessage :
12883 if message .role == "assistant" :
12984 type_adapter = TypeAdapter (ChatCompletionMessage )
13085 oai_message = type_adapter .validate_python (message )
13186 # Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
132- return self .util ._process_response (
87+ return self ._util ._process_response (
13388 ChatCompletion (
13489 choices = [ChatCompletionChoice (message = oai_message , finish_reason = "stop" , index = 0 )],
13590 object = "chat.completion" ,
0 commit comments