2626from eval_protocol .pytest .default_dataset_adapter import default_dataset_adapter
2727from eval_protocol .pytest .default_no_op_rollout_process import default_no_op_rollout_processor
2828from eval_protocol .pytest .types import (
29- CompletionsParams ,
3029 Dataset ,
3130 DatasetPathParam ,
3231 EvaluationInputParam ,
5251
5352def evaluation_test ( # noqa: C901
5453 * ,
55- completion_params : List [CompletionsParams ],
54+ completion_params : List [CompletionParams ],
5655 input_messages : Optional [List [InputMessagesParam ]] = None ,
5756 input_dataset : Optional [List [DatasetPathParam ]] = None ,
5857 dataset_adapter : Callable [[List [Dict [str , Any ]]], Dataset ] = default_dataset_adapter ,
@@ -240,7 +239,7 @@ def generate_combinations():
240239 datasets = [[input_dataset ]] # type: ignore
241240 else :
242241 datasets = [None ]
243- cps : List [Optional [CompletionsParams ]] = completion_params if completion_params is not None else [None ] # type: ignore
242+ cps : List [Optional [CompletionParams ]] = completion_params if completion_params is not None else [None ] # type: ignore
244243 # Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over
245244 # each row. Instead, pass the entire sliced list through in a single test run
246245 # so summaries aggregate all rows together (AIME-style behavior).
@@ -348,7 +347,16 @@ def _log_eval_error(
348347 else :
349348 raise ValueError ("No input dataset or input messages provided" )
350349
351- completions_params = kwargs .get ("completion_params" ) or {}
350+ if "completion_params" not in kwargs or not kwargs ["completion_params" ]:
351+ raise ValueError (
352+ "No completion parameters provided. Please provide a completion parameters object."
353+ )
354+ completion_params = kwargs ["completion_params" ]
355+ if "model" not in completion_params or not completion_params ["model" ]:
356+ raise ValueError (
357+ "No model provided. Please provide a model in the completion parameters object."
358+ )
359+
352360 # Optional global overrides via environment for ad-hoc experimentation
353361 # EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged
354362 # into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
@@ -359,7 +367,7 @@ def _log_eval_error(
359367 if _env_override :
360368 override_obj = _json .loads (_env_override )
361369 if isinstance (override_obj , dict ):
362- completions_params = _deep_update_dict (dict (completions_params ), override_obj )
370+ completion_params = _deep_update_dict (dict (completion_params ), override_obj )
363371 except Exception :
364372 pass
365373
@@ -374,11 +382,6 @@ def _log_eval_error(
374382 passed = None ,
375383 )
376384
377- # Populate completion_params in input_metadata for all rows and initialize eval_metadata BEFORE rollouts
378- completion_params = CompletionParams (
379- ** completions_params ,
380- )
381-
382385 for row in data :
383386 if row .input_metadata is None :
384387 row .input_metadata = InputMetadata ()
@@ -398,13 +401,13 @@ def _log_eval_error(
398401
399402 # Prepare rollout processor config once; we will generate fresh outputs per run
400403 config = RolloutProcessorConfig (
401- completion_params = CompletionParams ( ** completions_params ) ,
404+ completion_params = completion_params ,
402405 mcp_config_path = mcp_config_path or "" ,
403406 max_concurrent_rollouts = max_concurrent_rollouts ,
404407 server_script_path = server_script_path ,
405408 steps = steps ,
406409 logger = active_logger ,
407- kwargs = rollout_processor_kwargs ,
410+ kwargs = rollout_processor_kwargs or {} ,
408411 )
409412
410413 for i in range (num_runs ):
@@ -611,7 +614,7 @@ def _extract_effort_tag(params: dict) -> str | None:
611614 return None
612615
613616 model_slug = _sanitize_filename (model_used )
614- effort_tag = _extract_effort_tag (completions_params ) or ""
617+ effort_tag = _extract_effort_tag (completion_params ) or ""
615618 effort_suffix = f"__effort-{ _sanitize_filename (effort_tag )} " if effort_tag else ""
616619 base_name = f"{ suite_name } __{ model_slug } { effort_suffix } __{ mode } __runs{ num_runs } .json"
617620
@@ -788,7 +791,7 @@ def __ep_run_direct(
788791 input_messages = cfg .get ("input_messages" ),
789792 input_dataset = cfg .get ("input_dataset" ),
790793 dataset_adapter = cfg .get ("dataset_adapter" ),
791- completions_params = rip ,
794+ completion_params = rip ,
792795 rollout_processor = cfg .get ("rollout_processor" ),
793796 aggregation_method = cfg .get ("aggregation_method" ),
794797 passed_threshold = cfg .get ("passed_threshold" ),
@@ -818,7 +821,7 @@ def run_evaluation_test_direct(
818821 input_messages : Optional [List [InputMessagesParam ]] = None ,
819822 input_dataset : Optional [List [DatasetPathParam ]] = None ,
820823 dataset_adapter : Callable [[List [Dict [str , Any ]]], Dataset ] = default_dataset_adapter ,
821- completions_params : Optional [CompletionsParams ] = None ,
824+ completion_params : Optional [CompletionParams ] = None ,
822825 rollout_processor : RolloutProcessor = default_no_op_rollout_processor ,
823826 rollout_processor_kwargs : Optional [RolloutProcessorInputParam ] = None ,
824827 aggregation_method : AggregationMethod = "mean" ,
@@ -885,7 +888,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
885888 raise ValueError ("No input dataset or input messages provided" )
886889
887890 # Build input params and apply env JSON override
888- completion_params : Dict [str , Any ] = completions_params or {}
891+ completion_params : Dict [str , Any ] = completion_params or {}
889892 try :
890893 import json as _json
891894
@@ -911,7 +914,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
911914 for row in data :
912915 if row .input_metadata is None :
913916 row .input_metadata = InputMetadata ()
914- row .input_metadata .completion_params = CompletionParams ( ** completion_params )
917+ row .input_metadata .completion_params = completion_params
915918 if row .input_metadata .session_data is None :
916919 row .input_metadata .session_data = {}
917920 row .input_metadata .session_data ["mode" ] = mode
@@ -925,7 +928,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
925928 max_concurrent_rollouts = max_concurrent_rollouts ,
926929 server_script_path = server_script_path ,
927930 steps = steps ,
928- kwargs = rollout_processor_kwargs ,
931+ kwargs = rollout_processor_kwargs or {} ,
929932 )
930933
931934 all_results : List [EvaluationRow ] = []
0 commit comments