Skip to content

Commit aaed955

Browse files
committed
Merge branch 'main' into derekx/persist-onto-fireworks
2 parents 55c0aea + fd2bec1 commit aaed955

48 files changed

Lines changed: 18074 additions & 318 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/ci.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ jobs:
8686
E2B_API_KEY: ${{ secrets.E2B_API_KEY }}
8787
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
8888
FIREWORKS_ACCOUNT_ID: ${{ secrets.FIREWORKS_ACCOUNT_ID }}
89+
SUPABASE_PASSWORD: ${{ secrets.SUPABASE_PASSWORD }}
90+
SUPABASE_HOST: ${{ secrets.SUPABASE_HOST }}
91+
SUPABASE_PORT: ${{ secrets.SUPABASE_PORT }}
92+
SUPABASE_DATABASE: ${{ secrets.SUPABASE_DATABASE }}
93+
SUPABASE_USER: ${{ secrets.SUPABASE_USER }}
8994
PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
9095
run: |
9196
# Run most tests in parallel, but explicitly ignore tests that manage their own servers or are slow

README.md

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,61 +2,53 @@
22

33
[![PyPI - Version](https://img.shields.io/pypi/v/eval-protocol)](https://pypi.org/project/eval-protocol/)
44

5-
**Eval Protocol (EP) is the open-source standard and toolkit for practicing Eval-Driven Development.**
5+
**The open-source toolkit for building your internal model leaderboard.**
66

7-
Building with AI is different. Traditional software is deterministic, but AI systems are probabilistic. How do you ship new features without causing silent regressions? How do you prove a new prompt is actually better?
8-
9-
The answer is a new engineering discipline: **Eval-Driven Development (EDD)**. It adapts the rigor of Test-Driven Development for the uncertain world of AI. With EDD, you define your AI's desired behavior as a suite of executable tests, creating a safety net that allows you to innovate with confidence.
10-
11-
EP provides a consistent way to write evals, store traces, and analyze results.
12-
13-
<p align="center">
14-
<img src="https://raw.githubusercontent.com/eval-protocol/python-sdk/refs/heads/main/assets/ui.png" alt="UI" />
15-
<br>
16-
<sub><b>Log Viewer: Monitor your evaluation rollouts in real time.</b></sub>
17-
</p>
7+
When you have multiple AI models to choose from—different versions, providers, or configurations—how do you know which one is best for your use case?
188

199
## Quick Example
2010

21-
Here's a simple test function that checks if a model's response contains **bold** text formatting:
11+
Compare models on a simple formatting task:
2212

2313
```python test_bold_format.py
2414
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
25-
from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
15+
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
2616

2717
@evaluation_test(
2818
input_messages=[
2919
[
30-
Message(role="system", content="You are a helpful assistant. Use bold text to highlight important information."),
31-
Message(role="user", content="Explain why **evaluations** matter for building AI agents. Make it dramatic!"),
20+
Message(role="system", content="Use bold text to highlight important information."),
21+
Message(role="user", content="Explain why evaluations matter for AI agents. Make it dramatic!"),
3222
],
3323
],
34-
completion_params=[{"model": "accounts/fireworks/models/llama-v3p1-8b-instruct"}],
35-
rollout_processor=SingleTurnRolloutProcessor(),
24+
model=[
25+
"fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct",
26+
"openai/gpt-4",
27+
"anthropic/claude-3-sonnet"
28+
],
29+
rollout_processor=default_single_turn_rollout_processor,
3630
mode="pointwise",
3731
)
3832
def test_bold_format(row: EvaluationRow) -> EvaluationRow:
39-
"""
40-
Simple evaluation that checks if the model's response contains bold text.
41-
"""
42-
33+
"""Check if the model's response contains bold text."""
4334
assistant_response = row.messages[-1].content
4435

45-
# Check if response contains **bold** text
46-
has_bold = "**" in assistant_response
36+
if assistant_response is None:
37+
row.evaluation_result = EvaluateResult(score=0.0, reason="No response")
38+
return row
4739

48-
if has_bold:
49-
result = EvaluateResult(score=1.0, reason="✅ Response contains bold text")
50-
else:
51-
result = EvaluateResult(score=0.0, reason="❌ No bold text found")
40+
has_bold = "**" in str(assistant_response)
41+
score = 1.0 if has_bold else 0.0
42+
reason = "Contains bold text" if has_bold else "No bold text found"
5243

53-
row.evaluation_result = result
44+
row.evaluation_result = EvaluateResult(score=score, reason=reason)
5445
return row
5546
```
5647

57-
## Documentation
48+
## 📚 Resources
5849

59-
See our [documentation](https://evalprotocol.io) for more details.
50+
- **[Documentation](https://evalprotocol.io)** - Complete guides and API reference
51+
- **[Discord](https://discord.com/channels/1137072072808472616/1400975572405850155)** - Community discussions
6052

6153
## Installation
6254

eval_protocol/models.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -397,15 +397,11 @@ def __iter__(self):
397397

398398
CompletionParams = Dict[str, Any]
399399
"""
400-
Common set of completion parameters that most model providers support in their
401-
API. Set total=False to allow extra fields since LiteLLM + providers have their
402-
own set of parameters. The following parameters are common fields that are
403-
populated.
404-
405-
model: str
406-
temperature: Optional[float]
407-
max_tokens: Optional[int]
408-
top_p: Optional[float]
400+
The completion parameters for the respective LLM SDK or agent framework.
401+
Depending on the rollout processor, this might be the parameters passed to
402+
LiteLLM completion call or parameters for the "run" method of the "Agent" class
403+
in Pydantic AI. You can also customize this dictionary to whatever you need if
404+
you implement your own custom rollout processor.
409405
"""
410406

411407

@@ -576,6 +572,13 @@ def get_assistant_messages(self) -> List[Message]:
576572
"""Returns only the assistant messages from the conversation."""
577573
return [msg for msg in self.messages if msg.role == "assistant"]
578574

575+
def last_assistant_message(self) -> Optional[Message]:
    """Return the most recent assistant message, or ``None`` if there is none."""
    candidates = self.get_assistant_messages()
    return candidates[-1] if candidates else None
581+
579582
def get_user_messages(self) -> List[Message]:
580583
"""Returns only the user messages from the conversation."""
581584
return [msg for msg in self.messages if msg.role == "user"]

eval_protocol/pytest/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88
from .rollout_processor import RolloutProcessor
99
from .types import RolloutProcessorConfig
1010

11+
# Conditional import for optional dependency
12+
try:
13+
from .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
14+
15+
PYDANTIC_AI_AVAILABLE = True
16+
except ImportError:
17+
PYDANTIC_AI_AVAILABLE = False
18+
PydanticAgentRolloutProcessor = None
19+
1120
__all__ = [
1221
"AgentRolloutProcessor",
1322
"MCPGymRolloutProcessor",
@@ -21,3 +30,7 @@
2130
"BackoffConfig",
2231
"get_default_exception_handler_config",
2332
]
33+
34+
# Only add to __all__ if available
35+
if PYDANTIC_AI_AVAILABLE:
36+
__all__.append("PydanticAgentRolloutProcessor")
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import asyncio
2+
import logging
3+
import types
4+
from typing import List
5+
6+
from attr import dataclass
7+
from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam
8+
9+
from eval_protocol.models import EvaluationRow, Message
10+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
11+
from eval_protocol.pytest.types import RolloutProcessorConfig
12+
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
13+
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
14+
from pydantic_ai.models.anthropic import AnthropicModel
15+
from pydantic_ai.models.openai import OpenAIModel
16+
from pydantic_ai.models.google import GoogleModel
17+
from pydantic import TypeAdapter
18+
from pydantic_ai.messages import ModelMessage
19+
from pydantic_ai._utils import generate_tool_call_id
20+
from pydantic_ai import Agent
21+
from pydantic_ai.messages import (
22+
ModelRequest,
23+
SystemPromptPart,
24+
ToolReturnPart,
25+
UserPromptPart,
26+
)
27+
from pydantic_ai.providers.openai import OpenAIProvider
28+
from typing_extensions import TypedDict
29+
30+
logger = logging.getLogger(__name__)
31+
32+
33+
class PydanticAgentRolloutProcessor(RolloutProcessor):
    """Rollout processor for Pydantic AI agents.

    Mainly converts ``EvaluationRow.messages`` to and from the Pydantic AI
    ``ModelMessage`` format, then runs the configured agent on each row.
    """

    def __init__(self):
        # Dummy model instance: never called as a model — we only borrow its
        # private message-conversion helpers (_map_messages / _process_response).
        self.util = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Create agent rollout tasks and return them for external handling.

        Args:
            rows: Evaluation rows whose ``messages`` become the agent's
                message history.
            config: Rollout configuration. ``config.kwargs["agent"]`` must be
                either a Pydantic AI ``Agent`` instance or a function that
                returns one (a factory). In the factory form,
                ``config.completion_params["model"]`` must be a dict mapping
                factory argument names to model config dicts (with "model"
                and "provider" keys).

        Returns:
            One ``asyncio.Task`` per input row; each task resolves to its row
            with ``messages`` replaced by the full agent transcript.

        Raises:
            ValueError: If the "agent" entry is missing or has the wrong type,
                or if the factory form is used with a non-dict model config.
        """
        # Bound concurrency; fall back to 8 when the attribute is unset or falsy.
        max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
        semaphore = asyncio.Semaphore(max_concurrent)

        # Validate that the "agent" field is present with a valid Pydantic AI
        # Agent instance (or agent factory function) in the kwargs dict.
        if "agent" not in config.kwargs:
            raise ValueError("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance")
        agent_or_factory = config.kwargs["agent"]
        if not isinstance(agent_or_factory, (Agent, types.FunctionType)):
            raise ValueError(
                "kwargs['agent'] must be a valid Pydantic AI Agent instance or a function that returns an Agent"
            )

        if isinstance(agent_or_factory, types.FunctionType):
            # Factory form: build one model object per factory argument, then
            # call the factory to obtain the agent. The agent then owns its
            # models, so no per-run model override is passed to Agent.run().
            setup_agent = agent_or_factory
            if not isinstance(config.completion_params["model"], dict):
                raise ValueError(
                    "completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
                )
            factory_kwargs = {}
            for arg_name, model_cfg in config.completion_params["model"].items():
                model_name = model_cfg["model"]
                if model_name and model_name.startswith("anthropic:"):
                    factory_kwargs[arg_name] = AnthropicModel(model_name.removeprefix("anthropic:"))
                elif model_name and model_name.startswith("google:"):
                    factory_kwargs[arg_name] = GoogleModel(model_name.removeprefix("google:"))
                else:
                    # Default: OpenAI-compatible model with an explicit provider.
                    factory_kwargs[arg_name] = OpenAIModel(model_name, provider=model_cfg["provider"])
            agent = setup_agent(**factory_kwargs)
            model = None
        else:
            # Plain Agent instance: the model/provider pair from
            # completion_params is passed to Agent.run() as an override.
            agent = agent_or_factory
            model = OpenAIModel(
                config.completion_params["model"],
                provider=config.completion_params["provider"],
            )

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Run the agent on a single row and store the full transcript."""
            model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
            response = await agent.run(
                message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
            )
            # Replace the row's messages with the complete conversation,
            # including the agent's new assistant/tool messages.
            row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
            return row

        async def _sem_wrapper(row: EvaluationRow) -> EvaluationRow:
            # Cap the number of rollouts in flight at once.
            async with semaphore:
                return await process_row(row)

        # Create and return tasks for external handling (the caller awaits them).
        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
        """Convert Pydantic AI messages to eval-protocol ``Message`` objects.

        Uses the dummy model's private ``_map_messages`` helper to produce
        OpenAI-format dicts, then validates them into ``Message`` instances.
        """
        oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
        return [Message(**m) for m in oai_messages]

    def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
        """Convert one eval-protocol ``Message`` to a Pydantic AI ``ModelMessage``.

        Args:
            message: The eval-protocol message to convert.
            row: The row the message belongs to; its ``created_at`` is used as
                the synthetic completion timestamp for assistant messages.

        Raises:
            ValueError: If the role is unknown, or the content type is not
                supported for the given role (previously these cases fell
                through and returned ``None`` implicitly).
        """
        if message.role == "assistant":
            # Round-trip through an OpenAI ChatCompletion so the dummy model's
            # _process_response helper can build the ModelResponse for us.
            oai_message = TypeAdapter(ChatCompletionAssistantMessageParam).validate_python(message)
            # Provide required finish_reason and index, and ensure created is
            # an int timestamp whether created_at is a datetime or a number.
            created = int(row.created_at.timestamp()) if hasattr(row.created_at, "timestamp") else int(row.created_at)
            return self.util._process_response(
                ChatCompletion(
                    choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
                    object="chat.completion",
                    model="",
                    id="",
                    created=created,
                )
            )
        if message.role == "user":
            if isinstance(message.content, str):
                return ModelRequest(parts=[UserPromptPart(content=message.content)])
            if isinstance(message.content, list):
                # NOTE(review): only the first content part is kept — assumes
                # single-part content; confirm multi-part messages are not needed.
                return ModelRequest(parts=[UserPromptPart(content=message.content[0].text)])
            raise ValueError(f"Unsupported user content type: {type(message.content)!r}")
        if message.role == "system":
            if isinstance(message.content, str):
                return ModelRequest(parts=[SystemPromptPart(content=message.content)])
            if isinstance(message.content, list):
                return ModelRequest(parts=[SystemPromptPart(content=message.content[0].text)])
            raise ValueError(f"Unsupported system content type: {type(message.content)!r}")
        if message.role == "tool":
            return ModelRequest(
                parts=[
                    ToolReturnPart(
                        content=message.content,
                        # Tool name is not tracked on eval-protocol messages.
                        tool_name="",
                        tool_call_id=message.tool_call_id or generate_tool_call_id(),
                    )
                ]
            )
        raise ValueError(f"Unknown role: {message.role}")

eval_protocol/pytest/evaluation_test.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
parse_ep_max_concurrent_rollouts,
6262
parse_ep_num_runs,
6363
parse_ep_completion_params,
64+
parse_ep_passed_threshold,
6465
rollout_processor_with_retry,
6566
sanitize_filename,
6667
)
@@ -538,6 +539,7 @@ def evaluation_test( # noqa: C901
538539
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
539540
completion_params = parse_ep_completion_params(completion_params)
540541
original_completion_params = completion_params
542+
passed_threshold = parse_ep_passed_threshold(passed_threshold)
541543

542544
def decorator(
543545
test_func: TestFunction,
@@ -925,16 +927,18 @@ async def _collect_result(config, lst):
925927
r.eval_metadata.status = Status.eval_finished()
926928
active_logger.log(r)
927929

928-
tasks = []
929-
for i in range(num_runs):
930-
tasks.append(asyncio.create_task(execute_run(i, config)))
931-
932930
# if rollout_processor is McpGymRolloutProcessor, we execute runs sequentially since McpGym does not support concurrent runs
933931
# else, we execute runs in parallel
934932
if isinstance(rollout_processor, MCPGymRolloutProcessor):
935-
for task in tasks:
933+
# For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
934+
for i in range(num_runs):
935+
task = asyncio.create_task(execute_run(i, config))
936936
await task
937937
else:
938+
# For other processors, create all tasks at once and run in parallel
939+
tasks = []
940+
for i in range(num_runs):
941+
tasks.append(asyncio.create_task(execute_run(i, config)))
938942
await asyncio.gather(*tasks)
939943

940944
# for groupwise mode, the result contains eval output from multiple completion_params, so we need to differentiate them

0 commit comments

Comments
 (0)