Skip to content

Commit 4fa4162

Browse files
committed
skeleton of gepa trainer
1 parent 42e0b08 commit 4fa4162

File tree

6 files changed

+227
-2
lines changed

6 files changed

+227
-2
lines changed

eval_protocol/benchmarks/test_aime25.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
SingleTurnRolloutProcessor,
1313
)
1414
from eval_protocol.pytest.evaluation_test import evaluation_test
15+
from eval_protocol.training import GEPATrainer
16+
from eval_protocol.training.gepa_utils import build_reflection_lm
1517

1618
SYSTEM_PROMPT = (
1719
"You are a helpful math assistant. Please reason step by step, and put your final answer within \\boxed{...}."
@@ -131,3 +133,17 @@ def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow:
131133
metrics=metrics,
132134
)
133135
return row
136+
137+
138+
if __name__ == "__main__":
    # Smoke-test entrypoint: run GEPA prompt optimization over this eval.
    # Wrap the @evaluation_test-decorated eval in a GEPATrainer.
    trainer = GEPATrainer(test_aime25_pointwise)
    # Reflection model GEPA uses to propose improved prompts; the name must
    # be a key of REFLECTION_LM_CONFIGS (see gepa_utils).
    reflection_lm = build_reflection_lm("gpt-5")

    optimized_program = trainer.train(
        num_threads=32,  # parallelism for rollouts/metric calls
        track_stats=True,  # keep per-candidate statistics on the result
        reflection_minibatch_size=3,  # examples per reflection step
        reflection_lm=reflection_lm,
    )

    # NOTE(review): evaluate() is still a skeleton in GEPATrainer — confirm
    # what it returns before relying on this output.
    print(trainer.evaluate(optimized_program))

eval_protocol/trainable_gepa_design.md

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
- Specifies what is tunable (e.g., the system prompt) and how to adapt rows using a candidate.
2020
- Invokes a train routine (GEPA-based or otherwise).
2121

22-
- **training core**
22+
- **Training core**
2323
- Provides a single central abstraction:
2424
- **`EPParameters`**: Encapsulates everything `evaluation_test` knows about the eval in a structured form:
2525
- One field for every parameter that `evaluation_test` accepts (dataset sources, adapters, completion params, rollout processor, aggregation, thresholds, etc.), after parsing/env overrides.
@@ -68,7 +68,7 @@ setattr(dual_mode_wrapper, "__ep_params__", ep_params)
6868
- Rollout and mode information (processor, kwargs, concurrency limits, mode).
6969
- The training core can then **directly convert `__ep_params__` into an `EPParameters` instance** without maintaining a separate training-only config.
7070

71-
- training core will expose:
71+
- Training core will expose:
7272
- A factory like:
7373

