Skip to content

Commit 0a26115

Browse files
committed
gepa integration part 1
1 parent b103d2f commit 0a26115

File tree

6 files changed

+342
-8
lines changed

6 files changed

+342
-8
lines changed

eval_protocol/models.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import importlib
44
from datetime import datetime, timezone
55
from enum import Enum
6-
from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union
6+
from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union, Callable, Sequence
77

88
JSONType = Union[Dict[str, Any], List[Any], str, int, float, bool, None]
99

@@ -1190,3 +1190,32 @@ class MCPMultiClientConfiguration(BaseModel):
11901190
"""Represents a MCP configuration."""
11911191

11921192
mcpServers: Dict[str, Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]]
1193+
1194+
1195+
class EPParameters(BaseModel):
    """The parameters of an `@evaluation_test`. Used for trainable integrations.

    Captures every argument accepted by the `@evaluation_test` decorator in a
    structured form so training integrations (e.g. a GEPA optimizer) can
    introspect and re-run an evaluation outside of pytest. Most fields are
    intentionally typed `Any` because their concrete types live in the pytest
    layer and importing them here would create a circular dependency —
    TODO confirm; verify against eval_protocol/pytest/evaluation_test.py.
    """

    # --- Dataset / input sources -------------------------------------------
    completion_params: Any = None
    input_messages: Any = None
    input_dataset: Any = None
    input_rows: Any = None
    data_loaders: Any = None
    dataset_adapter: Optional[Callable[..., Any]] = None

    # --- Rollout configuration ---------------------------------------------
    rollout_processor: Any = None
    rollout_processor_kwargs: Dict[str, Any] | None = None

    # --- Scoring / aggregation ---------------------------------------------
    aggregation_method: Any = Field(default="mean")
    passed_threshold: Any = None

    # --- Execution knobs (defaults mirror the decorator's defaults) --------
    disable_browser_open: bool = False
    num_runs: int = 1
    filtered_row_ids: Optional[Sequence[str]] = None
    max_dataset_rows: Optional[int] = None
    mcp_config_path: Optional[str] = None
    max_concurrent_rollouts: int = 8
    max_concurrent_evaluations: int = 64
    server_script_path: Optional[str] = None
    steps: int = 30
    mode: Any = Field(default="pointwise")
    combine_datasets: bool = True

    # --- Hooks / observability ---------------------------------------------
    preprocess_fn: Optional[Callable[[list[EvaluationRow]], list[EvaluationRow]]] = None
    logger: Any = None
    exception_handler_config: Any = None

eval_protocol/pytest/evaluation_test.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
EvaluationThresholdDict,
2222
EvaluateResult,
2323
Status,
24+
EPParameters,
2425
)
2526
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
2627
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
@@ -695,13 +696,33 @@ async def _collect_result(config, lst):
695696
)
696697
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)
697698

698-
ep_params: dict[str, Any] = {
699-
"rollout_processor": rollout_processor,
700-
"server_script_path": server_script_path,
701-
"mcp_config_path": mcp_config_path,
702-
"rollout_processor_kwargs": rollout_processor_kwargs,
703-
"mode": mode,
704-
}
699+
# Attach full evaluation parameter metadata for training integrations
700+
ep_params: EPParameters = EPParameters(
701+
completion_params=completion_params,
702+
input_messages=input_messages,
703+
input_dataset=input_dataset,
704+
input_rows=input_rows,
705+
data_loaders=data_loaders,
706+
dataset_adapter=dataset_adapter,
707+
rollout_processor=rollout_processor,
708+
rollout_processor_kwargs=rollout_processor_kwargs,
709+
aggregation_method=aggregation_method,
710+
passed_threshold=passed_threshold,
711+
disable_browser_open=disable_browser_open,
712+
num_runs=num_runs,
713+
filtered_row_ids=filtered_row_ids,
714+
max_dataset_rows=max_dataset_rows,
715+
mcp_config_path=mcp_config_path,
716+
max_concurrent_rollouts=max_concurrent_rollouts,
717+
max_concurrent_evaluations=max_concurrent_evaluations,
718+
server_script_path=server_script_path,
719+
steps=steps,
720+
mode=mode,
721+
combine_datasets=combine_datasets,
722+
preprocess_fn=preprocess_fn,
723+
logger=logger,
724+
exception_handler_config=exception_handler_config,
725+
)
705726

