1111"""
1212
1313import pytest
14- from eval_protocol .models import Status , EvaluationRow , Message , ErrorInfo
14+ from eval_protocol .models import Status , EvaluationRow , ErrorInfo
15+ from eval_protocol .types import TerminationReason
1516
1617
1718class TestErrorInfoModel :
@@ -41,10 +42,10 @@ def test_error_info_to_aip193_format(self):
4142 def test_error_info_factory_methods (self ):
4243 """Test the factory methods for common error types."""
4344 # Test termination_reason
44- term_error = ErrorInfo .termination_reason ("goal_reached" )
45+ term_error = ErrorInfo .termination_reason (TerminationReason . CONTROL_PLANE_SIGNAL )
4546 assert term_error .reason == "TERMINATION_REASON"
4647 assert term_error .domain == "evalprotocol.io"
47- assert term_error .metadata ["termination_reason" ] == "goal_reached"
48+ assert term_error .metadata ["termination_reason" ] == TerminationReason . CONTROL_PLANE_SIGNAL
4849
4950 # Test extra_info
5051 extra_error = ErrorInfo .extra_info ({"steps" : 10 , "reward" : 0.8 })
@@ -100,7 +101,7 @@ def test_status_creation_methods(self):
100101 # Test finished status
101102 finished_status = Status .rollout_finished ()
102103 assert finished_status .code == Status .Code .FINISHED
103- assert finished_status .message == "Rollout finished successfully "
104+ assert finished_status .message == "Rollout finished"
104105 assert finished_status .details == []
105106
106107 # Test error status
@@ -127,18 +128,22 @@ def test_status_creation_methods(self):
127128 assert stopped_status .details == []
128129
129130 # Test with termination reason
130- termination_status = Status .with_termination_reason ("goal_reached" )
131+ termination_status = Status .with_termination_reason (TerminationReason . CONTROL_PLANE_SIGNAL )
131132 assert termination_status .code == Status .Code .FINISHED
132133 assert termination_status .message == "Rollout finished"
133134 assert len (termination_status .details ) == 1
134135 assert termination_status .details [0 ]["@type" ] == "type.googleapis.com/google.rpc.ErrorInfo"
135136 assert termination_status .details [0 ]["reason" ] == "TERMINATION_REASON"
136137 assert termination_status .details [0 ]["domain" ] == "evalprotocol.io"
137- assert termination_status .details [0 ]["metadata" ]["termination_reason" ] == "goal_reached"
138+ assert (
139+ termination_status .details [0 ]["metadata" ]["termination_reason" ] == TerminationReason .CONTROL_PLANE_SIGNAL
140+ )
138141
139142 # Test with termination reason and extra info
140143 extra_info = {"steps" : 10 , "reward" : 0.8 }
141- termination_status_with_info = Status .with_termination_reason ("goal_reached" , extra_info )
144+ termination_status_with_info = Status .with_termination_reason (
145+ TerminationReason .CONTROL_PLANE_SIGNAL , extra_info
146+ )
142147 assert termination_status_with_info .code == Status .Code .FINISHED
143148 assert len (termination_status_with_info .details ) == 2
144149 # First detail should be termination reason
@@ -179,27 +184,31 @@ def test_status_helper_methods(self):
179184
180185 def test_get_termination_reason (self ):
181186 """Test extracting termination reason from status details."""
187+
182188 # Status without termination reason
183189 running_status = Status .rollout_running ()
184190 assert running_status .get_termination_reason () is None
185191
186192 # Status with termination reason
187- termination_status = Status .with_termination_reason ("goal_reached" )
188- assert termination_status .get_termination_reason () == "goal_reached"
193+ termination_status = Status .with_termination_reason (TerminationReason . CONTROL_PLANE_SIGNAL )
194+ assert termination_status .get_termination_reason () == TerminationReason . CONTROL_PLANE_SIGNAL
189195
190196 # Status with termination reason and extra info
191197 extra_info = {"steps" : 10 }
192- termination_status_with_info = Status .with_termination_reason ("timeout" , extra_info )
193- assert termination_status_with_info .get_termination_reason () == "timeout"
198+ termination_status_with_info = Status .with_termination_reason (
199+ TerminationReason .CONTROL_PLANE_SIGNAL , extra_info
200+ )
201+ assert termination_status_with_info .get_termination_reason () == TerminationReason .CONTROL_PLANE_SIGNAL
194202
195203 def test_get_extra_info (self ):
196204 """Test extracting extra info from status details."""
205+
197206 # Status without extra info
198207 running_status = Status .rollout_running ()
199208 assert running_status .get_extra_info () is None
200209
201210 # Status with only termination reason (no extra info)
202- termination_status = Status .with_termination_reason ("goal_reached" )
211+ termination_status = Status .with_termination_reason (TerminationReason . CONTROL_PLANE_SIGNAL )
203212 assert termination_status .get_extra_info () is None
204213
205214 # Status with extra info
@@ -208,7 +217,9 @@ def test_get_extra_info(self):
208217 assert error_status .get_extra_info () == extra_info
209218
210219 # Status with both termination reason and extra info
211- termination_status_with_info = Status .with_termination_reason ("goal_reached" , extra_info )
220+ termination_status_with_info = Status .with_termination_reason (
221+ TerminationReason .CONTROL_PLANE_SIGNAL , extra_info
222+ )
212223 assert termination_status_with_info .get_extra_info () == extra_info
213224
214225 def test_aip_193_compliance (self ):
@@ -227,7 +238,7 @@ def test_aip_193_compliance(self):
227238 assert detail ["metadata" ] == extra_info
228239
229240 # Test multiple details
230- termination_status = Status .with_termination_reason ("goal_reached" , extra_info )
241+ termination_status = Status .with_termination_reason (TerminationReason . CONTROL_PLANE_SIGNAL , extra_info )
231242 assert len (termination_status .details ) == 2
232243
233244 # First detail should be termination reason
@@ -242,7 +253,7 @@ def test_aip_193_compliance(self):
242253
243254 def test_status_serialization (self ):
244255 """Test that Status can be serialized and deserialized."""
245- original_status = Status .with_termination_reason ("goal_reached" , {"steps" : 10 })
256+ original_status = Status .with_termination_reason (TerminationReason . CONTROL_PLANE_SIGNAL , {"steps" : 10 })
246257
247258 # Test model_dump
248259 status_dict = original_status .model_dump ()
@@ -255,7 +266,7 @@ def test_status_serialization(self):
255266 assert reconstructed_status .code == original_status .code
256267 assert reconstructed_status .message == original_status .message
257268 assert len (reconstructed_status .details ) == len (original_status .details )
258- assert reconstructed_status .get_termination_reason () == "goal_reached"
269+ assert reconstructed_status .get_termination_reason () == TerminationReason . CONTROL_PLANE_SIGNAL
259270 assert reconstructed_status .get_extra_info () == {"steps" : 10 }
260271
261272 def test_status_equality (self ):
@@ -304,7 +315,7 @@ def test_backwards_compatibility_methods(self):
304315 new_status = Status .rollout_finished ()
305316 row .set_rollout_status (new_status )
306317 assert row .rollout_status .code == Status .Code .FINISHED
307- assert row .rollout_status .message == "Rollout finished successfully "
318+ assert row .rollout_status .message == "Rollout finished"
308319
309320 def test_status_transitions (self ):
310321 """Test transitioning between different status states."""
@@ -333,14 +344,14 @@ def test_termination_reason_integration(self):
333344 row = EvaluationRow (messages = [])
334345
335346 # Set status with termination reason
336- termination_status = Status .with_termination_reason ("goal_reached" , {"steps" : 15 })
347+ termination_status = Status .with_termination_reason (TerminationReason . CONTROL_PLANE_SIGNAL , {"steps" : 15 })
337348 row .rollout_status = termination_status
338349
339350 # Should be finished
340351 assert row .rollout_status .is_finished ()
341352
342353 # Should have termination reason
343- assert row .rollout_status .get_termination_reason () == "goal_reached"
354+ assert row .rollout_status .get_termination_reason () == TerminationReason . CONTROL_PLANE_SIGNAL
344355
345356 # Should have extra info
346357 extra_info = row .rollout_status .get_extra_info ()
@@ -375,39 +386,6 @@ def test_empty_details(self):
375386 assert status .get_termination_reason () is None
376387 assert status .get_extra_info () is None
377388
378- def test_malformed_details (self ):
379- """Test Status with malformed details."""
380- malformed_details = [
381- {"not_type" : "invalid" , "reason" : "TEST" },
382- {"@type" : "type.googleapis.com/google.rpc.ErrorInfo" , "metadata" : {"termination_reason" : "test" }},
383- ]
384- status = Status (code = Status .Code .OK , message = "Test" , details = malformed_details )
385-
386- # Should handle malformed details gracefully
387- assert status .get_termination_reason () == "test"
388- assert status .get_extra_info () is None
389-
390- def test_duplicate_detail_types (self ):
391- """Test Status with duplicate detail types."""
392- details = [
393- {
394- "@type" : "type.googleapis.com/google.rpc.ErrorInfo" ,
395- "reason" : "TERMINATION_REASON" ,
396- "domain" : "evalprotocol.io" ,
397- "metadata" : {"termination_reason" : "first" },
398- },
399- {
400- "@type" : "type.googleapis.com/google.rpc.ErrorInfo" ,
401- "reason" : "TERMINATION_REASON" ,
402- "domain" : "evalprotocol.io" ,
403- "metadata" : {"termination_reason" : "second" },
404- },
405- ]
406- status = Status (code = Status .Code .OK , message = "Test" , details = details )
407-
408- # Should return the first termination reason found
409- assert status .get_termination_reason () == "first"
410-
411389 def test_large_metadata (self ):
412390 """Test Status with large metadata."""
413391 large_metadata = {f"key_{ i } " : f"value_{ i } " for i in range (100 )}
0 commit comments