Skip to content

Commit 332f25b

Browse files
author
Dylan Huang
committed
add more fields
1 parent 6dca6e3 commit 332f25b

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

eval_protocol/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,9 @@ class EvalMetadata(BaseModel):
209209
..., description="Version of the evaluation. By default, we will populate this with the current commit hash."
210210
)
211211
status: Literal["running", "finished", "error"] = Field("running", description="Status of the evaluation")
212+
num_runs: int = Field(..., description="Number of times the evaluation was repeated")
213+
aggregation_method: str = Field(..., description="Method used to aggregate scores across runs")
214+
threshold_of_success: Optional[float] = Field(None, description="Threshold score for test success")
212215

213216

214217
class EvaluationRow(BaseModel):

eval_protocol/pytest/evaluation_test.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# Import versioneer for getting version information
77
import versioneer
88
from eval_protocol.dataset_logger import default_logger
9-
from eval_protocol.models import EvalMetadata, EvaluationRow
9+
from eval_protocol.models import CompletionParams, EvalMetadata, EvaluationRow, InputMetadata
1010
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
1111
from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
1212
from eval_protocol.pytest.types import (
@@ -207,16 +207,34 @@ def wrapper_body(**kwargs):
207207
raise ValueError("No input dataset or input messages provided")
208208

209209
input_dataset: List[EvaluationRow] = []
210+
input_params = kwargs.get("input_params") or {}
210211
config = RolloutProcessorConfig(
211212
model=model_name,
212-
input_params=kwargs.get("input_params") or {},
213+
input_params=input_params,
213214
mcp_config_path=mcp_config_path or "",
214215
max_concurrent_rollouts=max_concurrent_rollouts,
215216
server_script_path=server_script_path,
216217
steps=steps,
217218
)
218219
input_dataset = execute_function(rollout_processor, rows=data, config=config)
219220

221+
# Populate completion_params in input_metadata for all rows
222+
completion_params = CompletionParams(
223+
model=model_name,
224+
temperature=input_params.get("temperature"),
225+
max_tokens=input_params.get("max_tokens"),
226+
max_tool_calls=input_params.get("max_tool_calls"),
227+
)
228+
229+
for row in input_dataset:
230+
if row.input_metadata is None:
231+
row.input_metadata = InputMetadata()
232+
row.input_metadata.completion_params = completion_params
233+
# Add mode to session_data
234+
if row.input_metadata.session_data is None:
235+
row.input_metadata.session_data = {}
236+
row.input_metadata.session_data["mode"] = mode
237+
220238
all_results: List[EvaluationRow] = []
221239
for _ in range(num_runs):
222240
if mode == "pointwise":
@@ -263,6 +281,9 @@ def wrapper_body(**kwargs):
263281
description=test_func.__doc__,
264282
version=versioneer.get_version(),
265283
status="finished",
284+
num_runs=num_runs,
285+
aggregation_method=aggregation_method,
286+
threshold_of_success=threshold_of_success,
266287
)
267288

268289
# Add metadata to all results before logging

vite-app/src/types/eval-protocol.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ export const EvalMetadataSchema = z.object({
7878
name: z.string().describe('Name of the evaluation'),
7979
description: z.string().optional().describe('Description of the evaluation'),
8080
version: z.string().describe('Version of the evaluation. By default, we will populate this with the current commit hash.'),
81-
status: z.enum(['running', 'finished', 'error']).default('running').describe('Status of the evaluation')
81+
status: z.enum(['running', 'finished', 'error']).default('running').describe('Status of the evaluation'),
82+
num_runs: z.number().int().describe('Number of times the evaluation was repeated'),
83+
aggregation_method: z.string().describe('Method used to aggregate scores across runs'),
84+
threshold_of_success: z.number().optional().describe('Threshold score for test success')
8285
});
8386

8487
export const EvaluationRowSchema = z.object({

0 commit comments

Comments (0)