Skip to content

Commit c00ac2d

Browse files
author
Dylan Huang
committed
fix test_migration_Changes
1 parent 39b17dc commit c00ac2d

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

tests/test_migration_changes.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_trajectory_terminated_status_creation(self):
2323
# Mock trajectory with termination
2424
trajectory = Mock()
2525
trajectory.terminated = True
26-
trajectory.termination_reason = "goal_reached"
26+
trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL
2727
trajectory.control_plane_summary = {"error_message": "No errors"}
2828

2929
# Create evaluation row
@@ -63,7 +63,7 @@ def test_trajectory_terminated_status_creation(self):
6363
assert row.rollout_status.is_finished()
6464

6565
# Verify termination reason
66-
assert row.rollout_status.get_termination_reason() == "goal_reached"
66+
assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL
6767

6868
# Verify extra info
6969
assert row.rollout_status.get_extra_info() == {"error_message": "No errors"}
@@ -98,7 +98,7 @@ def test_trajectory_terminated_without_error_message(self):
9898
# Mock trajectory with termination but no error
9999
trajectory = Mock()
100100
trajectory.terminated = True
101-
trajectory.termination_reason = "timeout"
101+
trajectory.termination_reason = TerminationReason.USER_STOP
102102
trajectory.control_plane_summary = {}
103103

104104
# Create evaluation row
@@ -137,7 +137,7 @@ def test_trajectory_terminated_without_error_message(self):
137137
# Verify the status
138138
assert row.rollout_status.code == Status.Code.FINISHED
139139
assert row.rollout_status.is_finished()
140-
assert row.rollout_status.get_termination_reason() == "timeout"
140+
assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP
141141

142142
# Should not have extra info since there was no error message
143143
assert row.rollout_status.get_extra_info() is None
@@ -300,19 +300,19 @@ def test_termination_reason_integration(self):
300300
row = EvaluationRow(messages=[])
301301

302302
# Test with termination reason
303-
termination_status = Status.with_termination_reason("goal_reached")
303+
termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL)
304304
row.rollout_status = termination_status
305305

306306
assert row.rollout_status.is_finished()
307-
assert row.rollout_status.get_termination_reason() == "goal_reached"
307+
assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL
308308

309309
# Test with termination reason and extra info
310310
extra_info = {"steps": 10, "reward": 0.8}
311-
termination_status_with_info = Status.with_termination_reason("timeout", extra_info)
311+
termination_status_with_info = Status.with_termination_reason(TerminationReason.USER_STOP, extra_info)
312312
row.rollout_status = termination_status_with_info
313313

314314
assert row.rollout_status.is_finished()
315-
assert row.rollout_status.get_termination_reason() == "timeout"
315+
assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP
316316
assert row.rollout_status.get_extra_info() == extra_info
317317

318318
def test_error_handling_integration(self):

0 commit comments

Comments
 (0)