-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathevaluation_test.py
More file actions
240 lines (211 loc) · 11.5 KB
/
evaluation_test.py
File metadata and controls
240 lines (211 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import inspect
from typing import Any, Callable, Dict, List, Optional
import pytest
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
from eval_protocol.pytest.types import (
Dataset,
DatasetPathParam,
EvaluationTestMode,
InputMessagesParam,
InputParam,
ModelParam,
RolloutProcessor,
RolloutProcessorConfig,
TestFunction,
)
from eval_protocol.pytest.utils import (
AggregationMethod,
aggregate,
create_dynamically_parameterized_wrapper,
execute_function,
)
from ..common_utils import load_jsonl
def evaluation_test(
    *,
    model: List[ModelParam],
    input_messages: Optional[List[InputMessagesParam]] = None,
    input_dataset: Optional[List[DatasetPathParam]] = None,
    dataset_adapter: Optional[Callable[[List[Dict[str, Any]]], Dataset]] = lambda x: x,
    input_params: Optional[List[InputParam]] = None,
    rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
    aggregation_method: AggregationMethod = "mean",
    threshold_of_success: Optional[float] = None,
    num_runs: int = 1,
    max_dataset_rows: Optional[int] = None,
    mcp_config_path: Optional[str] = None,
    mode: EvaluationTestMode = "batch",
) -> Callable[
    [TestFunction],
    TestFunction,
]:
    """Decorator to create pytest-based evaluation tests.

    The decorated function is replaced by a pytest-parametrized wrapper with
    one case per combination of ``model`` x ``input_dataset`` x
    ``input_params`` x ``input_messages``. Each case loads its rows, passes
    them through ``rollout_processor`` once, applies the decorated function
    ``num_runs`` times to the rollout output, aggregates the resulting scores
    and, if ``threshold_of_success`` is set, asserts on the aggregate.

    Args:
        model: Model identifiers to query.
        input_messages: Messages to send to the model. This is useful if you
            don't have a dataset but can hard-code the messages. Each entry is
            wrapped in a single EvaluationRow and reaches the test function as
            ``rows`` (batch mode) or ``row`` (pointwise mode). Ignored for a
            case that also has a dataset path.
        input_dataset: Paths to JSONL datasets. This is useful if you have a
            dataset already. Provide a dataset_adapter to convert the input
            dataset to a list of EvaluationRows if you have a custom dataset
            format.
        dataset_adapter: Function to convert the loaded JSONL records to a
            list of EvaluationRows. ``None`` is treated as the identity
            function.
        input_params: Generation parameters for the model.
        rollout_processor: Function used to perform the rollout.
        aggregation_method: How to aggregate scores across rows.
        threshold_of_success: If set, fail the test if the aggregated score is
            below this threshold.
        num_runs: Number of times the test function is applied to the rollout
            output. The rollout itself is executed once per test case.
        max_dataset_rows: Limit dataset to the first N rows (applied before
            ``dataset_adapter``).
        mcp_config_path: Path to MCP config file that follows
            MCPMultiClientConfiguration schema.
        mode: Evaluation mode. "batch" (default) expects the test function to
            handle the full dataset. "pointwise" applies the test function to
            each row. If your evaluation requires the full rollout of all rows
            to compute the score, use "batch".

    Returns:
        A decorator that turns an evaluation function into a parametrized
        pytest test.

    Raises:
        ValueError: At decoration time, if the decorated function's signature
            does not match ``mode`` (``row: EvaluationRow -> EvaluationRow``
            for pointwise, ``rows: List[EvaluationRow] -> List[EvaluationRow]``
            for batch).
    """
    # The parameter is typed Optional, so tolerate an explicit None instead of
    # failing later with "'NoneType' object is not callable".
    adapt_dataset = dataset_adapter if dataset_adapter is not None else (lambda records: records)

    def decorator(
        test_func: TestFunction,
    ):
        sig = inspect.signature(test_func)

        # Validate the signature up front so a mismatch fails at import/collection
        # time. Annotations are compared with != rather than identity: equality
        # is what typing guarantees for parameterized generics such as
        # List[EvaluationRow]; identity only holds thanks to an internal cache.
        if mode == "pointwise":
            if "row" not in sig.parameters:
                raise ValueError("In pointwise mode, your eval function must have a parameter named 'row'")
            if sig.parameters["row"].annotation != EvaluationRow:
                raise ValueError("In pointwise mode, the 'row' parameter must be of type EvaluationRow")
            if sig.return_annotation != EvaluationRow:
                raise ValueError("In pointwise mode, your eval function must return an EvaluationRow instance")
        else:
            if "rows" not in sig.parameters:
                raise ValueError("In batch mode, your eval function must have a parameter named 'rows'")
            if sig.parameters["rows"].annotation != List[EvaluationRow]:
                raise ValueError("In batch mode, the 'rows' parameter must be of type List[EvaluationRow]")
            if sig.return_annotation != List[EvaluationRow]:
                raise ValueError("In batch mode, your eval function must return a list of EvaluationRow instances")

        def execute_with_params(
            test_func: TestFunction,
            row: Optional[EvaluationRow] = None,
            rows: Optional[List[EvaluationRow]] = None,
        ):
            """Call ``test_func`` via ``execute_function`` with whichever of ``rows``/``row`` is set."""
            call_kwargs: Dict[str, Any] = {}
            if rows is not None:
                call_kwargs["rows"] = rows
            if row is not None:
                call_kwargs["row"] = row
            return execute_function(test_func, **call_kwargs)

        def generate_combinations():
            """Return every (model, dataset, params, messages) tuple to parametrize over."""
            # An omitted axis contributes a single None so the product is not empty.
            datasets: List[Optional[DatasetPathParam]] = input_dataset if input_dataset is not None else [None]  # type: ignore
            params: List[Optional[InputParam]] = input_params if input_params is not None else [None]  # type: ignore
            messages: List[Optional[InputMessagesParam]] = input_messages if input_messages is not None else [None]  # type: ignore

            combos = []
            for m in model:
                for ds in datasets:
                    for ip in params:
                        # A dataset without generation params is skipped.
                        # NOTE(review): when input_dataset is given but
                        # input_params is omitted, every combination is dropped
                        # and the test is parametrized with zero cases.
                        if ds is not None and ip is None:
                            continue
                        for im in messages:
                            combos.append((m, ds, ip, im))
            return combos

        # Only axes the caller actually supplied become pytest parameters; the
        # names and the tuple layout below must stay in the same order.
        test_param_names = ["model"]
        if input_dataset is not None:
            test_param_names.append("dataset_path")
        if input_params is not None:
            test_param_names.append("input_params")
        if input_messages is not None:
            test_param_names.append("input_messages")

        param_tuples = []
        for model_name, dataset, params, messages in generate_combinations():
            values = [model_name]
            if input_dataset is not None:
                values.append(dataset)
            if input_params is not None:
                values.append(params)
            if input_messages is not None:
                values.append(messages)
            param_tuples.append(tuple(values))

        def create_wrapper_with_signature():
            """Build the pytest-visible wrapper whose parameters are ``test_param_names``."""

            def wrapper_body(**kwargs):
                model_name = kwargs["model"]

                # Load the rows for this case; a dataset path wins over messages.
                if kwargs.get("dataset_path") is not None:
                    data = load_jsonl(kwargs["dataset_path"])
                    if max_dataset_rows is not None:
                        data = data[:max_dataset_rows]
                    data = adapt_dataset(data)
                elif kwargs.get("input_messages") is not None:
                    data = [EvaluationRow(messages=kwargs["input_messages"])]
                else:
                    raise ValueError("No input dataset or input messages provided")

                config = RolloutProcessorConfig(
                    model=model_name,
                    input_params=kwargs.get("input_params") or {},
                    mcp_config_path=mcp_config_path or "",
                )
                # The rollout runs once; its output is reused for every run below.
                rollout_rows = execute_function(rollout_processor, rows=data, config=config)

                all_results: List[EvaluationRow] = []
                for _ in range(num_runs):
                    if mode == "pointwise":
                        # Pointwise mode: apply the evaluator function to each row.
                        for row in rollout_rows:
                            result = execute_with_params(
                                test_func,
                                row=row,
                            )
                            # isinstance also rejects None.
                            if not isinstance(result, EvaluationRow):
                                raise ValueError(
                                    f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
                                )
                            all_results.append(result)
                    else:
                        # Batch mode: call the test function with the full dataset.
                        results = execute_with_params(
                            test_func,
                            rows=rollout_rows,
                        )
                        # Covers a None return as well, so batch mode always
                        # reports that a *list* was expected.
                        if not isinstance(results, list):
                            raise ValueError(
                                f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
                            )
                        if not results:
                            raise ValueError(
                                f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test."
                            )
                        if not all(isinstance(r, EvaluationRow) for r in results):
                            raise ValueError(
                                f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
                            )
                        all_results.extend(results)

                # Rows without an evaluation_result do not contribute a score.
                scores = [r.evaluation_result.score for r in all_results if r.evaluation_result]
                agg_score = aggregate(scores, aggregation_method)
                if threshold_of_success is not None:
                    # A plain assert is intentional: this is the pytest pass/fail signal.
                    assert (
                        agg_score >= threshold_of_success
                    ), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"

            return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)

        wrapper = create_wrapper_with_signature()
        wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(wrapper)
        return wrapper

    return decorator