1- import inspect
2- import os
1+ import asyncio
32import copy
3+ import inspect
44import math
5+ import os
56import statistics
67from typing import Any , Callable , Dict , List , Optional
78
3334from ..common_utils import load_jsonl
3435
3536
36- def evaluation_test (
37+ def evaluation_test ( # noqa: C901
3738 * ,
3839 model : List [ModelParam ],
3940 input_messages : Optional [List [InputMessagesParam ]] = None ,
@@ -221,7 +222,7 @@ def generate_combinations():
221222 # Create wrapper function with exact signature that pytest expects
222223 def create_wrapper_with_signature () -> Callable :
223224 # Create the function body that will be used
224- def wrapper_body (** kwargs ):
225+ async def wrapper_body (** kwargs ):
225226 model_name = kwargs ["model" ]
226227 eval_metadata = None
227228 all_results : List [EvaluationRow ] = []
@@ -300,10 +301,14 @@ def wrapper_body(**kwargs):
300301 # Regenerate outputs each run by deep-copying the pristine dataset
301302 # so model responses are not reused across runs.
302303 fresh_rows = [copy .deepcopy (r ) for r in data ]
303- input_dataset = execute_function (rollout_processor , rows = fresh_rows , config = config )
304+
305+ # All rollout processors now return AsyncIterator for pipelining
306+ rollout_result = rollout_processor (fresh_rows , config )
307+
304308 if mode == "pointwise" :
305- # Pointwise mode: apply the evaluator function to each row
306- for row in input_dataset :
309+ # Pointwise mode: true pipelining with concurrent evaluations
310+ async def process_evaluation (row ):
311+ """Process a single evaluation and return the result."""
307312 result = execute_with_params (
308313 test_func ,
309314 row = row ,
@@ -313,8 +318,25 @@ def wrapper_body(**kwargs):
313318 raise ValueError (
314319 f"Test function { test_func .__name__ } did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
315320 )
316- all_results .append (result )
321+ return result
322+
323+ # Start evaluations as rollouts complete - true pipelining
324+ eval_tasks = []
325+ async for row in rollout_result :
326+ # Start evaluation immediately when rollout completes
327+ eval_task = asyncio .create_task (process_evaluation (row ))
328+ eval_tasks .append (eval_task )
329+
330+ # Collect all evaluation results
331+ if eval_tasks :
332+ eval_results = await asyncio .gather (* eval_tasks )
333+ all_results .extend (eval_results )
317334 else :
335+ # Batch mode: collect all results first, then evaluate
336+ input_dataset = []
337+ async for row in rollout_result :
338+ input_dataset .append (row )
339+
318340 # Batch mode: call the test function with the full dataset
319341 results = execute_with_params (
320342 test_func ,
@@ -353,8 +375,12 @@ def wrapper_body(**kwargs):
353375 sample_std = statistics .stdev (scores )
354376 se = sample_std / math .sqrt (n )
355377 margin = 1.96 * se
356- ci_low = float (max (0.0 , (agg_score or 0.0 ) - margin )) if agg_score is not None else None
357- ci_high = float (min (1.0 , (agg_score or 0.0 ) + margin )) if agg_score is not None else None
378+ ci_low = (
379+ float (max (0.0 , (agg_score or 0.0 ) - margin )) if agg_score is not None else None
380+ )
381+ ci_high = (
382+ float (min (1.0 , (agg_score or 0.0 ) + margin )) if agg_score is not None else None
383+ )
358384 except Exception :
359385 ci_low = None
360386 ci_high = None
@@ -392,6 +418,7 @@ def wrapper_body(**kwargs):
392418 # Aggregate per-metric mean and 95% CI when available
393419 metrics_summary : Dict [str , Dict [str , float ]] = {}
394420 from collections import defaultdict
421+
395422 metric_scores : Dict [str , list ] = defaultdict (list )
396423 for r in all_results :
397424 if r .evaluation_result and r .evaluation_result .metrics :
@@ -435,12 +462,16 @@ def wrapper_body(**kwargs):
435462 parts = []
436463 for m_name , entry in metrics_summary .items ():
437464 if "ci_low" in entry and "ci_high" in entry :
438- parts .append (f"{ m_name } ={ entry ['mean' ]:.3f} ci95=[{ entry ['ci_low' ]:.3f} ,{ entry ['ci_high' ]:.3f} ]" )
465+ parts .append (
466+ f"{ m_name } ={ entry ['mean' ]:.3f} ci95=[{ entry ['ci_low' ]:.3f} ,{ entry ['ci_high' ]:.3f} ]"
467+ )
439468 else :
440469 parts .append (f"{ m_name } ={ entry ['mean' ]:.3f} " )
441470 print (f"EP Metrics | " + ", " .join (parts ))
442471 if summary_path :
443- import json , pathlib , time
472+ import json
473+ import pathlib
474+ import time
444475
445476 p = pathlib .Path (summary_path )
446477 p .parent .mkdir (parents = True , exist_ok = True )
@@ -483,6 +514,7 @@ def wrapper_body(**kwargs):
483514 # Create the pytest wrapper
484515 pytest_wrapper = create_wrapper_with_signature ()
485516 pytest_wrapper = pytest .mark .parametrize (test_param_names , param_tuples )(pytest_wrapper )
517+ pytest_wrapper = pytest .mark .asyncio (pytest_wrapper )
486518
487519 def create_dual_mode_wrapper () -> Callable :
488520 """
@@ -500,17 +532,21 @@ def create_dual_mode_wrapper() -> Callable:
500532 """
501533 import asyncio
502534
503- # Check if the test function is async
504- is_async = asyncio .iscoroutinefunction (test_func )
535+ # Check if the pytest wrapper is async (it should be now)
536+ is_pytest_wrapper_async = asyncio .iscoroutinefunction (pytest_wrapper )
537+ is_test_func_async = asyncio .iscoroutinefunction (test_func )
505538
506- if is_async :
539+ if is_pytest_wrapper_async :
507540
508541 async def dual_mode_wrapper (* args , ** kwargs ):
509542 # Check if this is a direct call with the expected signature
510543 if mode == "pointwise" :
511544 # For pointwise mode, check if called with a single row argument
512545 if len (args ) == 1 and isinstance (args [0 ], EvaluationRow ) and not kwargs :
513- return await test_func (row = args [0 ])
546+ if is_test_func_async :
547+ return await test_func (row = args [0 ])
548+ else :
549+ return test_func (row = args [0 ])
514550 else :
515551 # For batch mode, check if called with rows argument
516552 if (
@@ -519,18 +555,24 @@ async def dual_mode_wrapper(*args, **kwargs):
519555 and all (isinstance (r , EvaluationRow ) for r in args [0 ])
520556 and not kwargs
521557 ):
522- return await test_func (rows = args [0 ])
558+ if is_test_func_async :
559+ return await test_func (rows = args [0 ])
560+ else :
561+ return test_func (rows = args [0 ])
523562 # Also check if called with keyword argument 'rows'
524563 if (
525564 len (args ) == 0
526565 and "rows" in kwargs
527566 and isinstance (kwargs ["rows" ], list )
528567 and all (isinstance (r , EvaluationRow ) for r in kwargs ["rows" ])
529568 ):
530- return await test_func (** kwargs )
569+ if is_test_func_async :
570+ return await test_func (** kwargs )
571+ else :
572+ return test_func (** kwargs )
531573
532574 # If not a direct call, use the pytest wrapper
533- return pytest_wrapper (* args , ** kwargs )
575+ return await pytest_wrapper (* args , ** kwargs )
534576
535577 else :
536578
0 commit comments