Skip to content

Commit 971d6e4

Browse files
committed
add live bench
1 parent a5e1479 commit 971d6e4

File tree

7 files changed

+695
-8
lines changed

7 files changed

+695
-8
lines changed

eval_protocol/benchmarks/registry.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
128128
server_script_path = ep_config.get("server_script_path")
129129
steps = ep_config.get("steps")
130130
mode = ep_config.get("mode")
131-
combine_datasets = ep_config.get("combine_datasets")
131+
# combine_datasets captured but not used here
132132

133133
# Choose the first rollout param set by default
134134
rollout_params = None
@@ -172,3 +172,161 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
172172
return _decorator
173173

174174

175+
def register_composite_benchmark(name: str, children: List[str]) -> None:
    """
    Register a composite benchmark that runs multiple exported benchmarks and aggregates results.

    The composite runner forwards common overrides to each child benchmark and aggregates
    a combined score as a rows-weighted mean of each child's aggregated score.

    Args:
        name: Name of the composite benchmark to register.
        children: List of child benchmark names previously registered via export_benchmark.
    """

    def _composite_runner(
        *,
        model: Optional[str] = None,
        print_summary: bool = False,
        out: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        max_rows: Optional[int | str] = None,
        num_runs: Optional[int] = None,
        input_params_override: Optional[Dict[str, Any]] = None,
        max_concurrency: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run every child benchmark, then return an aggregated summary dict.

        Returns:
            ``{"summary": combined}`` where ``combined`` holds the rows-weighted
            aggregate score, per-metric means, and the child summaries.
        """
        # Resolve child runners at call-time to ensure all suites are imported
        # Local import avoided to prevent circular import at module import time
        _get_benchmark_runner = get_benchmark_runner
        import pathlib as _pathlib
        import time as _time
        _json = json

        child_summaries: List[Dict[str, Any]] = []
        total_rows = 0
        weighted_sum = 0.0
        # For per-metric aggregation across children
        metric_weighted_sums: Dict[str, float] = {}
        metric_total_rows: Dict[str, int] = {}
        combined_rows: List[Any] = []

        # 'out' is treated as a JSON *file* path only when it has a .json suffix
        # and does not end with a path separator; otherwise it is a directory.
        # Computed once here because both the child-artifact dir and the final
        # persist step below need the same decision.
        out_is_json_file = False
        if out:
            out_is_json_file = (
                _pathlib.Path(out).suffix.lower() == ".json" and not str(out).endswith("/")
            )

        # If 'out' is a file path, also compute a directory for child artifacts
        child_out_dir: Optional[str] = None
        if out:
            if out_is_json_file:
                # Use parent directory for child artifacts
                child_out_dir = str(_pathlib.Path(out).parent)
            else:
                child_out_dir = out

        for child_name in children:
            runner = _get_benchmark_runner(child_name)
            result = runner(
                model=model,
                print_summary=print_summary,
                out=child_out_dir,
                reasoning_effort=reasoning_effort,
                max_rows=max_rows,
                num_runs=num_runs,
                input_params_override=input_params_override,
                max_concurrency=max_concurrency,
            )
            summary = (result or {}).get("summary") if isinstance(result, dict) else None
            if not summary:
                continue
            # Gather underlying rows to recompute CI across children
            try:
                rows_obj = result.get("results") if isinstance(result, dict) else None
                if isinstance(rows_obj, list):
                    combined_rows.extend(rows_obj)
            except Exception:
                pass
            child_summaries.append(summary)
            rows = int(summary.get("rows", 0) or 0)
            agg = summary.get("agg_score")
            if isinstance(agg, (int, float)) and rows > 0:
                total_rows += rows
                weighted_sum += float(agg) * rows
            # Combine per-metric means if available
            metrics_agg = summary.get("metrics_agg") or {}
            if isinstance(metrics_agg, dict):
                for m_name, m_vals in metrics_agg.items():
                    # Guard: a malformed child summary may carry a non-dict
                    # metric entry; skip it instead of raising AttributeError.
                    if not isinstance(m_vals, dict):
                        continue
                    m_mean = m_vals.get("mean")
                    if isinstance(m_mean, (int, float)) and rows > 0:
                        metric_weighted_sums[m_name] = metric_weighted_sums.get(m_name, 0.0) + float(m_mean) * rows
                        metric_total_rows[m_name] = metric_total_rows.get(m_name, 0) + rows

        combined_agg = (weighted_sum / total_rows) if total_rows > 0 else None
        # Compute 95% CI for combined rows if available
        ci_low: Optional[float] = None
        ci_high: Optional[float] = None
        if combined_rows:
            try:
                from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci as _compute_ci

                r = _compute_ci(combined_rows)
                if r and len(r) >= 3 and r[1] is not None and r[2] is not None:
                    ci_low = float(r[1])
                    ci_high = float(r[2])
            except Exception:
                # Best-effort: CI is optional metadata; fall back to no CI.
                ci_low = None
                ci_high = None
        combined_metrics: Dict[str, Dict[str, float]] = {}
        for m_name, wsum in metric_weighted_sums.items():
            denom = metric_total_rows.get(m_name, 0)
            if denom > 0:
                combined_metrics[m_name] = {"mean": float(wsum / denom)}
        combined = {
            "suite": name,
            "model": model,
            "agg_score": float(combined_agg) if combined_agg is not None else None,
            "rows": total_rows,
            "children": child_summaries,
            "num_runs": num_runs,
            **({"metrics_agg": combined_metrics} if combined_metrics else {}),
            **({"agg_ci_low": ci_low, "agg_ci_high": ci_high} if (ci_low is not None and ci_high is not None) else {}),
        }

        # Optional print and persist
        # Respect either function arg or EP_PRINT_SUMMARY env
        _should_print = print_summary or (os.getenv("EP_PRINT_SUMMARY") == "1")
        if _should_print:
            try:
                if combined_agg is not None:
                    if ci_low is not None and ci_high is not None:
                        print(
                            f"EP Summary | suite={name} model={model} agg={combined['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] rows={total_rows}"
                        )
                    else:
                        print(
                            f"EP Summary | suite={name} model={model} agg={combined['agg_score']:.3f} rows={total_rows}"
                        )
                else:
                    print(
                        f"EP Summary | suite={name} model={model} agg=None rows={total_rows}"
                    )
            except Exception:
                # Printing is best-effort; never let summary output break the run.
                pass

        def _write_json(path: "_pathlib.Path") -> None:
            # One place for the artifact write: combined summary + unix timestamp.
            with open(path, "w", encoding="utf-8") as f:
                _json.dump({**combined, "timestamp": int(_time.time())}, f)

        if out:
            out_path = _pathlib.Path(out)
            if out_is_json_file:
                # Write to the specified file
                out_path.parent.mkdir(parents=True, exist_ok=True)
                _write_json(out_path)
            else:
                # Treat as directory
                out_path.mkdir(parents=True, exist_ok=True)
                safe_name = name.replace("/", "__")
                _write_json(out_path / f"{safe_name}__composite.json")

        return {"summary": combined}

    # Register (overwrite if exists)
    _BENCHMARK_REGISTRY[name] = _composite_runner
332+

eval_protocol/benchmarks/suites/aime25.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
7373
rollout_input_params=[{"max_tokens": 131000, "extra_body": {"reasoning_effort": "low"}}],
7474
rollout_processor=default_single_turn_rollout_processor,
7575
aggregation_method="mean",
76-
threshold_of_success=None,
76+
passed_threshold=None,
7777
num_runs=8,
7878
max_dataset_rows=2,
7979
max_concurrent_rollouts=4,

eval_protocol/benchmarks/suites/gpqa.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ def _load_gpqa_messages_from_csv() -> List[List[Message]]:
4040
[
4141
Message(role="system", content=SYSTEM_PROMPT),
4242
Message(role="user", content=user_content),
43-
# Correct answer is always option A by construction
44-
Message(role="system", content="__GT__:A"),
4543
]
4644
)
4745
if not messages_list:
@@ -66,7 +64,7 @@ def _extract_abcd_letter(text: str) -> str | None:
6664
rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
6765
rollout_processor=default_single_turn_rollout_processor,
6866
aggregation_method="mean",
69-
threshold_of_success=None,
67+
passed_threshold=None,
7068
num_runs=8,
7169
mode="pointwise",
7270
)

0 commit comments

Comments (0)