Skip to content

Commit a0edb8f

Browse files
committed
default
1 parent bdd630c commit a0edb8f

File tree

1 file changed

+16
-34
lines changed

1 file changed

+16
-34
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 16 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -418,43 +418,25 @@ def _log_eval_error(
418418
max_retry = int(os.getenv("EP_MAX_RETRY", "0"))
419419

420420
for i in range(num_runs):
421-
# Regenerate outputs each run by deep-copying the pristine dataset
422-
# so model responses are not reused across runs.
423421
run_id = generate_id()
424422
retry_attempt = 0
425423
current_data = data
426424

427-
while retry_attempt <= max_retry:
428-
if retry_attempt > 0:
429-
logged_rows = active_logger.read()
430-
failed_rows = [
431-
row
432-
for row in logged_rows
433-
if row.rollout_status
434-
and row.rollout_status.status == "error"
435-
and row.run_id == run_id
436-
]
437-
if not failed_rows:
438-
break
439-
current_data = failed_rows
440-
441-
# Regenerate outputs each run by deep-copying the pristine dataset
442-
# so model responses are not reused across runs.
443-
fresh_dataset = [r.model_copy(deep=True) for r in current_data]
444-
445-
# apply new run_id to fresh_dataset
446-
for row in fresh_dataset:
447-
row.run_id = run_id
448-
449-
# generate new rollout_id for each row
450-
for row in fresh_dataset:
451-
row.rollout_id = generate_id()
452-
453-
# log the fresh_dataset
454-
for row in fresh_dataset:
455-
active_logger.log(row)
456-
457-
rollout_result = rollout_processor(fresh_dataset, config)
425+
# Regenerate outputs each run by deep-copying the pristine dataset
426+
# so model responses are not reused across runs.
427+
fresh_dataset = [r.model_copy(deep=True) for r in current_data]
428+
429+
# apply new run_id to fresh_dataset
430+
for row in fresh_dataset:
431+
row.run_id = run_id
432+
433+
# generate new rollout_id for each row
434+
for row in fresh_dataset:
435+
row.rollout_id = generate_id()
436+
437+
# log the fresh_dataset
438+
for row in fresh_dataset:
439+
active_logger.log(row)
458440

459441
if mode == "pointwise":
460442
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
@@ -482,7 +464,7 @@ async def _execute_with_semaphore(row):
482464
else:
483465
# Batch mode: collect all results first, then evaluate (no pipelining)
484466
input_dataset = []
485-
async for row in rollout_result:
467+
async for row in rollout_processor(fresh_dataset, config):
486468
input_dataset.append(row)
487469

488470
results = await execute_with_params(

0 commit comments

Comments
 (0)