Skip to content

Commit 15a74d8

Browse files
author
Dylan Huang
committed
factory pattern works
1 parent e551be6 commit 15a74d8

2 files changed

Lines changed: 23 additions & 68 deletions

File tree

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 9 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -4,19 +4,13 @@
44
from collections.abc import Callable
55
import logging
66
import time
7-
import types
8-
from pydantic_ai.models import Model
97
from pydantic_ai.usage import UsageLimits
108
from typing_extensions import override
119
from eval_protocol.models import EvaluationRow, Message
12-
from openai.types import CompletionUsage
1310
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1411
from eval_protocol.pytest.types import RolloutProcessorConfig
1512
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
1613
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
17-
from openai.types.chat.chat_completion_assistant_message_param import (
18-
ChatCompletionAssistantMessageParam,
19-
)
2014
from pydantic import TypeAdapter
2115
from pydantic_ai import Agent
2216
from pydantic_ai._utils import generate_tool_call_id
@@ -27,8 +21,6 @@
2721
ToolReturnPart,
2822
UserPromptPart,
2923
)
30-
from pydantic_ai.models.anthropic import AnthropicModel
31-
from pydantic_ai.models.google import GoogleModel
3224
from pydantic_ai.models.openai import OpenAIModel
3325
from pydantic_ai.providers.openai import OpenAIProvider
3426

@@ -39,64 +31,27 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
3931
"""Rollout processor for Pydantic AI agents. Mainly converts
4032
EvaluationRow.messages to and from Pydantic AI ModelMessage format."""
4133

42-
def __init__(self, setup_agent: Callable[..., Agent] | Agent, usage_limits: UsageLimits | None = None):
34+
def __init__(
35+
self, agent_factory: Callable[[RolloutProcessorConfig], Agent], usage_limits: UsageLimits | None = None
36+
):
4337
# dummy model used for its helper functions for processing messages
44-
self.util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
38+
self._util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
39+
self._setup_agent = agent_factory
4540

4641
@override
4742
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
4843
"""Create agent rollout tasks and return them for external handling."""
4944

5045
semaphore = config.semaphore
5146

52-
# validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
53-
if "agent" not in config.kwargs:
54-
raise ValueError("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance")
55-
if not isinstance(config.kwargs["agent"], Agent) and not isinstance(
56-
config.kwargs["agent"], types.FunctionType
57-
):
58-
raise ValueError(
59-
"kwargs['agent'] must be a valid Pydantic AI Agent instance or a function that returns an Agent"
60-
)
61-
62-
if isinstance(config.kwargs["agent"], types.FunctionType):
63-
setup_agent = config.kwargs["agent"]
64-
if not isinstance(config.completion_params["model"], dict):
65-
raise ValueError(
66-
"completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
67-
)
68-
kwargs: dict[str, Model] = {}
69-
for k, v in config.completion_params["model"].items(): # pyright: ignore[reportUnknownVariableType]
70-
if v["model"] and v["model"].startswith("anthropic:"): # pyright: ignore[reportUnknownMemberType]
71-
kwargs[k] = AnthropicModel(
72-
v["model"].removeprefix("anthropic:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
73-
)
74-
elif v["model"] and v["model"].startswith("google:"): # pyright: ignore[reportUnknownMemberType]
75-
kwargs[k] = GoogleModel(
76-
v["model"].removeprefix("google:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
77-
)
78-
else:
79-
kwargs[k] = OpenAIModel(
80-
v["model"], # pyright: ignore[reportUnknownArgumentType]
81-
provider=v["provider"], # pyright: ignore[reportUnknownArgumentType]
82-
)
83-
agent_instance: Agent = setup_agent(**kwargs) # pyright: ignore[reportAny]
84-
model = None
85-
else:
86-
agent_instance = config.kwargs["agent"] # pyright: ignore[reportAssignmentType]
87-
model = OpenAIModel(
88-
config.completion_params["model"], # pyright: ignore[reportAny]
89-
provider=config.completion_params["provider"], # pyright: ignore[reportAny]
90-
)
47+
agent = self._setup_agent(config)
9148

9249
async def process_row(row: EvaluationRow) -> EvaluationRow:
9350
"""Process a single row with agent rollout."""
9451
start_time = time.perf_counter()
9552

9653
model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
97-
response = await agent_instance.run(
98-
message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
99-
)
54+
response = await agent.run(message_history=model_messages, usage_limits=config.kwargs.get("usage_limits"))
10055
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
10156

10257
# TODO: pydantic ai accumulates usage info across all models in multi-agent setup, so this simple tracking doesn't work for cost. to discuss with @dphuang2 when he's back.
@@ -121,15 +76,15 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
12176
return tasks
12277

12378
async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
124-
oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
79+
oai_messages: list[ChatCompletionMessageParam] = await self._util._map_messages(messages)
12580
return [Message(**m) for m in oai_messages] # pyright: ignore[reportArgumentType]
12681

12782
def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
12883
if message.role == "assistant":
12984
type_adapter = TypeAdapter(ChatCompletionMessage)
13085
oai_message = type_adapter.validate_python(message)
13186
# Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
132-
return self.util._process_response(
87+
return self._util._process_response(
13388
ChatCompletion(
13489
choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
13590
object="chat.completion",

tests/chinook/pydantic/test_pydantic_chinook.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from eval_protocol.pytest import evaluation_test
77

88
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
9+
from eval_protocol.pytest.types import RolloutProcessorConfig
910
from tests.chinook.pydantic.agent import setup_agent
1011
import os
1112
from pydantic_ai.models.openai import OpenAIModel
@@ -20,20 +21,23 @@
2021
)
2122

2223

24+
def agent_factory(config: RolloutProcessorConfig) -> Agent:
25+
model_name = config.completion_params["model"]
26+
provider = config.completion_params["provider"]
27+
model = OpenAIModel(model_name, provider=provider)
28+
return setup_agent(model)
29+
30+
2331
@pytest.mark.asyncio
2432
@evaluation_test(
2533
input_messages=[[[Message(role="user", content="What is the total number of tracks in the database?")]]],
2634
completion_params=[
2735
{
28-
"model": {
29-
"orchestrator_agent_model": {
30-
"model": "accounts/fireworks/models/kimi-k2-instruct",
31-
"provider": "fireworks",
32-
}
33-
}
36+
"model": "accounts/fireworks/models/kimi-k2-instruct",
37+
"provider": "fireworks",
3438
},
3539
],
36-
rollout_processor=PydanticAgentRolloutProcessor(setup_agent),
40+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
3741
mode="pointwise",
3842
)
3943
async def test_simple_query(row: EvaluationRow) -> EvaluationRow:
@@ -91,15 +95,11 @@ class Response(BaseModel):
9195
input_rows=[collect_dataset()],
9296
completion_params=[
9397
{
94-
"model": {
95-
"orchestrator_agent_model": {
96-
"model": "accounts/fireworks/models/kimi-k2-instruct",
97-
"provider": "fireworks",
98-
}
99-
}
98+
"model": "accounts/fireworks/models/kimi-k2-instruct",
99+
"provider": "fireworks",
100100
},
101101
],
102-
rollout_processor=PydanticAgentRolloutProcessor(setup_agent),
102+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
103103
mode="pointwise",
104104
)
105105
async def test_complex_queries(row: EvaluationRow) -> EvaluationRow:

0 commit comments

Comments
 (0)