Skip to content

Commit 5c00d97

Browse files
committed
Adding Passed Threshold to Flags
1 parent 9d00e74 commit 5c00d97

File tree

3 files changed

+97
-10
lines changed

3 files changed

+97
-10
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
parse_ep_max_concurrent_rollouts,
5757
parse_ep_num_runs,
5858
parse_ep_completion_params,
59+
parse_ep_passed_threshold,
5960
rollout_processor_with_retry,
6061
sanitize_filename,
6162
)
@@ -344,6 +345,7 @@ def evaluation_test( # noqa: C901
344345
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
345346
completion_params = parse_ep_completion_params(completion_params)
346347
original_completion_params = completion_params
348+
passed_threshold = parse_ep_passed_threshold(passed_threshold)
347349

348350
def decorator(
349351
test_func: TestFunction,

eval_protocol/pytest/plugin.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import logging
1616
import os
1717
from typing import Optional
18+
import json
19+
import pathlib
1820

1921

2022
def pytest_addoption(parser) -> None:
@@ -87,6 +89,21 @@ def pytest_addoption(parser) -> None:
8789
"Default: true (fail on permanent failures). Set to 'false' to continue with remaining rollouts."
8890
),
8991
)
92+
group.addoption(
93+
"--ep-success-threshold",
94+
action="store",
95+
default=None,
96+
help=("Override the success threshold for evaluation_test. Pass a float between 0.0 and 1.0 (e.g., 0.8)."),
97+
)
98+
group.addoption(
99+
"--ep-se-threshold",
100+
action="store",
101+
default=None,
102+
help=(
103+
"Override the standard error threshold for evaluation_test. "
104+
"Pass a float >= 0.0 (e.g., 0.05). If only this is set, success threshold defaults to 0.0."
105+
),
106+
)
90107

91108

92109
def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
@@ -117,6 +134,49 @@ def _normalize_number(val: Optional[str]) -> Optional[str]:
117134
return None
118135

119136

137+
def _normalize_success_threshold(val: Optional[str]) -> Optional[float]:
138+
"""Normalize success threshold value as float between 0.0 and 1.0."""
139+
if val is None:
140+
return None
141+
142+
try:
143+
threshold_float = float(val.strip())
144+
if 0.0 <= threshold_float <= 1.0:
145+
return threshold_float
146+
else:
147+
return None # threshold must be between 0 and 1
148+
except ValueError:
149+
return None
150+
151+
152+
def _normalize_se_threshold(val: Optional[str]) -> Optional[float]:
153+
"""Normalize standard error threshold value as float >= 0.0."""
154+
if val is None:
155+
return None
156+
157+
try:
158+
threshold_float = float(val.strip())
159+
if threshold_float >= 0.0:
160+
return threshold_float
161+
else:
162+
return None # standard error must be >= 0
163+
except ValueError:
164+
return None
165+
166+
167+
def _build_passed_threshold_env(success: Optional[float], se: Optional[float]) -> Optional[str]:
168+
"""Build the EP_PASSED_THRESHOLD environment variable value from the two separate thresholds."""
169+
if success is None and se is None:
170+
return None
171+
172+
if se is None:
173+
return str(success)
174+
else:
175+
success_val = success if success is not None else 0.0
176+
threshold_dict = {"success": success_val, "standard_error": se}
177+
return json.dumps(threshold_dict)
178+
179+
120180
def pytest_configure(config) -> None:
121181
# Quiet LiteLLM INFO spam early in pytest session unless user set a level
122182
try:
@@ -161,11 +221,16 @@ def pytest_configure(config) -> None:
161221
if fail_on_max_retry is not None:
162222
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry
163223

