Skip to content

Commit ad98650

Browse files
authored
Revert "Revert "add default try catch to evaluator function"" (#302)
* Revert "Revert "add try catch (#297)" (#301)" This reverts commit 56e00c2. * set the eval metadata status as well * add * lint * add * add * avoid override * fix ut * add
1 parent 44606f7 commit ad98650

4 files changed

Lines changed: 597 additions & 15 deletions

File tree

eval_protocol/pytest/evaluation_test.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
EvaluationRow,
2020
EvaluationThreshold,
2121
EvaluationThresholdDict,
22+
EvaluateResult,
2223
Status,
2324
)
2425
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
@@ -370,7 +371,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
370371
row.input_metadata.session_data = {}
371372
row.input_metadata.session_data["mode"] = mode
372373
# Initialize eval_metadata for each row
373-
row.eval_metadata = eval_metadata
374+
row.eval_metadata = eval_metadata.model_copy(deep=True)
374375
row.execution_metadata.experiment_id = experiment_id
375376
row.execution_metadata.invocation_id = invocation_id
376377

@@ -429,11 +430,23 @@ async def _execute_pointwise_eval_with_semaphore(
429430
experiment_id=experiment_id,
430431
run_id=run_id,
431432
):
432-
result = await execute_pytest(
433-
test_func,
434-
processed_row=row,
435-
evaluation_test_kwargs=evaluation_test_kwargs,
436-
)
433+
try:
434+
result = await execute_pytest(
435+
test_func,
436+
processed_row=row,
437+
evaluation_test_kwargs=evaluation_test_kwargs,
438+
)
439+
except Exception as e:
440+
result = row
441+
result.evaluation_result = EvaluateResult(
442+
score=0.0,
443+
is_score_valid=False,
444+
reason=f"Error during evaluation: {type(e).__name__}: {e}",
445+
)
446+
if result.eval_metadata is not None:
447+
result.eval_metadata.status = Status.error(
448+
f"Error during evaluation: {type(e).__name__}: {e}",
449+
)
437450
if not isinstance(result, EvaluationRow):
438451
raise ValueError(
439452
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
@@ -455,11 +468,24 @@ async def _execute_groupwise_eval_with_semaphore(
455468
run_id=run_id,
456469
rollout_ids=group_rollout_ids or None,
457470
):
458-
results = await execute_pytest(
459-
test_func,
460-
processed_dataset=rows,
461-
evaluation_test_kwargs=evaluation_test_kwargs,
462-
)
471+
try:
472+
results = await execute_pytest(
473+
test_func,
474+
processed_dataset=rows,
475+
evaluation_test_kwargs=evaluation_test_kwargs,
476+
)
477+
except Exception as e:
478+
results = rows
479+
for row in results:
480+
row.evaluation_result = EvaluateResult(
481+
score=0.0,
482+
is_score_valid=False,
483+
reason=f"Error during evaluation: {type(e).__name__}: {e}",
484+
)
485+
if row.eval_metadata is not None:
486+
row.eval_metadata.status = Status.error(
487+
f"Error during evaluation: {type(e).__name__}: {e}",
488+
)
463489
if not isinstance(results, list):
464490
raise ValueError(
465491
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
@@ -576,7 +602,10 @@ async def _collect_result(config, lst):
576602
r.eval_metadata.status = Status.error(
577603
r.rollout_status.message, r.rollout_status.details
578604
)
579-
else:
605+
elif not (
606+
r.eval_metadata.status and r.eval_metadata.status.code != Status.Code.RUNNING
607+
):
608+
# if the eval_metadata status code has not been set to something else, consider it as finished
580609
r.eval_metadata.status = Status.eval_finished()
581610
# Optional debug print for assistant/tool sequence
582611
if os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1":

eval_protocol/pytest/evaluation_test_postprocess.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def postprocess(
9191
result.evaluation_result.standard_error = standard_error
9292
if result.evaluation_result.is_score_valid is False:
9393
if result.eval_metadata is not None:
94-
result.eval_metadata.status = Status.score_invalid()
94+
if not result.eval_metadata.status or not result.eval_metadata.status.is_error():
95+
result.eval_metadata.status = Status.score_invalid()
9596
result.execution_metadata.experiment_duration_seconds = experiment_duration_seconds
9697
active_logger.log(result)
9798

eval_protocol/pytest/exception_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@
3333
litellm.exceptions.InternalServerError,
3434
litellm.exceptions.Timeout,
3535
litellm.exceptions.NotFoundError,
36-
litellm.exceptions.BadRequestError, # remove this once we have a long term solution
36+
litellm.exceptions.BadRequestError,
3737
litellm.exceptions.ServiceUnavailableError,
38-
litellm.exceptions.APIError
38+
litellm.exceptions.APIError,
3939
}
4040

4141

0 commit comments

Comments (0)