|
| 1 | +""" |
| 2 | +Benchmark registry and export decorator. |
| 3 | +
|
| 4 | +This module provides a lightweight registry for benchmarks and a decorator |
| 5 | +`@export_benchmark(name)` that can be stacked with `@evaluation_test`. |
| 6 | +
|
| 7 | +It registers a runnable handle that executes the exact same evaluation pipeline |
| 8 | +as the pytest flow by calling `run_evaluation_test_direct` with the parameters |
| 9 | +captured from the decorated function. |
| 10 | +
|
| 11 | +Usage in a suite module (stack under @evaluation_test): |
| 12 | +
|
| 13 | + from eval_protocol.benchmarks.registry import export_benchmark |
| 14 | +
|
| 15 | + @export_benchmark("aime25_low") |
| 16 | + @evaluation_test(...) |
| 17 | + def test_aime_pointwise(row: EvaluationRow) -> EvaluationRow: |
| 18 | + ... |
| 19 | +
|
| 20 | +Programmatic run: |
| 21 | +
|
| 22 | + from eval_protocol.benchmarks.registry import get_benchmark_runner |
| 23 | + get_benchmark_runner("aime25_low")(model="fireworks_ai/...", print_summary=True, out="artifacts/aime.json") |
| 24 | +""" |
| 25 | + |
| 26 | +from __future__ import annotations |
| 27 | + |
| 28 | +import json |
| 29 | +import os |
| 30 | +from typing import Any, Callable, Dict, List, Optional |
| 31 | + |
| 32 | + |
# Global registry mapping benchmark name -> runnable handle. Populated at
# import time by @export_benchmark; registering the same name twice
# overwrites the earlier entry (last definition wins).
_BENCHMARK_REGISTRY: Dict[str, Callable[..., Any]] = {}
| 35 | + |
| 36 | + |
def list_benchmarks() -> List[str]:
    """Return every registered benchmark name, sorted alphabetically."""
    return sorted(_BENCHMARK_REGISTRY)
| 39 | + |
| 40 | + |
def get_benchmark_runner(name: str) -> Callable[..., Any]:
    """Look up the runner registered under *name*.

    Raises a KeyError that lists the available benchmark names when *name*
    has not been registered; the original KeyError is chained as the cause.
    """
    try:
        runner = _BENCHMARK_REGISTRY[name]
    except KeyError as missing:
        raise KeyError(f"Benchmark '{name}' not found. Available: {list_benchmarks()}") from missing
    return runner
