Skip to content

Commit c8918fd

Browse files
author
Dylan Huang
authored
Replace deprecated class for pydantic ai / fix function toolset query (#169)
* fix bg color and hover color * remove deprecated usage of openaimodel * done * vite build
1 parent 3410973 commit c8918fd

17 files changed

Lines changed: 198 additions & 176 deletions

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Callable
55
import logging
66
import time
7+
from pydantic_ai.toolsets import FunctionToolset
78
from pydantic_ai.usage import UsageLimits
89
from typing_extensions import override
910
from eval_protocol.models import EvaluationRow, Message
@@ -21,7 +22,7 @@
2122
ToolReturnPart,
2223
UserPromptPart,
2324
)
24-
from pydantic_ai.models.openai import OpenAIModel
25+
from pydantic_ai.models.openai import OpenAIChatModel
2526
from pydantic_ai.providers.openai import OpenAIProvider
2627

2728
logger = logging.getLogger(__name__)
@@ -37,7 +38,7 @@ def __init__(
3738
usage_limits: UsageLimits | None = None,
3839
):
3940
# dummy model used for its helper functions for processing messages
40-
self._util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
41+
self._util: OpenAIChatModel = OpenAIChatModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
4142
self._setup_agent = agent_factory
4243

4344
@override
@@ -53,18 +54,19 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
5354
start_time = time.perf_counter()
5455

5556
tools = []
56-
for _, tool in agent._function_tools.items():
57-
tool_dict = {
58-
"type": "function",
59-
"function": {
60-
"name": tool.name,
61-
"parameters": tool.function_schema.json_schema,
62-
},
63-
}
64-
if tool.description:
65-
tool_dict["function"]["description"] = tool.description
66-
67-
tools.append(tool_dict)
57+
for toolset in agent.toolsets:
58+
if isinstance(toolset, FunctionToolset):
59+
for _, tool in toolset.tools.items():
60+
tool_dict = {
61+
"type": "function",
62+
"function": {
63+
"name": tool.name,
64+
"parameters": tool.function_schema.json_schema,
65+
},
66+
}
67+
if tool.description:
68+
tool_dict["function"]["description"] = tool.description
69+
tools.append(tool_dict)
6870
row.tools = tools
6971

7072
model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]

tests/chinook/langfuse/generate_traces.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
try:
1414
from langfuse import get_client, observe # pyright: ignore[reportPrivateImportUsage]
1515
from pydantic_ai.agent import Agent
16-
from pydantic_ai.models.openai import OpenAIModel
16+
from pydantic_ai.models.openai import OpenAIChatModel
1717

1818
LANGFUSE_AVAILABLE = True
1919
langfuse_client = get_client()
@@ -42,7 +42,7 @@ def decorator(func):
4242
def agent_factory(config: RolloutProcessorConfig) -> Agent:
4343
model_name = config.completion_params["model"]
4444
provider = config.completion_params["provider"]
45-
model = OpenAIModel(model_name, provider=provider)
45+
model = OpenAIChatModel(model_name, provider=provider)
4646
return setup_agent(model)
4747

4848

tests/chinook/langfuse/test_langfuse_chinook.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717
from pydantic import BaseModel
1818
from pydantic_ai import Agent
19-
from pydantic_ai.models.openai import OpenAIModel
19+
from pydantic_ai.models.openai import OpenAIChatModel
2020

