Skip to content

Commit 2c52f3a

Browse files
author
Dylan Huang
committed
fix test_pydantic_agent.py
1 parent c2b19b7 commit 2c52f3a

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
3232
EvaluationRow.messages to and from Pydantic AI ModelMessage format."""
3333

3434
def __init__(
35-
self, agent_factory: Callable[[RolloutProcessorConfig], Agent], usage_limits: UsageLimits | None = None
35+
self,
36+
agent_factory: Callable[[RolloutProcessorConfig], Agent],
37+
usage_limits: UsageLimits | None = None,
3638
):
3739
# dummy model used for its helper functions for processing messages
3840
self._util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))

tests/pytest/test_pydantic_agent.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
1+
from pydantic_ai.agent import Agent
2+
from pydantic_ai.models.openai import OpenAIModel
13
import pytest
24

35
from eval_protocol.models import EvaluationRow, Message
46
from eval_protocol.pytest import evaluation_test
5-
from pydantic_ai import Agent
67

78
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
9+
from eval_protocol.pytest.types import RolloutProcessorConfig
810

9-
agent = Agent()
11+
12+
def agent_factory(config: RolloutProcessorConfig) -> Agent:
13+
model = OpenAIModel(config.completion_params["model"], provider="fireworks")
14+
return Agent(model=model)
1015

1116

1217
@pytest.mark.asyncio
1318
@evaluation_test(
1419
input_messages=[[[Message(role="user", content="Hello, how are you?")]]],
1520
completion_params=[
16-
{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"},
21+
{"model": "accounts/fireworks/models/gpt-oss-120b"},
1722
],
18-
rollout_processor=PydanticAgentRolloutProcessor(agent),
23+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
1924
mode="pointwise",
2025
)
2126
async def test_pydantic_agent(row: EvaluationRow) -> EvaluationRow:

0 commit comments

Comments
 (0)