| 46 | + |
| 47 | + |
def export_benchmark(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Decorator to export a benchmark test into the global registry.

    This expects to be stacked with `@evaluation_test`, so the decorated function
    should carry `__ep_config` and `__ep_original_test_func` attributes that the
    decorator can read to construct a direct runner.

    The registered runner supports a subset of convenient overrides and maps them
    to the same EP_* environment variables used by the pytest plugin to ensure
    identical summaries and JSON artifact behavior.

    Args:
        name: Registry key for the runner. Registering the same name twice
            overwrites the earlier entry (last definition wins).

    Returns:
        A decorator that registers a runner under `name` and returns the
        wrapped test function unchanged, so pytest collection is unaffected.
    """

    def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively merge `over` into `base` in place; `over` wins on conflicts."""
        for key, value in over.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                _deep_update(base[key], value)
            else:
                base[key] = value
        return base

    def _decorator(test_wrapper: Callable[..., Any]) -> Callable[..., Any]:
        # Pull through metadata attached by evaluation_test. Missing attributes
        # are tolerated here; the runner fails later if required pieces are absent.
        ep_config: Dict[str, Any] = getattr(test_wrapper, "__ep_config", {})
        original_test_func: Optional[Callable[..., Any]] = getattr(
            test_wrapper, "__ep_original_test_func", None
        )

        def _runner(
            *,
            model: Optional[str] = None,
            print_summary: bool = False,
            out: Optional[str] = None,
            reasoning_effort: Optional[str] = None,
            max_rows: Optional[int | str] = None,
            num_runs: Optional[int] = None,
            input_params_override: Optional[Dict[str, Any]] = None,
            max_concurrency: Optional[int] = None,
        ) -> Any:
            """Run the benchmark through the same pipeline as the pytest flow.

            Convenience keyword overrides are translated into the EP_* env vars
            read by the pytest plugin so programmatic runs produce identical
            summaries and JSON artifacts.

            NOTE(review): the os.environ mutations below persist for the
            process lifetime and are shared by all runners in this process.
            """
            # Map convenience flags to EP_* env used by the pytest flow.
            if print_summary:
                os.environ["EP_PRINT_SUMMARY"] = "1"
            if out:
                os.environ["EP_SUMMARY_JSON"] = out

            # Merge reasoning effort and arbitrary overrides into EP_INPUT_PARAMS_JSON.
            merged: Dict[str, Any] = {}
            if reasoning_effort:
                # Fireworks OpenAI-compatible endpoint expects
                # extra_body.reasoning_effort, not a nested reasoning dict.
                merged.setdefault("extra_body", {})["reasoning_effort"] = str(reasoning_effort)
            if input_params_override:
                # Copy before merging so the caller's dict is never mutated.
                _deep_update(merged, dict(input_params_override))
            if merged:
                os.environ["EP_INPUT_PARAMS_JSON"] = json.dumps(merged)

            if max_rows is not None:
                # "all" (case-insensitive) means no row cap; the pytest plugin
                # interprets the literal string "None" as unlimited.
                if isinstance(max_rows, str) and max_rows.strip().lower() == "all":
                    os.environ["EP_MAX_DATASET_ROWS"] = "None"
                else:
                    os.environ["EP_MAX_DATASET_ROWS"] = str(max_rows)

            # Build effective parameters, preferring call-time overrides.
            models: List[str] = ep_config.get("model") or []
            model_to_use = model or (models[0] if models else None)
            if not model_to_use:
                raise ValueError(
                    f"No model provided and none captured from evaluation_test for benchmark '{name}'"
                )

            input_messages = ep_config.get("input_messages")
            input_dataset = ep_config.get("input_dataset")
            dataset_adapter = ep_config.get("dataset_adapter")
            rollout_input_params_list = ep_config.get("rollout_input_params")
            rollout_processor = ep_config.get("rollout_processor")
            aggregation_method = ep_config.get("aggregation_method")
            threshold = ep_config.get("threshold_of_success")
            default_num_runs = ep_config.get("num_runs")
            max_dataset_rows = ep_config.get("max_dataset_rows")
            mcp_config_path = ep_config.get("mcp_config_path")
            max_concurrent_rollouts = ep_config.get("max_concurrent_rollouts")
            if max_concurrency is not None:
                max_concurrent_rollouts = int(max_concurrency)
            server_script_path = ep_config.get("server_script_path")
            steps = ep_config.get("steps")
            mode = ep_config.get("mode")
            # NOTE(review): the original also read ep_config["combine_datasets"]
            # but never forwarded it to run_evaluation_test_direct; the unused
            # read has been dropped.

            # Choose the first rollout param set by default.
            rollout_params = None
            if isinstance(rollout_input_params_list, list) and rollout_input_params_list:
                rollout_params = rollout_input_params_list[0]

            # Import the runner lazily to avoid hard import dependencies and
            # circular imports at module load time.
            from eval_protocol.pytest.evaluation_test import run_evaluation_test_direct

            return run_evaluation_test_direct(
                test_func=original_test_func or test_wrapper,
                model=model_to_use,
                input_messages=input_messages,
                input_dataset=input_dataset,
                dataset_adapter=dataset_adapter,
                rollout_input_params=rollout_params,
                rollout_processor=rollout_processor,
                aggregation_method=aggregation_method,
                threshold_of_success=threshold,
                num_runs=(num_runs if num_runs is not None else default_num_runs),
                max_dataset_rows=max_dataset_rows,
                mcp_config_path=mcp_config_path,
                max_concurrent_rollouts=max_concurrent_rollouts,
                server_script_path=server_script_path,
                steps=steps,
                mode=mode,
            )

        # Register (or re-register) the runner; a plain assignment covers both
        # the new-name and overwrite cases (the original if/else was redundant).
        _BENCHMARK_REGISTRY[name] = _runner

        return test_wrapper

    return _decorator
| 173 | + |
| 174 | + |