1111from dataclasses import replace
1212from typing import Any , Callable , Dict , List , Literal , Optional , Union
1313from collections import defaultdict
14-
14+ import hashlib
15+ import ast
1516from mcp .types import Completion
1617import pytest
1718
3536 EvaluationInputParam ,
3637 EvaluationTestMode ,
3738 InputMessagesParam ,
39+ InputRowsParam ,
3840 ModelParam ,
3941 RolloutProcessorConfig ,
4042 RolloutProcessorInputParam ,
@@ -81,14 +83,16 @@ def postprocess(
8183 if aggregation_method == "mean" :
8284 try :
8385 result_ci = compute_fixed_set_mu_ci ([item for sublist in all_results for item in sublist ])
84- _ , mu_ci_low , mu_ci_high , standard_error = result_ci
85- if mu_ci_low is not None and mu_ci_high is not None :
86+ _ , mu_ci_low , mu_ci_high , se = result_ci
87+ if mu_ci_low is not None and mu_ci_high is not None and se is not None :
8688 ci_low = float (mu_ci_low )
8789 ci_high = float (mu_ci_high )
90+ standard_error = float (se )
8891 # Keep agg_score as-is (mean over scores). For equal repeats per question these match.
8992 except Exception :
9093 ci_low = None
9194 ci_high = None
95+ standard_error = None
9296
9397 # Determine if the evaluation passed based on threshold
9498 passed = None
@@ -127,9 +131,10 @@ def postprocess(
127131 "num_runs" : num_runs ,
128132 "rows" : total_rows ,
129133 }
130- if ci_low is not None and ci_high is not None :
134+ if ci_low is not None and ci_high is not None and standard_error is not None :
131135 summary_obj ["agg_ci_low" ] = ci_low
132136 summary_obj ["agg_ci_high" ] = ci_high
137+ summary_obj ["standard_error" ] = standard_error
133138
134139 # Aggregate per-metric mean and 95% CI when available
135140 metrics_summary : Dict [str , Dict [str , float ]] = {}
@@ -164,9 +169,9 @@ def postprocess(
164169 if metrics_summary :
165170 summary_obj ["metrics_agg" ] = metrics_summary
166171 if should_print :
167- if ci_low is not None and ci_high is not None :
172+ if ci_low is not None and ci_high is not None and standard_error is not None :
168173 print (
169- f"EP Summary | suite={ suite_name } model={ model_used } agg={ summary_obj ['agg_score' ]:.3f} ci95=[{ ci_low :.3f} ,{ ci_high :.3f} ] runs={ num_runs } rows={ total_rows } "
174+ f"EP Summary | suite={ suite_name } model={ model_used } agg={ summary_obj ['agg_score' ]:.3f} se= { summary_obj [ 'standard_error' ]:.3f } ci95=[{ ci_low :.3f} ,{ ci_high :.3f} ] runs={ num_runs } rows={ total_rows } "
170175 )
171176 else :
172177 print (
@@ -235,6 +240,7 @@ def evaluation_test( # noqa: C901
235240 completion_params : List [CompletionParams ],
236241 input_messages : Optional [List [InputMessagesParam ]] = None ,
237242 input_dataset : Optional [List [DatasetPathParam ]] = None ,
243+ input_rows : Optional [List [InputRowsParam ]] = None ,
238244 dataset_adapter : Callable [[List [Dict [str , Any ]]], Dataset ] = default_dataset_adapter ,
239245 rollout_processor : RolloutProcessor = NoOpRolloutProcessor (),
240246 evaluation_test_kwargs : Optional [List [EvaluationInputParam ]] = None ,
@@ -245,6 +251,7 @@ def evaluation_test( # noqa: C901
245251 max_dataset_rows : Optional [int ] = None ,
246252 mcp_config_path : Optional [str ] = None ,
247253 max_concurrent_rollouts : int = 8 ,
254+ max_concurrent_evaluations : int = 64 ,
248255 server_script_path : Optional [str ] = None ,
249256 steps : int = 30 ,
250257 mode : EvaluationTestMode = "pointwise" ,
@@ -295,6 +302,9 @@ def evaluation_test( # noqa: C901
295302 input_dataset: Paths to JSONL datasets. This is useful if you have a
296303 dataset already. Provide a dataset_adapter to convert the input dataset
297304 to a list of EvaluationRows if you have a custom dataset format.
305+ input_rows: Pre-constructed EvaluationRow objects to use directly. This is useful
306+ when you want to provide EvaluationRow objects with custom metadata, input_messages,
307+ or other fields already populated. Will be passed as "input_dataset" to the test function.
298308 dataset_adapter: Function to convert the input dataset to a list of
299309 EvaluationRows. This is useful if you have a custom dataset format.
300310 completion_params: Generation parameters for the rollout.
@@ -309,6 +319,7 @@ def evaluation_test( # noqa: C901
309319 max_dataset_rows: Limit dataset to the first N rows.
310320 mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
311321 max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
322+ max_concurrent_evaluations: Maximum number of concurrent evaluations to run in parallel.
312323 server_script_path: Path to the MCP server script to run (default: "examples/tau2_mcp/server.py").
313324 steps: Number of rollout steps to execute (default: 30).
314325 mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result).
@@ -408,33 +419,42 @@ async def execute_with_params(
408419 # Calculate all possible combinations of parameters
409420 if mode == "groupwise" :
410421 combinations = generate_parameter_combinations (
411- input_dataset , None , input_messages , evaluation_test_kwargs , max_dataset_rows , combine_datasets
422+ input_dataset ,
423+ None ,
424+ input_messages ,
425+ input_rows ,
426+ evaluation_test_kwargs ,
427+ max_dataset_rows ,
428+ combine_datasets ,
412429 )
413430 else :
414431 combinations = generate_parameter_combinations (
415432 input_dataset ,
416433 completion_params ,
417434 input_messages ,
435+ input_rows ,
418436 evaluation_test_kwargs ,
419437 max_dataset_rows ,
420438 combine_datasets ,
421439 )
422440 if len (combinations ) == 0 :
423441 raise ValueError (
424- "No combinations of parameters were found. Please provide at least a model and one of input_dataset or input_messages ."
442+ "No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows ."
425443 )
426444
427445 # Create parameter tuples for pytest.mark.parametrize
428446 param_tuples = []
429447 for combo in combinations :
430- dataset , cp , messages , etk = combo
448+ dataset , cp , messages , rows , etk = combo
431449 param_tuple = []
432450 if input_dataset is not None :
433451 param_tuple .append (dataset )
434452 if completion_params is not None :
435453 param_tuple .append (cp )
436454 if input_messages is not None :
437455 param_tuple .append (messages )
456+ if input_rows is not None :
457+ param_tuple .append (rows )
438458 if evaluation_test_kwargs is not None :
439459 param_tuple .append (etk )
440460 param_tuples .append (tuple (param_tuple ))
@@ -447,6 +467,8 @@ async def execute_with_params(
447467 test_param_names .append ("completion_params" )
448468 if input_messages is not None :
449469 test_param_names .append ("input_messages" )
470+ if input_rows is not None :
471+ test_param_names .append ("input_rows" )
450472 if evaluation_test_kwargs is not None :
451473 test_param_names .append ("evaluation_test_kwargs" )
452474
@@ -472,6 +494,8 @@ def _log_eval_error(
472494 try :
473495 # Handle dataset loading
474496 data : List [EvaluationRow ] = []
497+ # Track all rows processed in the current run for error logging
498+ processed_rows_in_run : List [EvaluationRow ] = []
475499 if "dataset_path" in kwargs and kwargs ["dataset_path" ] is not None :
476500 ds_arg = kwargs ["dataset_path" ]
477501 # Support either a single path or a list of paths; if a list is provided,
@@ -496,8 +520,11 @@ def _log_eval_error(
496520 else :
497521 # Multiple rows: list of List[Message]
498522 data = [EvaluationRow (messages = m ) for m in im ]
523+ elif "input_rows" in kwargs and kwargs ["input_rows" ] is not None :
524+ # Use pre-constructed EvaluationRow objects directly
525+ data = kwargs ["input_rows" ]
499526 else :
500- raise ValueError ("No input dataset or input messages provided" )
527+ raise ValueError ("No input dataset, input messages, or input rows provided" )
501528
502529 for row in data :
503530 # generate a stable row_id for each row
@@ -585,30 +612,44 @@ def _log_eval_error(
585612 # log the fresh_dataset
586613 for row in fresh_dataset :
587614 active_logger .log (row )
615+ processed_rows_in_run .append (row )
588616
589- if mode == "pointwise" :
590- # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
591- semaphore = asyncio .Semaphore (max_concurrent_rollouts )
592- tasks = []
617+ # prepare parallel eval helper function
618+ semaphore = asyncio .Semaphore (max_concurrent_evaluations )
593619
594- async def _execute_with_semaphore (row ):
595- async with semaphore :
596- # NOTE: we will still evaluate errored rows (give users control over this)
597- # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
async def _execute_eval_with_semaphore(**inner_kwargs):
    """Run the decorated test function on one row or one group of rows.

    Every call acquires the shared evaluation semaphore first, so the number
    of test-function invocations in flight is bounded by that semaphore.

    Keyword Args:
        row: A single rolled-out row; forwarded to the test function as
            ``processed_row``.
        rows: A list of rows; forwarded to the test function as
            ``processed_dataset``.

    Returns:
        The ``EvaluationRow`` returned by the test function when ``row`` is
        given, or the list returned by the test function when ``rows`` is
        given. If both are supplied, ``row`` takes precedence.

    Raises:
        ValueError: If the test function returns a value of the wrong type,
            or if neither ``row`` nor ``rows`` is supplied.
    """
    async with semaphore:
        # NOTE: we will still evaluate errored rows (give users control over this)
        # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
        eval_kwargs = kwargs.get("evaluation_test_kwargs") or {}

        if "row" in inner_kwargs:
            result = await execute_with_params(
                test_func,
                processed_row=inner_kwargs["row"],
                evaluation_test_kwargs=eval_kwargs,
            )
            if result is None or not isinstance(result, EvaluationRow):
                raise ValueError(
                    f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
                )
            return result

        if "rows" in inner_kwargs:
            results = await execute_with_params(
                test_func,
                processed_dataset=inner_kwargs["rows"],
                evaluation_test_kwargs=eval_kwargs,
            )
            if results is None or not isinstance(results, list):
                raise ValueError(
                    f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
                )
            return results

        # Fail loudly instead of implicitly returning None, which would only
        # surface later as an unrelated error when the caller consumes the result.
        raise ValueError(
            "_execute_eval_with_semaphore requires either a 'row' or a 'rows' keyword argument."
        )
608646
647+ if mode == "pointwise" :
648+ # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
649+ tasks = []
609650 # Use wrapper that handles retry logic internally
610651 async for row in rollout_processor_with_retry (rollout_processor , fresh_dataset , config ):
611- tasks .append (asyncio .create_task (_execute_with_semaphore ( row )))
652+ tasks .append (asyncio .create_task (_execute_eval_with_semaphore ( row = row )))
612653
613654 results = await asyncio .gather (* tasks )
614655
@@ -649,14 +690,13 @@ async def _collect_result(config, lst):
649690 for result in rollout_results :
650691 for row in result :
651692 row_groups [row .input_metadata .row_id ].append (row )
652- results = []
693+ tasks = []
653694 for row_id , rows in row_groups .items ():
654- result = await execute_with_params (
655- test_func ,
656- processed_dataset = rows ,
657- evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
658- )
659- results .extend (result )
695+ tasks .append (asyncio .create_task (_execute_eval_with_semaphore (rows = rows )))
696+ results = []
697+ for task in tasks :
698+ res = await task
699+ results .extend (res )
660700 all_results [i ] = results
661701 else :
662702 # Batch mode: collect all results first, then evaluate (no pipelining)
@@ -728,10 +768,16 @@ async def _collect_result(config, lst):
728768 )
729769
730770 except AssertionError :
731- _log_eval_error ("finished" , data if "data" in locals () else None , passed = False )
771+ _log_eval_error (
772+ "finished" ,
773+ processed_rows_in_run if "processed_rows_in_run" in locals () else None ,
774+ passed = False ,
775+ )
732776 raise
733777 except Exception :
734- _log_eval_error ("error" , data if "data" in locals () else None , passed = False )
778+ _log_eval_error (
779+ "error" , processed_rows_in_run if "processed_rows_in_run" in locals () else None , passed = False
780+ )
735781 raise
736782
737783 return create_dynamically_parameterized_wrapper (test_func , wrapper_body , test_param_names )
@@ -794,6 +840,13 @@ async def dual_mode_wrapper(*args, **kwargs):
794840 # If not a direct call, use the pytest wrapper
795841 return await pytest_wrapper (* args , ** kwargs )
796842
843+ dual_mode_wrapper ._origin_func = test_func
844+ dual_mode_wrapper ._metainfo = {
845+ "mode" : mode ,
846+ "max_rollout_concurrency" : max_concurrent_rollouts ,
847+ "max_evaluation_concurrency" : max_concurrent_evaluations ,
848+ }
849+
797850 # Copy all attributes from the pytest wrapper to our dual mode wrapper
798851 import functools
799852
0 commit comments