Skip to content

Commit 775b07e

Browse files
author
Dylan Huang
committed
fix tests
1 parent 38829b4 commit 775b07e

4 files changed

Lines changed: 20 additions & 38 deletions

File tree

eval_protocol/models.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -164,18 +164,6 @@ def rollout_stopped(cls, message: str, extra_info: Optional[Dict[str, Any]] = No
164164
details.append(ErrorInfo.extra_info(extra_info).to_aip193_format())
165165
return cls(code=cls.Code.CANCELLED, message=message, details=details)
166166

167-
@classmethod
168-
def with_termination_reason(
169-
cls, termination_reason: TerminationReason, extra_info: Optional[Dict[str, Any]] = None
170-
) -> "Status":
171-
"""Create a status indicating the rollout finished with termination reason."""
172-
details = [ErrorInfo.termination_reason(termination_reason).to_aip193_format()]
173-
174-
if extra_info:
175-
details.append(ErrorInfo.extra_info(extra_info).to_aip193_format())
176-
177-
return cls(code=cls.Code.FINISHED, message="Rollout finished", details=details)
178-
179167
def is_running(self) -> bool:
180168
"""Check if the status indicates the rollout is running."""
181169
return self.code == self.Code.OK and self.message == "Rollout is running"
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,15 +300,15 @@ def test_termination_reason_integration(self):
300300
row = EvaluationRow(messages=[])
301301

302302
# Test with termination reason
303-
termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL)
303+
termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL)
304304
row.rollout_status = termination_status
305305

306306
assert row.rollout_status.is_finished()
307307
assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL
308308

309309
# Test with termination reason and extra info
310310
extra_info = {"steps": 10, "reward": 0.8}
311-
termination_status_with_info = Status.with_termination_reason(TerminationReason.USER_STOP, extra_info)
311+
termination_status_with_info = Status.rollout_finished(TerminationReason.USER_STOP, extra_info)
312312
row.rollout_status = termination_status_with_info
313313

314314
assert row.rollout_status.is_finished()
@@ -392,7 +392,7 @@ def test_termination_reason_structure_compliance(self):
392392
row = EvaluationRow(messages=[])
393393

394394
# Create status with termination reason
395-
termination_status = Status.with_termination_reason("goal_reached")
395+
termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL)
396396
row.rollout_status = termination_status
397397

398398
# Check AIP-193 structure
@@ -411,7 +411,7 @@ def test_multiple_details_compliance(self):
411411

412412
# Create status with both termination reason and extra info
413413
extra_info = {"steps": 15, "reward": 0.9}
414-
status = Status.with_termination_reason("goal_reached", extra_info)
414+
status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
415415
row.rollout_status = status
416416

417417
# Should have two details

tests/test_status_migration_integration.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_termination_reason_in_status_details(self):
106106
row = EvaluationRow(messages=[])
107107

108108
# Set status with termination reason
109-
termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL)
109+
termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL)
110110
row.rollout_status = termination_status
111111

112112
# Should be finished
@@ -128,7 +128,7 @@ def test_termination_reason_with_extra_info(self):
128128
row = EvaluationRow(messages=[])
129129

130130
extra_info = {"steps": 10, "reward": 0.8}
131-
termination_status = Status.with_termination_reason(TerminationReason.USER_STOP, extra_info)
131+
termination_status = Status.rollout_finished(TerminationReason.USER_STOP, extra_info)
132132
row.rollout_status = termination_status
133133

134134
# Should have both termination reason and extra info
@@ -262,7 +262,7 @@ def test_multiple_detail_types(self):
262262

263263
# Create status with both termination reason and extra info
264264
extra_info = {"steps": 15, "reward": 0.9}
265-
status = Status.with_termination_reason("goal_reached", extra_info)
265+
status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
266266
row.rollout_status = status
267267

268268
# Should have two details
@@ -309,7 +309,7 @@ def test_status_model_dump(self):
309309

310310
# Set a complex status
311311
extra_info = {"steps": 10, "reward": 0.8}
312-
termination_status = Status.with_termination_reason("goal_reached", extra_info)
312+
termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
313313
row.rollout_status = termination_status
314314

315315
# Dump to dict
@@ -331,7 +331,7 @@ def test_status_model_validate(self):
331331

