Skip to content

Commit e709563

Browse files
author
Dylan Huang
committed
Super simple hello world test for Pydantic AI
1 parent d951083 commit e709563

File tree

6 files changed

+492
-14
lines changed

6 files changed

+492
-14
lines changed

eval_protocol/models.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -397,15 +397,11 @@ def __iter__(self):
397397

398398
CompletionParams = Dict[str, Any]
399399
"""
400-
Common set of completion parameters that most model providers support in their
401-
API. Set total=False to allow extra fields since LiteLLM + providers have their
402-
own set of parameters. The following parameters are common fields that are
403-
populated.
404-
405-
model: str
406-
temperature: Optional[float]
407-
max_tokens: Optional[int]
408-
top_p: Optional[float]
400+
The completion parameters for the respective LLM SDK or agent framework.
401+
Depending on the rollout processor, this might be the parameters passed to
402+
LiteLLM completion call or parameters for the "run" method of the "Agent" class
403+
in Pydantic AI. You can also customize this dictionary to whatever you need if
404+
you implement your own custom rollout processor.
409405
"""
410406

411407

eval_protocol/pytest/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88
from .rollout_processor import RolloutProcessor
99
from .types import RolloutProcessorConfig
1010

11+
# Conditional import for optional dependency
12+
try:
13+
from .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
14+
15+
PYDANTIC_AI_AVAILABLE = True
16+
except ImportError:
17+
PYDANTIC_AI_AVAILABLE = False
18+
PydanticAgentRolloutProcessor = None
19+
1120
__all__ = [
1221
"AgentRolloutProcessor",
1322
"MCPGymRolloutProcessor",
@@ -21,3 +30,7 @@
2130
"BackoffConfig",
2231
"get_default_exception_handler_config",
2332
]
33+
34+
# Only add to __all__ if available
35+
if PYDANTIC_AI_AVAILABLE:
36+
__all__.append("PydanticAgentRolloutProcessor")
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import asyncio
2+
import logging
3+
from typing import List
4+
5+
from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam
6+
7+
from eval_protocol.models import EvaluationRow, Message
8+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
9+
from eval_protocol.pytest.types import RolloutProcessorConfig
10+
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
11+
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
12+
13+
from pydantic_ai.models.openai import OpenAIModel
14+
from pydantic import TypeAdapter
15+
from pydantic_ai.messages import ModelMessage
16+
from pydantic_ai._utils import generate_tool_call_id
17+
from pydantic_ai import Agent
18+
from pydantic_ai.messages import (
19+
ModelRequest,
20+
SystemPromptPart,
21+
ToolReturnPart,
22+
UserPromptPart,
23+
)
24+
25+
logger = logging.getLogger(__name__)
26+
27+
28+
class PydanticAgentRolloutProcessor(RolloutProcessor):
29+
"""Rollout processor for Pydantic AI agents. Mainly converts
30+
EvaluationRow.messages to and from Pydantic AI ModelMessage format."""
31+
32+
def __init__(self):
33+
# dummy model used for its helper functions for processing messages
34+
self.util = OpenAIModel("dummy-model")
35+
36+
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
37+
"""Create agent rollout tasks and return them for external handling."""
38+
39+
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
40+
semaphore = asyncio.Semaphore(max_concurrent)
41+
42+
# validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
43+
if "agent" not in config.kwargs:
44+
raise ValueError("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance")
45+
if not isinstance(config.kwargs["agent"], Agent):
46+
raise ValueError("kwargs['agent'] must be a valid Pydantic AI Agent instance")
47+
48+
agent: Agent = config.kwargs["agent"]
49+
50+
model = OpenAIModel(
51+
config.completion_params["model"],
52+
provider=config.completion_params["provider"],
53+
)
54+
55+
async def process_row(row: EvaluationRow) -> EvaluationRow:
56+
"""Process a single row with agent rollout."""
57+
model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
58+
response = await agent.run(message_history=model_messages, model=model)
59+
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
60+
return row
61+
62+
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
63+
async with semaphore:
64+
result = await process_row(r)
65+
return result
66+
67+
# Create and return tasks for external handling
68+
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
69+
return tasks
70+
71+
async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
72+
oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
73+
return [Message(**m) for m in oai_messages]
74+
75+
def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
76+
if message.role == "assistant":
77+
type_adapter = TypeAdapter(ChatCompletionAssistantMessageParam)
78+
oai_message = type_adapter.validate_python(message)
79+
# Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
80+
return self.util._process_response(
81+
ChatCompletion(
82+
choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
83+
object="chat.completion",
84+
model="",
85+
id="",
86+
created=(
87+
int(row.created_at.timestamp())
88+
if hasattr(row.created_at, "timestamp")
89+
else int(row.created_at)
90+
),
91+
)
92+
)
93+
elif message.role == "user":
94+
if isinstance(message.content, str):
95+
return ModelRequest(parts=[UserPromptPart(content=message.content)])
96+
elif isinstance(message.content, list):
97+
return ModelRequest(parts=[UserPromptPart(content=message.content[0].text)])
98+
elif message.role == "system":
99+
if isinstance(message.content, str):
100+
return ModelRequest(parts=[SystemPromptPart(content=message.content)])
101+
elif isinstance(message.content, list):
102+
return ModelRequest(parts=[SystemPromptPart(content=message.content[0].text)])
103+
elif message.role == "tool":
104+
return ModelRequest(
105+
parts=[
106+
ToolReturnPart(
107+
content=message.content,
108+
tool_name="",
109+
tool_call_id=message.tool_call_id or generate_tool_call_id(),
110+
)
111+
]
112+
)
113+
else:
114+
raise ValueError(f"Unknown role: {message.role}")

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ bigquery = [
122122
svgbench = [
123123
"selenium>=4.0.0",
124124
]
125+
pydantic = [
126+
"pydantic-ai",
127+
]
125128

126129
[tool.pytest.ini_options]
127130
addopts = "-q"
@@ -170,7 +173,6 @@ dev = [
170173
"haikus==0.3.8",
171174
"pytest>=8.4.1",
172175
]
173-
174176
[tool.ruff]
175177
line-length = 119
176178
target-version = "py310"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import os
2+
import pytest
3+
4+
from eval_protocol.models import EvaluationRow, Message
5+
from eval_protocol.pytest import evaluation_test
6+
from pydantic_ai import Agent
7+
8+
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
9+
10+
agent = Agent()
11+
12+
13+
@pytest.mark.asyncio
14+
@evaluation_test(
15+
input_messages=[Message(role="user", content="Hello, how are you?")],
16+
completion_params=[
17+
{"model": "accounts/fireworks/models/gpt-oss-120b", "provider": "fireworks"},
18+
],
19+
rollout_processor=PydanticAgentRolloutProcessor(),
20+
rollout_processor_kwargs={"agent": agent},
21+
mode="pointwise",
22+
)
23+
async def test_pydantic_agent(row: EvaluationRow) -> EvaluationRow:
24+
"""
25+
Super simple hello world test for Pydantic AI.
26+
"""
27+
return row

0 commit comments

Comments
 (0)