Skip to content

Commit bdd630c

Browse files
committed
test
1 parent 9cffbf5 commit bdd630c

File tree

1 file changed

+342
-0
lines changed

1 file changed

+342
-0
lines changed

tests/pytest/test_retry_logic.py

Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
"""
2+
Test suite for the individual rollout retry logic in evaluation_test.
3+
4+
Tests the new efficient retry system that retries individual rollouts immediately
5+
as they fail, rather than waiting for entire batches to complete.
6+
"""
7+
8+
import asyncio
9+
import os
10+
from typing import List
11+
from unittest.mock import patch
12+
13+
import pytest
14+
15+
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus
16+
from eval_protocol.pytest import evaluation_test
17+
from eval_protocol.pytest.types import RolloutProcessor, RolloutProcessorConfig
18+
19+
20+
class MockRetryRolloutProcessor:
    """
    Mock rollout processor that simulates different rollout statuses.

    The first time a logical rollout (keyed by ``rollout_id``) is seen, its
    status is derived from keywords in its prompt text: finished, error, or
    running. Any later attempt for the same rollout id is treated as a retry
    and is reported as finished.
    """

    def __init__(self):
        # Number of times __call__ has been invoked (batches, not rows).
        self.call_count = 0
        # Rollout ids that have been handed to this processor at least once.
        self.processed_rollout_ids = set()

    async def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):
        """Yield each row back with a simulated rollout status attached."""
        self.call_count += 1

        for row in rows:
            rollout_id = row.execution_metadata.rollout_id
            is_retry = rollout_id in self.processed_rollout_ids

            if is_retry:
                # Retries always succeed so the retry machinery can converge.
                row.rollout_status = RolloutStatus(status="finished")
                row.messages.append(
                    Message(role="assistant", content=f"Retry success for {row.execution_metadata.rollout_id}")
                )
            else:
                # First time processing this logical rollout: remember it and
                # pick an outcome based on the prompt's keyword.
                self.processed_rollout_ids.add(rollout_id)
                prompt = row.messages[0].content if row.messages else ""

                if "should_finish" in prompt:
                    # Succeeds immediately; must not be retried.
                    row.rollout_status = RolloutStatus(status="finished")
                    row.messages.append(Message(role="assistant", content="Success on first try"))
                elif "should_error" in prompt:
                    # Errors on the first try; the retry path should pick it up.
                    row.rollout_status = RolloutStatus(status="error", termination_reason="Simulated error")
                    row.messages.append(Message(role="assistant", content="Error on first try"))
                elif "should_be_running" in prompt:
                    # Left in a running state; also expected to be retried.
                    row.rollout_status = RolloutStatus(status="running")
                    row.messages.append(Message(role="assistant", content="Left running, needs retry"))
                else:
                    # Anything unrecognized defaults to success.
                    row.rollout_status = RolloutStatus(status="finished")
                    row.messages.append(Message(role="assistant", content="Default success"))

            yield row
68+
69+
70+
class MockAlwaysFailRolloutProcessor:
    """Mock rollout processor that always fails, to test retry exhaustion"""

    def __init__(self):
        # Counts every invocation so tests can assert how many attempts ran.
        self.call_count = 0

    async def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):
        """Always return error status to test retry exhaustion"""
        self.call_count += 1

        for current in rows:
            # Mark the row as failed, recording which attempt this was.
            current.rollout_status = RolloutStatus(
                status="error", termination_reason=f"Persistent failure (attempt {self.call_count})"
            )
            current.messages.append(Message(role="assistant", content=f"Failed attempt {self.call_count}"))
            yield current
