5858from ..common_utils import load_jsonl
5959
6060
61- def postprocess (all_results : List [List [EvaluationRow ]],
62- aggregation_method : AggregationMethod ,
63- threshold : Optional [EvaluationThreshold ],
64- active_logger : DatasetLogger ,
65- mode : EvaluationTestMode ,
66- completion_params : CompletionParams ,
67- test_func_name : str ,
68- num_runs : int ):
61+ def postprocess (
62+ all_results : List [List [EvaluationRow ]],
63+ aggregation_method : AggregationMethod ,
64+ threshold : Optional [EvaluationThreshold ],
65+ active_logger : DatasetLogger ,
66+ mode : EvaluationTestMode ,
67+ completion_params : CompletionParams ,
68+ test_func_name : str ,
69+ num_runs : int ,
70+ ):
6971 scores = [
70- sum ([r .evaluation_result .score for r in result if r .evaluation_result ]) / len (result )
71- for result in all_results
72+ sum ([r .evaluation_result .score for r in result if r .evaluation_result ]) / len (result ) for result in all_results
7273 ]
7374 agg_score = aggregate (scores , aggregation_method )
7475
@@ -220,9 +221,7 @@ def postprocess(all_results: List[List[EvaluationRow]],
220221
221222 # Check threshold after logging
222223 if threshold is not None and not passed :
223- assert agg_score >= threshold .success , (
224- f"Aggregated score { agg_score :.3f} below threshold { threshold .success } "
225- )
224+ assert agg_score >= threshold .success , f"Aggregated score { agg_score :.3f} below threshold { threshold .success } "
226225 if threshold .standard_error is not None and standard_error is not None :
227226 assert standard_error <= threshold .standard_error , (
228227 f"Standard error { standard_error :.3f} above threshold { threshold .standard_error } "
@@ -350,21 +349,15 @@ def decorator(
350349 # additional check for groupwise evaluation
351350 elif mode == "groupwise" :
352351 if "rows" not in sig .parameters :
353- raise ValueError (
354- "In listwise mode, your eval function must have a parameter named 'rows'"
355- )
352+ raise ValueError ("In listwise mode, your eval function must have a parameter named 'rows'" )
356353
357354 # validate that "Rows" is of type List[EvaluationRow]
358355 if sig .parameters ["rows" ].annotation is not List [EvaluationRow ]:
359- raise ValueError (
360- "In listwise mode, the 'rows' parameter must be of type List[EvaluationRow"
361- )
356+ raise ValueError ("In listwise mode, the 'rows' parameter must be of type List[EvaluationRow" )
362357
363358 # validate that the function has a return type of List[EvaluationRow]
364359 if sig .return_annotation is not List [EvaluationRow ]:
365- raise ValueError (
366- "In listwise mode, your eval function must return a list of EvaluationRow instances"
367- )
360+ raise ValueError ("In listwise mode, your eval function must return a list of EvaluationRow instances" )
368361 if len (completion_params ) < 2 :
369362 raise ValueError ("In groupwise mode, you must provide at least 2 completion parameters" )
370363 else :
@@ -378,9 +371,7 @@ def decorator(
378371
379372 # validate that the function has a return type of List[EvaluationRow]
380373 if sig .return_annotation is not List [EvaluationRow ]:
381- raise ValueError (
382- "In listwise mode, your eval function must return a list of EvaluationRow instances"
383- )
374+ raise ValueError ("In listwise mode, your eval function must return a list of EvaluationRow instances" )
384375
385376 async def execute_with_params (
386377 test_func : TestFunction ,
@@ -411,7 +402,9 @@ async def execute_with_params(
411402
412403 # Calculate all possible combinations of parameters
413404 if mode == "groupwise" :
414- combinations = generate_parameter_combinations (input_dataset , None , input_dataset , evaluation_test_kwargs , max_dataset_rows , combine_datasets )
405+ combinations = generate_parameter_combinations (
406+ input_dataset , None , input_dataset , evaluation_test_kwargs , max_dataset_rows , combine_datasets
407+ )
415408 else :
416409 combinations = generate_parameter_combinations (
417410 input_dataset ,
@@ -619,7 +612,7 @@ async def _execute_with_semaphore(row):
619612 all_results [i ] = results
620613 elif mode == "groupwise" :
621614 # rollout all the completion_params for the same row at once, and then send the output to the test_func
622- row_groups = defaultdict (list ) # key: row_id, value: list of rollout_result
615+ row_groups = defaultdict (list ) # key: row_id, value: list of rollout_result
623616 tasks : List [asyncio .Task [List [EvaluationRow ]]] = []
624617 # completion_groups = []
625618 for idx , cp in enumerate (original_completion_params_list ):
@@ -636,7 +629,9 @@ async def _execute_with_semaphore(row):
636629
637630 async def _collect_result (config , lst , max_retry ):
638631 result = []
639- async for row in rollout_processor_with_retry (rollout_processor , lst , config , max_retry ):
632+ async for row in rollout_processor_with_retry (
633+ rollout_processor , lst , config , max_retry
634+ ):
640635 result .append (row )
641636 return result
642637
@@ -654,7 +649,9 @@ async def _collect_result(config, lst, max_retry):
654649 results = []
655650 for row_id , rows in row_groups .items ():
656651 result = await execute_with_params (
657- test_func , processed_dataset = rows , evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {}
652+ test_func ,
653+ processed_dataset = rows ,
654+ evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
658655 )
659656 results .extend (result )
660657 all_results [i ] = results
@@ -670,10 +667,7 @@ async def _collect_result(config, lst, max_retry):
670667 results = await execute_with_params (
671668 test_func ,
672669 processed_dataset = input_dataset ,
673- evaluation_test_kwargs = kwargs .get (
674- "evaluation_test_kwargs"
675- )
676- or {},
670+ evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
677671 )
678672 if results is None :
679673 raise ValueError (
@@ -698,17 +692,37 @@ async def _collect_result(config, lst, max_retry):
698692 r .eval_metadata .status = "finished"
699693 active_logger .log (r )
700694
701- # for groupwise mode, the result contains eval otuput from multiple completion_params, we need to differentiate them
695+ # for groupwise mode, the result contains eval otuput from multiple completion_params, we need to differentiate them
702696 # rollout_id is used to differentiate the result from different completion_params
703697 if mode == "groupwise" :
704- results_by_group = [[[] for _ in range (num_runs )] for _ in range (len (original_completion_params_list ))]
698+ results_by_group = [
699+ [[] for _ in range (num_runs )] for _ in range (len (original_completion_params_list ))
700+ ]
705701 for i , result in enumerate (all_results ):
706702 for r in result :
707703 results_by_group [int (r .execution_metadata .rollout_id )][i ].append (r )
708704 for i , result in enumerate (results_by_group ):
709- postprocess (result , aggregation_method , threshold , active_logger , mode , original_completion_params_list [i ], test_func .__name__ , num_runs )
705+ postprocess (
706+ result ,
707+ aggregation_method ,
708+ threshold ,
709+ active_logger ,
710+ mode ,
711+ original_completion_params_list [i ],
712+ test_func .__name__ ,
713+ num_runs ,
714+ )
710715 else :
711- postprocess (all_results , aggregation_method , threshold , active_logger , mode , completion_params , test_func .__name__ , num_runs )
716+ postprocess (
717+ all_results ,
718+ aggregation_method ,
719+ threshold ,
720+ active_logger ,
721+ mode ,
722+ completion_params ,
723+ test_func .__name__ ,
724+ num_runs ,
725+ )
712726
713727 except AssertionError :
714728 _log_eval_error ("finished" , data if "data" in locals () else None , passed = False )
0 commit comments