Skip to content

Commit e2198ac

Browse files
authored
add live bench (#68)
* add live bench * fix live bench and rollout processor
1 parent a9e7009 commit e2198ac

File tree

6 files changed

+708
-9
lines changed

6 files changed

+708
-9
lines changed

eval_protocol/benchmarks/registry.py

Lines changed: 160 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
126126
server_script_path = ep_config.get("server_script_path")
127127
steps = ep_config.get("steps")
128128
mode = ep_config.get("mode")
129-
combine_datasets = ep_config.get("combine_datasets")
129+
# combine_datasets captured but not used here
130130

131131
# Choose the first rollout param set by default
132132
rollout_params = None
@@ -169,3 +169,162 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
169169
return test_wrapper
170170

171171
return _decorator
172+
173+
174+
def register_composite_benchmark(name: str, children: List[str]) -> None:
    """
    Register a composite benchmark that runs multiple exported benchmarks and aggregates results.

    The composite runner forwards common overrides to each child benchmark and aggregates
    a combined score as a rows-weighted mean of each child's aggregated score. Per-metric
    means are combined the same way, and a 95% CI is recomputed over all children's rows
    when the stats helper is importable.

    Args:
        name: Name of the composite benchmark to register.
        children: List of child benchmark names previously registered via export_benchmark.
    """

    def _composite_runner(
        *,
        model: Optional[str] = None,
        print_summary: bool = False,
        out: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        max_rows: Optional[int | str] = None,
        num_runs: Optional[int] = None,
        input_params_override: Optional[Dict[str, Any]] = None,
        max_concurrency: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run every child benchmark, then aggregate, optionally print, and persist a combined summary.

        Returns:
            ``{"summary": <combined summary dict>}`` mirroring the shape child runners produce.
        """
        # Resolve child runners at call-time to ensure all suites are imported.
        # Local import avoided to prevent circular import at module import time.
        _get_benchmark_runner = get_benchmark_runner
        import pathlib as _pathlib
        import time as _time

        _json = json

        child_summaries: List[Dict[str, Any]] = []
        total_rows = 0
        weighted_sum = 0.0
        # For per-metric aggregation across children (rows-weighted means).
        metric_weighted_sums: Dict[str, float] = {}
        metric_total_rows: Dict[str, int] = {}
        combined_rows: List[Any] = []

        # If 'out' is a file path, also compute a directory for child artifacts
        # so each child writes next to the composite JSON rather than over it.
        child_out_dir: Optional[str] = None
        if out:
            p = _pathlib.Path(out)
            if p.suffix.lower() == ".json" and not str(out).endswith("/"):
                # Use parent directory for child artifacts
                child_out_dir = str(p.parent)
            else:
                child_out_dir = out

        for child_name in children:
            runner = _get_benchmark_runner(child_name)
            result = runner(
                model=model,
                print_summary=print_summary,
                out=child_out_dir,
                reasoning_effort=reasoning_effort,
                max_rows=max_rows,
                num_runs=num_runs,
                input_params_override=input_params_override,
                max_concurrency=max_concurrency,
            )
            summary = (result or {}).get("summary") if isinstance(result, dict) else None
            if not summary:
                # A child without a summary contributes nothing to the aggregate.
                continue
            # Gather underlying rows to recompute CI across children (best-effort).
            try:
                rows_obj = result.get("results") if isinstance(result, dict) else None
                if isinstance(rows_obj, list):
                    combined_rows.extend(rows_obj)
            except Exception:
                pass
            child_summaries.append(summary)
            rows = int(summary.get("rows", 0) or 0)
            agg = summary.get("agg_score")
            if isinstance(agg, (int, float)) and rows > 0:
                total_rows += rows
                weighted_sum += float(agg) * rows
            # Combine per-metric means if available.
            metrics_agg = summary.get("metrics_agg") or {}
            if isinstance(metrics_agg, dict):
                for m_name, m_vals in metrics_agg.items():
                    # Guard against malformed entries: a non-dict value would
                    # otherwise raise AttributeError and abort the whole composite.
                    if not isinstance(m_vals, dict):
                        continue
                    m_mean = m_vals.get("mean")
                    if isinstance(m_mean, (int, float)) and rows > 0:
                        metric_weighted_sums[m_name] = metric_weighted_sums.get(m_name, 0.0) + float(m_mean) * rows
                        metric_total_rows[m_name] = metric_total_rows.get(m_name, 0) + rows

        combined_agg = (weighted_sum / total_rows) if total_rows > 0 else None
        # Compute 95% CI for combined rows if available (best-effort; failures fall back to None).
        ci_low: Optional[float] = None
        ci_high: Optional[float] = None
        if combined_rows:
            try:
                from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci as _compute_ci

                r = _compute_ci(combined_rows)
                if r and len(r) >= 3 and r[1] is not None and r[2] is not None:
                    ci_low = float(r[1])
                    ci_high = float(r[2])
            except Exception:
                ci_low = None
                ci_high = None
        combined_metrics: Dict[str, Dict[str, float]] = {}
        for m_name, wsum in metric_weighted_sums.items():
            denom = metric_total_rows.get(m_name, 0)
            if denom > 0:
                combined_metrics[m_name] = {"mean": float(wsum / denom)}
        combined = {
            "suite": name,
            "model": model,
            "agg_score": float(combined_agg) if combined_agg is not None else None,
            "rows": total_rows,
            "children": child_summaries,
            "num_runs": num_runs,
            **({"metrics_agg": combined_metrics} if combined_metrics else {}),
            **({"agg_ci_low": ci_low, "agg_ci_high": ci_high} if (ci_low is not None and ci_high is not None) else {}),
        }

        # Optional print and persist.
        # Respect either function arg or EP_PRINT_SUMMARY env.
        _should_print = print_summary or (os.getenv("EP_PRINT_SUMMARY") == "1")
        if _should_print:
            try:
                if combined_agg is not None:
                    if ci_low is not None and ci_high is not None:
                        print(
                            f"EP Summary | suite={name} model={model} agg={combined['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] rows={total_rows}"
                        )
                    else:
                        print(
                            f"EP Summary | suite={name} model={model} agg={combined['agg_score']:.3f} rows={total_rows}"
                        )
                else:
                    print(
                        f"EP Summary | suite={name} model={model} agg=None rows={total_rows}"
                    )
            except Exception:
                pass

        if out:
            out_path = _pathlib.Path(out)
            # Build the persisted payload once; both branches write the same document.
            payload = {**combined, "timestamp": int(_time.time())}
            if out_path.suffix.lower() == ".json" and not str(out).endswith("/"):
                # Write to the specified file
                out_path.parent.mkdir(parents=True, exist_ok=True)
                with open(out_path, "w", encoding="utf-8") as f:
                    _json.dump(payload, f)
            else:
                # Treat as directory
                dir_path = out_path
                dir_path.mkdir(parents=True, exist_ok=True)
                safe_name = name.replace("/", "__")
                file_path = dir_path / f"{safe_name}__composite.json"
                with open(file_path, "w", encoding="utf-8") as f:
                    _json.dump(payload, f)

        return {"summary": combined}

    # Register (overwrite if exists)
    _BENCHMARK_REGISTRY[name] = _composite_runner

eval_protocol/benchmarks/suites/aime25.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
6969
rollout_input_params=[{"max_tokens": 131000, "extra_body": {"reasoning_effort": "low"}}],
7070
rollout_processor=default_single_turn_rollout_processor,
7171
aggregation_method="mean",
72+
passed_threshold=None,
7273
num_runs=8,
7374
max_dataset_rows=2,
7475
max_concurrent_rollouts=4,

eval_protocol/benchmarks/suites/gpqa.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ def _load_gpqa_messages_from_csv() -> List[List[Message]]:
3939
[
4040
Message(role="system", content=SYSTEM_PROMPT),
4141
Message(role="user", content=user_content),
42-
# Correct answer is always option A by construction
43-
Message(role="system", content="__GT__:A"),
4442
]
4543
)
4644
if not messages_list:
@@ -57,14 +55,31 @@ def _extract_abcd_letter(text: str) -> str | None:
5755

5856
_GPQA_INPUT_MESSAGES = _load_gpqa_messages_from_csv()
5957

58+
def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
59+
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
60+
61+
62+
async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[EvaluationRow]:
    """Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to the default processor.

    Args:
        rows: Evaluation rows whose message lists may contain system messages of the
            form ``__GT__:<answer>`` carrying the ground-truth label.
        config: Rollout configuration, forwarded unchanged to
            ``default_single_turn_rollout_processor``.

    Returns:
        The rows returned by ``default_single_turn_rollout_processor``, each with
        ``ground_truth`` populated (when a marker was present) and marker messages removed.
    """
    processed: List[EvaluationRow] = []
    for r in rows:
        gt_tokens = [m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")]
        if gt_tokens:
            # Last marker wins; the text after the first ":" is the ground-truth value.
            gt_val = gt_tokens[-1].split(":", 1)[1].strip()
            r.ground_truth = gt_val
        # Reuse the shared helper rather than duplicating the same filter inline,
        # keeping the marker-stripping predicate defined in exactly one place.
        r.messages = _strip_gt_messages(r.messages)
        processed.append(r)
    return await default_single_turn_rollout_processor(processed, config)
73+
6074

6175
@export_benchmark("gpqa")
6276
@evaluation_test(
6377
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
6478
input_messages=_GPQA_INPUT_MESSAGES,
6579
rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
66-
rollout_processor=default_single_turn_rollout_processor,
80+
rollout_processor=gpqa_strip_gt_rollout_processor,
6781
aggregation_method="mean",
82+
passed_threshold=None,
6883
num_runs=8,
6984
mode="pointwise",
7085
)
@@ -73,9 +88,8 @@ def gpqa_pointwise(row: EvaluationRow) -> EvaluationRow:
7388
content = assistant_msgs[-1].content if assistant_msgs else ""
7489

7590
pred = _extract_abcd_letter(content or "")
76-
# Retrieve GT from the trailing system message we appended
77-
gt_tokens = [m.content for m in row.messages if m.role == "system" and (m.content or "").startswith("__GT__:")]
78-
gt = gt_tokens[-1].split(":", 1)[1].strip() if gt_tokens else None
91+
# GPQA diamond CSV constructs options so that the correct answer is always A
92+
gt = "A"
7993

8094
is_valid = pred is not None and gt in {"A", "B", "C", "D"}
8195
score = 1.0 if (is_valid and pred == gt) else 0.0

0 commit comments

Comments
 (0)