|
6 | 6 | # Import versioneer for getting version information |
7 | 7 | import versioneer |
8 | 8 | from eval_protocol.dataset_logger import default_logger |
9 | | -from eval_protocol.models import EvalMetadata, EvaluationRow |
| 9 | +from eval_protocol.models import CompletionParams, EvalMetadata, EvaluationRow, InputMetadata |
10 | 10 | from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter |
11 | 11 | from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor |
12 | 12 | from eval_protocol.pytest.types import ( |
@@ -207,16 +207,34 @@ def wrapper_body(**kwargs): |
207 | 207 | raise ValueError("No input dataset or input messages provided") |
208 | 208 |
|
209 | 209 | input_dataset: List[EvaluationRow] = [] |
| 210 | + input_params = kwargs.get("input_params") or {} |
210 | 211 | config = RolloutProcessorConfig( |
211 | 212 | model=model_name, |
212 | | - input_params=kwargs.get("input_params") or {}, |
| 213 | + input_params=input_params, |
213 | 214 | mcp_config_path=mcp_config_path or "", |
214 | 215 | max_concurrent_rollouts=max_concurrent_rollouts, |
215 | 216 | server_script_path=server_script_path, |
216 | 217 | steps=steps, |
217 | 218 | ) |
218 | 219 | input_dataset = execute_function(rollout_processor, rows=data, config=config) |
219 | 220 |
|
| 221 | + # Populate completion_params in input_metadata for all rows |
| 222 | + completion_params = CompletionParams( |
| 223 | + model=model_name, |
| 224 | + temperature=input_params.get("temperature"), |
| 225 | + max_tokens=input_params.get("max_tokens"), |
| 226 | + max_tool_calls=input_params.get("max_tool_calls"), |
| 227 | + ) |
| 228 | + |
| 229 | + for row in input_dataset: |
| 230 | + if row.input_metadata is None: |
| 231 | + row.input_metadata = InputMetadata() |
| 232 | + row.input_metadata.completion_params = completion_params |
| 233 | + # Add mode to session_data |
| 234 | + if row.input_metadata.session_data is None: |
| 235 | + row.input_metadata.session_data = {} |
| 236 | + row.input_metadata.session_data["mode"] = mode |
| 237 | + |
220 | 238 | all_results: List[EvaluationRow] = [] |
221 | 239 | for _ in range(num_runs): |
222 | 240 | if mode == "pointwise": |
@@ -263,6 +281,9 @@ def wrapper_body(**kwargs): |
263 | 281 | description=test_func.__doc__, |
264 | 282 | version=versioneer.get_version(), |
265 | 283 | status="finished", |
| 284 | + num_runs=num_runs, |
| 285 | + aggregation_method=aggregation_method, |
| 286 | + threshold_of_success=threshold_of_success, |
266 | 287 | ) |
267 | 288 |
|
268 | 289 | # Add metadata to all results before logging |
|
0 commit comments