Skip to content

Commit c76dc68

Browse files
committed
first pass
1 parent 6bf482b commit c76dc68

File tree

5 files changed

+44
-27
lines changed

5 files changed

+44
-27
lines changed

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,6 @@ class SingleTurnRolloutProcessor(RolloutProcessor):
2020

2121
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
2222
"""Generate single turn rollout tasks and return them for external handling."""
23-
24-
# Quiet LiteLLM logs in test runs unless user overrode
25-
try:
26-
if os.environ.get("LITELLM_LOG") is None:
27-
os.environ["LITELLM_LOG"] = "ERROR"
28-
_llog = logging.getLogger("LiteLLM")
29-
_llog.setLevel(logging.CRITICAL)
30-
_llog.propagate = False
31-
for _h in list(_llog.handlers):
32-
_llog.removeHandler(_h)
33-
except Exception:
34-
pass
35-
3623
# Do not modify global LiteLLM cache. Disable caching per-request instead.
3724

3825
async def process_row(row: EvaluationRow) -> EvaluationRow:

eval_protocol/pytest/evaluation_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
generate_parameter_combinations,
5353
log_eval_status_and_rows,
5454
parse_ep_max_rows,
55+
parse_ep_max_concurrent_rollouts,
5556
parse_ep_num_runs,
5657
rollout_processor_with_retry,
5758
sanitize_filename,
@@ -482,6 +483,7 @@ async def wrapper_body(**kwargs):
482483

483484
# Apply environment override for num_runs if present
484485
effective_num_runs = parse_ep_num_runs(num_runs)
486+
effective_max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
485487
all_results: List[List[EvaluationRow]] = [[] for _ in range(effective_num_runs)]
486488

