@@ -404,7 +404,39 @@ def _log_eval_error(
     for row in fresh_dataset:
         active_logger.log(row)
 
-    processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config)
+    # Checkpointing: skip rows whose rollouts already completed in a prior run.
+    rows_to_process = []
+    completed_rollout_ids = set()
+
+    finished_logs = active_logger.read()
+
+    for finished_row in finished_logs:
+        # Finished rows still belong in all_results so aggregation sees them.
+        all_results.append(finished_row)
+        # TODO: also record num_run to track which run the row belongs to.
+        # NOTE(review): rows without a row_id cannot be matched against the log
+        # and will be re-processed on every resume — consider requiring row_id.
+        if finished_row.input_metadata and finished_row.input_metadata.row_id:
+            completed_rollout_ids.add(finished_row.input_metadata.row_id)
+
+    for row in fresh_dataset:
+        row_id = row.input_metadata.row_id if row.input_metadata else None
+        if row_id not in completed_rollout_ids:
+            rows_to_process.append(row)
+
+    if len(rows_to_process) < len(fresh_dataset):
+        print(
+            f"Checkpointing: Found {len(fresh_dataset) - len(rows_to_process)} completed rows, processing {len(rows_to_process)} remaining rows"
+        )
+
+    if rows_to_process:
+        processed_dataset = execute_function(
+            rollout_processor, rows=rows_to_process, config=config
+        )
+    else:
+        # Every row was already completed in a previous run. Bind an empty work
+        # list so the aggregation code below never sees an unbound name.
+        processed_dataset = []
 
     if mode == "pointwise":
         # Pointwise mode: apply the evaluator function to each row