224+
success_threshold_val = config.getoption("--ep-success-threshold")
225+
se_threshold_val = config.getoption("--ep-se-threshold")
226+
norm_success = _normalize_success_threshold(success_threshold_val)
227+
norm_se = _normalize_se_threshold(se_threshold_val)
228+
threshold_env = _build_passed_threshold_env(norm_success, norm_se)
229+
if threshold_env is not None:
230+
os.environ["EP_PASSED_THRESHOLD"] = threshold_env
231+
164232
# Allow ad-hoc overrides of input params via CLI flags
165233
try:
166-
import json as _json
167-
import pathlib as _pathlib
168-
169234
merged: dict = {}
170235
input_params_opts = config.getoption("--ep-input-param")
171236
if input_params_opts:
@@ -174,17 +239,17 @@ def pytest_configure(config) -> None:
174239
continue
175240
opt = str(opt)
176241
if opt.startswith("@"): # load JSON file
177-
p = _pathlib.Path(opt[1:])
242+
p = pathlib.Path(opt[1:])
178243
if p.is_file():
179244
with open(p, "r", encoding="utf-8") as f:
180-
obj = _json.load(f)
245+
obj = json.load(f)
181246
if isinstance(obj, dict):
182247
merged.update(obj)
183248
elif "=" in opt:
184249
k, v = opt.split("=", 1)
185250
# Try parse JSON values, fallback to string
186251
try:
187-
merged[k] = _json.loads(v)
252+
merged[k] = json.loads(v)
188253
except Exception:
189254
merged[k] = v
190255
reasoning_effort = config.getoption("--ep-reasoning-effort")
@@ -194,7 +259,7 @@ def pytest_configure(config) -> None:
194259
# Convert "none" string to None value for API compatibility
195260
eb["reasoning_effort"] = None if reasoning_effort.lower() == "none" else str(reasoning_effort)
196261
if merged:
197-
os.environ["EP_INPUT_PARAMS_JSON"] = _json.dumps(merged)
262+
os.environ["EP_INPUT_PARAMS_JSON"] = json.dumps(merged)
198263
except Exception:
199264
# best effort, do not crash pytest session
200265
pass

eval_protocol/pytest/utils.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config
1919

2020
import logging
21+
import json
2122

2223

2324
def execute_function(func: Callable, **kwargs) -> Any:
@@ -176,11 +177,9 @@ def parse_ep_completion_params(completion_params: List[CompletionParams]) -> Lis
176177
Reads the environment variable set by plugin.py and applies deep merge to each completion param.
177178
"""
178179
try:
179-
import json as _json
180-
181180
_env_override = os.getenv("EP_INPUT_PARAMS_JSON")
182181
if _env_override:
183-
override_obj = _json.loads(_env_override)
182+
override_obj = json.loads(_env_override)
184183
if isinstance(override_obj, dict):
185184
# Apply override to each completion_params item
186185
return [deep_update_dict(dict(cp), override_obj) for cp in completion_params]
@@ -189,6 +188,27 @@ def parse_ep_completion_params(completion_params: List[CompletionParams]) -> Lis
189188
return completion_params
190189

191190

191+
def parse_ep_passed_threshold(default_value: Optional[Union[float, dict]]) -> Optional[Union[float, dict]]:
192+
"""Read EP_PASSED_THRESHOLD env override as float or dict.
193+
194+
Assumes the environment variable was already validated by plugin.py.
195+
Supports both float values (e.g., "0.8") and JSON dict format (e.g., '{"success":0.8}').
196+
"""
197+
raw = os.getenv("EP_PASSED_THRESHOLD")
198+
if raw is None:
199+
return default_value
200+
201+
try:
202+
return float(raw)
203+
except ValueError:
204+
pass
205+
206+
try:
207+
return json.loads(raw)
208+
except (json.JSONDecodeError, TypeError, ValueError) as e:
209+
raise ValueError(f"EP_PASSED_THRESHOLD env var exists but can't be parsed: {raw}") from e
210+
211+
192212
def deep_update_dict(base: dict, override: dict) -> dict:
193213
"""Recursively update nested dictionaries in-place and return base."""
194214
for key, value in override.items():

0 commit comments

Comments
 (0)