2121
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata
2222
from eval_protocol.pytest import evaluation_test, NoOpRolloutProcessor
@@ -99,7 +99,7 @@ async def test_langfuse_evaluation(row: EvaluationRow) -> EvaluationRow:
9999
reason="No assistant message found",
100100
)
101101
else:
102-
model = OpenAIModel(
102+
model = OpenAIChatModel(
103103
"accounts/fireworks/models/kimi-k2-instruct",
104104
provider="fireworks",
105105
)

tests/chinook/pydantic/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pydantic_ai import Agent, RunContext
22
import asyncio
33
from pydantic_ai.models import Model
4-
from pydantic_ai.models.openai import OpenAIModel
4+
from pydantic_ai.models.openai import OpenAIChatModel
55
from pydantic_ai.exceptions import ModelRetry
66
import sys
77
import os
@@ -68,7 +68,7 @@ def execute_sql(ctx: RunContext, query: str) -> str:
6868

6969

7070
async def main():
71-
model = OpenAIModel(
71+
model = OpenAIChatModel(
7272
"accounts/fireworks/models/kimi-k2-instruct",
7373
provider="fireworks",
7474
)

tests/chinook/pydantic/test_pydantic_chinook.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from eval_protocol.pytest.types import RolloutProcessorConfig
1010
from tests.chinook.pydantic.agent import setup_agent
1111
import os
12-
from pydantic_ai.models.openai import OpenAIModel
12+
from pydantic_ai.models.openai import OpenAIChatModel
1313

1414
from tests.chinook.dataset import collect_dataset
1515

@@ -24,7 +24,7 @@
2424
def agent_factory(config: RolloutProcessorConfig) -> Agent:
2525
model_name = config.completion_params["model"]
2626
provider = config.completion_params["provider"]
27-
model = OpenAIModel(model_name, provider=provider)
27+
model = OpenAIChatModel(model_name, provider=provider)
2828
return setup_agent(model)
2929

3030

@@ -44,6 +44,23 @@ async def test_simple_query(row: EvaluationRow) -> EvaluationRow:
4444
"""
4545
Super simple query for the Chinook database
4646
"""
47+
expected_tools = [
48+
{
49+
"type": "function",
50+
"function": {
51+
"name": "execute_sql",
52+
"parameters": {
53+
"additionalProperties": False,
54+
"properties": {"query": {"type": "string"}},
55+
"required": ["query"],
56+
"type": "object",
57+
},
58+
},
59+
}
60+
]
61+
assert hasattr(row, "tools"), "Row missing 'tools' attribute"
62+
assert row.tools == expected_tools, f"Tools validation failed. Expected: {expected_tools}, Got: {row.tools}"
63+
4764
last_assistant_message = row.last_assistant_message()
4865
if last_assistant_message is None:
4966
row.evaluation_result = EvaluateResult(
@@ -56,7 +73,7 @@ async def test_simple_query(row: EvaluationRow) -> EvaluationRow:
5673
reason="No assistant message found",
5774
)
5875
else:
59-
model = OpenAIModel(
76+
model = OpenAIChatModel(
6077
"accounts/fireworks/models/kimi-k2-instruct",
6178
provider="fireworks",
6279
)

tests/chinook/pydantic/test_pydantic_complex_queries.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from pydantic import BaseModel
33
from pydantic_ai import Agent
4-
from pydantic_ai.models.openai import OpenAIModel
4+
from pydantic_ai.models.openai import OpenAIChatModel
55
import pytest
66

77
from eval_protocol.models import EvaluateResult, EvaluationRow
@@ -22,7 +22,7 @@
2222
def agent_factory(config: RolloutProcessorConfig) -> Agent:
2323
model_name = config.completion_params["model"]
2424
provider = config.completion_params["provider"]
25-
model = OpenAIModel(model_name, provider=provider)
25+
model = OpenAIChatModel(model_name, provider=provider)
2626
return setup_agent(model)
2727

2828

@@ -57,7 +57,7 @@ async def test_pydantic_complex_queries(row: EvaluationRow) -> EvaluationRow:
5757
reason="No assistant message found",
5858
)
5959
else:
60-
model = OpenAIModel(
60+
model = OpenAIChatModel(
6161
"accounts/fireworks/models/kimi-k2-instruct",
6262
provider="fireworks",
6363
)

tests/pytest/test_pydantic_agent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from pydantic_ai.agent import Agent
2-
from pydantic_ai.models.openai import OpenAIModel
2+
from pydantic_ai.models.openai import OpenAIChatModel
33
import pytest
44

5-
from eval_protocol.models import EvaluationRow, Message
5+
from eval_protocol.models import EvaluationRow, Message, Status
66
from eval_protocol.pytest import evaluation_test
77

88
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
99
from eval_protocol.pytest.types import RolloutProcessorConfig
1010

1111

1212
def agent_factory(config: RolloutProcessorConfig) -> Agent:
13-
model = OpenAIModel(config.completion_params["model"], provider="fireworks")
13+
model = OpenAIChatModel(config.completion_params["model"], provider="fireworks")
1414
return Agent(model=model)
1515

1616

@@ -27,4 +27,5 @@ async def test_pydantic_agent(row: EvaluationRow) -> EvaluationRow:
2727
"""
2828
Super simple hello world test for Pydantic AI.
2929
"""
30+
assert row.rollout_status.code == Status.Code.FINISHED
3031
return row

tests/pytest/test_pydantic_multi_agent.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
to agent based on key.
88
"""
99

10-
from pydantic_ai.models.openai import OpenAIModel
10+
from pydantic_ai.models.openai import OpenAIChatModel
1111
import pytest
1212

1313
from eval_protocol.models import EvaluationRow, Message
@@ -49,10 +49,12 @@ async def joke_factory(ctx: RunContext[None], count: int) -> list[str]: # pyrig
4949

5050

5151
def agent_factory(config: RolloutProcessorConfig) -> Agent:
52-
joke_generation_model = OpenAIModel(
52+
joke_generation_model = OpenAIChatModel(
5353
config.completion_params["model"]["joke_generation_model"], provider="fireworks"
5454
)
55-
joke_selection_model = OpenAIModel(config.completion_params["model"]["joke_selection_model"], provider="fireworks")
55+
joke_selection_model = OpenAIChatModel(
56+
config.completion_params["model"]["joke_selection_model"], provider="fireworks"
57+
)
5658
return setup_agent(
5759
joke_generation_model,
5860
joke_selection_model,

0 commit comments

Comments
 (0)