From 8ee3e460321f2f66daa626f50fd79a98f8fd2c3c Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 19 Aug 2025 16:02:51 -0700 Subject: [PATCH 1/2] Added Configurable Exception Handler --- .../benchmarks/test_tau_bench_retail.py | 9 +- eval_protocol/pytest/__init__.py | 4 + eval_protocol/pytest/evaluation_test.py | 15 +- eval_protocol/pytest/exception_config.py | 126 +++++++++ eval_protocol/pytest/types.py | 4 + eval_protocol/pytest/utils.py | 138 +++++----- pyproject.toml | 1 + tests/test_retry_mechanism.py | 243 +++++++++++++++++- 8 files changed, 453 insertions(+), 87 deletions(-) create mode 100644 eval_protocol/pytest/exception_config.py diff --git a/eval_protocol/benchmarks/test_tau_bench_retail.py b/eval_protocol/benchmarks/test_tau_bench_retail.py index a47d1520..0db242f1 100644 --- a/eval_protocol/benchmarks/test_tau_bench_retail.py +++ b/eval_protocol/benchmarks/test_tau_bench_retail.py @@ -11,8 +11,9 @@ from typing import Any, Dict, List from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message -from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest import evaluation_test, ExceptionHandlerConfig from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor +import litellm from vendor.tau2.data_model.message import ( AssistantMessage, SystemMessage, @@ -87,6 +88,12 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu mode="pointwise", max_concurrent_rollouts=50, server_script_path=_get_server_script_path(), + exception_handler_config=ExceptionHandlerConfig( + retryable_exceptions={ + litellm.RateLimitError, + litellm.APIConnectionError, + } + ), ) def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow: """ diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py index 171fa3dc..b64d8ef2 100644 --- a/eval_protocol/pytest/__init__.py +++ b/eval_protocol/pytest/__init__.py @@ -4,6 +4,7 @@ from .default_no_op_rollout_processor import NoOpRolloutProcessor from .default_single_turn_rollout_process import SingleTurnRolloutProcessor from .evaluation_test import evaluation_test +from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig @@ -16,4 +17,7 @@ "default_dataset_adapter", "RolloutProcessorConfig", "evaluation_test", + "ExceptionHandlerConfig", + "BackoffConfig", + "get_default_exception_handler_config", ] diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index ef1a70a7..34bd5d56 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -51,6 +51,7 @@ rollout_processor_with_retry, sanitize_filename, ) +from eval_protocol.pytest.exception_config import ExceptionHandlerConfig from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci from ..common_utils import load_jsonl @@ -76,6 +77,7 @@ def evaluation_test( # noqa: C901 mode: EvaluationTestMode = "batch", combine_datasets: bool = True, logger: Optional[DatasetLogger] = None, + exception_handler_config: Optional[ExceptionHandlerConfig] = None, ) -> Callable[ [TestFunction], TestFunction, @@ -140,6 +142,8 @@ def evaluation_test( # noqa: C901 full dataset. "pointwise" applies test function to each row. If your evaluation requires the full rollout of all rows to compute the score, use logger: DatasetLogger to use for logging. If not provided, a default logger will be used. + exception_handler_config: Configuration for exception handling and backoff retry logic. + If not provided, a default configuration will be used with common retryable exceptions. """ active_logger: DatasetLogger = logger if logger else default_logger @@ -365,10 +369,9 @@ def _log_eval_error( steps=steps, logger=active_logger, kwargs=rollout_processor_kwargs or {}, + exception_handler_config=exception_handler_config, ) - max_retry = int(os.getenv("EP_MAX_RETRY", "0")) - for i in range(num_runs): # Regenerate outputs each run by deep-copying the pristine dataset # so model responses are not reused across runs. @@ -408,9 +411,7 @@ async def _execute_with_semaphore(row): return result # Use wrapper that handles retry logic internally - async for row in rollout_processor_with_retry( - rollout_processor, fresh_dataset, config, max_retry - ): + async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config): tasks.append(asyncio.create_task(_execute_with_semaphore(row))) results = await asyncio.gather(*tasks) @@ -420,9 +421,7 @@ async def _execute_with_semaphore(row): else: # Batch mode: collect all results first, then evaluate (no pipelining) input_dataset = [] - async for row in rollout_processor_with_retry( - rollout_processor, fresh_dataset, config, max_retry - ): + async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config): input_dataset.append(row) # NOTE: we will still evaluate errored rows (give users control over this) # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func diff --git a/eval_protocol/pytest/exception_config.py b/eval_protocol/pytest/exception_config.py new file mode 100644 index 00000000..5c195f4e --- /dev/null +++ b/eval_protocol/pytest/exception_config.py @@ -0,0 +1,126 @@ +""" +Exception handling configuration for rollout processors with backoff retry logic. +""" + +import os +from dataclasses import dataclass, field +from typing import Callable, Set, Type, Union + +import backoff + + +import requests +import httpx + +# Default exceptions that should be retried with backoff +DEFAULT_RETRYABLE_EXCEPTIONS: Set[Type[Exception]] = { + # Standard library exceptions + ConnectionError, + TimeoutError, + OSError, # Covers network-related OS errors + # Requests library exceptions + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.HTTPError, + requests.exceptions.RequestException, + # HTTPX library exceptions + httpx.ConnectError, + httpx.TimeoutException, + httpx.NetworkError, + httpx.RemoteProtocolError, +} + + +@dataclass +class BackoffConfig: + """Configuration for backoff behavior.""" + + # Backoff strategy: 'expo' for exponential, 'constant' for constant delay + strategy: str = "expo" + + # Base delay in seconds + base_delay: float = 1.0 + + # Maximum delay in seconds + max_delay: float = 60.0 + + # Maximum number of retry attempts + max_tries: int = 3 + + # Jitter: adds randomness to backoff delays (None = no jitter for predictable timing) + jitter: Union[None, Callable] = None + + # Factor for exponential backoff (only used if strategy == 'expo') + factor: float = 2.0 + + # Whether to raise the exception when giving up (instead of returning it) + raise_on_giveup: bool = True + + # Optional custom giveup function - if provided, overrides the default exception handling logic + giveup_func: Callable[[Exception], bool] = lambda e: False + + def get_backoff_decorator(self, exceptions: Set[Type[Exception]]): + """Get the appropriate backoff decorator based on configuration.""" + if not exceptions: + # If no exceptions specified, return a no-op decorator + def no_op_decorator(func): + return func + + return no_op_decorator + + if self.strategy == "expo": + return backoff.on_exception( + backoff.expo, + tuple(exceptions), + max_tries=self.max_tries, + base=self.base_delay, + max_value=self.max_delay, + factor=self.factor, + jitter=self.jitter, + giveup=self.giveup_func, + raise_on_giveup=self.raise_on_giveup, + ) + elif self.strategy == "constant": + return backoff.on_exception( + backoff.constant, + tuple(exceptions), + max_tries=self.max_tries, + interval=self.base_delay, + jitter=self.jitter, + giveup=self.giveup_func, + raise_on_giveup=self.raise_on_giveup, + ) + else: + raise ValueError(f"Unknown backoff strategy: {self.strategy}") + + +@dataclass +class ExceptionHandlerConfig: + """Configuration for exception handling in rollout processors.""" + + # Exceptions that should be retried using backoff + retryable_exceptions: Set[Type[Exception]] = field(default_factory=lambda: DEFAULT_RETRYABLE_EXCEPTIONS.copy()) + + # Backoff configuration + backoff_config: BackoffConfig = field(default_factory=BackoffConfig) + + def __post_init__(self): + """Automatically apply environment variable overrides after initialization.""" + # Override backoff settings from environment variables + if "EP_MAX_RETRY" in os.environ: + max_retry = int(os.environ["EP_MAX_RETRY"]) + if max_retry > 0: + self.backoff_config.max_tries = max_retry + + if "EP_FAIL_ON_MAX_RETRY" in os.environ: + fail_on_max_retry = os.environ["EP_FAIL_ON_MAX_RETRY"].lower() + self.backoff_config.raise_on_giveup = fail_on_max_retry != "false" + + def get_backoff_decorator(self): + """Get the backoff decorator configured for this exception handler.""" + return self.backoff_config.get_backoff_decorator(self.retryable_exceptions) + + +def get_default_exception_handler_config() -> ExceptionHandlerConfig: + """Get a fresh default exception handler configuration.""" + return ExceptionHandlerConfig() diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py index 8a3be489..7be70be9 100644 --- a/eval_protocol/pytest/types.py +++ b/eval_protocol/pytest/types.py @@ -10,6 +10,7 @@ from eval_protocol.dataset_logger.dataset_logger import DatasetLogger from ..models import CompletionParams, EvaluationRow, Message +from .exception_config import ExceptionHandlerConfig ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct DatasetPathParam = str @@ -50,3 +51,6 @@ class RolloutProcessorConfig: steps: int = 30 # max number of rollout steps logger: DatasetLogger = default_logger # logger to use during rollout for mid-rollout logs kwargs: Dict[str, Any] = field(default_factory=dict) # any additional kwargs to pass to the rollout processor + exception_handler_config: Optional[ExceptionHandlerConfig] = ( + None # configuration for exception handling with backoff + ) diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index e0b8328a..bed1ac89 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -15,6 +15,7 @@ InputMessagesParam, RolloutProcessorConfig, ) +from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config def execute_function(func: Callable, **kwargs) -> Any: @@ -239,94 +240,83 @@ async def rollout_processor_with_retry( rollout_processor: RolloutProcessor, fresh_dataset: List[EvaluationRow], config: RolloutProcessorConfig, - max_retry: int, ): """ - Wrapper around rollout_processor that handles retry logic internally. - Uses async queue pattern to yield results immediately as they become available. - Yields both successful and failed results, leaving it up to the user to handle them in test_func. - """ + Wrapper around rollout_processor that handles retry logic using the Python backoff library. - try: - queue: asyncio.Queue[EvaluationRow] = asyncio.Queue() - retry_counts = {r.execution_metadata.rollout_id: 0 for r in fresh_dataset} - failed_permanently = [] + Provides configurable exception handling with automatic retry for specific exception types: + - Retryable exceptions (e.g., ConnectionError, TimeoutError) are automatically retried with backoff + - Fail-fast exceptions (e.g., ValueError, TypeError) are not retried and return immediately + - Unknown exceptions can be configured to either re-raise or return as failed rows - async def retry_handler(failed_row: EvaluationRow): - rollout_id = failed_row.execution_metadata.rollout_id - current_attempts = retry_counts.get(rollout_id, 0) + The backoff behavior (exponential/constant, delays, max attempts) is fully configurable + through the ExceptionHandlerConfig in the RolloutProcessorConfig. - if current_attempts >= max_retry: - assert failed_row.rollout_status and failed_row.rollout_status.status == RolloutStatus.Status.ERROR, ( - f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" - ) - failed_permanently.append(failed_row) - await queue.put(failed_row) # put failed row on queue - return + Yields results as they complete, allowing for concurrent processing while handling + retries transparently in the background. + """ - retry_counts[rollout_id] = current_attempts + 1 + # Use provided exception handler config or fall back to default + # Environment variable overrides are automatically applied in __post_init__ + exception_config = config.exception_handler_config or get_default_exception_handler_config() - # add kwargs start_server=False to config so we don't start new MCP server + try: + # Create initial batch of tasks (preserves indexing for mock processors) + try: + base_tasks = rollout_processor(fresh_dataset, config) + except Exception as e: + print(f"āŒ Rollout processor failed to initialize: {e}") + raise e + + # Create a single backoff-decorated retry function that can be reused + @exception_config.get_backoff_decorator() + async def execute_row_with_backoff_retry(row: EvaluationRow): + """Execute rollout for a single row with backoff retry.""" retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False}) + retry_tasks = rollout_processor([row], retry_config) + return await retry_tasks[0] - retry_tasks = rollout_processor([failed_row], retry_config) + async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> EvaluationRow: + """Execute a single row task with backoff retry.""" try: - retry_result = await retry_tasks[0] - retry_result.rollout_status.status = RolloutStatus.Status.FINISHED - await queue.put(retry_result) + # Try original task first + result = await task + result.rollout_status.status = RolloutStatus.Status.FINISHED + return result except Exception as e: - failed_row.rollout_status.status = RolloutStatus.Status.ERROR - failed_row.rollout_status.termination_reason = str(e) - asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry - - async def initial_processor(): - """Process initial batch and spawn retries for failures""" - # catch any task creation errors and raise them immediately, i.e. port already in use - try: - base_tasks = rollout_processor(fresh_dataset, config) - except Exception as e: - print(f"āŒ Rollout processor failed to initialize: {e}") - raise e - - pending = set(base_tasks) - - while pending: - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - - for task in done: - task_index = base_tasks.index(task) - + # NOTE: we perform these checks because we don't put the backoff decorator on initial batch call. we don't want to retry whole batch if anything fails. + # Check if this exception should be retried + is_retryable = any(isinstance(e, exc_type) for exc_type in exception_config.retryable_exceptions) + giveup_func = exception_config.backoff_config.giveup_func + should_giveup = giveup_func and giveup_func(e) + + if is_retryable and not should_giveup: + # Use shared backoff function for retryable exceptions try: - result = await task + result = await execute_row_with_backoff_retry(row) result.rollout_status.status = RolloutStatus.Status.FINISHED - await queue.put(result) - except Exception as e: - failed_row = fresh_dataset[task_index] - failed_row.rollout_status.status = RolloutStatus.Status.ERROR - failed_row.rollout_status.termination_reason = str(e) - asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task - - processor_task = asyncio.create_task(initial_processor()) - - # yield results as they become available - completed_count = 0 - total_expected = len(fresh_dataset) - - while completed_count < total_expected: - finished_row = await queue.get() - - # only permanent failure rows are put on the queue, so we can check for them here - if finished_row.rollout_status and finished_row.rollout_status.status == RolloutStatus.Status.ERROR: - if max_retry > 0 and os.getenv("EP_FAIL_ON_MAX_RETRY", "true") != "false": - raise RuntimeError( - f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {finished_row.rollout_status.termination_reason}" - ) - - completed_count += 1 - yield finished_row - - await processor_task # explicitly wait for task completion and catch any exceptions + return result + except Exception as retry_error: + # Backoff gave up + row.rollout_status.status = RolloutStatus.Status.ERROR + row.rollout_status.termination_reason = str(retry_error) + return row + else: + # Non-retryable exception - fail immediately + row.rollout_status.status = RolloutStatus.Status.ERROR + row.rollout_status.termination_reason = str(e) + return row + + # Process all tasks concurrently with backoff retry + retry_tasks = [ + asyncio.create_task(execute_row_with_backoff(task, fresh_dataset[i])) for i, task in enumerate(base_tasks) + ] + + # Yield results as they complete + for task in asyncio.as_completed(retry_tasks): + result = await task + yield result finally: rollout_processor.cleanup() diff --git a/pyproject.toml b/pyproject.toml index 46b66d77..79b5867a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "fastapi>=0.116.1", "pytest>=6.0.0", "peewee>=3.18.2", + "backoff>=2.2.0", ] [project.urls] diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py index 8b55869f..6895ad55 100644 --- a/tests/test_retry_mechanism.py +++ b/tests/test_retry_mechanism.py @@ -7,7 +7,7 @@ import os from collections import Counter from typing import List -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest @@ -15,6 +15,8 @@ from eval_protocol.pytest.evaluation_test import evaluation_test from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import RolloutProcessorConfig +from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, BackoffConfig +import litellm class MockRolloutProcessorWithRetries(RolloutProcessor): @@ -59,7 +61,7 @@ async def process_single_row( print(f"šŸŽ‰ FINISHED {'error' if should_fail else 'finished'}: {row.execution_metadata.rollout_id}") if should_fail: - raise Exception("Simulated failure for testing") + raise ConnectionError("Simulated failure for testing") return row @@ -76,7 +78,6 @@ async def process_single_row( shared_processor = MockRolloutProcessorWithRetries() -@patch.dict(os.environ, {"EP_MAX_RETRY": "2"}) @evaluation_test( completion_params=[{"model": "gpt-4o-mini", "temperature": 0}], input_messages=[ @@ -89,6 +90,7 @@ async def process_single_row( rollout_processor=shared_processor, num_runs=1, mode="pointwise", + exception_handler_config=ExceptionHandlerConfig(backoff_config=BackoffConfig(max_tries=3)), ) def test_retry_mechanism(row: EvaluationRow) -> EvaluationRow: """MOCK TEST: Tests that retry mechanism works - one task fails on first attempt, succeeds on retry.""" @@ -103,7 +105,6 @@ def test_retry_mechanism(row: EvaluationRow) -> EvaluationRow: return row -@patch.dict(os.environ, {"EP_MAX_RETRY": "2"}) def test_retry_mechanism_mock_verification(): """Test that verifies the retry mechanism worked by checking the mock calls""" # Get our mock tracker @@ -155,3 +156,237 @@ def test_retry_mechanism_mock_verification(): ) print("āœ… All mock-based assertions passed! Retry mechanism is working correctly.") + + +# Test 2: Fail-fast exceptions should not retry +class MockRolloutProcessorFailFast(RolloutProcessor): + """Mock processor that always raises ValueError (fail-fast exception)""" + + def __init__(self): + self.mock_tracker = Mock() + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + self.mock_tracker.batch_call(len(rows)) + + async def process_single_row(row: EvaluationRow) -> EvaluationRow: + self.mock_tracker.process_row_call(row.execution_metadata.rollout_id) + # Always raise ValueError (fail-fast exception) + raise ValueError("This should not be retried") + + tasks = [asyncio.create_task(process_single_row(row)) for row in rows] + return tasks + + +shared_processor_fail_fast = MockRolloutProcessorFailFast() + + +@evaluation_test( + completion_params=[{"model": "gpt-4o-mini", "temperature": 0}], + input_messages=[[Message(role="user", content="Test")]], + rollout_processor=shared_processor_fail_fast, + num_runs=1, + mode="pointwise", + exception_handler_config=ExceptionHandlerConfig(backoff_config=BackoffConfig(max_tries=4)), +) +def test_fail_fast_exceptions(row: EvaluationRow) -> EvaluationRow: + """Test that fail-fast exceptions like ValueError are not retried.""" + print( + f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})" + ) + score = 1.0 if row.rollout_status.status == "finished" else 0.0 + row.evaluation_result = EvaluateResult(score=score) + return row + + +def test_fail_fast_verification(): + """Verify that fail-fast exceptions are not retried""" + mock_tracker = shared_processor_fail_fast.mock_tracker + + print("\nšŸ”„ FAIL-FAST TEST ANALYSIS:") + print(f" Batch calls made: {mock_tracker.batch_call.call_count}") + print(f" Total row processing calls: {mock_tracker.process_row_call.call_count}") + + # Debug: Print all the calls that were made + print(" Batch call args:", mock_tracker.batch_call.call_args_list) + print(" Process row call args:", mock_tracker.process_row_call.call_args_list) + + # Should have exactly 1 call (no retries for fail-fast exceptions) + assert mock_tracker.process_row_call.call_count == 1, ( + f"Expected 1 call for fail-fast exception, got {mock_tracker.process_row_call.call_count}" + ) + + # Should have exactly 1 batch call (no retry batches) + assert mock_tracker.batch_call.call_count == 1, f"Expected 1 batch call, got {mock_tracker.batch_call.call_count}" + + print("āœ… Fail-fast exception test passed! ValueError was not retried.") + + +# Test 3: Custom giveup function +class MockRolloutProcessorCustomGiveup(RolloutProcessor): + """Mock processor for testing custom giveup functions""" + + def __init__(self): + self.mock_tracker = Mock() + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + self.mock_tracker.batch_call(len(rows)) + + async def process_single_row(row: EvaluationRow) -> EvaluationRow: + self.mock_tracker.process_row_call(row.execution_metadata.rollout_id) + + # Raise real litellm exceptions based on task content + task_content = row.messages[0].content if row.messages else "" + if "429" in task_content: + raise litellm.RateLimitError( + "Rate limit exceeded", llm_provider="test", model="test-model" + ) # Should retry + else: + raise litellm.BadRequestError( + "Bad request", model="test-model", llm_provider="test" + ) # Should not retry + + tasks = [asyncio.create_task(process_single_row(row)) for row in rows] + return tasks + + +shared_processor_custom_giveup = MockRolloutProcessorCustomGiveup() + + +# Custom giveup function for litellm exceptions +def custom_http_giveup(e): + # Don't retry bad requests (400-level errors), but do retry rate limits (429) + if isinstance(e, litellm.BadRequestError): + return True # Give up immediately on bad requests + elif isinstance(e, litellm.RateLimitError): + return False # Retry rate limits with backoff + + return False # Retry everything else + + +@evaluation_test( + completion_params=[{"model": "gpt-4o-mini", "temperature": 0}], + input_messages=[ + [Message(role="user", content="Test 429")], # Should retry + [Message(role="user", content="Test 400")], # Should not retry + ], + rollout_processor=shared_processor_custom_giveup, + num_runs=1, + mode="pointwise", + exception_handler_config=ExceptionHandlerConfig( + retryable_exceptions={ + litellm.RateLimitError, + litellm.BadRequestError, + }, + backoff_config=BackoffConfig(max_tries=3, giveup_func=custom_http_giveup), + ), +) +def test_custom_giveup_function(row: EvaluationRow) -> EvaluationRow: + """Test custom giveup function behavior.""" + task_content = row.messages[0].content if row.messages else "" + print(f"šŸ“Š EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})") + score = 1.0 if row.rollout_status.status == "finished" else 0.0 + row.evaluation_result = EvaluateResult(score=score) + return row + + +def test_custom_giveup_verification(): + """Verify custom giveup function works correctly""" + mock_tracker = shared_processor_custom_giveup.mock_tracker + + print("\nšŸ”„ CUSTOM GIVEUP TEST ANALYSIS:") + print(f" Batch calls made: {mock_tracker.batch_call.call_count}") + print(f" Total row processing calls: {mock_tracker.process_row_call.call_count}") + + call_args = mock_tracker.process_row_call.call_args_list + rollout_ids = [call[0][0] for call in call_args] + call_counts = Counter(rollout_ids) + + print(f" Call counts per rollout_id: {dict(call_counts)}") + + # Should have 5 calls: 1 for 400 error (giveup immediately), 4 for 429 error (1 original + 3 backoff) + assert mock_tracker.process_row_call.call_count == 5, ( + f"Expected 5 calls total, got {mock_tracker.process_row_call.call_count}" + ) + + # One rollout should be called 4 times (RateLimitError: 1 original + 3 backoff), one called once (BadRequestError: immediate giveup) + call_count_values = list(call_counts.values()) + assert call_count_values.count(4) == 1, ( + f"Expected 1 rollout with 4 calls (RateLimitError: 1 original + 3 backoff), got {call_count_values}" + ) + assert call_count_values.count(1) == 1, ( + f"Expected 1 rollout with 1 call (BadRequestError: immediate giveup), got {call_count_values}" + ) + + print("āœ… Custom giveup function test passed! HTTP status-based retry logic worked correctly.") + + +# Test 4: Simple giveup function - retry all exceptions but give up on 4xx +class MockRolloutProcessorSimpleGiveup(RolloutProcessor): + """Mock processor that raises BadRequestError""" + + def __init__(self): + self.mock_tracker = Mock() + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + self.mock_tracker.batch_call(len(rows)) + + async def process_single_row(row: EvaluationRow) -> EvaluationRow: + self.mock_tracker.process_row_call(row.execution_metadata.rollout_id) + # Always raise BadRequestError (400) - should be caught by giveup + mock_response = Mock() + mock_response.status_code = 400 + error = litellm.BadRequestError("Bad request", model="test-model", llm_provider="test") + error.response = mock_response + raise error + + tasks = [asyncio.create_task(process_single_row(row)) for row in rows] + return tasks + + +shared_processor_simple_giveup = MockRolloutProcessorSimpleGiveup() + + +# Simple giveup function for 4xx errors +def simple_4xx_giveup(e): + if hasattr(e, "response") and hasattr(e.response, "status_code"): + status = e.response.status_code + return 400 <= status < 500 # Give up on all 4xx client errors + return False # Retry everything else + + +@evaluation_test( + completion_params=[{"model": "gpt-4o-mini", "temperature": 0}], + input_messages=[[Message(role="user", content="Test 400 giveup")]], + rollout_processor=shared_processor_simple_giveup, + num_runs=1, + mode="pointwise", + exception_handler_config=ExceptionHandlerConfig( + retryable_exceptions={Exception}, # Retry all exceptions + backoff_config=BackoffConfig(max_tries=5, giveup_func=simple_4xx_giveup), + ), +) +def test_simple_giveup_function(row: EvaluationRow) -> EvaluationRow: + """Test that giveup function prevents retries immediately.""" + print( + f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})" + ) + score = 1.0 if row.rollout_status.status == "finished" else 0.0 + row.evaluation_result = EvaluateResult(score=score) + return row + + +def test_simple_giveup_verification(): + """Verify that giveup function prevents retries.""" + mock_tracker = shared_processor_simple_giveup.mock_tracker + + print("\nšŸ”„ SIMPLE GIVEUP TEST ANALYSIS:") + print(f" Batch calls made: {mock_tracker.batch_call.call_count}") + print(f" Total row processing calls: {mock_tracker.process_row_call.call_count}") + print(" Process row call args:", mock_tracker.process_row_call.call_args_list) + + # Should have exactly 1 call (giveup function should prevent retries) + assert mock_tracker.process_row_call.call_count == 1, ( + f"Expected 1 call due to giveup, got {mock_tracker.process_row_call.call_count}" + ) + + print("āœ… Simple giveup test passed! 4xx error was not retried due to giveup function.") From 41079034c7f12fd72905ef37e7424f24d28bf2fe Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 19 Aug 2025 16:08:48 -0700 Subject: [PATCH 2/2] uv lock --- uv.lock | 2 ++ 1 file changed, 2 insertions(+) diff --git a/uv.lock b/uv.lock index 6758d24c..76310f48 100644 --- a/uv.lock +++ b/uv.lock @@ -1087,6 +1087,7 @@ dependencies = [ { name = "aiohttp" }, { name = "aiosqlite" }, { name = "anthropic" }, + { name = "backoff" }, { name = "dataclasses-json" }, { name = "datasets" }, { name = "deepdiff" }, @@ -1196,6 +1197,7 @@ requires-dist = [ { name = "aiohttp" }, { name = "aiosqlite" }, { name = "anthropic", specifier = ">=0.59.0" }, + { name = "backoff", specifier = ">=2.2.0" }, { name = "build", marker = "extra == 'dev'" }, { name = "dataclasses-json", specifier = ">=0.5.7" }, { name = "datasets" },