Skip to content

Commit 56bf3ab

Browse files
author
Dylan Huang
committed
fix
1 parent b95909b commit 56bf3ab

3 files changed

Lines changed: 18 additions & 53 deletions

File tree

eval_protocol/pytest/evaluation_test.py

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
RolloutProcessorConfig,
1717
TestFunction,
1818
)
19-
from eval_protocol.pytest.utils import aggregate, create_dynamically_parameterized_wrapper, execute_function
19+
from eval_protocol.pytest.utils import (
20+
AggregationMethod,
21+
aggregate,
22+
create_dynamically_parameterized_wrapper,
23+
execute_function,
24+
)
2025

2126
from ..common_utils import load_jsonl
2227

@@ -29,7 +34,7 @@ def evaluation_test(
2934
dataset_adapter: Optional[Callable[[List[Dict[str, Any]]], Dataset]] = lambda x: x,
3035
input_params: Optional[List[InputParam]] = None,
3136
rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
32-
aggregation_method: str = "mean",
37+
aggregation_method: AggregationMethod = "mean",
3338
threshold_of_success: Optional[float] = None,
3439
num_runs: int = 1,
3540
max_dataset_rows: Optional[int] = None,
@@ -58,45 +63,10 @@ def evaluation_test(
5863
below this threshold.
5964
num_runs: Number of times to repeat the evaluation.
6065
max_dataset_rows: Limit dataset to the first N rows.
66+
mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
6167
mode: Evaluation mode. "batch" (default) expects test function to handle
6268
full dataset. "pointwise" applies test function to each row. If your evaluation requires
6369
the full rollout of all rows to compute the score, use
64-
65-
Usage:
66-
With an input dataset and input params, the test function will be called with the following arguments:
67-
68-
```python
69-
@evaluation_test(
70-
model=["gpt-4o", "gpt-4o-mini"],
71-
input_dataset=["data/test.jsonl"],
72-
input_params=[{"temperature": 0.5}],
73-
rollout_processor=default_rollout_processor,
74-
aggregation_method="mean",
75-
)
76-
def test_func(dataset_path: str, model_name: str, input_params: Dict[str, Any]):
77-
pass
78-
```
79-
80-
Without an input dataset and input params, the test function will be called with the following arguments:
81-
82-
```python
83-
@evaluation_test(
84-
model=["gpt-4o", "gpt-4o-mini"],
85-
)
86-
def test_func(model_name: str):
87-
pass
88-
```
89-
90-
With model and input_messages, the test function will be called with the following arguments:
91-
92-
```python
93-
@evaluation_test(
94-
model=["gpt-4o", "gpt-4o-mini"],
95-
input_messages=[{"role": "user", "content": "Hello, how are you?"}],
96-
)
97-
def test_func(model_name: str, input_messages: List[List[Message]]):
98-
pass
99-
```
10070
"""
10171

10272
def decorator(
@@ -132,18 +102,12 @@ def decorator(
132102

133103
def execute_with_params(
134104
test_func: TestFunction,
135-
model: str,
136105
row: EvaluationRow | None = None,
137106
input_dataset: List[EvaluationRow] | None = None,
138-
input_params: InputParam | None = None,
139107
):
140108
kwargs = {}
141109
if input_dataset is not None:
142110
kwargs["rows"] = input_dataset
143-
if input_params is not None:
144-
kwargs["input_params"] = input_params
145-
if model is not None:
146-
kwargs["model"] = model
147111
if row is not None:
148112
kwargs["row"] = row
149113
return execute_function(test_func, **kwargs)
@@ -231,9 +195,7 @@ def wrapper_body(**kwargs):
231195
for row in input_dataset:
232196
result = execute_with_params(
233197
test_func,
234-
model=model_name,
235198
row=row,
236-
input_params=kwargs.get("input_params") if "input_params" in kwargs else None,
237199
)
238200
if result is None or not isinstance(result, EvaluationRow):
239201
raise ValueError(
@@ -244,9 +206,7 @@ def wrapper_body(**kwargs):
244206
# Batch mode: call the test function with the full dataset
245207
results = execute_with_params(
246208
test_func,
247-
model=model_name,
248209
input_dataset=input_dataset,
249-
input_params=kwargs.get("input_params") if "input_params" in kwargs else None,
250210
)
251211
if results is None:
252212
raise ValueError(

eval_protocol/pytest/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import inspect
3-
from typing import Any, Callable, List
3+
from typing import Any, Callable, List, Literal
44

55
from ..models import EvaluateResult, EvaluationRow
66

@@ -51,7 +51,10 @@ def evaluate(
5151
return evaluated
5252

5353

54-
def aggregate(scores: List[float], method: str) -> float:
54+
AggregationMethod = Literal["mean", "max", "min"]
55+
56+
57+
def aggregate(scores: List[float], method: AggregationMethod) -> float:
5558
if not scores:
5659
return 0.0
5760
if method == "mean":

tests/pytest/test_pytest_word_count_example.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
from haikus import haikus
2+
3+
from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
14
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
2-
from eval_protocol.models import EvaluateResult, MetricResult, EvaluationRow
35
from tests.pytest.helper.word_count_to_evaluation_row import word_count_to_evaluation_row
4-
from haikus import haikus
56

67

78
@evaluation_test(
@@ -74,8 +75,9 @@ def test_word_count_evaluate(row: EvaluationRow) -> EvaluationRow:
7475
),
7576
}
7677

77-
return EvaluateResult(
78+
row.evaluation_result = EvaluateResult(
7879
score=word_count_score,
7980
reason=f"Word count: {word_count}. {haiku_metric_reason}",
8081
metrics=metrics,
8182
)
83+
return row

0 commit comments

Comments (0)