Skip to content

Commit 7c10e74

Browse files
authored
tqdm progress bars (#154)
1 parent 2d4a350 commit 7c10e74

2 files changed

Lines changed: 61 additions & 17 deletions

File tree

eval_protocol/pytest/evaluation_test.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import asyncio
22
import inspect
33
import os
4+
import sys
45
from collections import defaultdict
56
from typing import Any, Callable
67
from typing_extensions import Unpack
78
from collections.abc import Sequence
89

910
import pytest
11+
from tqdm import tqdm
1012

1113
from eval_protocol.dataset_logger import default_logger
1214
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
@@ -297,7 +299,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
297299
exception_handler_config=exception_handler_config,
298300
)
299301

300-
async def execute_run(i: int, config: RolloutProcessorConfig):
302+
async def execute_run(run_idx: int, config: RolloutProcessorConfig):
301303
nonlocal all_results
302304

303305
# Regenerate outputs each run by deep-copying the pristine dataset
@@ -357,13 +359,15 @@ async def _execute_groupwise_eval_with_semaphore(
357359
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
358360
pointwise_tasks: list[asyncio.Task[EvaluationRow]] = []
359361
# Use wrapper that handles retry logic internally
360-
async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
362+
async for row in rollout_processor_with_retry(
363+
rollout_processor, fresh_dataset, config, run_idx
364+
):
361365
pointwise_tasks.append(
362366
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
363367
)
364368
results = await asyncio.gather(*pointwise_tasks)
365369

366-
all_results[i] = results
370+
all_results[run_idx] = results
367371
elif mode == "groupwise":
368372
# rollout all the completion_params for the same row at once, and then send the output to the test_func
369373
row_groups = defaultdict( # pyright: ignore[reportUnknownVariableType]
@@ -385,7 +389,9 @@ async def _execute_groupwise_eval_with_semaphore(
385389

386390
async def _collect_result(config, lst): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
387391
result = []
388-
async for row in rollout_processor_with_retry(rollout_processor, lst, config): # pyright: ignore[reportUnknownArgumentType]
392+
async for row in rollout_processor_with_retry(
393+
rollout_processor, lst, config, run_idx
394+
): # pyright: ignore[reportUnknownArgumentType]
389395
result.append(row) # pyright: ignore[reportUnknownMemberType]
390396
return result # pyright: ignore[reportUnknownVariableType]
391397

@@ -409,11 +415,13 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
409415
for task in tasks:
410416
res = await task
411417
results.extend(res) # pyright: ignore[reportUnknownMemberType]
412-
all_results[i] = results
418+
all_results[run_idx] = results
413419
else:
414420
# Batch mode: collect all results first, then evaluate (no pipelining)
415421
input_dataset = []
416-
async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
422+
async for row in rollout_processor_with_retry(
423+
rollout_processor, fresh_dataset, config, run_idx
424+
):
417425
input_dataset.append(row) # pyright: ignore[reportUnknownMemberType]
418426
# NOTE: we will still evaluate errored rows (give users control over this)
419427
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
@@ -438,7 +446,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
438446
raise ValueError(
439447
f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
440448
)
441-
all_results[i] = results
449+
all_results[run_idx] = results
442450

443451
for r in results:
444452
if r.eval_metadata is not None:
@@ -472,16 +480,34 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
472480
# else, we execute runs in parallel
473481
if isinstance(rollout_processor, MCPGymRolloutProcessor):
474482
# For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
475-
for i in range(num_runs):
476-
task = asyncio.create_task(execute_run(i, config))
483+
# For now, no tqdm progress bar here because log output overwrites it; we can revisit this later.
484+
for run_idx in range(num_runs):
485+
task = asyncio.create_task(execute_run(run_idx, config))
477486
await task
478487
else:
479488
# For other processors, create all tasks at once and run in parallel
480489
# Concurrency is now controlled by the shared semaphore in each rollout processor
481-
tasks = []
482-
for i in range(num_runs):
483-
tasks.append(asyncio.create_task(execute_run(i, config))) # pyright: ignore[reportUnknownMemberType]
484-
await asyncio.gather(*tasks) # pyright: ignore[reportUnknownArgumentType]
490+
with tqdm(
491+
total=num_runs,
492+
desc="Runs (Parallel)",
493+
unit="run",
494+
file=sys.__stderr__,
495+
position=0,
496+
leave=True,
497+
dynamic_ncols=True,
498+
miniters=1,
499+
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
500+
) as run_pbar:
501+
502+
async def execute_run_with_progress(run_idx: int, config):
503+
result = await execute_run(run_idx, config)
504+
run_pbar.update(1)
505+
return result
506+
507+
tasks = []
508+
for run_idx in range(num_runs):
509+
tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config)))
510+
await asyncio.gather(*tasks) # pyright: ignore[reportUnknownArgumentType]
485511

486512
# for groupwise mode, the result contains eval output from multiple completion_params, so we need to differentiate them
487513
# rollout_id is used to differentiate the result from different completion_params

eval_protocol/pytest/utils.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
from collections.abc import Sequence
33
import os
44
import re
5+
import sys
56
from dataclasses import replace
67
from typing import Any, Literal
78

9+
from tqdm import tqdm
10+
811
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
912
from eval_protocol.models import (
1013
EvalMetadata,
@@ -157,6 +160,7 @@ async def rollout_processor_with_retry(
157160
rollout_processor: RolloutProcessor,
158161
fresh_dataset: list[EvaluationRow],
159162
config: RolloutProcessorConfig,
163+
run_idx: int = 0,
160164
):
161165
"""
162166
Wrapper around rollout_processor that handles retry logic using the Python backoff library.
@@ -240,10 +244,24 @@ async def execute_row_with_backoff_and_log(task: asyncio.Task, row: EvaluationRo
240244
for i, task in enumerate(base_tasks)
241245
]
242246

243-
# Yield results as they complete
244-
for task in asyncio.as_completed(retry_tasks):
245-
result = await task
246-
yield result
247+
position = run_idx + 1 # Position 0 is reserved for main run bar, so shift up by 1
248+
with tqdm(
249+
total=len(retry_tasks),
250+
desc=f" Run {position}",
251+
unit="rollout",
252+
file=sys.__stderr__,
253+
leave=False,
254+
position=position,
255+
dynamic_ncols=True,
256+
miniters=1,
257+
mininterval=0.1,
258+
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
259+
) as rollout_pbar:
260+
# Yield results as they complete
261+
for task in asyncio.as_completed(retry_tasks):
262+
result = await task
263+
rollout_pbar.update(1)
264+
yield result
247265

248266
finally:
249267
rollout_processor.cleanup()

0 commit comments

Comments
 (0)