Skip to content

Commit 3ad780b

Browse files
author
Dylan Huang
committed
Refactor EvalMetadata and EvaluationRow models; add cohort_id, rollout_id, and run_id fields. Update evaluation_test to handle new identifiers and improve documentation on evaluation concepts.
1 parent b28fa2b commit 3ad780b

File tree

3 files changed

+85
-27
lines changed

3 files changed

+85
-27
lines changed

eval_protocol/models.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -214,14 +214,6 @@ class EvalMetadata(BaseModel):
214214
status: Optional[Literal["running", "finished", "error", "stopped"]] = Field(
215215
None, description="Status of the evaluation"
216216
)
217-
run_id: Optional[str] = Field(
218-
None,
219-
description=(
220-
"Unique identifier for the run. A 'run' is a group of rows"
221-
"that were evaluated together in single configuration of a @evaluation_test."
222-
" This means that running the save @evaluation_test with "
223-
),
224-
)
225217
num_runs: int = Field(..., description="Number of times the evaluation was repeated")
226218
aggregation_method: str = Field(..., description="Method used to aggregate scores across runs")
227219
threshold_of_success: Optional[float] = Field(None, description="Threshold score for test success")
@@ -253,8 +245,8 @@ class EvaluationRow(BaseModel):
253245
supporting both row-wise batch evaluation and trajectory-based RL evaluation.
254246
"""
255247

256-
# Core conversation data
257-
messages: List[Message] = Field(description="List of messages in the conversation/trajectory.")
248+
# Core OpenAI ChatCompletion compatible conversation data
249+
messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")
258250

259251
# Tool and function call information
260252
tools: Optional[List[Dict[str, Any]]] = Field(
@@ -272,6 +264,21 @@ class EvaluationRow(BaseModel):
272264
description="The status of the rollout.",
273265
)
274266

267+
cohort_id: Optional[str] = Field(
268+
default_factory=generate_id,
269+
description="The ID of the cohort that this row belongs to.",
270+
)
271+
272+
rollout_id: Optional[str] = Field(
273+
default_factory=generate_id,
274+
description="The ID of the rollout that this row belongs to.",
275+
)
276+
277+
run_id: Optional[str] = Field(
278+
None,
279+
description=("The ID of the run that this row belongs to."),
280+
)
281+
275282
# Ground truth reference (moved from EvaluateResult to top level)
276283
ground_truth: Optional[str] = Field(
277284
default=None, description="Optional ground truth reference for this evaluation."

eval_protocol/pytest/evaluation_test.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,33 @@ def evaluation_test( # noqa: C901
6161
]:
6262
"""Decorator to create pytest-based evaluation tests.
6363
64+
Here are some key concepts to understand the terminology in EP:
65+
66+
- "cohort" is a group of runs with a static set of parameters. A single
67+
cohort will have multiple runs if num_runs > 1.
68+
1. If your evaluation_test has combinations of parameters, it will generate
69+
multiple cohorts per combination of parameters.
70+
2. A new execution of a test function will generate a new cohort.
71+
- "run" is a group of rollouts. When num_runs > 1, there will be
72+
multiple "run_id"s.
73+
- "rollout" is the execution/process that produces a "trajectory". You
74+
"execute" multiple rollouts to generate a dataset of trajectories.
75+
- "trajectory" is the result produced by a rollout — a list of OpenAI Chat
76+
Completion messages (e.g. the "messages" field in EvaluationRow).
77+
- "row" is both the input and output of an evaluation. For example, in
78+
tau-bench, a row is a task within the dataset that can be identified as
79+
"airline_task_0" or "airline_task_1" etc. The "row_id" can be populated from
80+
the dataset itself to identify a particular task you want to evaluate. If
81+
not provided, EP will generate a "row_id" for each row whenever you call the
82+
evaluation test.
83+
- "dataset" is a collection of rows (e.g. List[EvaluationRow])
84+
- "eval" is a rubric implemented in the body of an @evaluation_test
85+
decorated test. It simply produces a score from 0 to 1 and attaches it
86+
to the row as the "evaluation_result" field.
87+
88+
A "cohort", "run", "rollout", and "row" each have a unique ID which can be
89+
used to easily group and identify them.
90+
6491
Args:
6592
model: Model identifiers to query.
6693
input_messages: Messages to send to the model. This is useful if you
@@ -121,15 +148,15 @@ def decorator(
121148

122149
def execute_with_params(
123150
test_func: TestFunction,
124-
row: EvaluationRow | None = None,
125-
input_dataset: List[EvaluationRow] | None = None,
151+
processed_row: EvaluationRow | None = None,
152+
processed_dataset: List[EvaluationRow] | None = None,
126153
evaluation_test_kwargs: Optional[EvaluationInputParam] = None,
127154
):
128155
kwargs = {}
129-
if input_dataset is not None:
130-
kwargs["rows"] = input_dataset
131-
if row is not None:
132-
kwargs["row"] = row
156+
if processed_dataset is not None:
157+
kwargs["rows"] = processed_dataset
158+
if processed_row is not None:
159+
kwargs["row"] = processed_row
133160
if evaluation_test_kwargs is not None:
134161
if "row" in evaluation_test_kwargs:
135162
raise ValueError("'row' is a reserved parameter for the evaluation function")
@@ -244,7 +271,7 @@ def generate_combinations():
244271
# Create wrapper function with exact signature that pytest expects
245272
def create_wrapper_with_signature() -> Callable:
246273
# Create the function body that will be used
247-
run_id = generate_id()
274+
cohort_id = generate_id()
248275

249276
def wrapper_body(**kwargs):
250277
model_name = kwargs["model"]
@@ -310,7 +337,6 @@ def _log_eval_error(
310337
aggregation_method=aggregation_method,
311338
threshold_of_success=threshold_of_success,
312339
passed=None,
313-
run_id=run_id,
314340
)
315341

316342
# Populate completion_params in input_metadata for all rows and initialize eval_metadata BEFORE rollouts
@@ -331,6 +357,7 @@ def _log_eval_error(
331357
row.input_metadata.session_data["mode"] = mode
332358
# Initialize eval_metadata for each row
333359
row.eval_metadata = eval_metadata
360+
row.cohort_id = cohort_id
334361

335362
# has to be done in the pytest main process since it's
336363
# used to determine whether this eval has stopped
@@ -350,14 +377,25 @@ def _log_eval_error(
350377
for _ in range(num_runs):
351378
# Regenerate outputs each run by deep-copying the pristine dataset
352379
# so model responses are not reused across runs.
353-
fresh_rows = [copy.deepcopy(r) for r in data]
354-
input_dataset = execute_function(rollout_processor, rows=fresh_rows, config=config)
380+
run_id = generate_id()
381+
fresh_dataset = [copy.deepcopy(r) for r in data]
382+
383+
# apply new run_id to fresh_dataset
384+
for row in fresh_dataset:
385+
row.run_id = run_id
386+
387+
# generate new rollout_id for each row
388+
for row in fresh_dataset:
389+
row.rollout_id = generate_id()
390+
391+
processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config)
392+
355393
if mode == "pointwise":
356394
# Pointwise mode: apply the evaluator function to each row
357-
for row in input_dataset:
395+
for row in processed_dataset:
358396
result = execute_with_params(
359397
test_func,
360-
row=row,
398+
processed_row=row,
361399
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
362400
)
363401
if result is None or not isinstance(result, EvaluationRow):
@@ -369,7 +407,7 @@ def _log_eval_error(
369407
# Batch mode: call the test function with the full dataset
370408
results = execute_with_params(
371409
test_func,
372-
input_dataset=input_dataset,
410+
processed_dataset=processed_dataset,
373411
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
374412
)
375413
if results is None:

vite-app/src/types/eval-protocol.ts

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ export const CompletionParamsSchema = z.object({
6262
});
6363

6464
export const InputMetadataSchema = z.object({
65-
row_id: z.string().describe('Unique string to ID the row'),
65+
row_id: z.string().optional().describe('Unique string to ID the row'),
6666
completion_params: CompletionParamsSchema.optional().describe('Completion endpoint parameters used'),
6767
dataset_info: z.record(z.string(), z.any()).optional().describe('Dataset row details: seed, system_prompt, environment_context, etc'),
6868
session_data: z.record(z.string(), z.any()).optional().describe('Session metadata like timestamp (input only, no duration/usage)')
@@ -78,18 +78,30 @@ export const EvalMetadataSchema = z.object({
7878
name: z.string().describe('Name of the evaluation'),
7979
description: z.string().optional().describe('Description of the evaluation'),
8080
version: z.string().describe('Version of the evaluation. By default, we will populate this with the current commit hash.'),
81-
status: z.enum(['running', 'finished', 'error', 'stopped']).default('running').describe('Status of the evaluation'),
81+
status: z.enum(['running', 'finished', 'error', 'stopped']).optional().describe('Status of the evaluation'),
8282
num_runs: z.number().int().describe('Number of times the evaluation was repeated'),
8383
aggregation_method: z.string().describe('Method used to aggregate scores across runs'),
8484
threshold_of_success: z.number().optional().describe('Threshold score for test success'),
85-
passed: z.boolean().optional().describe('Whether the evaluation passed based on the threshold'),
86-
run_id: z.string().optional().describe('Unique identifier for the run. A "run" is a group of rows that were evaluated together in single configuration of a @evaluation_test.')
85+
passed: z.boolean().optional().describe('Whether the evaluation passed based on the threshold')
86+
});
87+
88+
// Rollout status model (matches Python RolloutStatus)
89+
export const RolloutStatusSchema = z.object({
90+
status: z
91+
.enum(['running', 'finished', 'error', 'stopped'])
92+
.default('finished')
93+
.describe('Status of the rollout.'),
94+
error_message: z.string().optional().describe('Error message if the rollout failed.')
8795
});
8896

8997
export const EvaluationRowSchema = z.object({
9098
messages: z.array(MessageSchema).describe('List of messages in the conversation/trajectory.'),
9199
tools: z.array(z.record(z.string(), z.any())).optional().describe('Available tools/functions that were provided to the agent.'),
92100
input_metadata: InputMetadataSchema.describe('Metadata related to the input (dataset info, model config, session data, etc.).'),
101+
rollout_status: RolloutStatusSchema.default({ status: 'finished' }).describe('The status of the rollout.'),
102+
cohort_id: z.string().optional().describe('The ID of the cohort that this row belongs to.'),
103+
rollout_id: z.string().optional().describe('The ID of the rollout that this row belongs to.'),
104+
run_id: z.string().optional().describe('The ID of the run that this row belongs to.'),
93105
ground_truth: z.string().optional().describe('Optional ground truth reference for this evaluation.'),
94106
evaluation_result: EvaluateResultSchema.optional().describe('The evaluation result for this row/trajectory.'),
95107
usage: CompletionUsageSchema.optional().describe('Token usage statistics from LLM calls during execution.'),
@@ -158,6 +170,7 @@ export type InputMetadata = z.infer<typeof InputMetadataSchema>;
158170
export type CompletionUsage = z.infer<typeof CompletionUsageSchema>;
159171
export type EvalMetadata = z.infer<typeof EvalMetadataSchema>;
160172
export type EvaluationRow = z.infer<typeof EvaluationRowSchema>;
173+
export type RolloutStatus = z.infer<typeof RolloutStatusSchema>;
161174
export type ResourceServerConfig = z.infer<typeof ResourceServerConfigSchema>;
162175
export type EvaluationCriteriaModel = z.infer<typeof EvaluationCriteriaModelSchema>;
163176
export type TaskDefinitionModel = z.infer<typeof TaskDefinitionModelSchema>;

0 commit comments

Comments
 (0)