Skip to content

Commit c5edee1

Browse files
committed
almost finished
1 parent c76dc68 commit c5edee1

File tree

3 files changed

+24
-20
lines changed

3 files changed

+24
-20
lines changed

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
3535
# Single-level reasoning effort: expect `reasoning_effort` only
3636
effort_val = None
3737

38-
if "reasoning_effort" in config.completion_params:
38+
if (
39+
"reasoning_effort" in config.completion_params
40+
and config.completion_params["reasoning_effort"] is not None
41+
):
3942
effort_val = str(config.completion_params["reasoning_effort"]) # flat shape
4043
elif (
4144
isinstance(config.completion_params.get("extra_body"), dict)
4245
and "reasoning_effort" in config.completion_params["extra_body"]
46+
and config.completion_params["extra_body"]["reasoning_effort"] is not None
4347
):
4448
# Accept if user passed it directly inside extra_body
4549
effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body

eval_protocol/pytest/evaluation_test.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
aggregate,
4848
create_dynamically_parameterized_wrapper,
4949
deep_update_dict,
50-
execute_function,
5150
extract_effort_tag,
5251
generate_parameter_combinations,
5352
log_eval_status_and_rows,
@@ -333,6 +332,11 @@ def evaluation_test( # noqa: C901
333332

334333
active_logger: DatasetLogger = logger if logger else default_logger
335334

335+
# Apply override from pytest flags if present
336+
num_runs = parse_ep_num_runs(num_runs)
337+
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
338+
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
339+
336340
def decorator(
337341
test_func: TestFunction,
338342
):
@@ -481,10 +485,7 @@ def create_wrapper_with_signature() -> Callable:
481485
async def wrapper_body(**kwargs):
482486
eval_metadata = None
483487

484-
# Apply environment override for num_runs if present
485-
effective_num_runs = parse_ep_num_runs(num_runs)
486-
effective_max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
487-
all_results: List[List[EvaluationRow]] = [[] for _ in range(effective_num_runs)]
488+
all_results: List[List[EvaluationRow]] = [[] for _ in range(num_runs)]
488489

489490
experiment_id = generate_id()
490491

@@ -508,10 +509,9 @@ def _log_eval_error(
508509
data_jsonl.extend(load_jsonl(p))
509510
else:
510511
data_jsonl = load_jsonl(ds_arg)
511-
# Apply env override for max rows if present
512-
effective_max_rows = parse_ep_max_rows(max_dataset_rows)
513-
if effective_max_rows is not None:
514-
data_jsonl = data_jsonl[:effective_max_rows]
512+
# Apply override for max rows if present
513+
if max_dataset_rows is not None:
514+
data_jsonl = data_jsonl[:max_dataset_rows]
515515
data = dataset_adapter(data_jsonl)
516516
elif "input_messages" in kwargs and kwargs["input_messages"] is not None:
517517
# Support either a single row (List[Message]) or many rows (List[List[Message]])
@@ -563,7 +563,7 @@ def _log_eval_error(
563563
name=test_func.__name__,
564564
description=test_func.__doc__,
565565
status="running",
566-
num_runs=effective_num_runs,
566+
num_runs=num_runs,
567567
aggregation_method=aggregation_method,
568568
passed_threshold=threshold,
569569
passed=None,
@@ -589,15 +589,15 @@ def _log_eval_error(
589589
config = RolloutProcessorConfig(
590590
completion_params=completion_params,
591591
mcp_config_path=mcp_config_path or "",
592-
max_concurrent_rollouts=effective_max_concurrent_rollouts,
592+
max_concurrent_rollouts=max_concurrent_rollouts,
593593
server_script_path=server_script_path,
594594
steps=steps,
595595
logger=active_logger,
596596
kwargs=rollout_processor_kwargs or {},
597597
exception_handler_config=exception_handler_config,
598598
)
599599

600-
for i in range(effective_num_runs):
600+
for i in range(num_runs):
601601
# Regenerate outputs each run by deep-copying the pristine dataset
602602
# so model responses are not reused across runs.
603603
run_id = generate_id()
@@ -617,7 +617,7 @@ def _log_eval_error(
617617
processed_rows_in_run.append(row)
618618

619619
# prepare parallel eval helper function
620-
semaphore = asyncio.Semaphore(effective_max_concurrent_rollouts)
620+
semaphore = asyncio.Semaphore(max_concurrent_rollouts)
621621

622622
async def _execute_eval_with_semaphore(**inner_kwargs):
623623
async with semaphore:
@@ -665,7 +665,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
665665
config = RolloutProcessorConfig(
666666
completion_params=cp,
667667
mcp_config_path=mcp_config_path or "",
668-
max_concurrent_rollouts=effective_max_concurrent_rollouts,
668+
max_concurrent_rollouts=max_concurrent_rollouts,
669669
server_script_path=server_script_path,
670670
steps=steps,
671671
logger=active_logger,
@@ -739,8 +739,7 @@ async def _collect_result(config, lst):
739739
# rollout_id is used to differentiate the result from different completion_params
740740
if mode == "groupwise":
741741
results_by_group = [
742-
[[] for _ in range(effective_num_runs)]
743-
for _ in range(len(original_completion_params_list))
742+
[[] for _ in range(num_runs)] for _ in range(len(original_completion_params_list))
744743
]
745744
for i_run, result in enumerate(all_results):
746745
for r in result:
@@ -755,7 +754,7 @@ async def _collect_result(config, lst):
755754
mode,
756755
original_completion_params_list[rollout_id],
757756
test_func.__name__,
758-
effective_num_runs,
757+
num_runs,
759758
)
760759
else:
761760
postprocess(
@@ -766,7 +765,7 @@ async def _collect_result(config, lst):
766765
mode,
767766
completion_params,
768767
test_func.__name__,
769-
effective_num_runs,
768+
num_runs,
770769
)
771770

772771
except AssertionError:
@@ -845,7 +844,7 @@ async def dual_mode_wrapper(*args, **kwargs):
845844
dual_mode_wrapper._origin_func = test_func
846845
dual_mode_wrapper._metainfo = {
847846
"mode": mode,
848-
"max_rollout_concurrency": max_concurrent_rollouts, # TODO: fix this
847+
"max_rollout_concurrency": max_concurrent_rollouts,
849848
"max_evaluation_concurrency": max_concurrent_evaluations,
850849
}
851850

eval_protocol/pytest/plugin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def pytest_addoption(parser) -> None:
8383
default="true",
8484
choices=["true", "false"],
8585
help=(
86+
# TODO: this is not working as expected
8687
"Whether to fail the entire rollout when permanent failures occur after max retries. "
8788
"Default: true (fail on permanent failures). Set to 'false' to continue with remaining rollouts."
8889
),

0 commit comments

Comments (0)