7474
```python
@@ -199,3 +199,38 @@ class GEPATrainer:
199199
- After GEPA integration works for AIME:
200200
- Decide on the canonical way to treat GEPA’s `run_dir` and/or additional artifacts for tuned prompts.
201201
- Optionally add a small helper that knows how to “run evaluation once with best GEPA candidate” for CI workflows.
202+
203+
204+
Future work:
205+
206+
This is how GEPA (via DSPy) defines its evaluation metric:
207+
208+
def metric(
209+
gold: Example,
210+
pred: Prediction,
211+
trace: Optional[DSPyTrace] = None,
212+
pred_name: Optional[str] = None,
213+
pred_trace: Optional[DSPyTrace] = None,
214+
) -> float | ScoreWithFeedback:
215+
"""
216+
This function is called with the following arguments:
217+
- gold: The gold example.
218+
- pred: The predicted output.
219+
- trace: Optional. The trace of the program's execution.
220+
- pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which
221+
the feedback is being requested.
222+
- pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for.
223+
224+
Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric to obtain
225+
feedback for individual predictors being optimized. GEPA provides the name of the predictor in `pred_name`
226+
and the sub-trace (of the trace) corresponding to the predictor in `pred_trace`.
227+
If available at the predictor level, the metric should return {'score': float, 'feedback': str} corresponding
228+
to the predictor.
229+
If not available at the predictor level, the metric can also return a text feedback at the program level
230+
(using just the gold, pred and trace).
231+
If no feedback is returned, GEPA will use a simple text feedback consisting of just the score:
232+
f"This trajectory got a score of {score}."
233+
"""
234+
...
235+
236+
Ideally we would provide a generic way to turn an `@evaluation_test`-decorated function into a metric with this signature.

eval_protocol/training/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Package-relative import: a bare `from gepa_adapter import ...` is an
# absolute import in Python 3 and fails unless the package directory itself
# happens to be on sys.path.
from .gepa_adapter import GEPATrainer

__all__ = ["GEPATrainer"]
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from typing import Any, Dict, Literal
2+
3+
import dspy
4+
from dspy.clients.lm import LM
5+
from dspy.primitives import Module
6+
from dspy.teleprompt.gepa.gepa import GEPA
7+
from gepa.core.adapter import ProposalFn
8+
from gepa.proposer.reflective_mutation.base import ReflectionComponentSelector
9+
10+
from eval_protocol.models import EPParameters, EvaluationRow
11+
from eval_protocol.pytest.types import TestFunction
12+
from eval_protocol.training.gepa_utils import REFLECTION_LM_CONFIGS
13+
from eval_protocol.training.utils import build_ep_parameters_from_test
14+
15+
16+
class GEPATrainer:
    """
    High-level entrypoint for running GEPA-style training against an existing
    `@evaluation_test`-decorated function.

    This class is intentionally minimal for now:
    - It captures `EPParameters` from the provided test function via
      `build_ep_parameters_from_test`.
    - It forwards GEPA configuration from `train` to `dspy.teleprompt.GEPA`.
    - Conversion between the EP eval and GEPA's program/metric/dataset
      abstractions is still a TODO (see the inline notes in `__init__`).
    """

    def __init__(self, test_fn: TestFunction) -> None:
        """
        Args:
            test_fn: The `@evaluation_test`-decorated function defining the eval.
        """
        self.test_fn = test_fn
        # Structured snapshot of everything `evaluation_test` knows about the
        # eval (datasets, completion params, rollout processor, ...).
        self.ep_params: EPParameters = build_ep_parameters_from_test(test_fn)

        # TODO: convert the EP test_fn into a GEPA metric; the metric must also
        # surface feedback text alongside the numeric score.
        self.metric = test_fn

        # TODO: converting between a program (dspy.Module) and an
        # @evaluation_test is a bit tricky.
        self.program = ...

        # TODO: derive train/val/test splits from the eval's input_dataset.
        self.train_set, self.val_set, self.test_set = (..., ..., ...)

    def train(
        self,
        auto: Literal["light", "medium", "heavy"] | None = None,
        max_full_evals: int | None = None,
        max_metric_calls: int | None = None,
        reflection_minibatch_size: int = 3,
        candidate_selection_strategy: Literal["pareto", "current_best"] = "pareto",
        reflection_lm: LM | None = None,
        skip_perfect_score: bool = True,
        add_format_failure_as_feedback: bool = False,
        instruction_proposer: ProposalFn | None = None,
        component_selector: ReflectionComponentSelector | str = "round_robin",
        use_merge: bool = True,
        max_merge_invocations: int | None = 5,
        num_threads: int | None = None,
        failure_score: float = 0.0,
        perfect_score: float = 1.0,
        log_dir: str | None = None,
        track_stats: bool = False,
        use_wandb: bool = False,
        wandb_api_key: str | None = None,
        wandb_init_kwargs: dict[str, Any] | None = None,
        track_best_outputs: bool = False,
        warn_on_score_mismatch: bool = True,
        enable_tool_optimization: bool = False,
        use_mlflow: bool = False,
        seed: int | None = 0,
        gepa_kwargs: dict[str, Any] | None = None,
    ) -> Module:
        """
        Run GEPA to optimize over candidates.

        All named parameters mirror the corresponding `dspy.teleprompt.GEPA`
        constructor arguments and are forwarded unchanged. Entries in
        `gepa_kwargs` override the named arguments, which also allows passing
        GEPA options not surfaced here.

        Returns:
            The optimized program produced by `GEPA.compile`.
        """
        gepa_args: dict[str, Any] = {
            "auto": auto,
            "max_full_evals": max_full_evals,
            "max_metric_calls": max_metric_calls,
            "reflection_minibatch_size": reflection_minibatch_size,
            "candidate_selection_strategy": candidate_selection_strategy,
            "reflection_lm": reflection_lm,
            "skip_perfect_score": skip_perfect_score,
            "add_format_failure_as_feedback": add_format_failure_as_feedback,
            "instruction_proposer": instruction_proposer,
            "component_selector": component_selector,
            "use_merge": use_merge,
            "max_merge_invocations": max_merge_invocations,
            "num_threads": num_threads,
            "failure_score": failure_score,
            "perfect_score": perfect_score,
            "log_dir": log_dir,
            "track_stats": track_stats,
            "use_wandb": use_wandb,
            "wandb_api_key": wandb_api_key,
            "wandb_init_kwargs": wandb_init_kwargs,
            "track_best_outputs": track_best_outputs,
            "warn_on_score_mismatch": warn_on_score_mismatch,
            "enable_tool_optimization": enable_tool_optimization,
            "use_mlflow": use_mlflow,
            "seed": seed,
        }
        # Explicit gepa_kwargs win over the named arguments above.
        gepa_args.update(gepa_kwargs or {})

        optimizer = GEPA(
            metric=self.metric,
            **gepa_args,
        )

        optimized_program = optimizer.compile(
            self.program,
            trainset=self.train_set,
            valset=self.val_set,
        )

        return optimized_program

    def evaluate(self, optimized_program: Module) -> list[EvaluationRow]:
        """
        Evaluate an optimized program against the held-out test set.

        Not implemented yet. Two candidate designs under consideration:
        - Convert the optimized program back to EP and re-run the
          `evaluation_test` function on it.
        - Use `dspy.Evaluate`, e.g.:
              evaluate = dspy.Evaluate(
                  devset=self.test_set,
                  metric=self.metric,
                  num_threads=32,
                  display_table=True,
                  display_progress=True,
              )
              return evaluate(optimized_program)
        """
        # Fail loudly instead of silently returning None: the original body was
        # a bare `...`, which violated the declared `list[EvaluationRow]`
        # return type and would surface only as a confusing `None` downstream.
        raise NotImplementedError("GEPATrainer.evaluate is not implemented yet")
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import os
2+
3+
import dspy
4+
from dspy.clients.lm import LM
5+
6+
# Named presets for the reflection LM GEPA uses to propose prompt mutations.
# NOTE(review): os.getenv is evaluated at import time, so the API key is
# frozen to whatever the environment held when this module was first imported.
REFLECTION_LM_CONFIGS = {
    "gpt-5": {
        "model": "gpt-5",
        "temperature": 1.0,
        "max_tokens": 32000,
        "api_key": os.getenv("OPENAI_API_KEY"),
        "base_url": "https://api.openai.com/v1",
    },
    "kimi-k2-instruct-0905": {
        "model": "accounts/fireworks/models/kimi-k2-instruct-0905",
        "temperature": 0.6,  # Kimi recommended temperature
        "max_tokens": 131000,
        "api_key": os.getenv("FIREWORKS_API_KEY"),
        "base_url": "https://api.fireworks.ai/inference/v1",
    },
}


def build_reflection_lm(reflection_lm_name: str) -> LM:
    """Construct a `dspy.LM` for GEPA reflection from a named preset.

    Args:
        reflection_lm_name: A key of `REFLECTION_LM_CONFIGS` (e.g. "gpt-5").

    Returns:
        A `dspy.LM` configured with the preset's model, temperature,
        max_tokens, api_key, and base_url.

    Raises:
        ValueError: If `reflection_lm_name` is not a known preset (instead of
            the bare, unhelpful KeyError a direct dict lookup would raise).
    """
    try:
        reflection_lm_config = REFLECTION_LM_CONFIGS[reflection_lm_name]
    except KeyError:
        available = ", ".join(sorted(REFLECTION_LM_CONFIGS))
        raise ValueError(
            f"Unknown reflection LM {reflection_lm_name!r}; available presets: {available}"
        ) from None
    # Preset keys match dspy.LM's keyword arguments exactly, so unpack directly.
    return dspy.LM(**reflection_lm_config)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies = [
4747
"deepdiff>=6.0.0",
4848
"websockets>=15.0.1",
4949
"fastapi>=0.116.1",
50+
"dspy>=3.0.0",
5051
]
5152

5253
[project.urls]

0 commit comments

Comments
 (0)