Skip to content

Commit 29b3aea

Browse files
author
Dylan Huang
committed
migrate eval_metadata.status to AIP-193
1 parent 087695c commit 29b3aea

File tree

9 files changed

+140
-49
lines changed

9 files changed

+140
-49
lines changed

eval_protocol/models.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,29 @@ class Code(int, Enum):
107107
DATA_LOSS = 15
108108
UNAUTHENTICATED = 16
109109

110-
# Custom codes for rollout states (using higher numbers to avoid conflicts)
110+
# Custom codes for EP (using higher numbers to avoid conflicts)
111111
FINISHED = 100
112+
RUNNING = 101
112113

113114
@classmethod
114115
def rollout_running(cls) -> "Status":
115116
"""Create a status indicating the rollout is running."""
116-
return cls(code=cls.Code.OK, message="Rollout is running", details=[])
117+
return cls(code=cls.Code.RUNNING, message="Rollout is running", details=[])
118+
119+
@classmethod
120+
def eval_running(cls) -> "Status":
121+
"""Create a status indicating the evaluation is running."""
122+
return cls(code=cls.Code.RUNNING, message="Evaluation is running", details=[])
123+
124+
@classmethod
125+
def eval_finished(cls) -> "Status":
126+
"""Create a status indicating the evaluation finished."""
127+
return cls(code=cls.Code.FINISHED, message="Evaluation finished", details=[])
128+
129+
@classmethod
130+
def aborted(cls, message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status":
131+
"""Create a status indicating the evaluation was aborted."""
132+
return cls(code=cls.Code.ABORTED, message=message, details=details or [])
117133

118134
@classmethod
119135
def rollout_finished(
@@ -144,7 +160,7 @@ def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = Non
144160

145161
def is_running(self) -> bool:
146162
"""Check if the status indicates the rollout is running."""
147-
return self.code == self.Code.OK and self.message == "Rollout is running"
163+
return self.code == self.Code.RUNNING
148164

149165
def is_finished(self) -> bool:
150166
"""Check if the status indicates the rollout finished successfully."""
@@ -436,9 +452,7 @@ class EvalMetadata(BaseModel):
436452
default_factory=get_pep440_version,
437453
description="Version of the evaluation. Should be populated with a PEP 440 version string.",
438454
)
439-
status: Optional[Literal["running", "finished", "error", "stopped"]] = Field(
440-
None, description="Status of the evaluation"
441-
)
455+
status: Optional[Status] = Field(None, description="Status of the evaluation")
442456
num_runs: int = Field(..., description="Number of times the evaluation was repeated")
443457
aggregation_method: str = Field(..., description="Method used to aggregate scores across runs")
444458
passed_threshold: Optional[EvaluationThreshold] = Field(

eval_protocol/pytest/evaluation_test.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
from eval_protocol.human_id import generate_id, num_combinations
2222
from eval_protocol.models import (
2323
CompletionParams,
24+
ErrorInfo,
2425
EvalMetadata,
2526
EvaluationRow,
2627
EvaluationThreshold,
2728
InputMetadata,
2829
Message,
30+
Status,
2931
)
3032
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
3133
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
@@ -57,6 +59,7 @@
5759
)
5860
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig
5961
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
62+
from eval_protocol.types.types import TerminationReason
6063

6164
from ..common_utils import load_jsonl
6265

@@ -419,7 +422,7 @@ async def execute_with_params(
419422
if mode == "groupwise":
420423
combinations = generate_parameter_combinations(
421424
input_dataset,
422-
None,
425+
completion_params,
423426
input_messages,
424427
input_rows,
425428
evaluation_test_kwargs,
@@ -482,9 +485,7 @@ async def wrapper_body(**kwargs):
482485

483486
experiment_id = generate_id()
484487

485-
def _log_eval_error(
486-
status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool
487-
) -> None:
488+
def _log_eval_error(status: Status, rows: Optional[List[EvaluationRow]] | None, passed: bool) -> None:
488489
log_eval_status_and_rows(eval_metadata, rows, status, passed, active_logger)
489490

490491
try:
@@ -556,7 +557,7 @@ def _log_eval_error(
556557
eval_metadata = EvalMetadata(
557558
name=test_func.__name__,
558559
description=test_func.__doc__,
559-
status="running",
560+
status=Status.eval_running(),
560561
num_runs=num_runs,
561562
aggregation_method=aggregation_method,
562563
passed_threshold=threshold,
@@ -727,9 +728,11 @@ async def _collect_result(config, lst):
727728
for r in results:
728729
if r.eval_metadata is not None:
729730
if r.rollout_status.is_error():
730-
r.eval_metadata.status = "error"
731+
r.eval_metadata.status = Status.error(
732+
r.rollout_status.message, r.rollout_status.details
733+
)
731734
else:
732-
r.eval_metadata.status = "finished"
735+
r.eval_metadata.status = Status.eval_finished()
733736
active_logger.log(r)
734737

735738
# for groupwise mode, the result contains eval output from multiple completion_params, so we need to differentiate them
@@ -767,14 +770,16 @@ async def _collect_result(config, lst):
767770

768771
except AssertionError:
769772
_log_eval_error(
770-
"finished",
773+
Status.eval_finished(),
771774
processed_rows_in_run if "processed_rows_in_run" in locals() else None,
772775
passed=False,
773776
)
774777
raise
775-
except Exception:
778+
except Exception as e:
776779
_log_eval_error(
777-
"error", processed_rows_in_run if "processed_rows_in_run" in locals() else None, passed=False
780+
Status.error(str(e)),
781+
processed_rows_in_run if "processed_rows_in_run" in locals() else None,
782+
passed=False,
778783
)
779784
raise
780785

eval_protocol/pytest/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ async def wrapper(**kwargs):
113113
def log_eval_status_and_rows(
114114
eval_metadata: Optional[EvalMetadata],
115115
rows: Optional[List[EvaluationRow]] | None,
116-
status: Literal["finished", "error"],
116+
status: Status,
117117
passed: bool,
118118
logger: DatasetLogger,
119119
) -> None:

eval_protocol/utils/logs_server.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from eval_protocol.dataset_logger import default_logger
1616
from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE
1717
from eval_protocol.event_bus import event_bus
18+
from eval_protocol.models import Status
1819
from eval_protocol.utils.vite_server import ViteServer
1920

2021
if TYPE_CHECKING:
@@ -179,7 +180,9 @@ def _check_running_evaluations(self):
179180
if self._should_update_status(row):
180181
logger.info(f"Updating status to 'stopped' for row {row.input_metadata.row_id} (PID {row.pid})")
181182
if row.eval_metadata is not None:
182-
row.eval_metadata.status = "stopped"
183+
row.eval_metadata.status = Status.aborted(
184+
f"Evaluation aborted since process {row.pid} stopped"
185+
)
183186
updated_rows.append(row)
184187

185188
# Log all updated rows
@@ -194,7 +197,12 @@ def _check_running_evaluations(self):
194197
def _should_update_status(self, row: "EvaluationRow") -> bool:
195198
"""Check if a row's status should be updated to 'stopped'."""
196199
# Check if the row has running status and a PID
197-
if row.eval_metadata and row.eval_metadata.status == "running" and row.pid is not None:
200+
if (
201+
row.eval_metadata
202+
and row.eval_metadata.status
203+
and row.eval_metadata.status.is_running()
204+
and row.pid is not None
205+
):
198206
# Check if the process is still running
199207
try:
200208
process = psutil.Process(row.pid)

tests/pytest/test_pytest_propagate_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,4 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
6060

6161
# assert that eval_metadata.status is an error status for every rollout
6262
assert len(rollouts) == 5
63-
assert all(row.eval_metadata.status == "error" for row in rollouts.values())
63+
assert all(row.eval_metadata.status.is_error() for row in rollouts.values())

vite-app/src/App.tsx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,11 @@ const App = observer(() => {
150150
</div>
151151
<div className="flex items-center gap-2">
152152
<StatusIndicator
153-
status={state.isConnected ? "connected" : "disconnected"}
153+
status={
154+
state.isConnected
155+
? { code: 0, message: "Connected", details: [] }
156+
: { code: 1, message: "Disconnected", details: [] }
157+
}
154158
/>
155159
<Button onClick={handleManualRefresh} className="ml-2">
156160
Refresh

vite-app/src/components/EvaluationRow.tsx

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import { observer } from "mobx-react";
2-
import type { EvaluationRow as EvaluationRowType } from "../types/eval-protocol";
2+
import type {
3+
EvaluationRow as EvaluationRowType,
4+
Status,
5+
} from "../types/eval-protocol";
36
import { ChatInterface } from "./ChatInterface";
47
import { MetadataSection } from "./MetadataSection";
58
import StatusIndicator from "./StatusIndicator";
@@ -146,11 +149,14 @@ const RowStatus = observer(
146149
status,
147150
showSpinner,
148151
}: {
149-
status: string | undefined;
152+
status: Status | undefined;
150153
showSpinner: boolean;
151154
}) => (
152155
<div className="whitespace-nowrap">
153-
<StatusIndicator showSpinner={showSpinner} status={status || "N/A"} />
156+
<StatusIndicator
157+
showSpinner={showSpinner}
158+
status={status || { code: 2, message: "N/A", details: [] }}
159+
/>
154160
</div>
155161
)
156162
);
@@ -340,7 +346,7 @@ export const EvaluationRow = observer(
340346
<TableCell className="py-3 text-xs">
341347
<RowStatus
342348
status={row.eval_metadata?.status}
343-
showSpinner={row.eval_metadata?.status === "running"}
349+
showSpinner={row.eval_metadata?.status?.code === 101}
344350
/>
345351
</TableCell>
346352

vite-app/src/components/StatusIndicator.tsx

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import React from "react";
2+
import { getStatusCodeName, type Status } from "../types/eval-protocol";
23

34
interface StatusIndicatorProps {
4-
status: string;
5+
status: Status;
56
className?: string;
67
showSpinner?: boolean;
78
}
@@ -17,39 +18,41 @@ const StatusIndicator: React.FC<StatusIndicatorProps> = ({
1718
className = "",
1819
showSpinner = false,
1920
}) => {
20-
const getStatusConfig = (status: string) => {
21-
switch (status.toLowerCase()) {
22-
case "connected":
21+
const getStatusConfig = (status: Status) => {
22+
const statusCodeName = getStatusCodeName(status.code);
23+
24+
switch (statusCodeName) {
25+
case "OK":
2326
return {
2427
dotColor: "bg-green-500",
2528
textColor: "text-green-700",
2629
text: "Connected",
2730
};
28-
case "disconnected":
31+
case "CANCELLED":
2932
return {
3033
dotColor: "bg-red-500",
3134
textColor: "text-red-700",
3235
text: "Disconnected",
3336
};
34-
case "finished":
37+
case "FINISHED":
3538
return {
3639
dotColor: "bg-green-500",
3740
textColor: "text-green-700",
3841
text: "finished",
3942
};
40-
case "running":
43+
case "RUNNING":
4144
return {
4245
dotColor: "bg-blue-500",
4346
textColor: "text-blue-700",
4447
text: "running",
4548
};
46-
case "error":
49+
case "INTERNAL":
4750
return {
4851
dotColor: "bg-red-500",
4952
textColor: "text-red-700",
5053
text: "error",
5154
};
52-
case "stopped":
55+
case "ABORTED":
5356
return {
5457
dotColor: "bg-yellow-500",
5558
textColor: "text-yellow-700",
@@ -59,13 +62,13 @@ const StatusIndicator: React.FC<StatusIndicatorProps> = ({
5962
return {
6063
dotColor: "bg-gray-500",
6164
textColor: "text-gray-700",
62-
text: status,
65+
text: status.message,
6366
};
6467
}
6568
};
6669

6770
const config = getStatusConfig(status);
68-
const shouldShowSpinner = showSpinner && status.toLowerCase() === "running";
71+
const shouldShowSpinner = showSpinner && status.code === 101;
6972

7073
return (
7174
<div

0 commit comments

Comments
 (0)