Skip to content

Commit a0cb830

Browse files
author
Dylan Huang
committed
Refactor evaluation_test to improve error handling and metadata initialization. Ensure eval_metadata is set for each row before rollouts, and enhance exception management to log errors appropriately while maintaining pytest behavior.
1 parent ffe942e commit a0cb830

File tree

1 file changed

+133
-109
lines changed

1 file changed

+133
-109
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 133 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -194,122 +194,146 @@ def create_wrapper_with_signature() -> Callable:
194194
# Create the function body that will be used
195195
def wrapper_body(**kwargs):
    """Run one parameterized evaluation and log every row, including on failure.

    The steps are: load the dataset (or wrap ``input_messages`` in a single
    row), attach completion params, mode and a shared ``EvalMetadata`` to each
    row, run the rollout processor, apply ``test_func`` ``num_runs`` times,
    aggregate the scores, log the results, and finally enforce
    ``threshold_of_success``.

    Args:
        **kwargs: Parameterized test arguments. ``model`` is required. Either
            ``dataset_path`` or ``input_messages`` must be non-None.
            ``input_params`` and ``evaluation_test_kwargs`` are optional.

    Raises:
        ValueError: If no input is provided, or ``test_func`` returns something
            other than an ``EvaluationRow`` (pointwise mode) or a non-empty
            list of ``EvaluationRow`` (batch mode).
        AssertionError: If the aggregated score is below
            ``threshold_of_success``.
        Exception: Any other failure is re-raised after the rows have been
            logged with ``status="error"``.
    """
    model_name = kwargs["model"]
    eval_metadata = None
    all_results: List[EvaluationRow] = []
    # Bound before the ``try`` so the error handler can always inspect it,
    # no matter how early the failure happens.
    data: List[EvaluationRow] = []

    try:
        # Handle dataset loading
        if kwargs.get("dataset_path") is not None:
            data_jsonl = load_jsonl(kwargs["dataset_path"])
            if max_dataset_rows is not None:
                data_jsonl = data_jsonl[:max_dataset_rows]
            data = dataset_adapter(data_jsonl)
        elif kwargs.get("input_messages") is not None:
            data = [EvaluationRow(messages=kwargs["input_messages"])]
        else:
            raise ValueError("No input dataset or input messages provided")

        input_params = kwargs.get("input_params") or {}

        # Create eval metadata with test function info and current version
        eval_metadata = EvalMetadata(
            name=test_func.__name__,
            description=test_func.__doc__,
            version=versioneer.get_version(),
            status="running",
            num_runs=num_runs,
            aggregation_method=aggregation_method,
            threshold_of_success=threshold_of_success,
            passed=None,
        )

        # Populate completion_params in input_metadata for all rows and
        # initialize eval_metadata BEFORE rollouts
        completion_params = CompletionParams(
            model=model_name,
            temperature=input_params.get("temperature"),
            max_tokens=input_params.get("max_tokens"),
            max_tool_calls=input_params.get("max_tool_calls"),
        )

        for row in data:
            if row.input_metadata is None:
                row.input_metadata = InputMetadata()
            row.input_metadata.completion_params = completion_params
            # Add mode to session_data
            if row.input_metadata.session_data is None:
                row.input_metadata.session_data = {}
            row.input_metadata.session_data["mode"] = mode
            # Every row references the same EvalMetadata object, so a later
            # status update is visible through all of them.
            row.eval_metadata = eval_metadata

        # Now run the rollout processor with metadata-initialized data
        config = RolloutProcessorConfig(
            model=model_name,
            input_params=input_params,
            mcp_config_path=mcp_config_path or "",
            max_concurrent_rollouts=max_concurrent_rollouts,
            server_script_path=server_script_path,
            steps=steps,
        )
        input_dataset = execute_function(rollout_processor, rows=data, config=config)

        for _ in range(num_runs):
            if mode == "pointwise":
                # Pointwise mode: apply the evaluator function to each row
                for row in input_dataset:
                    result = execute_with_params(
                        test_func,
                        row=row,
                        evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
                    )
                    if result is None or not isinstance(result, EvaluationRow):
                        raise ValueError(
                            f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
                        )
                    all_results.append(result)
            else:
                # Batch mode: call the test function with the full dataset
                results = execute_with_params(
                    test_func,
                    input_dataset=input_dataset,
                    evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
                )
                if results is None:
                    raise ValueError(
                        f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
                    )
                if not isinstance(results, list):
                    raise ValueError(
                        f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
                    )
                if not results:
                    raise ValueError(
                        f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test."
                    )
                if not all(isinstance(r, EvaluationRow) for r in results):
                    raise ValueError(
                        f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
                    )
                all_results.extend(results)

        scores = [r.evaluation_result.score for r in all_results if r.evaluation_result]
        agg_score = aggregate(scores, aggregation_method)

        # Determine if the evaluation passed based on threshold
        passed = None
        if threshold_of_success is not None:
            passed = agg_score >= threshold_of_success

        # Update eval metadata status and passed field for all results
        for r in all_results:
            if r.eval_metadata is not None:
                r.eval_metadata.status = "finished"
                r.eval_metadata.passed = passed
            default_logger.log(r)

    except Exception:
        # Update eval metadata status to error and log it
        if eval_metadata is not None:
            eval_metadata.status = "error"
            eval_metadata.passed = False

        if not data:
            # No rows were loaded, so create a minimal row to carry the error.
            error_row = EvaluationRow(messages=[], eval_metadata=eval_metadata, evaluation_result=None)
            default_logger.log(error_row)
        else:
            # Update the loaded rows with error status
            for r in data:
                if r.eval_metadata is not None:
                    r.eval_metadata.status = "error"
                    r.eval_metadata.passed = False
                default_logger.log(r)

        # Re-raise the exception to maintain pytest behavior
        raise

    else:
        # Check threshold after logging. This lives in ``else`` rather than in
        # the ``try`` body because AssertionError is an Exception subclass:
        # inside the ``try`` a below-threshold run would be caught above,
        # relabelled from "finished" to "error" and logged a second time.
        if threshold_of_success is not None and not passed:
            assert (
                agg_score >= threshold_of_success
            ), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"
313337

314338
return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)
315339

0 commit comments

Comments
 (0)