@@ -106,33 +106,33 @@ def test_termination_reason_in_status_details(self):
106106 row = EvaluationRow (messages = [])
107107
108108 # Set status with termination reason
109- termination_status = Status .with_termination_reason ("goal_reached" )
109+ termination_status = Status .with_termination_reason (TerminationReason . CONTROL_PLANE_SIGNAL )
110110 row .rollout_status = termination_status
111111
112112 # Should be finished
113113 assert row .rollout_status .is_finished ()
114114
115115 # Should have termination reason in details
116- assert row .rollout_status .get_termination_reason () == "goal_reached"
116+ assert row .rollout_status .get_termination_reason () == TerminationReason . CONTROL_PLANE_SIGNAL
117117
118118 # Check details structure
119119 assert len (row .rollout_status .details ) == 1
120120 detail = row .rollout_status .details [0 ]
121121 assert detail ["@type" ] == "type.googleapis.com/google.rpc.ErrorInfo"
122122 assert detail ["reason" ] == "TERMINATION_REASON"
123123 assert detail ["domain" ] == "evalprotocol.io"
124- assert detail ["metadata" ]["termination_reason" ] == "goal_reached"
124+ assert detail ["metadata" ]["termination_reason" ] == TerminationReason . CONTROL_PLANE_SIGNAL
125125
126126 def test_termination_reason_with_extra_info (self ):
127127 """Test termination reason with additional extra info."""
128128 row = EvaluationRow (messages = [])
129129
130130 extra_info = {"steps" : 10 , "reward" : 0.8 }
131- termination_status = Status .with_termination_reason ("timeout" , extra_info )
131+ termination_status = Status .with_termination_reason (TerminationReason . USER_STOP , extra_info )
132132 row .rollout_status = termination_status
133133
134134 # Should have both termination reason and extra info
135- assert row .rollout_status .get_termination_reason () == "timeout"
135+ assert row .rollout_status .get_termination_reason () == TerminationReason . USER_STOP
136136 assert row .rollout_status .get_extra_info () == extra_info
137137
138138 # Check details structure
@@ -157,21 +157,21 @@ def test_multiple_termination_reasons(self):
157157 "@type" : "type.googleapis.com/google.rpc.ErrorInfo" ,
158158 "reason" : "TERMINATION_REASON" ,
159159 "domain" : "evalprotocol.io" ,
160- "metadata" : {"termination_reason" : "first" },
160+ "metadata" : {"termination_reason" : TerminationReason . USER_STOP },
161161 },
162162 {
163163 "@type" : "type.googleapis.com/google.rpc.ErrorInfo" ,
164164 "reason" : "TERMINATION_REASON" ,
165165 "domain" : "evalprotocol.io" ,
166- "metadata" : {"termination_reason" : "second" },
166+ "metadata" : {"termination_reason" : TerminationReason . SKIPPABLE_ERROR },
167167 },
168168 ]
169169
170170 status = Status (code = Status .Code .FINISHED , message = "Test" , details = details )
171171 row .rollout_status = status
172172
173173 # Should return the first termination reason found
174- assert row .rollout_status .get_termination_reason () == "first"
174+ assert row .rollout_status .get_termination_reason () == TerminationReason . USER_STOP
175175
176176
177177class TestErrorHandlingIntegration :
@@ -204,7 +204,7 @@ def test_error_status_with_metadata(self):
204204 assert len (row .rollout_status .details ) == 1
205205 detail = row .rollout_status .details [0 ]
206206 assert detail ["@type" ] == "type.googleapis.com/google.rpc.ErrorInfo"
207- assert detail ["reason" ] == "ROLLOUT_ERROR "
207+ assert detail ["reason" ] == "EXTRA_INFO "
208208 assert detail ["domain" ] == "evalprotocol.io"
209209 assert detail ["metadata" ] == error_info
210210
@@ -331,7 +331,7 @@ def test_status_model_validate(self):
331331
332332 # Set a complex status
333333 extra_info = {"steps" : 10 , "reward" : 0.8 }
334- original_status = Status .with_termination_reason ("goal_reached" , extra_info )
334+ original_status = Status .with_termination_reason (TerminationReason . CONTROL_PLANE_SIGNAL , extra_info )
335335 row .rollout_status = original_status
336336
337337 # Dump and reconstruct
@@ -344,7 +344,7 @@ def test_status_model_validate(self):
344344 assert len (reconstructed_status .details ) == len (original_status .details )
345345
346346 # Should preserve functionality
347- assert reconstructed_status .get_termination_reason () == "goal_reached"
347+ assert reconstructed_status .get_termination_reason () == TerminationReason . CONTROL_PLANE_SIGNAL
348348 assert reconstructed_status .get_extra_info () == extra_info
349349
350350
@@ -377,7 +377,7 @@ def test_malformed_status_details(self):
377377 row .rollout_status = malformed_status
378378
379379 # Should handle gracefully
380- assert row .rollout_status .get_termination_reason () == "test"
380+ assert row .rollout_status .get_termination_reason () is None
381381 assert row .rollout_status .get_extra_info () is None
382382
383383 def test_large_metadata_handling (self ):
0 commit comments