Skip to content

Commit e1fab60

Browse files
author
Dylan Huang
committed
convert rollout_input_params to completion_params
1 parent ed4409e commit e1fab60

30 files changed: +149 additions, -130 deletions

eval_protocol/benchmarks/suites/aime25.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,18 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
6060

6161
@export_benchmark("aime25")
6262
@evaluation_test(
63-
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
6463
input_dataset=[
6564
"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl",
6665
"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl",
6766
],
6867
dataset_adapter=aime2025_dataset_adapter,
69-
rollout_input_params=[{"max_tokens": 131000, "extra_body": {"reasoning_effort": "low"}}],
68+
completion_params=[
69+
{
70+
"max_tokens": 131000,
71+
"extra_body": {"reasoning_effort": "low"},
72+
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
73+
}
74+
],
7075
rollout_processor=default_single_turn_rollout_processor,
7176
aggregation_method="mean",
7277
num_runs=8,

eval_protocol/benchmarks/suites/gpqa.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,10 @@ def _extract_abcd_letter(text: str) -> str | None:
6060

6161
@export_benchmark("gpqa")
6262
@evaluation_test(
63-
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
6463
input_messages=_GPQA_INPUT_MESSAGES,
65-
rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
64+
completion_params=[
65+
{"extra_body": {"reasoning_effort": "low"}, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}
66+
],
6667
rollout_processor=default_single_turn_rollout_processor,
6768
aggregation_method="mean",
6869
num_runs=8,

eval_protocol/benchmarks/suites/tau_bench_retail.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,13 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
6666
@evaluation_test(
6767
input_dataset=["tests/pytest/data/retail_dataset.jsonl"],
6868
dataset_adapter=tau_bench_retail_to_evaluation_row,
69-
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
70-
rollout_input_params=[{"temperature": 0.8, "extra_body": {"reasoning_effort": "medium"}}],
69+
completion_params=[
70+
{
71+
"temperature": 0.8,
72+
"extra_body": {"reasoning_effort": "medium"},
73+
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
74+
}
75+
],
7176
rollout_processor=default_mcp_gym_rollout_processor,
7277
num_runs=8,
7378
mode="pointwise",

eval_protocol/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ class CompletionParams(BaseModel):
186186
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
187187
max_tool_calls: Optional[int] = Field(None, description="Maximum tool calls per turn")
188188

189+
# there might be model or provider specific parameters that you want to pass that should be preserved
190+
model_config = ConfigDict(extra="allow")
191+
189192

190193
class InputMetadata(BaseModel):
191194
"""Comprehensive metadata for input to evaluation and logging systems."""

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ async def default_agent_rollout_processor(
117117
) -> List[EvaluationRow]:
118118
dataset: Dataset = []
119119
for row in rows:
120-
agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger)
120+
agent = Agent(
121+
model=config.completion_params.model, row=row, config_path=config.mcp_config_path, logger=config.logger
122+
)
121123
await agent.setup()
122124
await agent.call_agent()
123125
dataset.append(agent.evaluation_row)

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,10 @@ async def default_mcp_gym_rollout_processor(
219219
server.start()
220220

221221
policy = ep.LiteLLMPolicy(
222-
model_id=config.model,
223-
temperature=config.input_params.get("temperature", 0.0),
224-
max_tokens=config.input_params.get("max_tokens", 4096),
225-
reasoning_effort=config.input_params.get("reasoning_effort", None),
222+
model_id=config.completion_params.model,
223+
temperature=config.completion_params.get("temperature", 0.0),
224+
max_tokens=config.completion_params.get("max_tokens", 4096),
225+
reasoning_effort=config.completion_params.get("reasoning_effort", None),
226226
)
227227

228228
# Create MCP environments directly from evaluation_rows

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,24 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
3333

3434
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
3535

36-
request_params = {"model": config.model, "messages": messages_payload, **config.input_params}
36+
request_params = {
37+
"model": config.completion_params.model,
38+
"messages": messages_payload,
39+
**config.completion_params,
40+
}
3741
# Ensure caching is disabled only for this request (review feedback)
3842
request_params["cache"] = {"no-cache": True}
3943
# Single-level reasoning effort: expect `reasoning_effort` only
4044
effort_val = None
41-
if isinstance(config.input_params, dict):
42-
if "reasoning_effort" in config.input_params:
43-
effort_val = str(config.input_params["reasoning_effort"]) # flat shape
44-
elif isinstance(config.input_params.get("extra_body"), dict) and "reasoning_effort" in config.input_params["extra_body"]:
45+
if isinstance(config.completion_params, dict):
46+
if "reasoning_effort" in config.completion_params:
47+
effort_val = str(config.completion_params["reasoning_effort"]) # flat shape
48+
elif (
49+
isinstance(config.completion_params.get("extra_body"), dict)
50+
and "reasoning_effort" in config.completion_params["extra_body"]
51+
):
4552
# Accept if user passed it directly inside extra_body
46-
effort_val = str(config.input_params["extra_body"]["reasoning_effort"]) # already in extra_body
53+
effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body
4754

4855
if effort_val:
4956
# Always under extra_body so LiteLLM forwards to provider-specific param set

0 commit comments

Comments (0)