86+
87+
88+
# Create instances that will be shared across test functions
# NOTE: these are module-level singletons, so tests that use them reset their
# counters/state themselves to avoid cross-test interference.
mock_retry_processor = MockRetryRolloutProcessor()
mock_always_fail_processor = MockAlwaysFailRolloutProcessor()
91+
92+
93+
# Set environment variable at module level for this test
@patch.dict(os.environ, {"EP_MAX_RETRY": "3"})
@evaluation_test(
    input_messages=[
        [Message(role="user", content="Test case that should_finish immediately")],
        [Message(role="user", content="Test case that should_error on first try")],
        [Message(role="user", content="Test case that should_be_running and need retry")],
    ],
    model=["dummy/local-model"],
    rollout_processor=mock_retry_processor,
    mode="batch",
    num_runs=1,
)
def test_retry_mixed_statuses_batch_mode(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """
    Test that retry logic works with mixed rollout statuses in batch mode.

    Tests:
    - One rollout finishes immediately (should not retry)
    - One rollout has error status (should retry and succeed)
    - One rollout has running status (should retry and succeed)
    """
    # Reset the shared processor so later tests start from a clean slate.
    mock_retry_processor.call_count = 0
    mock_retry_processor.processed_rollout_ids.clear()

    # All three input cases must have made it through the rollout phase.
    assert len(rows) == 3

    for current in rows:
        # After retries, every rollout should have converged to "finished".
        assert current.rollout_status is not None
        assert current.rollout_status.status == "finished", (
            f"Row should be finished but was {current.rollout_status.status}"
        )

        # Rows that started as error/running should carry the retry response.
        prompt = current.messages[0].content
        if "should_error" in prompt or "should_be_running" in prompt:
            replies = [m for m in current.messages if m.role == "assistant"]
            assert len(replies) >= 1
            assert "Retry success" in replies[-1].content

    # Score every row as a success.
    for current in rows:
        current.evaluation_result = EvaluateResult(score=1.0, reason="All rollouts completed successfully")

    return rows
140+
141+
142+
@patch.dict(os.environ, {"EP_MAX_RETRY": "3"})
@evaluation_test(
    input_messages=[
        [Message(role="user", content="Test pointwise should_error")],
        [Message(role="user", content="Test pointwise should_be_running")],
        [Message(role="user", content="Test pointwise should_finish")],
    ],
    model=["dummy/local-model"],
    rollout_processor=mock_retry_processor,
    mode="pointwise",
    num_runs=1,
)
def test_retry_mixed_statuses_pointwise_mode(row: EvaluationRow) -> EvaluationRow:
    """
    Test that retry logic works with mixed rollout statuses in pointwise mode.

    Each rollout is processed individually and should retry if not finished.
    """
    # After any needed retries, the single rollout must be finished.
    status = row.rollout_status
    assert status is not None
    assert status.status == "finished", f"Row should be finished but was {status.status}"

    # Score it as a success.
    row.evaluation_result = EvaluateResult(score=1.0, reason="Rollout completed successfully")

    return row
168+
169+
170+
def test_retry_exhaustion_should_fail():
    """
    Test that rollout process fails when max retries are exceeded.

    Sets EP_MAX_RETRY=2 and uses a processor that always fails.
    Should fail after 3 total attempts (initial + 2 retries).
    """

    # Set max retries environment variable
    with patch.dict(os.environ, {"EP_MAX_RETRY": "2"}):

        @evaluation_test(
            input_messages=[
                [Message(role="user", content="This will always fail")],
            ],
            model=["dummy/local-model"],
            rollout_processor=mock_always_fail_processor,
            mode="batch",
            num_runs=1,
        )
        def failing_evaluation_test(rows: List[EvaluationRow]) -> List[EvaluationRow]:
            # This should never be reached due to rollout failures
            for row in rows:
                row.evaluation_result = EvaluateResult(score=1.0, reason="Should not reach here")
            return rows

        # The evaluation_test should raise RuntimeError due to retry exhaustion.
        with pytest.raises(RuntimeError) as exc_info:
            # Reset the processor call count before triggering the retry logic.
            # (asyncio is already imported at module level; no local import needed.)
            mock_always_fail_processor.call_count = 0

            # Create test data
            rows = [EvaluationRow(messages=[Message(role="user", content="This will always fail")])]

            # This should fail after 3 attempts (initial + 2 retries)
            asyncio.run(failing_evaluation_test(rows))

        # Verify the error message mentions retrying. A single substring check
        # suffices: any "failed after N retries" message also contains "retry".
        error_msg = str(exc_info.value)
        assert "retry" in error_msg.lower()

        # Verify the processor was called multiple times (initial + retries)
        assert (
            mock_always_fail_processor.call_count >= 3
        ), f"Expected >= 3 calls, got {mock_always_fail_processor.call_count}"
218+
219+
220+
def test_no_retries_when_max_retry_zero():
    """
    Test that no retries happen when EP_MAX_RETRY=0 (default).

    Even with failing rollouts, should fail immediately without retries.
    """

    # Pin EP_MAX_RETRY to its default of 0.
    with patch.dict(os.environ, {"EP_MAX_RETRY": "0"}):

        @evaluation_test(
            input_messages=[
                [Message(role="user", content="This will fail once and not retry")],
            ],
            model=["dummy/local-model"],
            rollout_processor=mock_always_fail_processor,
            mode="batch",
            num_runs=1,
        )
        def no_retry_evaluation_test(rows: List[EvaluationRow]) -> List[EvaluationRow]:
            # Unreachable: the rollout phase fails before evaluation runs.
            for item in rows:
                item.evaluation_result = EvaluateResult(score=1.0, reason="Should not reach here")
            return rows

        # Should fail immediately without retries.
        with pytest.raises(RuntimeError) as exc_info:
            # Start the shared processor's counter from zero.
            mock_always_fail_processor.call_count = 0

            # Build the single failing input row.
            test_rows = [EvaluationRow(messages=[Message(role="user", content="This will fail once and not retry")])]

            # Should fail after just 1 attempt.
            asyncio.run(no_retry_evaluation_test(test_rows))

        # Exactly one attempt was made — no retries.
        calls = mock_always_fail_processor.call_count
        assert calls == 1, f"Expected 1 call, got {calls}"
260+
261+
262+
@pytest.mark.asyncio
async def test_concurrent_retry_efficiency():
    """
    Test that retries happen efficiently with proper concurrency.

    Verifies that successful rollouts don't wait for failing ones,
    and that retries start immediately as failures are detected.
    """

    class TimingMockProcessor:
        """Mock processor that tracks timing of rollout processing"""

        def __init__(self):
            # rollout_id -> wall-clock time of the first (failing) attempt's
            # completion; also doubles as the "already failed once" marker.
            self.processing_times = {}
            # rollout_id -> wall-clock time each attempt started (overwritten
            # on retry, so only the latest attempt's start is kept).
            self.start_times = {}

        async def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):
            import time

            for row in rows:
                rollout_id = row.execution_metadata.rollout_id
                self.start_times[rollout_id] = time.time()

                # Simulate different processing times
                content = row.messages[0].content if row.messages else ""

                if "slow_success" in content:
                    # Slow but successful rollout
                    await asyncio.sleep(0.1)
                    row.rollout_status = RolloutStatus(status="finished")
                    row.messages.append(Message(role="assistant", content="Slow success"))
                elif "fast_fail" in content:
                    # Fast failure that should retry quickly
                    await asyncio.sleep(0.01)
                    if rollout_id not in self.processing_times:
                        # First attempt - fail
                        row.rollout_status = RolloutStatus(status="error", termination_reason="Fast failure")
                        row.messages.append(Message(role="assistant", content="Fast failure"))
                        self.processing_times[rollout_id] = time.time()
                    else:
                        # Retry - succeed
                        row.rollout_status = RolloutStatus(status="finished")
                        row.messages.append(Message(role="assistant", content="Fast retry success"))
                # NOTE(review): no else branch — a row whose content matches
                # neither keyword is yielded with its status unchanged. Fine
                # for this test's two fixed inputs, but worth confirming if
                # the processor is ever reused.

                yield row

    timing_processor = TimingMockProcessor()

    with patch.dict(os.environ, {"EP_MAX_RETRY": "3"}):

        @evaluation_test(
            input_messages=[
                [Message(role="user", content="slow_success - this takes longer but succeeds")],
                [Message(role="user", content="fast_fail - this fails fast then retries")],
            ],
            model=["dummy/local-model"],
            rollout_processor=timing_processor,
            mode="batch",
            num_runs=1,
        )
        def timing_test(rows: List[EvaluationRow]) -> List[EvaluationRow]:
            # Both should succeed eventually
            assert len(rows) == 2
            for row in rows:
                assert row.rollout_status.status == "finished"
                row.evaluation_result = EvaluateResult(score=1.0, reason="Success")
            return rows

        # Create test data
        rows = [
            EvaluationRow(messages=[Message(role="user", content="slow_success - this takes longer but succeeds")]),
            EvaluationRow(messages=[Message(role="user", content="fast_fail - this fails fast then retries")]),
        ]

        # Run the test - should complete successfully with proper retry timing
        result = await timing_test(rows)
        assert len(result) == 2

        # Verify that the fast-failing rollout was processed multiple times due to retry
        fast_fail_processed = any("fast_fail" in row.messages[0].content for row in result)
        assert fast_fail_processed, "Fast-failing rollout should have been processed"

0 commit comments

Comments
 (0)