Skip to content

Commit 0e609e2

Browse files
author
Dylan Huang
committed
save
1 parent c0bece6 commit 0e609e2

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from eval_protocol.models import EvaluationRow, Message
1010
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1111
from eval_protocol.pytest.types import RolloutProcessorConfig
12-
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
12+
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
1313
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
1414
from pydantic_ai.models.anthropic import AnthropicModel
1515
from pydantic_ai.models.openai import OpenAIModel
@@ -36,7 +36,7 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
3636

3737
def __init__(self):
3838
# dummy model used for its helper functions for processing messages
39-
self.util = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
39+
self.util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
4040

4141
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
4242
"""Create agent rollout tasks and return them for external handling."""
@@ -60,7 +60,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
6060
raise ValueError(
6161
"completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
6262
)
63-
kwargs = {}
63+
kwargs: dict = {}
6464
for k, v in config.completion_params["model"].items():
6565
if v["model"] and v["model"].startswith("anthropic:"):
6666
kwargs[k] = AnthropicModel(
@@ -75,10 +75,10 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
7575
v["model"],
7676
provider=v["provider"],
7777
)
78-
agent = setup_agent(**kwargs)
78+
agent_instance: Agent = setup_agent(**kwargs)
7979
model = None
8080
else:
81-
agent = config.kwargs["agent"]
81+
agent_instance = config.kwargs["agent"]
8282
model = OpenAIModel(
8383
config.completion_params["model"],
8484
provider=config.completion_params["provider"],
@@ -87,7 +87,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
8787
async def process_row(row: EvaluationRow) -> EvaluationRow:
8888
"""Process a single row with agent rollout."""
8989
model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
90-
response = await agent.run(
90+
response = await agent_instance.run(
9191
message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
9292
)
9393
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
@@ -104,11 +104,11 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
104104

105105
async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
106106
oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
107-
return [Message(**m) for m in oai_messages]
107+
return [Message(role=m["role"], **m) for m in oai_messages]
108108

109109
def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
110110
if message.role == "assistant":
111-
type_adapter = TypeAdapter(ChatCompletionAssistantMessageParam)
111+
type_adapter = TypeAdapter(ChatCompletionMessage)
112112
oai_message = type_adapter.validate_python(message)
113113
# Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
114114
return self.util._process_response(
@@ -117,23 +117,23 @@ def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow
117117
object="chat.completion",
118118
model="",
119119
id="",
120-
created=(
121-
int(row.created_at.timestamp())
122-
if hasattr(row.created_at, "timestamp")
123-
else int(row.created_at)
124-
),
120+
created=int(row.created_at.timestamp()),
125121
)
126122
)
127123
elif message.role == "user":
128124
if isinstance(message.content, str):
129125
return ModelRequest(parts=[UserPromptPart(content=message.content)])
130126
elif isinstance(message.content, list):
131127
return ModelRequest(parts=[UserPromptPart(content=message.content[0].text)])
128+
else:
129+
raise ValueError(f"Unsupported content type for user message: {type(message.content)}")
132130
elif message.role == "system":
133131
if isinstance(message.content, str):
134132
return ModelRequest(parts=[SystemPromptPart(content=message.content)])
135133
elif isinstance(message.content, list):
136134
return ModelRequest(parts=[SystemPromptPart(content=message.content[0].text)])
135+
else:
136+
raise ValueError(f"Unsupported content type for system message: {type(message.content)}")
137137
elif message.role == "tool":
138138
return ModelRequest(
139139
parts=[

tests/chinook/test_pydantic_chinook.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from eval_protocol.pytest import evaluation_test
77

88
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
9-
from agent import setup_agent
9+
from tests.chinook.agent import setup_agent
10+
import os
1011
from pydantic_ai.models.openai import OpenAIModel
1112

1213
from tests.chinook.dataset import collect_dataset
@@ -21,7 +22,7 @@
2122

2223
@pytest.mark.asyncio
2324
@evaluation_test(
24-
input_messages=[Message(role="user", content="What is the total number of tracks in the database?")],
25+
input_messages=[[Message(role="user", content="What is the total number of tracks in the database?")]],
2526
completion_params=[
2627
{
2728
"model": {
@@ -82,7 +83,10 @@ class Response(BaseModel):
8283
return row
8384

8485

85-
@pytest.mark.skip(reason="takes too long to run")
86+
@pytest.mark.skipif(
87+
os.environ.get("CI") == "true",
88+
reason="Only run this test locally (skipped in CI)",
89+
)
8690
@pytest.mark.asyncio
8791
@evaluation_test(
8892
input_rows=collect_dataset(),

0 commit comments

Comments
 (0)