
Commit 90933d3
Author: Dylan Huang
Parent: 8915c7e

DRY completion params and make it a dict

File tree: 14 files changed (+203, −211 lines)
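In short, this commit turns `CompletionParams` from a Pydantic model into a plain dict alias and threads that change through the adapters, rollout processors, and tests. A minimal sketch of the shape change (the model name is illustrative):

```python
from typing import Any, Dict

# After this commit, CompletionParams is just a type alias for a plain dict.
CompletionParams = Dict[str, Any]

# Before: params = CompletionParams(model="gpt-4.1", temperature=0.2); params.model
# After: plain dict with key access; provider-specific extras pass through untouched.
params: CompletionParams = {"model": "gpt-4.1", "temperature": 0.2, "top_p": 0.9}
print(params["model"])  # gpt-4.1
```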

eval_protocol/adapters/huggingface.py

Lines changed: 133 additions & 135 deletions
Large diffs are not rendered by default.

eval_protocol/adapters/langfuse.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -4,11 +4,11 @@
 to EvaluationRow format for use in evaluation pipelines.
 """
 
-from typing import Any, Dict, Iterator, List, Optional
-from datetime import datetime
 import logging
+from datetime import datetime
+from typing import Any, Dict, Iterator, List, Optional
 
-from eval_protocol.models import EvaluationRow, Message, InputMetadata, CompletionParams
+from eval_protocol.models import EvaluationRow, InputMetadata, Message
 
 
 logger = logging.getLogger(__name__)
 
@@ -277,20 +277,20 @@ def _create_input_metadata(self, trace: Any, observations: List[Any]) -> InputMetadata:
             InputMetadata object
         """
         # Extract completion parameters from observations
-        completion_params = CompletionParams()
+        completion_params = {}
 
         # Look for model parameters in observations
         for obs in observations:
             if hasattr(obs, "model") and obs.model:
-                completion_params.model = obs.model
+                completion_params["model"] = obs.model
             if hasattr(obs, "model_parameters") and obs.model_parameters:
                 params = obs.model_parameters
                 if "temperature" in params:
-                    completion_params.temperature = params["temperature"]
+                    completion_params["temperature"] = params["temperature"]
                 if "max_tokens" in params:
-                    completion_params.max_tokens = params["max_tokens"]
+                    completion_params["max_tokens"] = params["max_tokens"]
                 if "top_p" in params:
-                    completion_params.top_p = params["top_p"]
+                    completion_params["top_p"] = params["top_p"]
             break
 
         # Create dataset info from trace metadata
```
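The loop above now accumulates a plain dict instead of mutating a model instance. A self-contained sketch of the same extraction, with a `SimpleNamespace` standing in for a real Langfuse observation (the stub values are assumptions for illustration):

```python
from types import SimpleNamespace
from typing import Any, Dict

# Hypothetical observation stub; real ones come from the Langfuse SDK.
obs = SimpleNamespace(model="gpt-4.1", model_parameters={"temperature": 0.2, "max_tokens": 256})

completion_params: Dict[str, Any] = {}
if getattr(obs, "model", None):
    completion_params["model"] = obs.model
for key in ("temperature", "max_tokens", "top_p"):  # condensed form of the ifs above
    if getattr(obs, "model_parameters", None) and key in obs.model_parameters:
        completion_params[key] = obs.model_parameters[key]

print(completion_params)  # {'model': 'gpt-4.1', 'temperature': 0.2, 'max_tokens': 256}
```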

eval_protocol/benchmarks/suites/tau_bench_retail.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,7 +11,7 @@
 from typing import Any, Dict, List
 
 from eval_protocol.benchmarks.registry import export_benchmark
-from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message
+from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
 from eval_protocol.pytest import evaluation_test
 from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
 from vendor.tau2.data_model.message import (
```

eval_protocol/mcp/execution/manager.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -20,7 +20,7 @@
 from vendor.tau2.data_model.message import AssistantMessage, UserMessage
 from vendor.tau2.user.user_simulator import UserSimulator
 
-from ...models import CompletionParams, EvaluationRow, InputMetadata, Message
+from ...models import EvaluationRow, InputMetadata, Message
 from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory
 
 if TYPE_CHECKING:
@@ -128,12 +128,12 @@ async def _execute_with_semaphore(idx):
             evaluation_row.messages = messages
             evaluation_row.tools = shared_tool_schema
             evaluation_row.usage = CompletionUsage(**trajectory.usage)
-            evaluation_row.input_metadata.completion_params = CompletionParams(
-                model=policy.model_id,
-                temperature=getattr(policy, "temperature", None),
-                max_tokens=getattr(policy, "max_tokens", None),
-                max_tool_calls=getattr(policy, "max_tools_per_turn", None),
-            )
+            evaluation_row.input_metadata.completion_params = {
+                "model": policy.model_id,
+                "temperature": getattr(policy, "temperature", None),
+                "max_tokens": getattr(policy, "max_tokens", None),
+                "max_tool_calls": getattr(policy, "max_tools_per_turn", None),
+            }
 
             if trajectory.terminated:
                 if trajectory.termination_reason == TerminationReason.ERROR:
```
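Note that the dict literal above keeps `None` values when a `getattr` fallback fires, where the old Pydantic model simply left optional fields unset. If such a dict is splatted directly into a completion call, it may be worth stripping the `None`s first; a sketch of that (this filtering is an assumption, not something the commit adds):

```python
from typing import Any, Dict

# Shape produced by the assignment above when the policy lacks some attributes.
completion_params: Dict[str, Any] = {
    "model": "my-model",       # stands in for policy.model_id
    "temperature": None,       # getattr(policy, "temperature", None) found nothing
    "max_tokens": 512,
    "max_tool_calls": None,
}

# Drop None entries so a provider never receives explicit nulls.
cleaned = {k: v for k, v in completion_params.items() if v is not None}
print(cleaned)  # {'model': 'my-model', 'max_tokens': 512}
```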

eval_protocol/models.py

Lines changed: 13 additions & 11 deletions
```diff
@@ -1,6 +1,6 @@
 import os
 from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
 
 from openai.types import CompletionUsage
 from openai.types.chat.chat_completion_message import (
@@ -178,16 +178,18 @@ def __iter__(self):
         return iter(self.__fields__.keys())  # Changed to __fields__
 
 
-class CompletionParams(BaseModel):
-    """Configuration for the language model used in the session."""
+CompletionParams = Dict[str, Any]
+"""
+Common set of completion parameters that most model providers support in their
+API. A plain dict is used so that extra fields are allowed, since LiteLLM and
+individual providers accept their own parameters. The following fields are
+commonly populated:
 
-    model: str = Field(..., description="Model identifier (e.g., 'gpt-4.1', 'fireworks/llama')")
-    temperature: Optional[float] = Field(None, description="Temperature setting for model generation")
-    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
-    max_tool_calls: Optional[int] = Field(None, description="Maximum tool calls per turn")
-
-    # there might be model or provider specific parameters that you want to pass that should be preserved
-    model_config = ConfigDict(extra="allow")
+model: str
+temperature: Optional[float]
+max_tokens: Optional[int]
+top_p: Optional[float]
+"""
 
 
 class InputMetadata(BaseModel):
@@ -196,7 +198,7 @@ class InputMetadata(BaseModel):
     model_config = ConfigDict(extra="allow")
 
     row_id: Optional[str] = Field(default_factory=generate_id, description="Unique string to ID the row")
-    completion_params: Optional[CompletionParams] = Field(None, description="Completion endpoint parameters used")
+    completion_params: CompletionParams = Field(..., description="Completion endpoint parameters used")
     dataset_info: Optional[Dict[str, Any]] = Field(
         None, description="Dataset row details: seed, system_prompt, environment_context, etc"
     )
```
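With the alias in place, `InputMetadata.completion_params` is a required plain dict, and provider-specific extras ride along unvalidated. A sketch of the new shape (which extra keys are accepted depends on LiteLLM and the provider):

```python
from typing import Any, Dict

CompletionParams = Dict[str, Any]  # mirrors the alias in eval_protocol/models.py

params: CompletionParams = {
    "model": "fireworks/llama",
    "temperature": 0.0,
    "max_tokens": 1024,
    # Extra keys are preserved as-is; nothing validates or strips them anymore.
    "extra_body": {"reasoning_effort": "low"},
}

# In application code this becomes, e.g.:
#   from eval_protocol.models import InputMetadata
#   meta = InputMetadata(completion_params=params)
```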

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 10 additions & 14 deletions
```diff
@@ -41,24 +41,20 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
 
         messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
 
-        request_params = {
-            "model": config.completion_params.model,
-            "messages": messages_payload,
-            **config.completion_params,
-        }
+        request_params = {"messages": messages_payload, **config.completion_params}
         # Ensure caching is disabled only for this request (review feedback)
         request_params["cache"] = {"no-cache": True}
         # Single-level reasoning effort: expect `reasoning_effort` only
         effort_val = None
-        if isinstance(config.completion_params, dict):
-            if "reasoning_effort" in config.completion_params:
-                effort_val = str(config.completion_params["reasoning_effort"])  # flat shape
-            elif (
-                isinstance(config.completion_params.get("extra_body"), dict)
-                and "reasoning_effort" in config.completion_params["extra_body"]
-            ):
-                # Accept if user passed it directly inside extra_body
-                effort_val = str(config.completion_params["extra_body"]["reasoning_effort"])  # already in extra_body
+
+        if "reasoning_effort" in config.completion_params:
+            effort_val = str(config.completion_params["reasoning_effort"])  # flat shape
+        elif (
+            isinstance(config.completion_params.get("extra_body"), dict)
+            and "reasoning_effort" in config.completion_params["extra_body"]
+        ):
+            # Accept if user passed it directly inside extra_body
+            effort_val = str(config.completion_params["extra_body"]["reasoning_effort"])  # already in extra_body
 
         if effort_val:
             # Always under extra_body so LiteLLM forwards to provider-specific param set
```
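Since `completion_params` is now always a dict, the outer `isinstance` guard can go and the `reasoning_effort` lookup flattens by one level. A condensed, self-contained sketch of the normalization; whether the flat key is also removed from the outgoing params isn't shown in this hunk, so the `pop` here is an assumption:

```python
from typing import Any, Dict

def normalize_reasoning_effort(completion_params: Dict[str, Any]) -> Dict[str, Any]:
    request_params = dict(completion_params)
    effort_val = None
    if "reasoning_effort" in request_params:
        effort_val = str(request_params.pop("reasoning_effort"))  # flat shape (pop is assumed)
    elif isinstance(request_params.get("extra_body"), dict) and "reasoning_effort" in request_params["extra_body"]:
        effort_val = str(request_params["extra_body"]["reasoning_effort"])  # already nested
    if effort_val:
        # Always under extra_body so LiteLLM forwards it to the provider.
        request_params.setdefault("extra_body", {})["reasoning_effort"] = effort_val
    return request_params

print(normalize_reasoning_effort({"model": "gpt-4.1", "reasoning_effort": "low"}))
# {'model': 'gpt-4.1', 'extra_body': {'reasoning_effort': 'low'}}
```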

eval_protocol/pytest/evaluation_test.py

Lines changed: 21 additions & 18 deletions
```diff
@@ -26,7 +26,6 @@
 from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
 from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
 from eval_protocol.pytest.types import (
-    CompletionsParams,
     Dataset,
     DatasetPathParam,
     EvaluationInputParam,
@@ -52,7 +51,7 @@
 
 def evaluation_test(  # noqa: C901
     *,
-    completion_params: List[CompletionsParams],
+    completion_params: List[CompletionParams],
     input_messages: Optional[List[InputMessagesParam]] = None,
     input_dataset: Optional[List[DatasetPathParam]] = None,
     dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
@@ -240,7 +239,7 @@ def generate_combinations():
             datasets = [[input_dataset]]  # type: ignore
         else:
             datasets = [None]
-        cps: List[Optional[CompletionsParams]] = completion_params if completion_params is not None else [None]  # type: ignore
+        cps: List[Optional[CompletionParams]] = completion_params if completion_params is not None else [None]  # type: ignore
         # Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over
         # each row. Instead, pass the entire sliced list through in a single test run
         # so summaries aggregate all rows together (AIME-style behavior).
@@ -348,7 +347,16 @@ def _log_eval_error(
         else:
             raise ValueError("No input dataset or input messages provided")
 
-        completions_params = kwargs.get("completion_params") or {}
+        if "completion_params" not in kwargs or not kwargs["completion_params"]:
+            raise ValueError(
+                "No completion parameters provided. Please provide a completion parameters object."
+            )
+        completion_params = kwargs["completion_params"]
+        if "model" not in completion_params or not completion_params["model"]:
+            raise ValueError(
+                "No model provided. Please provide a model in the completion parameters object."
+            )
+
         # Optional global overrides via environment for ad-hoc experimentation
         # EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged
         # into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
@@ -359,7 +367,7 @@ def _log_eval_error(
             if _env_override:
                 override_obj = _json.loads(_env_override)
                 if isinstance(override_obj, dict):
-                    completions_params = _deep_update_dict(dict(completions_params), override_obj)
+                    completion_params = _deep_update_dict(dict(completion_params), override_obj)
         except Exception:
             pass
 
@@ -374,11 +382,6 @@ def _log_eval_error(
                 passed=None,
             )
 
-            # Populate completion_params in input_metadata for all rows and initialize eval_metadata BEFORE rollouts
-            completion_params = CompletionParams(
-                **completions_params,
-            )
-
             for row in data:
                 if row.input_metadata is None:
                     row.input_metadata = InputMetadata()
@@ -398,13 +401,13 @@ def _log_eval_error(
 
             # Prepare rollout processor config once; we will generate fresh outputs per run
             config = RolloutProcessorConfig(
-                completion_params=CompletionParams(**completions_params),
+                completion_params=completion_params,
                 mcp_config_path=mcp_config_path or "",
                 max_concurrent_rollouts=max_concurrent_rollouts,
                 server_script_path=server_script_path,
                 steps=steps,
                 logger=active_logger,
-                kwargs=rollout_processor_kwargs,
+                kwargs=rollout_processor_kwargs or {},
             )
 
             for i in range(num_runs):
@@ -611,7 +614,7 @@ def _extract_effort_tag(params: dict) -> str | None:
                 return None
 
             model_slug = _sanitize_filename(model_used)
-            effort_tag = _extract_effort_tag(completions_params) or ""
+            effort_tag = _extract_effort_tag(completion_params) or ""
            effort_suffix = f"__effort-{_sanitize_filename(effort_tag)}" if effort_tag else ""
             base_name = f"{suite_name}__{model_slug}{effort_suffix}__{mode}__runs{num_runs}.json"
 
@@ -788,7 +791,7 @@ def __ep_run_direct(
         input_messages=cfg.get("input_messages"),
         input_dataset=cfg.get("input_dataset"),
         dataset_adapter=cfg.get("dataset_adapter"),
-        completions_params=rip,
+        completion_params=rip,
         rollout_processor=cfg.get("rollout_processor"),
         aggregation_method=cfg.get("aggregation_method"),
         passed_threshold=cfg.get("passed_threshold"),
@@ -818,7 +821,7 @@ def run_evaluation_test_direct(
     input_messages: Optional[List[InputMessagesParam]] = None,
     input_dataset: Optional[List[DatasetPathParam]] = None,
     dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
-    completions_params: Optional[CompletionsParams] = None,
+    completion_params: Optional[CompletionParams] = None,
     rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
     rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None,
    aggregation_method: AggregationMethod = "mean",
@@ -885,7 +888,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
         raise ValueError("No input dataset or input messages provided")
 
     # Build input params and apply env JSON override
-    completion_params: Dict[str, Any] = completions_params or {}
+    completion_params: Dict[str, Any] = completion_params or {}
     try:
         import json as _json
 
@@ -911,7 +914,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
     for row in data:
         if row.input_metadata is None:
             row.input_metadata = InputMetadata()
-        row.input_metadata.completion_params = CompletionParams(**completion_params)
+        row.input_metadata.completion_params = completion_params
         if row.input_metadata.session_data is None:
             row.input_metadata.session_data = {}
         row.input_metadata.session_data["mode"] = mode
@@ -925,7 +928,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
         max_concurrent_rollouts=max_concurrent_rollouts,
         server_script_path=server_script_path,
         steps=steps,
-        kwargs=rollout_processor_kwargs,
+        kwargs=rollout_processor_kwargs or {},
     )
 
     all_results: List[EvaluationRow] = []
```
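Caller-side, `completion_params` is now a list of plain dicts, and the new validation insists each dict carries a `model`. A sketch of a decorated test under the new shape; the decorated function's exact contract (signature, return type) is assumed here for illustration:

```python
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
from eval_protocol.pytest import evaluation_test

@evaluation_test(
    completion_params=[{"model": "gpt-4o-mini", "temperature": 0.0}],  # plain dicts; "model" is required
    input_messages=[[Message(role="user", content="Say hello.")]],
)
def test_hello(row: EvaluationRow) -> EvaluateResult:
    # Placeholder scoring for illustration only.
    return EvaluateResult(score=1.0, reason="stub")
```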

eval_protocol/pytest/types.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
from eval_protocol.dataset_logger import default_logger
99
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1010

11-
from ..models import EvaluationRow, Message
11+
from ..models import CompletionParams, EvaluationRow, Message
1212

1313
ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
1414
DatasetPathParam = str
15-
CompletionsParams = Dict[str, Any]
1615
InputMessagesParam = List[Message]
1716
EvaluationInputParam = Dict[str, Any]
1817
RolloutProcessorInputParam = Dict[str, Any]
@@ -41,7 +40,7 @@
4140

4241
@dataclass
4342
class RolloutProcessorConfig:
44-
completion_params: CompletionsParams # input parameters for inference
43+
completion_params: CompletionParams # input parameters for inference
4544
mcp_config_path: str
4645
server_script_path: Optional[str] = (
4746
None # TODO: change from server_script_path to mcp_config_path for agent rollout processor
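`RolloutProcessorConfig` keeps the same field, just retyped to the shared alias. A self-contained sketch using an abbreviated mirror of the dataclass (only the fields visible in this diff; values are illustrative):

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

CompletionParams = Dict[str, Any]  # imported from ..models in the real module

@dataclass
class RolloutProcessorConfig:  # abbreviated mirror of the dataclass above
    completion_params: CompletionParams  # input parameters for inference
    mcp_config_path: str
    server_script_path: Optional[str] = None

cfg = RolloutProcessorConfig(
    completion_params={"model": "gpt-4o-mini", "max_tokens": 256},
    mcp_config_path="",
)
print(cfg.completion_params["model"])  # gpt-4o-mini
```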

tests/pytest/test_frozen_lake.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -7,7 +7,7 @@
 
 from typing import Any, Dict, List
 
-from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message, MetricResult
+from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message, MetricResult
 from eval_protocol.pytest import evaluation_test
 from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
```

tests/pytest/test_lunar_lander.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -7,7 +7,7 @@
 
 from typing import Any, Dict, List
 
-from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message
+from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
 from eval_protocol.pytest import evaluation_test
 from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
```
