|
21 | 21 | from eval_protocol.human_id import generate_id, num_combinations |
22 | 22 | from eval_protocol.models import ( |
23 | 23 | CompletionParams, |
| 24 | + ErrorInfo, |
24 | 25 | EvalMetadata, |
25 | 26 | EvaluationRow, |
26 | 27 | EvaluationThreshold, |
27 | 28 | InputMetadata, |
28 | 29 | Message, |
| 30 | + Status, |
29 | 31 | ) |
30 | 32 | from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter |
31 | 33 | from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor |
|
57 | 59 | ) |
58 | 60 | from eval_protocol.pytest.exception_config import ExceptionHandlerConfig |
59 | 61 | from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci |
| 62 | +from eval_protocol.types.types import TerminationReason |
60 | 63 |
|
61 | 64 | from ..common_utils import load_jsonl |
62 | 65 |
|
@@ -419,7 +422,7 @@ async def execute_with_params( |
419 | 422 | if mode == "groupwise": |
420 | 423 | combinations = generate_parameter_combinations( |
421 | 424 | input_dataset, |
422 | | - None, |
| 425 | + completion_params, |
423 | 426 | input_messages, |
424 | 427 | input_rows, |
425 | 428 | evaluation_test_kwargs, |
@@ -482,9 +485,7 @@ async def wrapper_body(**kwargs): |
482 | 485 |
|
483 | 486 | experiment_id = generate_id() |
484 | 487 |
|
485 | | - def _log_eval_error( |
486 | | - status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool |
487 | | - ) -> None: |
| 488 | + def _log_eval_error(status: Status, rows: Optional[List[EvaluationRow]] | None, passed: bool) -> None: |
488 | 489 | log_eval_status_and_rows(eval_metadata, rows, status, passed, active_logger) |
489 | 490 |
|
490 | 491 | try: |
@@ -556,7 +557,7 @@ def _log_eval_error( |
556 | 557 | eval_metadata = EvalMetadata( |
557 | 558 | name=test_func.__name__, |
558 | 559 | description=test_func.__doc__, |
559 | | - status="running", |
| 560 | + status=Status.eval_running(), |
560 | 561 | num_runs=num_runs, |
561 | 562 | aggregation_method=aggregation_method, |
562 | 563 | passed_threshold=threshold, |
@@ -727,9 +728,11 @@ async def _collect_result(config, lst): |
727 | 728 | for r in results: |
728 | 729 | if r.eval_metadata is not None: |
729 | 730 | if r.rollout_status.is_error(): |
730 | | - r.eval_metadata.status = "error" |
| 731 | + r.eval_metadata.status = Status.error( |
| 732 | + r.rollout_status.message, r.rollout_status.details |
| 733 | + ) |
731 | 734 | else: |
732 | | - r.eval_metadata.status = "finished" |
| 735 | + r.eval_metadata.status = Status.eval_finished() |
733 | 736 | active_logger.log(r) |
734 | 737 |
|
735 | 738 | # for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them |
@@ -767,14 +770,16 @@ async def _collect_result(config, lst): |
767 | 770 |
|
768 | 771 | except AssertionError: |
769 | 772 | _log_eval_error( |
770 | | - "finished", |
| 773 | + Status.eval_finished(), |
771 | 774 | processed_rows_in_run if "processed_rows_in_run" in locals() else None, |
772 | 775 | passed=False, |
773 | 776 | ) |
774 | 777 | raise |
775 | | - except Exception: |
| 778 | + except Exception as e: |
776 | 779 | _log_eval_error( |
777 | | - "error", processed_rows_in_run if "processed_rows_in_run" in locals() else None, passed=False |
| 780 | + Status.error(str(e)), |
| 781 | + processed_rows_in_run if "processed_rows_in_run" in locals() else None, |
| 782 | + passed=False, |
778 | 783 | ) |
779 | 784 | raise |
780 | 785 |
|
|
0 commit comments