Skip to content

Commit 7be345a

Browse files
committed
checkpointing
1 parent 38a4444 commit 7be345a

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,34 @@ def _log_eval_error(
404404
for row in fresh_dataset:
405405
active_logger.log(row)
406406

407-
processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config)
407+
# filter out rows that already have completed rollouts via checkpointing
408+
rows_to_process = []
409+
completed_rollout_ids = set()
410+
411+
finished_logs = active_logger.read()
412+
413+
for finished_row in finished_logs:
414+
# need to add finished rows to all_results so that we can aggregate them later.
415+
all_results.append(finished_row)
416+
# TODO: also record num_run so we can track which run each row belongs to.
417+
# TODO: find out why row_id was made optional in the first place; checkpointing cannot work without some ID to match rows on.
418+
if finished_row.input_metadata and finished_row.input_metadata.row_id:
419+
completed_rollout_ids.add(finished_row.input_metadata.row_id)
420+
421+
for row in fresh_dataset:
422+
row_id = row.input_metadata.row_id if row.input_metadata else None
423+
if row_id not in completed_rollout_ids:
424+
rows_to_process.append(row)
425+
426+
if len(rows_to_process) < len(fresh_dataset):
427+
print(
428+
f"Checkpointing: Found {len(fresh_dataset) - len(rows_to_process)} completed rows, processing {len(rows_to_process)} remaining rows"
429+
)
430+
431+
if rows_to_process:
432+
processed_dataset = execute_function(
433+
rollout_processor, rows=rows_to_process, config=config
434+
)
408435

409436
if mode == "pointwise":
410437
# Pointwise mode: apply the evaluator function to each row

0 commit comments

Comments
 (0)