
Commit 67918cb

Author: Dylan Huang
Rollout input params to completion params (#73)
* convert rollout_input_params to completion_params
* fix
* DISABLE_EP_SQLITE_LOG
* fix kwargs access to "model"
* DRY completion params and make it a dict
* fix tests
* revert
* fix
* ensure logging
* fix smoke test params
1 parent 8ad4c06 commit 67918cb

40 files changed (+431 / -341 lines)
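
The heart of the change, repeated across the suites below, is that the old model= argument and the old rollout_input_params= argument fold into a single completion_params dict per run. A minimal before/after sketch with plain dicts (the key names mirror the suite diffs below; nothing else is implied):

# Old shape: the model list and the rollout params travel separately.
old_style = {
    "model": ["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
    "rollout_input_params": [{"extra_body": {"reasoning_effort": "low"}}],
}

# New shape: one completion_params dict per run, with "model" folded in.
new_style = {
    "completion_params": [
        {
            "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
            "extra_body": {"reasoning_effort": "low"},
        }
    ],
}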

eval_protocol/adapters/huggingface.py

Lines changed: 133 additions & 135 deletions
Large diffs are not rendered by default.

eval_protocol/adapters/langfuse.py

Lines changed: 8 additions & 8 deletions

@@ -4,11 +4,11 @@
 to EvaluationRow format for use in evaluation pipelines.
 """

-from typing import Any, Dict, Iterator, List, Optional
-from datetime import datetime
 import logging
+from datetime import datetime
+from typing import Any, Dict, Iterator, List, Optional

-from eval_protocol.models import EvaluationRow, Message, InputMetadata, CompletionParams
+from eval_protocol.models import EvaluationRow, InputMetadata, Message

 logger = logging.getLogger(__name__)

@@ -277,20 +277,20 @@ def _create_input_metadata(self, trace: Any, observations: List[Any]) -> InputMe
             InputMetadata object
         """
         # Extract completion parameters from observations
-        completion_params = CompletionParams()
+        completion_params = {}

         # Look for model parameters in observations
         for obs in observations:
             if hasattr(obs, "model") and obs.model:
-                completion_params.model = obs.model
+                completion_params["model"] = obs.model
             if hasattr(obs, "model_parameters") and obs.model_parameters:
                 params = obs.model_parameters
                 if "temperature" in params:
-                    completion_params.temperature = params["temperature"]
+                    completion_params["temperature"] = params["temperature"]
                 if "max_tokens" in params:
-                    completion_params.max_tokens = params["max_tokens"]
+                    completion_params["max_tokens"] = params["max_tokens"]
                 if "top_p" in params:
-                    completion_params.top_p = params["top_p"]
+                    completion_params["top_p"] = params["top_p"]
             break

         # Create dataset info from trace metadata
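
Since completion_params is now a plain dict rather than a Pydantic model, downstream readers switch from attribute access to key access. A self-contained sketch of the new access pattern (the helper name is hypothetical, not part of the adapter):

from typing import Any, Dict

def summarize_params(completion_params: Dict[str, Any]) -> str:
    # .get() tolerates observations that never populated a given field
    model = completion_params.get("model")
    temperature = completion_params.get("temperature")
    return f"model={model!r} temperature={temperature!r}"

print(summarize_params({"model": "gpt-4.1", "temperature": 0.2}))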

eval_protocol/benchmarks/suites/aime25.py

Lines changed: 7 additions & 2 deletions

@@ -60,13 +60,18 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:

 @export_benchmark("aime25")
 @evaluation_test(
-    model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
     input_dataset=[
         "https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl",
         "https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl",
     ],
     dataset_adapter=aime2025_dataset_adapter,
-    rollout_input_params=[{"max_tokens": 131000, "extra_body": {"reasoning_effort": "low"}}],
+    completion_params=[
+        {
+            "max_tokens": 131000,
+            "extra_body": {"reasoning_effort": "low"},
+            "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
+        }
+    ],
     rollout_processor=default_single_turn_rollout_processor,
     aggregation_method="mean",
     passed_threshold=None,

eval_protocol/benchmarks/suites/gpqa.py

Lines changed: 7 additions & 3 deletions

@@ -55,6 +55,7 @@ def _extract_abcd_letter(text: str) -> str | None:

 _GPQA_INPUT_MESSAGES = _load_gpqa_messages_from_csv()

+
 def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
     return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]

@@ -67,16 +68,19 @@ async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) ->
         if gt_tokens:
             gt_val = gt_tokens[-1].split(":", 1)[1].strip()
             r.ground_truth = gt_val
-            r.messages = [m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
+            r.messages = [
+                m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
+            ]
         processed.append(r)
     return await default_single_turn_rollout_processor(processed, config)


 @export_benchmark("gpqa")
 @evaluation_test(
-    model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
     input_messages=_GPQA_INPUT_MESSAGES,
-    rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
+    completion_params=[
+        {"extra_body": {"reasoning_effort": "low"}, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}
+    ],
     rollout_processor=gpqa_strip_gt_rollout_processor,
     aggregation_method="mean",
     passed_threshold=None,
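
The __GT__: convention in this suite smuggles the ground truth into a system message and strips it before rollout. A self-contained sketch of the same extract-then-strip logic, on plain dicts rather than Message objects:

messages = [
    {"role": "system", "content": "__GT__: C"},
    {"role": "user", "content": "Which option is correct?"},
]

# Extract the last ground-truth token, then drop the carrier message.
gt_tokens = [m["content"] for m in messages if m["content"].startswith("__GT__:")]
ground_truth = gt_tokens[-1].split(":", 1)[1].strip()
messages = [m for m in messages if not m["content"].startswith("__GT__:")]

assert ground_truth == "C" and len(messages) == 1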

eval_protocol/benchmarks/suites/livebench_data_analysis.py

Lines changed: 18 additions & 21 deletions

@@ -1,20 +1,19 @@
-from typing import Any, Dict, List, Optional
-
 import json
 import re
+from typing import Any, Dict, List, Optional

+from eval_protocol.benchmarks.registry import export_benchmark, register_composite_benchmark
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
 from eval_protocol.pytest.default_single_turn_rollout_process import (
     default_single_turn_rollout_processor,
 )
 from eval_protocol.pytest.evaluation_test import evaluation_test
-from eval_protocol.benchmarks.registry import export_benchmark, register_composite_benchmark
-

 # -------------------------
 # Lightweight ports of LiveBench scoring utilities for data_analysis tasks
 # -------------------------

+
 def _lb_clean_text(text: str) -> str:
     text = text.lower().strip()
     text = re.sub(r"[^\w]", "", text)

@@ -36,9 +35,7 @@ def _cta_process_results(ground_truth: str, llm_answer: str) -> int:
     boxed = _extract_last_boxed_segment(parsed_answer)
     if boxed is not None:
         parsed_answer = boxed
-    parsed_answer = (
-        parsed_answer.replace("\\text{", "").replace("}", "").replace("\\", "")
-    )
+    parsed_answer = parsed_answer.replace("\\text{", "").replace("}", "").replace("\\", "")

     gt_clean = _lb_clean_text(ground_truth)
     ans_clean = _lb_clean_text(parsed_answer)

@@ -132,17 +129,15 @@ def _tablejoin_process_results(ground_truth: Any, llm_answer: str) -> float:
     return round((2 * tp) / denom, 2)


-def _tablereformat_process_results(
-    input_command: str, ground_truth: str, llm_answer: str, version: str
-) -> int:
+def _tablereformat_process_results(input_command: str, ground_truth: str, llm_answer: str, version: str) -> int:
     try:
         import pandas as pd  # type: ignore
     except Exception:
         return 0

-    from io import StringIO
     import math as _math
     import traceback as _traceback
+    from io import StringIO

     def _read_df_v1(df_type: str, df_str: str):
         if df_type == "json":

@@ -252,8 +247,12 @@ def _read_jsonl_table_from_text(text: str, header_cols: List[str]):
         )
     else:
         lines = input_command.split("\n")
-        input_fmt = [l for l in lines if "Source Format" in l][-1].split("Source Format: ")[-1].strip().lower()
-        output_fmt = [l for l in lines if "Target Format" in l][-1].split("Target Format: ")[-1].strip().lower()
+        input_fmt = (
+            [line for line in lines if "Source Format" in line][-1].split("Source Format: ")[-1].strip().lower()
+        )
+        output_fmt = (
+            [line for line in lines if "Target Format" in line][-1].split("Target Format: ")[-1].strip().lower()
+        )

     reader = _read_df_v1 if version == "v1" else _read_df_v2
     gt_df = reader(output_fmt, ground_truth)

@@ -373,9 +372,9 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:

 @export_benchmark("live_bench/data_analysis/cta")
 @evaluation_test(
-    model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
+    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
     input_messages=[[m for m in r.messages] for r in _CTA_ROWS],
-    rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
+    rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
     rollout_processor=default_single_turn_rollout_processor,
     aggregation_method="mean",
     passed_threshold=None,

@@ -416,9 +415,9 @@ def livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:

 @export_benchmark("live_bench/data_analysis/tablejoin")
 @evaluation_test(
-    model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
+    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
     input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS],
-    rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
+    rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
     rollout_processor=default_single_turn_rollout_processor,
     aggregation_method="mean",
     passed_threshold=None,

@@ -460,9 +459,9 @@ def livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:

 @export_benchmark("live_bench/data_analysis/tablereformat")
 @evaluation_test(
-    model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
+    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
     input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS],
-    rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
+    rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
     rollout_processor=default_single_turn_rollout_processor,
     aggregation_method="mean",
     passed_threshold=None,

@@ -508,5 +507,3 @@ def livebench_tablereformat_pointwise(row: EvaluationRow) -> EvaluationRow:
         "live_bench/data_analysis/tablereformat",
     ],
 )
-
-
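
The tablejoin scorer's round((2 * tp) / denom, 2) visible above is an F1-style measure. A small illustrative reconstruction of the surrounding arithmetic (the tp/fp/fn bookkeeping is assumed, since this hunk does not show it):

def tablejoin_f1(tp: int, fp: int, fn: int) -> float:
    # Standard F1 written with a single denominator, matching the hunk above.
    denom = 2 * tp + fp + fn
    if denom == 0:
        return 0.0
    return round((2 * tp) / denom, 2)

assert tablejoin_f1(tp=3, fp=1, fn=1) == 0.75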

eval_protocol/benchmarks/suites/tau_bench_retail.py

Lines changed: 8 additions & 3 deletions

@@ -11,7 +11,7 @@
 from typing import Any, Dict, List

 from eval_protocol.benchmarks.registry import export_benchmark
-from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message
+from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
 from eval_protocol.pytest import evaluation_test
 from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
 from vendor.tau2.data_model.message import (

@@ -66,8 +66,13 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
 @evaluation_test(
     input_dataset=["tests/pytest/data/retail_dataset.jsonl"],
     dataset_adapter=tau_bench_retail_to_evaluation_row,
-    model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
-    rollout_input_params=[{"temperature": 0.8, "extra_body": {"reasoning_effort": "medium"}}],
+    completion_params=[
+        {
+            "temperature": 0.8,
+            "extra_body": {"reasoning_effort": "medium"},
+            "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
+        }
+    ],
     rollout_processor=default_mcp_gym_rollout_processor,
     rollout_processor_kwargs={"domain": "retail"},
     num_runs=8,

eval_protocol/dataset_logger/__init__.py

Lines changed: 6 additions & 3 deletions

@@ -1,11 +1,14 @@
-from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter
 import os

+from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
+from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter
+
 # Allow disabling sqlite logger to avoid environment-specific constraints in simple CLI runs.
-if os.getenv("EP_SQLITE_LOG", "0").strip() == "1":
+if os.getenv("DISABLE_EP_SQLITE_LOG", "0").strip() == "1":
     default_logger = SqliteDatasetLoggerAdapter()
 else:
-    class _NoOpLogger:
+
+    class _NoOpLogger(DatasetLogger):
         def log(self, row):
             return None
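
Because the branch runs at module level, DISABLE_EP_SQLITE_LOG is read once, when the package is first imported, so it must be set beforehand. A sketch of that ordering (set in-process here for illustration; normally it would come from the shell environment):

import os

os.environ["DISABLE_EP_SQLITE_LOG"] = "0"  # must precede the import below

from eval_protocol.dataset_logger import default_logger  # noqa: E402

# Whichever branch was taken, the same interface is exposed:
# default_logger.log(row)  # row is an EvaluationRow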

eval_protocol/mcp/execution/manager.py

Lines changed: 7 additions & 7 deletions

@@ -20,7 +20,7 @@
 from vendor.tau2.data_model.message import AssistantMessage, UserMessage
 from vendor.tau2.user.user_simulator import UserSimulator

-from ...models import CompletionParams, EvaluationRow, InputMetadata, Message
+from ...models import EvaluationRow, InputMetadata, Message
 from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory

 if TYPE_CHECKING:

@@ -128,12 +128,12 @@ async def _execute_with_semaphore(idx):
             evaluation_row.messages = messages
             evaluation_row.tools = shared_tool_schema
             evaluation_row.usage = CompletionUsage(**trajectory.usage)
-            evaluation_row.input_metadata.completion_params = CompletionParams(
-                model=policy.model_id,
-                temperature=getattr(policy, "temperature", None),
-                max_tokens=getattr(policy, "max_tokens", None),
-                max_tool_calls=getattr(policy, "max_tools_per_turn", None),
-            )
+            evaluation_row.input_metadata.completion_params = {
+                "model": policy.model_id,
+                "temperature": getattr(policy, "temperature", None),
+                "max_tokens": getattr(policy, "max_tokens", None),
+                "max_tool_calls": getattr(policy, "max_tools_per_turn", None),
+            }

             if trajectory.terminated:
                 if trajectory.termination_reason == TerminationReason.ERROR:

eval_protocol/models.py

Lines changed: 15 additions & 8 deletions

@@ -1,6 +1,6 @@
 import os
 from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, TypedDict, Union

 from openai.types import CompletionUsage
 from openai.types.chat.chat_completion_message import (

@@ -178,13 +178,18 @@ def __iter__(self):
         return iter(self.__fields__.keys())  # Changed to __fields__


-class CompletionParams(BaseModel):
-    """Configuration for the language model used in the session."""
+CompletionParams = Dict[str, Any]
+"""
+Common set of completion parameters that most model providers support in their
+API. Set total=False to allow extra fields since LiteLLM + providers have their
+own set of parameters. The following parameters are common fields that are
+populated.

-    model: str = Field(..., description="Model identifier (e.g., 'gpt-4.1', 'fireworks/llama')")
-    temperature: Optional[float] = Field(None, description="Temperature setting for model generation")
-    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
-    max_tool_calls: Optional[int] = Field(None, description="Maximum tool calls per turn")
+model: str
+temperature: Optional[float]
+max_tokens: Optional[int]
+top_p: Optional[float]
+"""


 class InputMetadata(BaseModel):

@@ -193,7 +198,9 @@ class InputMetadata(BaseModel):
     model_config = ConfigDict(extra="allow")

     row_id: Optional[str] = Field(default_factory=generate_id, description="Unique string to ID the row")
-    completion_params: Optional[CompletionParams] = Field(None, description="Completion endpoint parameters used")
+    completion_params: CompletionParams = Field(
+        default_factory=dict, description="Completion endpoint parameters used"
+    )
     dataset_info: Optional[Dict[str, Any]] = Field(
         None, description="Dataset row details: seed, system_prompt, environment_context, etc"
     )
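
With CompletionParams reduced to a Dict[str, Any] alias, construction and access move from attributes to keys, and provider-specific extras no longer require a schema change. A minimal sketch:

from typing import Any, Dict

CompletionParams = Dict[str, Any]  # mirrors the new alias above

params: CompletionParams = {"model": "gpt-4.1", "temperature": 0.0}
params["extra_body"] = {"reasoning_effort": "low"}  # arbitrary extras allowed

model = params.get("model")  # was: params.model on the old BaseModel
assert model == "gpt-4.1"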

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 3 additions & 1 deletion

@@ -125,7 +125,9 @@ async def default_agent_rollout_processor(

     async def process_row(row: EvaluationRow) -> EvaluationRow:
         """Process a single row with agent rollout."""
-        agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger)
+        agent = Agent(
+            model=config.completion_params.model, row=row, config_path=config.mcp_config_path, logger=config.logger
+        )
         try:
             await agent.setup()
             await agent.call_agent()
