@@ -53,7 +53,7 @@ def evaluation_test( # noqa: C901
5353 rollout_processor : RolloutProcessor = default_no_op_rollout_processor ,
5454 evaluation_test_kwargs : Optional [List [EvaluationInputParam ]] = None ,
5555 aggregation_method : AggregationMethod = "mean" ,
56- threshold : Optional [EvaluationThreshold ] = None ,
56+ passed_threshold : Optional [Union [ EvaluationThreshold , float ] ] = None ,
5757 num_runs : int = 1 ,
5858 max_dataset_rows : Optional [int ] = None ,
5959 mcp_config_path : Optional [str ] = None ,
@@ -113,7 +113,7 @@ def evaluation_test( # noqa: C901
113113 rollout_processor: Function used to perform the rollout.
114114 evaluation_test_kwargs: Kwargs for the evaluation function.
115115 aggregation_method: How to aggregate scores across rows.
116- threshold : Threshold configuration for test success.
116+ passed_threshold : Threshold configuration for test success.
117117 Success rate must be above success, and if set, standard deviation must be below standard_deviation.
118118 num_runs: Number of times to repeat the rollout and evaluations.
119119 max_dataset_rows: Limit dataset to the first N rows.
@@ -129,11 +129,11 @@ def evaluation_test( # noqa: C901
129129 def decorator (
130130 test_func : TestFunction ,
131131 ):
132- if threshold is not None :
133- if isinstance (threshold , dict ):
134- evaluation_threshold = EvaluationThreshold (** threshold )
135- elif isinstance ( threshold , float ) :
136- evaluation_threshold = EvaluationThreshold (success = threshold )
132+ if passed_threshold is not None :
133+ if isinstance (passed_threshold , float ):
134+ threshold = EvaluationThreshold (success = passed_threshold )
135+ else :
136+ threshold = EvaluationThreshold (** passed_threshold )
137137
138138 sig = inspect .signature (test_func )
139139
@@ -361,7 +361,7 @@ def _log_eval_error(
361361 status = "running" ,
362362 num_runs = num_runs ,
363363 aggregation_method = aggregation_method ,
364- threshold = evaluation_threshold ,
364+ passed_threshold = threshold ,
365365 passed = None ,
366366 )
367367
@@ -459,6 +459,7 @@ def _log_eval_error(
459459 sum ([r .evaluation_result .score for r in result if r .evaluation_result ]) / len (result )
460460 for result in all_results
461461 ]
462+ print (f"SCORES: { scores } " )
462463 agg_score = aggregate (scores , aggregation_method )
463464 score_std = statistics .stdev (scores ) if len (scores ) > 1 else 0.0
464465
@@ -495,13 +496,13 @@ def _log_eval_error(
495496 # Determine if the evaluation passed based on threshold
496497 passed = None
497498
498- if evaluation_threshold is not None :
499+ if threshold is not None :
499500 success_passed , std_passed = True , True
500501
501- success_passed = agg_score >= evaluation_threshold .success
502+ success_passed = agg_score >= threshold .success
502503
503- if evaluation_threshold .standard_deviation is not None :
504- std_passed = score_std <= evaluation_threshold .standard_deviation
504+ if threshold .standard_deviation is not None :
505+ std_passed = score_std <= threshold .standard_deviation
505506
506507 passed = success_passed and std_passed
507508
@@ -636,14 +637,14 @@ def _extract_effort_tag(params: dict) -> str | None:
636637 pass
637638
638639 # Check threshold after logging
639- if evaluation_threshold is not None and not passed :
640+ if threshold is not None and not passed :
640641 assert (
641- agg_score >= evaluation_threshold .success
642- ), f"Aggregated score { agg_score :.3f} below threshold { evaluation_threshold .success } "
643- if evaluation_threshold .standard_deviation is not None :
642+ agg_score >= threshold .success
643+ ), f"Aggregated score { agg_score :.3f} below threshold { threshold .success } "
644+ if threshold .standard_deviation is not None :
644645 assert (
645- score_std <= evaluation_threshold .standard_deviation
646- ), f"Standard deviation { score_std :.3f} above threshold { evaluation_threshold .standard_deviation } "
646+ score_std <= threshold .standard_deviation
647+ ), f"Standard deviation { score_std :.3f} above threshold { threshold .standard_deviation } "
647648
648649 except AssertionError :
649650 _log_eval_error ("finished" , data if "data" in locals () else None , passed = False )
0 commit comments