Skip to content

Commit 7dbd341

Browse files
committed
Merge branch 'main' into derekx/pipelining
2 parents dcbbfad + a5e1479 commit 7dbd341

File tree

14 files changed

+919
-105
lines changed

14 files changed

+919
-105
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .registry import export_benchmark, get_benchmark_runner, list_benchmarks

# Public API of the benchmarks package: the registration decorator plus the
# two read-side helpers for looking up and listing registered benchmarks.
__all__ = [
    "export_benchmark",
    "get_benchmark_runner",
    "list_benchmarks",
]
8+
9+
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""
2+
Benchmark registry and export decorator.
3+
4+
This module provides a lightweight registry for benchmarks and a decorator
5+
`@export_benchmark(name)` that can be stacked with `@evaluation_test`.
6+
7+
It registers a runnable handle that executes the exact same evaluation pipeline
8+
as the pytest flow by calling `run_evaluation_test_direct` with the parameters
9+
captured from the decorated function.
10+
11+
Usage in a suite module (stack under @evaluation_test):
12+
13+
from eval_protocol.benchmarks.registry import export_benchmark
14+
15+
@export_benchmark("aime25_low")
16+
@evaluation_test(...)
17+
def test_aime_pointwise(row: EvaluationRow) -> EvaluationRow:
18+
...
19+
20+
Programmatic run:
21+
22+
from eval_protocol.benchmarks.registry import get_benchmark_runner
23+
get_benchmark_runner("aime25_low")(model="fireworks_ai/...", print_summary=True, out="artifacts/aime.json")
24+
"""
25+
26+
from __future__ import annotations
27+
28+
import json
29+
import os
30+
from typing import Any, Callable, Dict, List, Optional
31+
32+
33+
# Global registry: name -> callable runner.
# Populated by @export_benchmark at import time; read by
# get_benchmark_runner() and list_benchmarks().
_BENCHMARK_REGISTRY: Dict[str, Callable[..., Any]] = {}
35+
36+
37+
def list_benchmarks() -> List[str]:
    """Return the names of all registered benchmarks, sorted alphabetically."""
    return sorted(_BENCHMARK_REGISTRY)
39+
40+
41+
def get_benchmark_runner(name: str) -> Callable[..., Any]:
    """Return the runner callable registered under *name*.

    Raises:
        KeyError: when no benchmark with that name has been registered;
            the message lists the currently known benchmark names.
    """
    try:
        return _BENCHMARK_REGISTRY[name]
    except KeyError as missing:
        available = list_benchmarks()
        raise KeyError(f"Benchmark '{name}' not found. Available: {available}") from missing
46+
47+
48+
def export_benchmark(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Decorator to export a benchmark test into the global registry.

    This expects to be stacked with `@evaluation_test`, so the decorated function
    should carry `__ep_config` and `__ep_original_test_func` attributes that the
    decorator can read to construct a direct runner.

    The registered runner supports a subset of convenient overrides and maps them
    to the same EP_* environment variables used by the pytest plugin to ensure
    identical summaries and JSON artifact behavior.

    Args:
        name: Registry key for the benchmark. Registering the same name twice
            overwrites the previous runner with the latest definition.

    Returns:
        A decorator that registers a runner and returns the wrapped test
        function unchanged, so pytest collection is unaffected.
    """

    def _decorator(test_wrapper: Callable[..., Any]) -> Callable[..., Any]:
        # Pull through metadata attached by evaluation_test.
        ep_config: Dict[str, Any] = getattr(test_wrapper, "__ep_config", {})
        original_test_func: Optional[Callable[..., Any]] = getattr(
            test_wrapper, "__ep_original_test_func", None
        )

        def _runner(
            *,
            model: Optional[str] = None,
            print_summary: bool = False,
            out: Optional[str] = None,
            reasoning_effort: Optional[str] = None,
            max_rows: Optional[int | str] = None,
            num_runs: Optional[int] = None,
            input_params_override: Optional[Dict[str, Any]] = None,
            max_concurrency: Optional[int] = None,
        ) -> Any:
            """Execute the benchmark through the same pipeline as the pytest flow.

            Keyword-only overrides are translated into the EP_* environment
            variables the pytest plugin reads, then the captured evaluation
            parameters are forwarded to `run_evaluation_test_direct`.
            """
            # Map convenience flags to EP_* env used by the pytest flow.
            if print_summary:
                os.environ["EP_PRINT_SUMMARY"] = "1"
            if out:
                os.environ["EP_SUMMARY_JSON"] = out

            # Merge reasoning effort and arbitrary overrides into EP_INPUT_PARAMS_JSON.
            merged: Dict[str, Any] = {}
            if reasoning_effort:
                # Fireworks OpenAI-compatible endpoint expects
                # extra_body.reasoning_effort, not a nested reasoning dict.
                merged.setdefault("extra_body", {})["reasoning_effort"] = str(reasoning_effort)
            if input_params_override:

                def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
                    # Recursive dict merge: values in `over` win; nested dicts
                    # are merged in place rather than replaced wholesale.
                    for k, v in over.items():
                        if isinstance(v, dict) and isinstance(base.get(k), dict):
                            _deep_update(base[k], v)
                        else:
                            base[k] = v
                    return base

                merged = _deep_update(merged, dict(input_params_override))
            if merged:
                os.environ["EP_INPUT_PARAMS_JSON"] = json.dumps(merged)

            if max_rows is not None:
                # The literal string "all" lifts the row cap; anything else is
                # forwarded verbatim for the plugin to interpret.
                if isinstance(max_rows, str) and max_rows.strip().lower() == "all":
                    os.environ["EP_MAX_DATASET_ROWS"] = "None"
                else:
                    os.environ["EP_MAX_DATASET_ROWS"] = str(max_rows)

            # Build effective parameters, preferring explicit overrides over
            # whatever @evaluation_test captured.
            models: List[str] = ep_config.get("model") or []
            model_to_use = model or (models[0] if models else None)
            if not model_to_use:
                raise ValueError(
                    f"No model provided and none captured from evaluation_test for benchmark '{name}'"
                )

            input_messages = ep_config.get("input_messages")
            input_dataset = ep_config.get("input_dataset")
            dataset_adapter = ep_config.get("dataset_adapter")
            rollout_input_params_list = ep_config.get("rollout_input_params")
            rollout_processor = ep_config.get("rollout_processor")
            aggregation_method = ep_config.get("aggregation_method")
            threshold = ep_config.get("threshold_of_success")
            default_num_runs = ep_config.get("num_runs")
            max_dataset_rows = ep_config.get("max_dataset_rows")
            mcp_config_path = ep_config.get("mcp_config_path")
            max_concurrent_rollouts = ep_config.get("max_concurrent_rollouts")
            if max_concurrency is not None:
                max_concurrent_rollouts = int(max_concurrency)
            server_script_path = ep_config.get("server_script_path")
            steps = ep_config.get("steps")
            mode = ep_config.get("mode")

            # Choose the first rollout param set by default.
            rollout_params = None
            if isinstance(rollout_input_params_list, list) and rollout_input_params_list:
                rollout_params = rollout_input_params_list[0]

            # Import the runner lazily to avoid hard import dependencies and
            # circular imports at module load time.
            import importlib

            _mod = importlib.import_module("eval_protocol.pytest.evaluation_test")
            run_evaluation_test_direct = _mod.run_evaluation_test_direct

            return run_evaluation_test_direct(
                test_func=original_test_func or test_wrapper,
                model=model_to_use,
                input_messages=input_messages,
                input_dataset=input_dataset,
                dataset_adapter=dataset_adapter,
                rollout_input_params=rollout_params,
                rollout_processor=rollout_processor,
                aggregation_method=aggregation_method,
                threshold_of_success=threshold,
                num_runs=(num_runs if num_runs is not None else default_num_runs),
                max_dataset_rows=max_dataset_rows,
                mcp_config_path=mcp_config_path,
                max_concurrent_rollouts=max_concurrent_rollouts,
                server_script_path=server_script_path,
                steps=steps,
                mode=mode,
            )

        # Register the runner; re-registration simply overwrites with the
        # latest definition (both branches of the original if/else did this).
        _BENCHMARK_REGISTRY[name] = _runner

        return test_wrapper

    return _decorator
