Skip to content

Commit 39b17dc

Browse files
author
Dylan Huang
committed
fix test_status_migration_integration
1 parent 7933dd1 commit 39b17dc

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

eval_protocol/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]]
148148
"""Create a status indicating the rollout failed with an error."""
149149
details = []
150150
if extra_info:
151-
details.append(ErrorInfo.rollout_error(extra_info).to_aip193_format())
151+
details.append(ErrorInfo.extra_info(extra_info).to_aip193_format())
152152
return cls.error(error_message, details)
153153

154154
@classmethod
@@ -210,7 +210,7 @@ def get_extra_info(self) -> Optional[Dict[str, Any]]:
210210
metadata = detail.get("metadata", {})
211211
reason = detail.get("reason")
212212
# Skip termination_reason and stopped details, return other error info
213-
if reason not in [ErrorInfo.REASON_TERMINATION_REASON, ErrorInfo.REASON_STOPPED]:
213+
if reason in [ErrorInfo.REASON_EXTRA_INFO]:
214214
return metadata
215215
return None
216216

tests/test_status_migration_integration.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,33 +106,33 @@ def test_termination_reason_in_status_details(self):
106106
row = EvaluationRow(messages=[])
107107

108108
# Set status with termination reason
109-
termination_status = Status.with_termination_reason("goal_reached")
109+
termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL)
110110
row.rollout_status = termination_status
111111

112112
# Should be finished
113113
assert row.rollout_status.is_finished()
114114

115115
# Should have termination reason in details
116-
assert row.rollout_status.get_termination_reason() == "goal_reached"
116+
assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL
117117

118118
# Check details structure
119119
assert len(row.rollout_status.details) == 1
120120
detail = row.rollout_status.details[0]
121121
assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo"
122122
assert detail["reason"] == "TERMINATION_REASON"
123123
assert detail["domain"] == "evalprotocol.io"
124-
assert detail["metadata"]["termination_reason"] == "goal_reached"
124+
assert detail["metadata"]["termination_reason"] == TerminationReason.CONTROL_PLANE_SIGNAL
125125

126126
def test_termination_reason_with_extra_info(self):
127127
"""Test termination reason with additional extra info."""
128128
row = EvaluationRow(messages=[])
129129

130130
extra_info = {"steps": 10, "reward": 0.8}
131-
termination_status = Status.with_termination_reason("timeout", extra_info)
131+
termination_status = Status.with_termination_reason(TerminationReason.USER_STOP, extra_info)
132132
row.rollout_status = termination_status
133133

134134
# Should have both termination reason and extra info
135-
assert row.rollout_status.get_termination_reason() == "timeout"
135+
assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP
136136
assert row.rollout_status.get_extra_info() == extra_info
137137

138138
# Check details structure
@@ -157,21 +157,21 @@ def test_multiple_termination_reasons(self):
157157
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
158158
"reason": "TERMINATION_REASON",
159159
"domain": "evalprotocol.io",
160-
"metadata": {"termination_reason": "first"},
160+
"metadata": {"termination_reason": TerminationReason.USER_STOP},
161161
},
162162
{
163163
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
164164
"reason": "TERMINATION_REASON",
165165
"domain": "evalprotocol.io",
166-
"metadata": {"termination_reason": "second"},
166+
"metadata": {"termination_reason": TerminationReason.SKIPPABLE_ERROR},
167167
},
168168
]
169169

170170
status = Status(code=Status.Code.FINISHED, message="Test", details=details)
171171
row.rollout_status = status
172172

173173
# Should return the first termination reason found
174-
assert row.rollout_status.get_termination_reason() == "first"
174+
assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP
175175

176176

177177
class TestErrorHandlingIntegration:
@@ -204,7 +204,7 @@ def test_error_status_with_metadata(self):
204204
assert len(row.rollout_status.details) == 1
205205
detail = row.rollout_status.details[0]
206206
assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo"
207-
assert detail["reason"] == "ROLLOUT_ERROR"
207+
assert detail["reason"] == "EXTRA_INFO"
208208
assert detail["domain"] == "evalprotocol.io"
209209
assert detail["metadata"] == error_info
210210

@@ -331,7 +331,7 @@ def test_status_model_validate(self):
331331

332332
# Set a complex status
333333
extra_info = {"steps": 10, "reward": 0.8}
334-
original_status = Status.with_termination_reason("goal_reached", extra_info)
334+
original_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
335335
row.rollout_status = original_status
336336

337337
# Dump and reconstruct
@@ -344,7 +344,7 @@ def test_status_model_validate(self):
344344
assert len(reconstructed_status.details) == len(original_status.details)
345345

346346
# Should preserve functionality
347-
assert reconstructed_status.get_termination_reason() == "goal_reached"
347+
assert reconstructed_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL
348348
assert reconstructed_status.get_extra_info() == extra_info
349349

350350

@@ -377,7 +377,7 @@ def test_malformed_status_details(self):
377377
row.rollout_status = malformed_status
378378

379379
# Should handle gracefully
380-
assert row.rollout_status.get_termination_reason() == "test"
380+
assert row.rollout_status.get_termination_reason() is None
381381
assert row.rollout_status.get_extra_info() is None
382382

383383
def test_large_metadata_handling(self):

0 commit comments

Comments
 (0)