Skip to content

Commit c2b19b7

Browse files
author
Dylan Huang
committed
refactor test_pydantic_multi_agent to work with factory setup
1 parent 15a74d8 commit c2b19b7

1 file changed

Lines changed: 23 additions & 11 deletions

File tree

tests/pytest/test_pydantic_multi_agent.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""
22
Copied and modified for eval-protocol from https://ai.pydantic.dev/multi-agent-applications/#agent-delegation
33
4-
To test your Pydantic AI multi-agent application, you can pass a function that
5-
sets up the agents and their tools. The function should accept parameters that
6-
map a model to each agent. In completion_params, you can provide mappings of
7-
model to agent based on key.
4+
To test your Pydantic AI multi-agent application, you can pass a factory that
5+
sets up the agent based on the completion_params. The factory should accept a
6+
RolloutProcessorConfig. In completion_params, you can provide mappings of model
7+
to agent based on key.
88
"""
99

10+
from pydantic_ai.models.openai import OpenAIModel
1011
import pytest
1112

1213
from eval_protocol.models import EvaluationRow, Message
@@ -18,6 +19,8 @@
1819
from pydantic_ai.models import Model
1920
from pydantic_ai.usage import UsageLimits
2021

22+
from eval_protocol.pytest.types import RolloutProcessorConfig
23+
2124

2225
def setup_agent(joke_generation_model: Model, joke_selection_model: Model) -> Agent:
2326
"""
@@ -45,22 +48,31 @@ async def joke_factory(ctx: RunContext[None], count: int) -> list[str]: # pyrig
4548
return joke_selection_agent
4649

4750

51+
def agent_factory(config: RolloutProcessorConfig) -> Agent:
    """Create the multi-agent joke application from a rollout config.

    The model names are read from ``config.completion_params["model"]`` under
    the keys ``"joke_generation_model"`` and ``"joke_selection_model"``. Each
    name is wrapped in an ``OpenAIModel`` that uses the ``"fireworks"``
    provider, and the two models are handed to ``setup_agent``.

    Args:
        config: Rollout configuration whose ``completion_params["model"]``
            entry maps each agent key to a model name.

    Returns:
        Whatever ``setup_agent`` returns for the two constructed models.

    Raises:
        KeyError: If ``"model"`` or either agent key is absent.
    """
    model_names = config.completion_params["model"]

    # Build the generation model before looking up the selection key, so the
    # order of lookups and constructions matches the two-statement original.
    generator = OpenAIModel(model_names["joke_generation_model"], provider="fireworks")
    selector = OpenAIModel(model_names["joke_selection_model"], provider="fireworks")

    return setup_agent(generator, selector)
60+
61+
4862
@pytest.mark.asyncio
4963
@evaluation_test(
5064
input_messages=[[[Message(role="user", content="Tell me a joke.")]]],
5165
completion_params=[
5266
# multi-agent
5367
{
54-
"joke_generation_model": {
55-
"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct",
56-
},
57-
"joke_selection_model": {
58-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
59-
},
68+
"model": {
69+
"joke_generation_model": "accounts/fireworks/models/kimi-k2-instruct",
70+
"joke_selection_model": "accounts/fireworks/models/deepseek-v3p1",
71+
}
6072
},
6173
],
6274
rollout_processor=PydanticAgentRolloutProcessor(
63-
setup_agent, UsageLimits(request_limit=5, total_tokens_limit=1000)
75+
agent_factory, UsageLimits(request_limit=5, total_tokens_limit=1000)
6476
),
6577
mode="pointwise",
6678
)

0 commit comments

Comments
 (0)