487489
experiment_id = generate_id()
@@ -587,7 +589,7 @@ def _log_eval_error(
587589
config = RolloutProcessorConfig(
588590
completion_params=completion_params,
589591
mcp_config_path=mcp_config_path or "",
590-
max_concurrent_rollouts=max_concurrent_rollouts,
592+
max_concurrent_rollouts=effective_max_concurrent_rollouts,
591593
server_script_path=server_script_path,
592594
steps=steps,
593595
logger=active_logger,
@@ -615,7 +617,7 @@ def _log_eval_error(
615617
processed_rows_in_run.append(row)
616618

617619
# prepare parallel eval helper function
618-
semaphore = asyncio.Semaphore(max_concurrent_evaluations)
620+
semaphore = asyncio.Semaphore(effective_max_concurrent_rollouts)
619621

620622
async def _execute_eval_with_semaphore(**inner_kwargs):
621623
async with semaphore:
@@ -663,7 +665,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
663665
config = RolloutProcessorConfig(
664666
completion_params=cp,
665667
mcp_config_path=mcp_config_path or "",
666-
max_concurrent_rollouts=max_concurrent_rollouts,
668+
max_concurrent_rollouts=effective_max_concurrent_rollouts,
667669
server_script_path=server_script_path,
668670
steps=steps,
669671
logger=active_logger,
@@ -843,7 +845,7 @@ async def dual_mode_wrapper(*args, **kwargs):
843845
dual_mode_wrapper._origin_func = test_func
844846
dual_mode_wrapper._metainfo = {
845847
"mode": mode,
846-
"max_rollout_concurrency": max_concurrent_rollouts,
848+
"max_rollout_concurrency": max_concurrent_rollouts, # TODO: fix this
847849
"max_evaluation_concurrency": max_concurrent_evaluations,
848850
}
849851

eval_protocol/pytest/exception_config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ def __post_init__(self):
109109
# Override backoff settings from environment variables
110110
if "EP_MAX_RETRY" in os.environ:
111111
max_retry = int(os.environ["EP_MAX_RETRY"])
112-
if max_retry > 0:
113-
self.backoff_config.max_tries = max_retry
112+
self.backoff_config.max_tries = max_retry
114113

115114
if "EP_FAIL_ON_MAX_RETRY" in os.environ:
116115
fail_on_max_retry = os.environ["EP_FAIL_ON_MAX_RETRY"].lower()

eval_protocol/pytest/plugin.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ def pytest_addoption(parser) -> None:
3434
default=None,
3535
help=("Override the number of runs for evaluation_test. Pass an integer (e.g., 1, 5, 10)."),
3636
)
37+
group.addoption(
38+
"--ep-max-concurrent-rollouts",
39+
action="store",
40+
default=None,
41+
help=("Override the maximum number of concurrent rollouts. Pass an integer (e.g., 8, 50, 100)."),
42+
)
3743
group.addoption(
3844
"--ep-print-summary",
3945
action="store_true",
@@ -62,14 +68,13 @@ def pytest_addoption(parser) -> None:
6268
default=None,
6369
help=(
6470
"Set reasoning.effort for providers that support it (e.g., Fireworks) via LiteLLM extra_body. "
65-
"Values: low|medium|high"
71+
"Values: low|medium|high|none"
6672
),
6773
)
6874
group.addoption(
6975
"--ep-max-retry",
7076
action="store",
71-
type=int,
72-
default=0,
77+
default=None,
7378
help=("Failed rollouts (with rollout_status.code indicating error) will be retried up to this many times."),
7479
)
7580
group.addoption(
@@ -98,7 +103,7 @@ def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
98103
return None
99104

100105

101-
def _normalize_num_runs(val: Optional[str]) -> Optional[str]:
106+
def _normalize_number(val: Optional[str]) -> Optional[str]:
102107
if val is None:
103108
return None
104109
s = val.strip()
@@ -131,10 +136,15 @@ def pytest_configure(config) -> None:
131136
os.environ["EP_MAX_DATASET_ROWS"] = norm
132137

133138
num_runs_val = config.getoption("--ep-num-runs")
134-
norm_runs = _normalize_num_runs(num_runs_val)
139+
norm_runs = _normalize_number(num_runs_val)
135140
if norm_runs is not None:
136141
os.environ["EP_NUM_RUNS"] = norm_runs
137142

143+
max_concurrent_val = config.getoption("--ep-max-concurrent-rollouts")
144+
norm_concurrent = _normalize_number(max_concurrent_val)
145+
if norm_concurrent is not None:
146+
os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent
147+
138148
if config.getoption("--ep-print-summary"):
139149
os.environ["EP_PRINT_SUMMARY"] = "1"
140150

@@ -143,10 +153,13 @@ def pytest_configure(config) -> None:
143153
os.environ["EP_SUMMARY_JSON"] = summary_json_path
144154

145155
max_retry = config.getoption("--ep-max-retry")
146-
os.environ["EP_MAX_RETRY"] = str(max_retry)
156+
norm_max_retry = _normalize_number(max_retry)
157+
if norm_max_retry is not None:
158+
os.environ["EP_MAX_RETRY"] = norm_max_retry
147159

148160
fail_on_max_retry = config.getoption("--ep-fail-on-max-retry")
149-
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry
161+
if fail_on_max_retry is not None:
162+
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry
150163

151164
# Allow ad-hoc overrides of input params via CLI flags
152165
try:
@@ -178,7 +191,8 @@ def pytest_configure(config) -> None:
178191
if reasoning_effort:
179192
# Always place under extra_body to avoid LiteLLM rejecting top-level params
180193
eb = merged.setdefault("extra_body", {})
181-
eb["reasoning_effort"] = str(reasoning_effort)
194+
# Convert "none" string to None value for API compatibility
195+
eb["reasoning_effort"] = None if reasoning_effort.lower() == "none" else str(reasoning_effort)
182196
if merged:
183197
os.environ["EP_INPUT_PARAMS_JSON"] = _json.dumps(merged)
184198
except Exception:

eval_protocol/pytest/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
)
1919
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config
2020

21+
import logging
22+
2123

2224
def execute_function(func: Callable, **kwargs) -> Any:
2325
"""
@@ -160,6 +162,15 @@ def parse_ep_num_runs(default_value: int) -> int:
160162
return int(raw) if raw is not None else default_value
161163

162164

165+
def parse_ep_max_concurrent_rollouts(default_value: int) -> int:
166+
"""Read EP_MAX_CONCURRENT_ROLLOUTS env override as int.
167+
168+
Assumes the environment variable was already validated by plugin.py.
169+
"""
170+
raw = os.getenv("EP_MAX_CONCURRENT_ROLLOUTS")
171+
return int(raw) if raw is not None else default_value
172+
173+
163174
def deep_update_dict(base: dict, override: dict) -> dict:
164175
"""Recursively update nested dictionaries in-place and return base."""
165176
for key, value in override.items():
@@ -322,10 +333,14 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
322333
return result
323334
except Exception as retry_error:
324335
# Backoff gave up
336+
logging.error(
337+
f"❌ Rollout failed, (retried {exception_config.backoff_config.max_tries} times): {repr(retry_error)}"
338+
)
325339
row.rollout_status = Status.rollout_error(str(retry_error))
326340
return row
327341
else:
328342
# Non-retryable exception - fail immediately
343+
logging.error(f"❌ Rollout failed (non-retryable error encountered): {repr(e)}")
329344
row.rollout_status = Status.rollout_error(str(e))
330345
return row
331346

0 commit comments

Comments
 (0)