Skip to content

Commit 65abb77

Browse files
committed
Fix Completion Params
1 parent 52c0f20 commit 65abb77

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
parse_ep_max_rows,
5555
parse_ep_max_concurrent_rollouts,
5656
parse_ep_num_runs,
57+
parse_ep_completion_params,
5758
rollout_processor_with_retry,
5859
sanitize_filename,
5960
)
@@ -338,6 +339,7 @@ def evaluation_test( # noqa: C901
338339
num_runs = parse_ep_num_runs(num_runs)
339340
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
340341
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
342+
completion_params = parse_ep_completion_params(completion_params)
341343

342344
def decorator(
343345
test_func: TestFunction,
@@ -420,9 +422,6 @@ async def execute_with_params(
420422
else:
421423
return test_func(**kwargs)
422424

423-
# preserve the original completion_params list for groupwise mode
424-
original_completion_params_list = completion_params
425-
426425
# Calculate all possible combinations of parameters
427426
if mode == "groupwise":
428427
combinations = generate_parameter_combinations(
@@ -544,20 +543,6 @@ def _log_eval_error(status: Status, rows: Optional[List[EvaluationRow]] | None,
544543
"No model provided. Please provide a model in the completion parameters object."
545544
)
546545

547-
# Optional global overrides via environment for ad-hoc experimentation
548-
# EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged
549-
# into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
550-
try:
551-
import json as _json
552-
553-
_env_override = os.getenv("EP_INPUT_PARAMS_JSON")
554-
if _env_override:
555-
override_obj = _json.loads(_env_override)
556-
if isinstance(override_obj, dict):
557-
completion_params = deep_update_dict(dict(completion_params), override_obj)
558-
except Exception:
559-
pass
560-
561546
# Create eval metadata with test function info and current commit hash
562547
eval_metadata = EvalMetadata(
563548
name=test_func.__name__,
@@ -661,7 +646,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
661646
row_groups = defaultdict(list) # key: row_id, value: list of rollout_result
662647
tasks: List[asyncio.Task[List[EvaluationRow]]] = []
663648
# completion_groups = []
664-
for idx, cp in enumerate(original_completion_params_list):
649+
for idx, cp in enumerate(completion_params):
665650
config = RolloutProcessorConfig(
666651
completion_params=cp,
667652
mcp_config_path=mcp_config_path or "",
@@ -743,9 +728,7 @@ async def _collect_result(config, lst):
743728
# for groupwise mode, the result contains eval otuput from multiple completion_params, we need to differentiate them
744729
# rollout_id is used to differentiate the result from different completion_params
745730
if mode == "groupwise":
746-
results_by_group = [
747-
[[] for _ in range(num_runs)] for _ in range(len(original_completion_params_list))
748-
]
731+
results_by_group = [[[] for _ in range(num_runs)] for _ in range(len(completion_params))]
749732
for i_run, result in enumerate(all_results):
750733
for r in result:
751734
completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1])
@@ -757,7 +740,7 @@ async def _collect_result(config, lst):
757740
threshold,
758741
active_logger,
759742
mode,
760-
original_completion_params_list[rollout_id],
743+
completion_params[rollout_id],
761744
test_func.__name__,
762745
num_runs,
763746
)

eval_protocol/pytest/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,25 @@ def parse_ep_max_concurrent_rollouts(default_value: int) -> int:
170170
return int(raw) if raw is not None else default_value
171171

172172

173+
def parse_ep_completion_params(completion_params: List[CompletionParams]) -> List[CompletionParams]:
174+
"""Apply EP_INPUT_PARAMS_JSON overrides to completion_params.
175+
176+
Reads the environment variable set by plugin.py and applies deep merge to each completion param.
177+
"""
178+
try:
179+
import json as _json
180+
181+
_env_override = os.getenv("EP_INPUT_PARAMS_JSON")
182+
if _env_override:
183+
override_obj = _json.loads(_env_override)
184+
if isinstance(override_obj, dict):
185+
# Apply override to each completion_params item
186+
return [deep_update_dict(dict(cp), override_obj) for cp in completion_params]
187+
except Exception:
188+
pass
189+
return completion_params
190+
191+
173192
def deep_update_dict(base: dict, override: dict) -> dict:
174193
"""Recursively update nested dictionaries in-place and return base."""
175194
for key, value in override.items():

0 commit comments

Comments
 (0)