4747 aggregate ,
4848 create_dynamically_parameterized_wrapper ,
4949 deep_update_dict ,
50- execute_function ,
5150 extract_effort_tag ,
5251 generate_parameter_combinations ,
5352 log_eval_status_and_rows ,
@@ -333,6 +332,11 @@ def evaluation_test( # noqa: C901
333332
334333 active_logger : DatasetLogger = logger if logger else default_logger
335334
335+ # Apply override from pytest flags if present
336+ num_runs = parse_ep_num_runs (num_runs )
337+ max_concurrent_rollouts = parse_ep_max_concurrent_rollouts (max_concurrent_rollouts )
338+ max_dataset_rows = parse_ep_max_rows (max_dataset_rows )
339+
336340 def decorator (
337341 test_func : TestFunction ,
338342 ):
@@ -481,10 +485,7 @@ def create_wrapper_with_signature() -> Callable:
481485 async def wrapper_body (** kwargs ):
482486 eval_metadata = None
483487
484- # Apply environment override for num_runs if present
485- effective_num_runs = parse_ep_num_runs (num_runs )
486- effective_max_concurrent_rollouts = parse_ep_max_concurrent_rollouts (max_concurrent_rollouts )
487- all_results : List [List [EvaluationRow ]] = [[] for _ in range (effective_num_runs )]
488+ all_results : List [List [EvaluationRow ]] = [[] for _ in range (num_runs )]
488489
489490 experiment_id = generate_id ()
490491
@@ -508,10 +509,9 @@ def _log_eval_error(
508509 data_jsonl .extend (load_jsonl (p ))
509510 else :
510511 data_jsonl = load_jsonl (ds_arg )
511- # Apply env override for max rows if present
512- effective_max_rows = parse_ep_max_rows (max_dataset_rows )
513- if effective_max_rows is not None :
514- data_jsonl = data_jsonl [:effective_max_rows ]
512+ # Apply override for max rows if present
513+ if max_dataset_rows is not None :
514+ data_jsonl = data_jsonl [:max_dataset_rows ]
515515 data = dataset_adapter (data_jsonl )
516516 elif "input_messages" in kwargs and kwargs ["input_messages" ] is not None :
517517 # Support either a single row (List[Message]) or many rows (List[List[Message]])
@@ -563,7 +563,7 @@ def _log_eval_error(
563563 name = test_func .__name__ ,
564564 description = test_func .__doc__ ,
565565 status = "running" ,
566- num_runs = effective_num_runs ,
566+ num_runs = num_runs ,
567567 aggregation_method = aggregation_method ,
568568 passed_threshold = threshold ,
569569 passed = None ,
@@ -589,15 +589,15 @@ def _log_eval_error(
589589 config = RolloutProcessorConfig (
590590 completion_params = completion_params ,
591591 mcp_config_path = mcp_config_path or "" ,
592- max_concurrent_rollouts = effective_max_concurrent_rollouts ,
592+ max_concurrent_rollouts = max_concurrent_rollouts ,
593593 server_script_path = server_script_path ,
594594 steps = steps ,
595595 logger = active_logger ,
596596 kwargs = rollout_processor_kwargs or {},
597597 exception_handler_config = exception_handler_config ,
598598 )
599599
600- for i in range (effective_num_runs ):
600+ for i in range (num_runs ):
601601 # Regenerate outputs each run by deep-copying the pristine dataset
602602 # so model responses are not reused across runs.
603603 run_id = generate_id ()
@@ -617,7 +617,7 @@ def _log_eval_error(
617617 processed_rows_in_run .append (row )
618618
619619 # prepare parallel eval helper function
620- semaphore = asyncio .Semaphore (effective_max_concurrent_rollouts )
620+ semaphore = asyncio .Semaphore (max_concurrent_rollouts )
621621
622622 async def _execute_eval_with_semaphore (** inner_kwargs ):
623623 async with semaphore :
@@ -665,7 +665,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
665665 config = RolloutProcessorConfig (
666666 completion_params = cp ,
667667 mcp_config_path = mcp_config_path or "" ,
668- max_concurrent_rollouts = effective_max_concurrent_rollouts ,
668+ max_concurrent_rollouts = max_concurrent_rollouts ,
669669 server_script_path = server_script_path ,
670670 steps = steps ,
671671 logger = active_logger ,
@@ -739,8 +739,7 @@ async def _collect_result(config, lst):
739739 # rollout_id is used to differentiate the result from different completion_params
740740 if mode == "groupwise" :
741741 results_by_group = [
742- [[] for _ in range (effective_num_runs )]
743- for _ in range (len (original_completion_params_list ))
742+ [[] for _ in range (num_runs )] for _ in range (len (original_completion_params_list ))
744743 ]
745744 for i_run , result in enumerate (all_results ):
746745 for r in result :
@@ -755,7 +754,7 @@ async def _collect_result(config, lst):
755754 mode ,
756755 original_completion_params_list [rollout_id ],
757756 test_func .__name__ ,
758- effective_num_runs ,
757+ num_runs ,
759758 )
760759 else :
761760 postprocess (
@@ -766,7 +765,7 @@ async def _collect_result(config, lst):
766765 mode ,
767766 completion_params ,
768767 test_func .__name__ ,
769- effective_num_runs ,
768+ num_runs ,
770769 )
771770
772771 except AssertionError :
@@ -845,7 +844,7 @@ async def dual_mode_wrapper(*args, **kwargs):
845844 dual_mode_wrapper ._origin_func = test_func
846845 dual_mode_wrapper ._metainfo = {
847846 "mode" : mode ,
848- "max_rollout_concurrency" : max_concurrent_rollouts , # TODO: fix this
847+ "max_rollout_concurrency" : max_concurrent_rollouts ,
849848 "max_evaluation_concurrency" : max_concurrent_evaluations ,
850849 }
851850
0 commit comments