import asyncio
import os
from collections import Counter
from typing import List
from unittest.mock import Mock

import pytest

from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus
from eval_protocol.pytest.evaluation_test import evaluation_test

# NOTE(review): the diff view hides one import line here. The code below also
# references RolloutProcessor and RolloutProcessorConfig — restore their import
# from the original file (it lives in the hidden line), e.g. something like:
#   from eval_protocol.pytest.types import RolloutProcessor, RolloutProcessorConfig
# TODO confirm the exact module path against the repository.

# Cap the retry budget before the evaluation framework reads it at import time.
os.environ["EP_MAX_RETRY"] = "2"  # Allow up to 2 retries
class MockRolloutProcessorWithRetries(RolloutProcessor):
    """Mock rollout processor that fails the second task on its first attempt and
    succeeds on retry.

    Call history is recorded on ``self.mock_tracker`` (a ``unittest.mock.Mock``):
    - ``mock_tracker.batch_call(n)`` once per ``__call__`` with the batch size;
    - ``mock_tracker.process_row_call(rollout_id)`` once per row attempt.
    A companion test inspects these recorded calls to verify retry behavior.
    """

    def __init__(self):
        # Mock used purely as a call recorder; assertions read its call_args_list.
        self.mock_tracker = Mock()

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        # Track this batch call (batch size distinguishes initial batch vs retry batch).
        self.mock_tracker.batch_call(len(rows))

        # Per-row scripted behavior; delays are tiny so the test runs fast.
        row_setup = {
            0: {"delay": 0.01, "should_fail": False},
            1: {"delay": 0.01, "should_fail": True},  # Will be adjusted based on attempt number
            2: {"delay": 0.01, "should_fail": False},
            3: {"delay": 0.01, "should_fail": False},
            4: {"delay": 0.01, "should_fail": False},
        }

        async def process_single_row(
            row: EvaluationRow, delay: float, base_should_fail: bool = False
        ) -> EvaluationRow:
            rollout_id = row.execution_metadata.rollout_id

            # Track individual row processing call.
            self.mock_tracker.process_row_call(rollout_id)

            # Determine attempt number by counting previous calls for this rollout_id
            # (the call just recorded above is included, so the first attempt is 1).
            previous_calls = [
                call for call in self.mock_tracker.process_row_call.call_args_list if call[0][0] == rollout_id
            ]
            attempt_number = len(previous_calls)

            # Determine if this specific attempt should fail.
            # Row 1 fails on first attempt (attempt_number == 1), succeeds on retry (attempt_number == 2).
            should_fail = base_should_fail and attempt_number == 1

            print(f"🔄 ATTEMPTING rollout_id={rollout_id}, attempt={attempt_number}, will_fail={should_fail}")

            await asyncio.sleep(delay)
            print(f"🎉 FINISHED {'error' if should_fail else 'finished'}: {row.execution_metadata.rollout_id}")

            if should_fail:
                raise Exception("Simulated failure for testing")

            return row

        # NOTE(review): the diff view hides the original task-creation lines of
        # this method; the list below is a reconstruction from the visible
        # pieces (row_setup keyed by batch index, `return tasks` at the end) —
        # verify against the original file.
        tasks = [
            asyncio.create_task(
                process_single_row(row, row_setup[i]["delay"], row_setup[i]["should_fail"])
            )
            for i, row in enumerate(rows)
        ]
        return tasks
# Create a single shared processor instance so the mock-verification test can
# inspect the call history recorded while the evaluation test ran.
shared_processor = MockRolloutProcessorWithRetries()
@evaluation_test(
    completion_params=[{"model": "gpt-4o-mini", "temperature": 0}],
    input_messages=[
        # NOTE(review): the diff view hides the first entries of this list;
        # reconstructed from the visible "Task D"/"Task E" pattern and the
        # five-row setup — verify against the original file.
        [Message(role="user", content="Task A")],
        [Message(role="user", content="Task B")],
        [Message(role="user", content="Task C")],
        [Message(role="user", content="Task D")],
        [Message(role="user", content="Task E")],
    ],
    rollout_processor=shared_processor,
    num_runs=1,
    mode="pointwise",
)
def test_retry_mechanism(row: EvaluationRow) -> EvaluationRow:
    """MOCK TEST: Tests that retry mechanism works - one task fails on first attempt, succeeds on retry."""
    print(
        f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})"
    )

    # Assign a score based on success/failure
    # NOTE(review): the diff view hides the original scoring lines here; the
    # two lines below are a plausible reconstruction — verify against the
    # original file.
    score = 1.0 if row.rollout_status.status == "finished" else 0.0
    row.evaluation_result = EvaluateResult(score=score)
    return row
def test_retry_mechanism_mock_verification():
    """Verify the retry mechanism worked by inspecting the recorded mock calls.

    Relies on the module-level ``shared_processor`` having been exercised by
    ``test_retry_mechanism`` earlier in the same pytest session; if it has not
    run, this test is skipped rather than silently passing.
    """
    mock_tracker = shared_processor.mock_tracker

    print(f"\n🔄 MOCK CALL ANALYSIS:")
    print(f"  Batch calls made: {mock_tracker.batch_call.call_count}")
    print(f"  Total row processing calls: {mock_tracker.process_row_call.call_count}")

    if mock_tracker.process_row_call.call_count == 0:
        # Previously this path printed a warning and returned, which made the
        # test pass vacuously when the evaluation test never ran. Skip instead
        # so the gap is visible in the pytest report.
        pytest.skip("No calls recorded; the evaluation test may not have run or completed.")

    # Rollout ids in the order they were processed (first positional arg of each call).
    call_args = mock_tracker.process_row_call.call_args_list
    rollout_ids = [call[0][0] for call in call_args]

    # Count calls per rollout_id.
    call_counts = Counter(rollout_ids)

    print(f"  Call counts per rollout_id: {dict(call_counts)}")
    print(f"  Individual calls:")
    # Running attempt number per rollout_id (avoids O(n^2) prefix slicing).
    attempts_seen = Counter()
    for i, rollout_id in enumerate(rollout_ids, 1):
        attempts_seen[rollout_id] += 1
        print(f"    {i}. rollout_id={rollout_id}, attempt={attempts_seen[rollout_id]}")

    # ASSERTIONS USING MOCK DATA
    # Should have exactly 6 total row processing calls (5 initial + 1 retry).
    assert (
        mock_tracker.process_row_call.call_count == 6
    ), f"Expected 6 total calls, got {mock_tracker.process_row_call.call_count}"

    # Should have exactly 2 batch calls (initial batch + retry batch).
    assert mock_tracker.batch_call.call_count == 2, f"Expected 2 batch calls, got {mock_tracker.batch_call.call_count}"

    # First batch should have 5 rows, second batch should have 1 row (the retry).
    batch_call_args = mock_tracker.batch_call.call_args_list
    assert batch_call_args[0][0][0] == 5, f"Expected first batch to have 5 rows, got {batch_call_args[0][0][0]}"
    assert batch_call_args[1][0][0] == 1, f"Expected second batch to have 1 row, got {batch_call_args[1][0][0]}"

    # Exactly one rollout_id should be called twice, the other four exactly once.
    call_count_values = list(call_counts.values())
    assert (
        call_count_values.count(2) == 1
    ), f"Expected exactly 1 rollout_id to be called twice, got counts: {dict(call_counts)}"
    assert (
        call_count_values.count(1) == 4
    ), f"Expected exactly 4 rollout_ids to be called once, got counts: {dict(call_counts)}"

    print("✅ All mock-based assertions passed! Retry mechanism is working correctly.")