Skip to content

Commit 2ceaf72

Browse files
author
Dylan Huang
committed
square up all the id madness and add a test
1 parent 87c3dcb commit 2ceaf72

File tree

3 files changed

+89
-11
lines changed

3 files changed

+89
-11
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ def evaluation_test( # noqa: C901
6363
6464
Here are some key concepts to understand the terminology in EP:
6565
66-
- "cohort" is a group of runs with a static set of parameters. A single
66+
- "invocation" is a single execution of a test function. An invocation can
67+
generate 1 or more cohorts. Grouping by invocation might be useful to
68+
aggregate eval scores across multiple invocations, for example when you want to combine
69+
scores across multiple datasets.
70+
- "cohort" is a group of runs for a combination of parameters. A single
6771
cohort will have multiple runs if num_runs > 1.
6872
1. If your evaluation_test has combinations of parameters, it will generate
6973
multiple cohorts per combination of parameters.
@@ -85,8 +89,8 @@ def evaluation_test( # noqa: C901
8589
decorated test. It simply produces a score from 0 to 1 and attaches it
8690
to the row as the "evaluation_result" field.
8791
88-
A "cohort", "run", "rollout", and "row" each have a unique ID which can be
89-
used to easily group and identify them.
92+
"invocation", "cohort", "run", "rollout", and "row" each have a unique ID
93+
which can be used to easily group and identify the rows of your dataset.
9094
9195
Args:
9296
model: Model identifiers to query.
@@ -205,7 +209,7 @@ def generate_combinations():
205209
datasets = [[input_dataset]] # type: ignore
206210
else:
207211
datasets = [None]
208-
params: List[Optional[RolloutInputParam]] = rollout_input_params if rollout_input_params is not None else [None] # type: ignore
212+
rips: List[Optional[RolloutInputParam]] = rollout_input_params if rollout_input_params is not None else [None] # type: ignore
209213
# Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over
210214
# each row. Instead, pass the entire sliced list through in a single test run
211215
# so summaries aggregate all rows together (AIME-style behavior).
@@ -224,15 +228,15 @@ def generate_combinations():
224228
# Generate all combinations
225229
for m in model:
226230
for ds in datasets:
227-
for ip in params:
231+
for rip in rips:
228232
for im in messages:
229233
for etk in kwargs:
230234
# if no dataset and no messages, raise an error
231235
if ds is None and im is None:
232236
raise ValueError(
233237
"No dataset or messages provided. Please provide at least one of input_dataset or input_messages."
234238
)
235-
combinations.append((m, ds, ip, im, etk))
239+
combinations.append((m, ds, rip, im, etk))
236240

237241
return combinations
238242

@@ -245,12 +249,12 @@ def generate_combinations():
245249
# Create parameter tuples for pytest.mark.parametrize
246250
param_tuples = []
247251
for combo in combinations:
248-
model_name, dataset, params, messages, etk = combo
252+
model_name, dataset, rip, messages, etk = combo
249253
param_tuple = [model_name]
250254
if input_dataset is not None:
251255
param_tuple.append(dataset)
252256
if rollout_input_params is not None:
253-
param_tuple.append(params)
257+
param_tuple.append(rip)
254258
if input_messages is not None:
255259
param_tuple.append(messages)
256260
if evaluation_test_kwargs is not None:
@@ -271,13 +275,15 @@ def generate_combinations():
271275
# Create wrapper function with exact signature that pytest expects
272276
def create_wrapper_with_signature() -> Callable:
273277
# Create the function body that will be used
274-
cohort_id = generate_id()
278+
invocation_id = generate_id()
275279

276280
def wrapper_body(**kwargs):
277281
model_name = kwargs["model"]
278282
eval_metadata = None
279283
all_results: List[EvaluationRow] = []
280284

285+
cohort_id = generate_id()
286+
281287
def _log_eval_error(
282288
status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool
283289
) -> None:
@@ -358,6 +364,7 @@ def _log_eval_error(
358364
# Initialize eval_metadata for each row
359365
row.eval_metadata = eval_metadata
360366
row.cohort_id = cohort_id
367+
row.invocation_id = invocation_id
361368

362369
# has to be done in the pytest main process since it's
363370
# used to determine whether this eval has stopped

tests/pytest/test_markdown_highlighting.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import re
88
from typing import Any, Dict, List
99

10-
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
10+
from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
1111
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
1212

1313

@@ -16,7 +16,11 @@ def markdown_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
1616
Convert entries from markdown dataset to EvaluationRow objects.
1717
"""
1818
return [
19-
EvaluationRow(messages=[Message(role="user", content=row["prompt"])], ground_truth=str(row["num_highlights"]))
19+
EvaluationRow(
20+
messages=[Message(role="user", content=row["prompt"])],
21+
ground_truth=str(row["num_highlights"]),
22+
input_metadata=InputMetadata(row_id=str(row["key"])),
23+
)
2024
for row in data
2125
]
2226

tests/pytest/test_pytest_ids.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import eval_protocol.pytest.evaluation_test as evaluation_test_module
2+
from eval_protocol.models import EvaluationRow
3+
from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
4+
from eval_protocol.pytest.evaluation_test import evaluation_test as evaluation_decorator
5+
from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
6+
7+
8+
class InMemoryLogger:
9+
def __init__(self):
10+
self._rows = []
11+
12+
def log(self, row):
13+
self._rows.append(row)
14+
15+
def read(self):
16+
return list(self._rows)
17+
18+
19+
def test_evaluation_test_decorator_ids_single(monkeypatch):
20+
# Use an in-memory logger to avoid sqlite side effects
21+
in_memory_logger = InMemoryLogger()
22+
monkeypatch.setattr(evaluation_test_module, "default_logger", in_memory_logger, raising=False)
23+
24+
unique_run_ids = set()
25+
unique_cohort_ids = set()
26+
unique_rollout_ids = set()
27+
unique_invocation_ids = set()
28+
unique_row_ids = set()
29+
30+
@evaluation_decorator(
31+
input_dataset=[
32+
"tests/pytest/data/markdown_dataset.jsonl",
33+
"tests/pytest/data/markdown_dataset.jsonl",
34+
],
35+
rollout_input_params=[{"temperature": 0.0}, {"temperature": 1.0}],
36+
model=["dummy/local-model"],
37+
dataset_adapter=markdown_dataset_to_evaluation_row,
38+
rollout_processor=default_no_op_rollout_processor,
39+
mode="pointwise",
40+
combine_datasets=False,
41+
num_runs=5,
42+
)
43+
def eval_fn(row: EvaluationRow) -> EvaluationRow:
44+
unique_run_ids.add(row.run_id)
45+
unique_cohort_ids.add(row.cohort_id)
46+
unique_rollout_ids.add(row.rollout_id)
47+
unique_invocation_ids.add(row.invocation_id)
48+
unique_row_ids.add(row.input_metadata.row_id)
49+
return row
50+
51+
dataset_paths = [
52+
"tests/pytest/data/markdown_dataset.jsonl",
53+
"tests/pytest/data/markdown_dataset.jsonl",
54+
]
55+
input_params_list = [{"temperature": 0.0}, {"temperature": 1.0}]
56+
57+
# Manually invoke all parameter combinations within a single test
58+
for ds_path in dataset_paths:
59+
for params in input_params_list:
60+
eval_fn(model="dummy/local-model", dataset_path=[ds_path], input_params=params)
61+
62+
# Assertions on IDs generated by the decorator logic
63+
assert len(unique_invocation_ids) == 1
64+
assert len(unique_run_ids) == 20 # 4 combinations * 5 runs each
65+
assert len(unique_cohort_ids) == 2 * 2 # 2 datasets * 2 param sets
66+
assert len(unique_row_ids) == 19 # from the markdown dataset
67+
assert len(unique_rollout_ids) == 19 * 5 * 2 * 2 # rows * runs * datasets * params

0 commit comments

Comments
 (0)