11import asyncio
22import inspect
33import os
4+ import sys
45from collections import defaultdict
56from typing import Any , Callable
67from typing_extensions import Unpack
78from collections .abc import Sequence
89
910import pytest
11+ from tqdm import tqdm
1012
1113from eval_protocol .dataset_logger import default_logger
1214from eval_protocol .dataset_logger .dataset_logger import DatasetLogger
@@ -297,7 +299,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
297299 exception_handler_config = exception_handler_config ,
298300 )
299301
300- async def execute_run (i : int , config : RolloutProcessorConfig ):
302+ async def execute_run (run_idx : int , config : RolloutProcessorConfig ):
301303 nonlocal all_results
302304
303305 # Regenerate outputs each run by deep-copying the pristine dataset
@@ -357,13 +359,15 @@ async def _execute_groupwise_eval_with_semaphore(
357359 # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
358360 pointwise_tasks : list [asyncio .Task [EvaluationRow ]] = []
359361 # Use wrapper that handles retry logic internally
360- async for row in rollout_processor_with_retry (rollout_processor , fresh_dataset , config ):
362+ async for row in rollout_processor_with_retry (
363+ rollout_processor , fresh_dataset , config , run_idx
364+ ):
361365 pointwise_tasks .append (
362366 asyncio .create_task (_execute_pointwise_eval_with_semaphore (row = row ))
363367 )
364368 results = await asyncio .gather (* pointwise_tasks )
365369
366- all_results [i ] = results
370+ all_results [run_idx ] = results
367371 elif mode == "groupwise" :
368372 # rollout all the completion_params for the same row at once, and then send the output to the test_func
369373 row_groups = defaultdict ( # pyright: ignore[reportUnknownVariableType]
@@ -385,7 +389,9 @@ async def _execute_groupwise_eval_with_semaphore(
385389
386390 async def _collect_result (config , lst ): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
387391 result = []
388- async for row in rollout_processor_with_retry (rollout_processor , lst , config ): # pyright: ignore[reportUnknownArgumentType]
392+ async for row in rollout_processor_with_retry (
393+ rollout_processor , lst , config , run_idx
394+ ): # pyright: ignore[reportUnknownArgumentType]
389395 result .append (row ) # pyright: ignore[reportUnknownMemberType]
390396 return result # pyright: ignore[reportUnknownVariableType]
391397
@@ -409,11 +415,13 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
409415 for task in tasks :
410416 res = await task
411417 results .extend (res ) # pyright: ignore[reportUnknownMemberType]
412- all_results [i ] = results
418+ all_results [run_idx ] = results
413419 else :
414420 # Batch mode: collect all results first, then evaluate (no pipelining)
415421 input_dataset = []
416- async for row in rollout_processor_with_retry (rollout_processor , fresh_dataset , config ):
422+ async for row in rollout_processor_with_retry (
423+ rollout_processor , fresh_dataset , config , run_idx
424+ ):
417425 input_dataset .append (row ) # pyright: ignore[reportUnknownMemberType]
418426 # NOTE: we will still evaluate errored rows (give users control over this)
419427 # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
@@ -438,7 +446,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
438446 raise ValueError (
439447 f"Test function { test_func .__name__ } returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
440448 )
441- all_results [i ] = results
449+ all_results [run_idx ] = results
442450
443451 for r in results :
444452 if r .eval_metadata is not None :
@@ -472,16 +480,34 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
472480 # else, we execute runs in parallel
473481 if isinstance (rollout_processor , MCPGymRolloutProcessor ):
474482 # For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
475- for i in range (num_runs ):
476- task = asyncio .create_task (execute_run (i , config ))
483+ # For now, no tqdm progress bar here because log output overwrites it; we can revisit this later
484+ for run_idx in range (num_runs ):
485+ task = asyncio .create_task (execute_run (run_idx , config ))
477486 await task
478487 else :
479488 # For other processors, create all tasks at once and run in parallel
480489 # Concurrency is now controlled by the shared semaphore in each rollout processor
481- tasks = []
482- for i in range (num_runs ):
483- tasks .append (asyncio .create_task (execute_run (i , config ))) # pyright: ignore[reportUnknownMemberType]
484- await asyncio .gather (* tasks ) # pyright: ignore[reportUnknownArgumentType]
490+ with tqdm (
491+ total = num_runs ,
492+ desc = "Runs (Parallel)" ,
493+ unit = "run" ,
494+ file = sys .__stderr__ ,
495+ position = 0 ,
496+ leave = True ,
497+ dynamic_ncols = True ,
498+ miniters = 1 ,
499+ bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]" ,
500+ ) as run_pbar :
501+
502+ async def execute_run_with_progress (run_idx : int , config ):
503+ result = await execute_run (run_idx , config )
504+ run_pbar .update (1 )
505+ return result
506+
507+ tasks = []
508+ for run_idx in range (num_runs ):
509+ tasks .append (asyncio .create_task (execute_run_with_progress (run_idx , config )))
510+ await asyncio .gather (* tasks ) # pyright: ignore[reportUnknownArgumentType]
485511
486512 # for groupwise mode, the result contains eval output from multiple completion_params, so we need to differentiate them
487513 # rollout_id is used to differentiate the result from different completion_params
0 commit comments