173+
174+

eval_protocol/benchmarks/run.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
2+
Minimal CLI runner for exported benchmarks.
3+
4+
Usage:
5+
6+
python -m eval_protocol.benchmarks.run aime25_low \
7+
--model fireworks_ai/accounts/fireworks/models/gpt-oss-120b \
8+
--print-summary \
9+
--out artifacts/aime25_low.json \
10+
--max-rows 50 \
11+
--reasoning-effort low
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import argparse
17+
from typing import Any
18+
19+
from importlib import import_module
20+
import pkgutil
21+
import eval_protocol.benchmarks.suites as suites_pkg
22+
from eval_protocol.benchmarks.registry import get_benchmark_runner, list_benchmarks
23+
24+
25+
def _parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the benchmark runner.

    Returns:
        argparse.Namespace with attributes: name, model, print_summary, out,
        reasoning_effort, max_rows, num_runs, max_tokens, max_concurrency.
    """
    parser = argparse.ArgumentParser(description="Run an exported eval-protocol benchmark")
    parser.add_argument("name", help=f"Benchmark name. Known: {', '.join(list_benchmarks()) or '(none)'}")
    parser.add_argument("--model", required=True, help="Model identifier (provider/model)")
    parser.add_argument("--print-summary", action="store_true", help="Print concise EP summary line")
    parser.add_argument("--out", help="Write JSON summary artifact to path or directory")
    parser.add_argument(
        "--reasoning-effort",
        choices=["low", "medium", "high"],
        # Keep wording in sync with the registry runner, which writes
        # extra_body.reasoning_effort (not a nested reasoning dict).
        help="Sets extra_body.reasoning_effort via EP_INPUT_PARAMS_JSON",
    )
    parser.add_argument(
        "--max-rows",
        help="Limit rows: integer or 'all' for no limit (maps to EP_MAX_DATASET_ROWS)",
    )
    parser.add_argument("--num-runs", type=int, help="Override num_runs if provided")
    parser.add_argument("--max-tokens", type=int, help="Override max_tokens for generation requests")
    parser.add_argument("--max-concurrency", type=int, help="Override max concurrent rollouts")
    return parser.parse_args()
46+
47+
48+
def main() -> int:
    """CLI entry point: register suites, then run the requested benchmark.

    Returns:
        0 on success. Failure gates are enforced inside the runner via
        assertions, which propagate out as a non-zero exit.
    """
    args = _parse_args()

    # Auto-import all suite modules so their @export_benchmark decorators run
    # and register benchmarks. Import failures are logged but non-fatal so one
    # broken suite does not block running the others.
    import sys
    import traceback

    for modinfo in pkgutil.iter_modules(suites_pkg.__path__):
        mod_name = f"{suites_pkg.__name__}.{modinfo.name}"
        try:
            import_module(mod_name)
        except Exception as e:
            print(f"[bench] failed to import suite module: {mod_name}: {e}", file=sys.stderr)
            traceback.print_exc()

    # Fallback: if nothing registered yet and a known suite was requested,
    # try an explicit import of the module that defines it.
    if not list_benchmarks():
        known_map = {
            "aime25_low": "eval_protocol.benchmarks.suites.aime25",
        }
        forced = known_map.get(args.name)
        if forced:
            try:
                import_module(forced)
            except Exception as e:
                print(f"[bench] explicit import failed for {forced}: {e}", file=sys.stderr)

    runner = get_benchmark_runner(args.name)

    # --max-rows accepts an integer or a keyword like 'all'; non-integers are
    # passed through as strings for the runner to interpret.
    max_rows: int | str | None = None
    if args.max_rows is not None:
        try:
            max_rows = int(args.max_rows)
        except ValueError:
            max_rows = args.max_rows

    # Build input params override if needed.
    ip_override = {}
    if args.max_tokens is not None:
        ip_override["max_tokens"] = int(args.max_tokens)

    _ = runner(
        model=args.model,
        print_summary=args.print_summary,
        out=args.out,
        reasoning_effort=args.reasoning_effort,
        max_rows=max_rows,
        num_runs=args.num_runs,
        input_params_override=(ip_override or None),
        max_concurrency=args.max_concurrency,
    )
    # Non-zero exit on failure gate is handled within the runner via assertions.
    return 0
95+
96+
97+
# Script entry point: `python -m eval_protocol.benchmarks.run <name> ...`.
if __name__ == "__main__":
    raise SystemExit(main())
99+
100+
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Suite modules are auto-imported by eval_protocol.benchmarks.run to register benchmarks.
2+
3+

0 commit comments

Comments
 (0)