332332
# Set a complex status
333333
extra_info = {"steps": 10, "reward": 0.8}
334-
original_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
334+
original_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
335335
row.rollout_status = original_status
336336

337337
# Dump and reconstruct

tests/test_status_model.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_status_creation_methods(self):
117117
assert error_status_with_info.message == "Something went wrong"
118118
assert len(error_status_with_info.details) == 1
119119
assert error_status_with_info.details[0]["@type"] == "type.googleapis.com/google.rpc.ErrorInfo"
120-
assert error_status_with_info.details[0]["reason"] == "ROLLOUT_ERROR"
120+
assert error_status_with_info.details[0]["reason"] == "EXTRA_INFO"
121121
assert error_status_with_info.details[0]["domain"] == "evalprotocol.io"
122122
assert error_status_with_info.details[0]["metadata"] == extra_info
123123

@@ -128,7 +128,7 @@ def test_status_creation_methods(self):
128128
assert stopped_status.details == []
129129

130130
# Test with termination reason
131-
termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL)
131+
termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL)
132132
assert termination_status.code == Status.Code.FINISHED
133133
assert termination_status.message == "Rollout finished"
134134
assert len(termination_status.details) == 1
@@ -141,9 +141,7 @@ def test_status_creation_methods(self):
141141

142142
# Test with termination reason and extra info
143143
extra_info = {"steps": 10, "reward": 0.8}
144-
termination_status_with_info = Status.with_termination_reason(
145-
TerminationReason.CONTROL_PLANE_SIGNAL, extra_info
146-
)
144+
termination_status_with_info = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
147145
assert termination_status_with_info.code == Status.Code.FINISHED
148146
assert len(termination_status_with_info.details) == 2
149147
# First detail should be termination reason
@@ -190,14 +188,12 @@ def test_get_termination_reason(self):
190188
assert running_status.get_termination_reason() is None
191189

192190
# Status with termination reason
193-
termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL)
191+
termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL)
194192
assert termination_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL
195193

196194
# Status with termination reason and extra info
197195
extra_info = {"steps": 10}
198-
termination_status_with_info = Status.with_termination_reason(
199-
TerminationReason.CONTROL_PLANE_SIGNAL, extra_info
200-
)
196+
termination_status_with_info = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
201197
assert termination_status_with_info.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL
202198

203199
def test_get_extra_info(self):
@@ -208,7 +204,7 @@ def test_get_extra_info(self):
208204
assert running_status.get_extra_info() is None
209205

210206
# Status with only termination reason (no extra info)
211-
termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL)
207+
termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL)
212208
assert termination_status.get_extra_info() is None
213209

214210
# Status with extra info
@@ -217,9 +213,7 @@ def test_get_extra_info(self):
217213
assert error_status.get_extra_info() == extra_info
218214

219215
# Status with both termination reason and extra info
220-
termination_status_with_info = Status.with_termination_reason(
221-
TerminationReason.CONTROL_PLANE_SIGNAL, extra_info
222-
)
216+
termination_status_with_info = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
223217
assert termination_status_with_info.get_extra_info() == extra_info
224218

225219
def test_aip_193_compliance(self):
@@ -233,12 +227,12 @@ def test_aip_193_compliance(self):
233227

234228
# Check AIP-193 ErrorInfo structure
235229
assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo"
236-
assert detail["reason"] == "ROLLOUT_ERROR"
230+
assert detail["reason"] == "EXTRA_INFO"
237231
assert detail["domain"] == "evalprotocol.io"
238232
assert detail["metadata"] == extra_info
239233

240234
# Test multiple details
241-
termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
235+
termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info)
242236
assert len(termination_status.details) == 2
243237

244238
# First detail should be termination reason
@@ -253,7 +247,7 @@ def test_aip_193_compliance(self):
253247

254248
def test_status_serialization(self):
255249
"""Test that Status can be serialized and deserialized."""
256-
original_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 10})
250+
original_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 10})
257251

258252
# Test model_dump
259253
status_dict = original_status.model_dump()
@@ -329,7 +323,7 @@ def test_termination_reason_integration(self):
329323
row = EvaluationRow(messages=[])
330324

331325
# Set status with termination reason
332-
termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 15})
326+
termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 15})
333327
row.rollout_status = termination_status
334328

335329
# Should be finished

0 commit comments

Comments
 (0)