Skip to content

Commit ffd073d

Browse files
committed
Add all options to plugin
1 parent 8c87240 commit ffd073d

File tree

3 files changed

+51
-14
lines changed

3 files changed

+51
-14
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
generate_parameter_combinations,
5151
log_eval_status_and_rows,
5252
parse_ep_max_rows,
53+
parse_ep_num_runs,
5354
rollout_processor_with_retry,
5455
sanitize_filename,
5556
)
@@ -456,7 +457,10 @@ def create_wrapper_with_signature() -> Callable:
456457

457458
async def wrapper_body(**kwargs):
458459
eval_metadata = None
459-
all_results: List[List[EvaluationRow]] = [[] for _ in range(num_runs)]
460+
461+
# Apply environment override for num_runs if present
462+
effective_num_runs = parse_ep_num_runs(num_runs)
463+
all_results: List[List[EvaluationRow]] = [[] for _ in range(effective_num_runs)]
460464

461465
experiment_id = generate_id()
462466

@@ -530,7 +534,7 @@ def _log_eval_error(
530534
name=test_func.__name__,
531535
description=test_func.__doc__,
532536
status="running",
533-
num_runs=num_runs,
537+
num_runs=effective_num_runs,
534538
aggregation_method=aggregation_method,
535539
passed_threshold=threshold,
536540
passed=None,
@@ -564,7 +568,7 @@ def _log_eval_error(
564568
exception_handler_config=exception_handler_config,
565569
)
566570

567-
for i in range(num_runs):
571+
for i in range(effective_num_runs):
568572
# Regenerate outputs each run by deep-copying the pristine dataset
569573
# so model responses are not reused across runs.
570574
run_id = generate_id()
@@ -693,7 +697,8 @@ async def _collect_result(config, lst):
693697
# rollout_id is used to differentiate the result from different completion_params
694698
if mode == "groupwise":
695699
results_by_group = [
696-
[[] for _ in range(num_runs)] for _ in range(len(original_completion_params_list))
700+
[[] for _ in range(effective_num_runs)]
701+
for _ in range(len(original_completion_params_list))
697702
]
698703
for i_run, result in enumerate(all_results):
699704
for r in result:
@@ -708,7 +713,7 @@ async def _collect_result(config, lst):
708713
mode,
709714
original_completion_params_list[rollout_id],
710715
test_func.__name__,
711-
num_runs,
716+
effective_num_runs,
712717
)
713718
else:
714719
postprocess(
@@ -719,7 +724,7 @@ async def _collect_result(config, lst):
719724
mode,
720725
completion_params,
721726
test_func.__name__,
722-
num_runs,
727+
effective_num_runs,
723728
)
724729

725730
except AssertionError:

eval_protocol/pytest/plugin.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def pytest_addoption(parser) -> None:
2828
"Pass an integer (e.g., 2, 50) or 'all' for no limit."
2929
),
3030
)
31+
group.addoption(
32+
"--ep-num-runs",
33+
action="store",
34+
default=None,
35+
help=("Override the number of runs for evaluation_test. Pass an integer (e.g., 1, 5, 10)."),
36+
)
3137
group.addoption(
3238
"--ep-print-summary",
3339
action="store_true",
@@ -92,6 +98,20 @@ def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
9298
return None
9399

94100

101+
def _normalize_num_runs(val: Optional[str]) -> Optional[str]:
102+
if val is None:
103+
return None
104+
s = val.strip()
105+
# Validate int; if invalid, ignore and return None (no override)
106+
try:
107+
num = int(s)
108+
if num <= 0:
109+
return None # num_runs must be positive
110+
return str(num)
111+
except ValueError:
112+
return None
113+
114+
95115
def pytest_configure(config) -> None:
96116
# Quiet LiteLLM INFO spam early in pytest session unless user set a level
97117
try:
@@ -110,6 +130,11 @@ def pytest_configure(config) -> None:
110130
if norm is not None:
111131
os.environ["EP_MAX_DATASET_ROWS"] = norm
112132

133+
num_runs_val = config.getoption("--ep-num-runs")
134+
norm_runs = _normalize_num_runs(num_runs_val)
135+
if norm_runs is not None:
136+
os.environ["EP_NUM_RUNS"] = norm_runs
137+
113138
if config.getoption("--ep-print-summary"):
114139
os.environ["EP_PRINT_SUMMARY"] = "1"
115140

eval_protocol/pytest/utils.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,17 +139,24 @@ def log_eval_status_and_rows(
139139

140140

141141
def parse_ep_max_rows(default_value: Optional[int]) -> Optional[int]:
    """Read the EP_MAX_DATASET_ROWS env override as int or None.

    plugin.py stores a validated value ("None" for the unlimited / "all"
    case), but the variable can also be exported directly by users, so an
    invalid value falls back to *default_value* instead of raising.

    Args:
        default_value: Row limit to use when no override is present.

    Returns:
        The overridden row limit, ``None`` for "no limit", or
        *default_value* when the variable is unset or invalid.
    """
    raw = os.getenv("EP_MAX_DATASET_ROWS")
    if raw is None:
        return default_value
    s = raw.strip().lower()
    # plugin.py stores "None" for --ep-max-rows=all; accept "all" too so a
    # hand-exported variable behaves like the CLI flag.
    if s in ("none", "all"):
        return None
    try:
        return int(s)
    except ValueError:
        # Garbage value (e.g. set manually outside pytest): ignore override.
        return default_value
151+
152+
153+
def parse_ep_num_runs(default_value: int) -> int:
    """Read the EP_NUM_RUNS env override as a positive int.

    plugin.py validates the value before storing it, but the variable can
    also be exported directly by users, so invalid or non-positive values
    fall back to *default_value* instead of raising.

    Args:
        default_value: Number of runs to use when no override is present.

    Returns:
        The overridden run count, or *default_value* when the variable is
        unset or invalid.
    """
    raw = os.getenv("EP_NUM_RUNS")
    if raw is None:
        return default_value
    try:
        num = int(raw.strip())
    except ValueError:
        # Garbage value (e.g. set manually outside pytest): ignore override.
        return default_value
    # num_runs must be positive, mirroring _normalize_num_runs in plugin.py.
    return num if num > 0 else default_value
153160

154161

155162
def deep_update_dict(base: dict, override: dict) -> dict:

0 commit comments

Comments
 (0)