Skip to content

Commit b4f8a5b

Browse files
authored
add chunk size (#303)
1 parent ad98650 commit b4f8a5b

File tree

4 files changed

+9
-0
lines changed

4 files changed

+9
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,4 @@ package-lock.json
242242
package.json
243243
tau2-bench
244244
*.err
245+
eval-protocol

eval_protocol/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,8 @@ def parse_args(args=None):
402402
rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id")
403403
rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True)
404404
rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false")
405+
# Rollout chunking
406+
rft_parser.add_argument("--chunk-size", type=int, help="Data chunk size for rollout batching")
405407
# Inference params
406408
rft_parser.add_argument("--temperature", type=float)
407409
rft_parser.add_argument("--top-p", type=float)

eval_protocol/cli_commands/create_rft.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def create_rft_command(args) -> int:
379379
"trainingConfig": training_config,
380380
"inferenceParameters": inference_params or None,
381381
"wandbConfig": wandb_config,
382+
"chunkSize": getattr(args, "chunk_size", None),
382383
"outputStats": None,
383384
"outputMetrics": None,
384385
"mcpServer": None,

tests/pytest/gsm8k/test_pytest_math_example.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import os
55
from eval_protocol.data_loader.jsonl_data_loader import EvaluationRowJsonlDataLoader
66
from typing import List, Dict, Any, Optional
7+
import logging
8+
9+
logger = logging.getLogger(__name__)
710

811

912
def extract_answer_digits(ground_truth: str) -> Optional[str]:
@@ -54,6 +57,7 @@ def test_math_dataset(row: EvaluationRow, **kwargs) -> EvaluationRow:
5457
EvaluationRow with the evaluation result
5558
"""
5659
#### Get predicted answer value
60+
logger.info(f"I am beginning to execute GSM8k rollout: {row.execution_metadata.rollout_id}")
5761
prediction = extract_answer_digits(str(row.messages[2].content))
5862
gt = extract_answer_digits(str(row.ground_truth))
5963

@@ -77,5 +81,6 @@ def test_math_dataset(row: EvaluationRow, **kwargs) -> EvaluationRow:
7781
is_score_valid=True, # Optional: Whether the score is valid, true by default
7882
reason=reason, # Optional: The reason for the score
7983
)
84+
logger.info(f"I am done executing GSM8k rollout: {row.execution_metadata.rollout_id}")
8085
row.evaluation_result = evaluation_result
8186
return row

0 commit comments

Comments (0)