706727
# Create the dual mode wrapper
707728
dual_mode_wrapper = create_dual_mode_wrapper(
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
## GEPA-Trainable Interface Design for Eval Protocol
2+
3+
### Goals
4+
5+
- **Tunable prompts for existing benchmarks**: Allow benchmarks like `test_aime25.py` and `test_gpqa.py` to expose parts of their configuration (e.g., system prompts) as trainable parameters, without changing their core evaluation logic.
6+
- **Tight coupling with `@evaluation_test`**: Reuse the same rollout configuration, datasets, and metrics that are already defined via `evaluation_test`, instead of duplicating that configuration in a separate training API.
7+
- **GEPA as one optimizer backend**: Provide a clean integration point for GEPA (and potentially other optimizers later) without requiring benchmarks to depend on DSPy or GEPA directly.
8+
9+
### High-Level Architecture
10+
11+
- **Benchmark file (e.g., `test_aime25.py`)**
12+
- Continues to define:
13+
- Dataset adapter (`aime2025_dataset_adapter`).
14+
- `@evaluation_test(...)`-decorated function (e.g., `test_aime25_pointwise`) that:
15+
- Uses `SingleTurnRolloutProcessor` (or another processor).
16+
- Computes per-row metrics and sets `row.evaluation_result`.
17+
- Adds *optional* trainable wiring at the bottom, under `if __name__ == "__main__":`, that:
18+
- Imports a trainable/core API from `eval_protocol.trainable`.
19+
- Specifies what is tunable (e.g., the system prompt) and how to adapt rows using a candidate.
20+
- Invokes a train routine (GEPA-based or otherwise).
21+
22+
- **Trainable core**
23+
- Provides a single central abstraction:
24+
- **`EPParameters`**: Encapsulates everything `evaluation_test` knows about the eval in a structured form:
25+
- One field for every parameter that `evaluation_test` accepts (dataset sources, adapters, completion params, rollout processor, aggregation, thresholds, etc.), after parsing/env overrides.
26+
- **Candidate representation**: Start with `dict[str, str]` (e.g., `{"system_prompt": "..."}`), anticipating future extensions (few-shot examples, tool docs, etc.).
27+
- Includes helper utilities to:
28+
- Build an `EPParameters` instance by introspecting an `@evaluation_test`-decorated function.
29+
- Run a single candidate or a batch of candidates through the full rollout + evaluation pipeline, returning aggregate scores (and optionally per-row scores).
30+
31+
- **GEPA adapter (e.g., `eval_protocol/trainable/gepa_adapter.py`)**
32+
- Wraps the trainable core and GEPA’s API:
33+
- Accepts:
34+
An `EPParameters` instance (the central abstraction defined above).
35+
- A candidate space definition (for now, implicit via `dict[str, str]` keys).
36+
- GEPA configuration (budget, reflection model, seed, component selection strategy, etc.).
37+
- Provides:
38+
- A GEPA-compatible metric interface that:
39+
- Given a candidate, uses `EPConfig` (and benchmark-specific logic such as a custom `dataset_adapter`) to:
40+
- Construct or adapt rows for that candidate.
41+
- Run rollouts (reusing the same processors and params as the test).
42+
- Compute scalar scores (e.g., mean exact-match over a batch).
43+
- A training routine that returns:
44+
- A `best_candidate: dict[str, str]`.
45+
- Optional rich result object (e.g., mapping to `GEPAResult`, additional stats).
46+
47+
### Relationship to `evaluation_test` and `__ep_params__`
48+
49+
- Existing `evaluation_test` code will attach:
50+
51+
```python
52+
ep_params: dict[str, Any] = {
53+
"rollout_processor": rollout_processor,
54+
"server_script_path": server_script_path,
55+
"mcp_config_path": mcp_config_path,
56+
"rollout_processor_kwargs": rollout_processor_kwargs,
57+
"mode": mode,
58+
}
59+
setattr(dual_mode_wrapper, "__ep_params__", ep_params)
60+
```
61+
62+
- Design direction:
63+
- **Use `__ep_params__` as the single source of truth**.
64+
- **`__ep_params__` should contain all effective `evaluation_test` parameters**, including:
65+
- Parsed `completion_params` (after env overrides).
66+
- Dataset sources (`input_dataset`, `input_rows`, dataloaders, and `dataset_adapter`), after `parse_ep_*` transforms.
67+
- `aggregation_method`, `num_runs`, `max_dataset_rows`, etc.
68+
- Rollout and mode information (processor, kwargs, concurrency limits, mode).
69+
- The trainable core can then **directly convert `__ep_params__` into an `EPParameters` instance** without maintaining a separate trainable-only config.
70+
71+
- Trainable core will expose:
72+
- A factory like:
73+
74+
```python
75+
def build_ep_parameters_from_test(
76+
test_fn: TestFunction,
77+
) -> EPParameters:
78+
...
79+
```
80+
81+
- This function:
82+
- Reads `test_fn.__ep_params__`.
83+
- Reconstructs how to:
84+
- Load and preprocess the dataset.
85+
- Configure the rollout processor (`RolloutProcessorConfig`).
86+
- Run rollouts and then apply the row-level metric (by calling the decorated test function in a library mode).
87+
88+
- Training code (e.g., `python test_aime25.py`) then becomes:
89+
- Import the test function (e.g., `test_aime25_pointwise`).
90+
- Build an `EPParameters` from it.
91+
- Call into a GEPA-based trainer that uses the `EPParameters`.
92+
93+
### Open Questions
94+
95+
- **Where tuned prompts live (storage format and location)**:
96+
- GEPA already supports a `run_dir` for logging and checkpoints.
97+
- We need to decide:
98+
- Whether EP should:
99+
- Treat `run_dir` as the canonical store and optionally add a small `best_candidate.json` there; or
100+
- Provide an additional EP-level artifact format.
101+
- For now, storage is left as an **explicit design TODO** and can be finalized once we have the core/adapter in place.
102+
103+
### Work Split: Person A vs Person B
104+
105+
#### Person A – Trainable Core & `evaluation_test` Integration
106+
107+
- **1. Extend `evaluation_test` metadata (no behavior change)**
108+
- Populate a single `__ep_params__` attribute (an `EPParameters` instance) on the decorated test function that includes:
109+
- Dataset specification (paths / input_rows / dataloaders, `dataset_adapter`, `max_dataset_rows`, etc.) after `parse_ep_*`.
110+
- Parsed `completion_params` (after env overrides like `parse_ep_completion_params_overwrite`).
111+
- Rollout settings (`rollout_processor`, `rollout_processor_kwargs`, `mode`, `max_concurrent_rollouts`, `max_concurrent_evaluations`).
112+
- Aggregation and threshold metadata.
113+
- Ensure:
114+
- Backwards compatibility for existing tests.
115+
- Clear typing and docstrings to guide future use.
116+
117+
- **2. Define core trainable abstractions in `eval_protocol/trainable/core.py`**
118+
- Define:
119+
- `EPConfig` (implemented as `EPParameters` in `eval_protocol/models.py`):
120+
- A field for every parameter `evaluation_test` accepts (dataset, adapters, completion params, rollout processor, aggregation, thresholds, etc.).
121+
- Can be serialized/inspected for external tooling.
122+
- Candidate type alias (initially `Candidate = dict[str, str]`).
123+
- Implement:
124+
- `build_ep_config_from_test(test_fn: TestFunction) -> EPConfig`.
125+
- Reads `__ep_config__`.
126+
- Reuses the same dataset and rollout logic as pytest, but in a library-friendly way (no pytest invocation).
127+
- Helper(s) to:
128+
- Run a single candidate over the dataset, possibly with:
129+
- A subset of rows (train vs val split initially determined by the benchmark or EPConfig).
130+
- A configurable aggregation method (mean score to start).
131+
132+
- **3. Minimal tests and documentation for the core**
133+
- Add unit/integration tests that:
134+
- Use a tiny fake `@evaluation_test` function.
135+
- Confirm `build_ep_config_from_test` produces a config that can:
136+
- Load mock rows.
137+
- Run a dummy rollout processor.
138+
- Apply a simple metric to produce scores.
139+
- Document (in this design file or a short README) how benchmarks should think about exposing tunable pieces (e.g., via custom dataset adapters or other wiring).
140+
141+
#### Person B – GEPA Adapter & Benchmark Wiring
142+
143+
- **4. Implement GEPA integration in `eval_protocol/trainable/gepa_adapter.py`**
144+
- Define a small adapter API, e.g.:
145+
146+
```python
147+
class GEPATrainer:
148+
def __init__(self, spec: EPConfig, inject_fn: InjectFn, ...gepa_config...):
149+
...
150+
151+
def train(self) -> tuple[Candidate, Any]:
152+
"""Run GEPA and return best candidate plus optional rich result."""
153+
```
154+
155+
- Inside, implement:
156+
- Conversion from `(spec, inject_fn)` into a GEPA metric:
157+
- For each candidate:
158+
- Clone or map the base dataset rows, applying `inject_fn(candidate, row)`.
159+
- Use the spec’s rollout runner + metric runner to compute per-example and aggregate scores.
160+
- Return the aggregate score (and optional textual feedback) to GEPA.
161+
- The call to `gepa.optimize(...)` with:
162+
- `seed_candidate` constructed from the baseline configuration (e.g., default system prompt).
163+
- Budget configuration (max metric calls / auto presets).
164+
- Reflection config (reflection LM or other knobs) passed in via constructor.
165+
- Mapping from `GEPAResult` (or equivalent) back into:
166+
- `best_candidate: Candidate`.
167+
- Optional rich result object (e.g., exposing Pareto-front stats).
168+
169+
- **5. Wire a first benchmark: AIME 2025**
170+
- In `eval_protocol/benchmarks/test_aime25.py`:
171+
- Factor the row-scoring logic inside `test_aime25_pointwise` into a **reusable metric function** (pure function that sets `row.evaluation_result` given a rolled-out row).
172+
- Decide how candidates should influence the evaluation:
173+
- For example, by making the dataset adapter or message-construction logic candidate-aware (e.g., changing the system prompt).
174+
- Add a `if __name__ == "__main__":` block that:
175+
- Imports `test_aime25_pointwise` and builds an `EPConfig` via `build_ep_config_from_test`.
176+
- Instantiates `GEPATrainer` with:
177+
- The `EPConfig`.
178+
- Initial GEPA config (budget, reflection model placeholder, seed).
179+
- Calls `trainer.train()` and prints/logs the resulting `best_candidate` for now.
180+
- Keep storage of tuned prompts as a TODO/extension point to be resolved later.
181+
182+
- **6. Optional second benchmark: GPQA**
183+
- Repeat step 5 for `test_gpqa.py`:
184+
- Identify what’s tunable (system prompt, possibly chain-of-thought instructions).
185+
- Extract metric logic into a reusable function.
186+
- Add candidate-aware wiring (e.g., via dataset adapters) and an optional `__main__` entrypoint calling the same GEPA trainer.
187+
- This will validate that:
188+
- The abstractions generalize across tasks.
189+
- No DSPy/GEPA-specific imports leak into benchmark files (other than a small, well-defined trainable API).
190+
191+
### Coordination Notes
192+
193+
- **Order of work**
194+
- Person A should go first (or in parallel up to the point where `EPConfig` and `build_ep_config_from_test` are usable).
195+
- Person B can stub against interfaces and adjust once Person A’s core is available.
196+
- **Integration checkpoints**
197+
- After Person A lands the core + tests:
198+
- Person B wires AIME with a very simple “optimizer” (even random search) to smoke-test the path before hooking up real GEPA.
199+
- After GEPA integration works for AIME:
200+
- Decide on the canonical way to treat GEPA’s `run_dir` and/or additional artifacts for tuned prompts.
201+
- Optionally add a small helper that knows how to “run evaluation once with best GEPA candidate” for CI workflows.

eval_protocol/training/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Any
2+
3+
from eval_protocol.models import EPParameters
4+
5+
6+
def build_ep_parameters_from_test(test_fn: Any) -> "EPParameters":
    """
    Build an `EPParameters` instance from an `@evaluation_test`-decorated function.

    The decorator attaches a `__ep_params__` attribute holding all effective
    evaluation parameters (after parsing/env overrides); this helper simply
    retrieves it, failing loudly for undecorated functions.
    """
    # Sentinel lets a single attribute lookup distinguish "missing" from an
    # attached value of None, matching the hasattr-based contract.
    _missing = object()
    attached = getattr(test_fn, "__ep_params__", _missing)
    if attached is _missing:
        raise ValueError(
            "The provided test function does not have `__ep_params__` attached. "
            "Ensure it is decorated with `@evaluation_test` from eval_protocol.pytest."
        )
    return attached

tests/test_models.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Message,
1212
MetricResult,
1313
StepOutput,
14+
EPParameters,
1415
)
1516

1617

@@ -721,3 +722,34 @@ def test_message_dump_for_chat_completion_request():
721722
assert "weight" not in dictionary
722723
assert "reasoning_content" not in dictionary
723724
assert dictionary["content"] == "Hello, how are you?"
725+
726+
727+
def test_ep_parameters_defaults():
    """EPParameters should have sensible defaults for core fields."""
    defaults = EPParameters()

    # The scalar defaults must mirror the `@evaluation_test` decorator defaults.
    assert defaults.completion_params is None
    assert defaults.num_runs == 1
    assert defaults.disable_browser_open is False
    assert defaults.max_concurrent_rollouts == 8
    assert defaults.max_concurrent_evaluations == 64
    assert defaults.mode == "pointwise"
    assert defaults.combine_datasets is True
738+
739+
740+
def test_ep_parameters_accepts_arbitrary_types():
    """EPParameters should allow rich Python types for callable/logger fields."""
    # Imported locally: the visible import block of this test module does not
    # include `logging`, so relying on a module-level import risks a NameError.
    import logging

    def dummy_preprocess(rows):
        return rows

    def dummy_adapter(*args, **kwargs):
        return None

    logger = logging.getLogger("ep-params-test")

    params = EPParameters(dataset_adapter=dummy_adapter, preprocess_fn=dummy_preprocess, logger=logger)

    # Identity (not mere equality) checks: the model must store the exact
    # objects, not coerced copies.
    assert params.dataset_adapter is dummy_adapter
    assert params.preprocess_fn is dummy_preprocess
    assert params.logger is logger

tests/test_training_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
3+
from eval_protocol.models import EPParameters
4+
from eval_protocol.training.utils import build_ep_parameters_from_test
5+
6+
7+
def test_build_ep_parameters_from_test_returns_attached_model():
    """build_ep_parameters_from_test should return the EPParameters attached to the test function."""

    def dummy_test() -> None:
        pass

    attached = EPParameters(num_runs=3, completion_params={"model": "gpt-4"})
    dummy_test.__ep_params__ = attached

    returned = build_ep_parameters_from_test(dummy_test)

    # The helper must hand back the very same object, not a copy.
    assert returned is attached
    assert returned.num_runs == 3
    assert returned.completion_params == {"model": "gpt-4"}
21+
22+
23+
def test_build_ep_parameters_from_test_missing_attr_raises():
    """build_ep_parameters_from_test should raise when __ep_params__ is missing."""

    def undecorated() -> None:
        pass

    # `match` asserts the attribute name appears in the error message.
    with pytest.raises(ValueError, match="__ep_params__"):
        build_ep_parameters_from_test(undecorated)

0 commit comments

Comments
 (0)