Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion eval_protocol/benchmarks/test_tau_bench_retail.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from typing import Any, Dict, List

from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest import evaluation_test, ExceptionHandlerConfig
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
import litellm
from vendor.tau2.data_model.message import (
AssistantMessage,
SystemMessage,
Expand Down Expand Up @@ -87,6 +88,12 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
mode="pointwise",
max_concurrent_rollouts=50,
server_script_path=_get_server_script_path(),
exception_handler_config=ExceptionHandlerConfig(
retryable_exceptions={
litellm.RateLimitError,
litellm.APIConnectionError,
}
),
)
def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
"""
Expand Down
4 changes: 4 additions & 0 deletions eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .default_no_op_rollout_processor import NoOpRolloutProcessor
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
from .evaluation_test import evaluation_test
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig

Expand All @@ -16,4 +17,7 @@
"default_dataset_adapter",
"RolloutProcessorConfig",
"evaluation_test",
"ExceptionHandlerConfig",
"BackoffConfig",
"get_default_exception_handler_config",
]
23 changes: 11 additions & 12 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
rollout_processor_with_retry,
sanitize_filename,
)
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci

from ..common_utils import load_jsonl
Expand Down Expand Up @@ -248,6 +249,7 @@ def evaluation_test( # noqa: C901
mode: EvaluationTestMode = "pointwise",
combine_datasets: bool = True,
logger: Optional[DatasetLogger] = None,
exception_handler_config: Optional[ExceptionHandlerConfig] = None,
) -> Callable[
[TestFunction],
TestFunction,
Expand Down Expand Up @@ -312,6 +314,8 @@ def evaluation_test( # noqa: C901
"groupwise" applies test function to a group of rollout results from the same original row (for use cases such as dpo/grpo).
"all" applies test function to the whole dataset.
logger: DatasetLogger to use for logging. If not provided, a default logger will be used.
exception_handler_config: Configuration for exception handling and backoff retry logic.
If not provided, a default configuration will be used with common retryable exceptions.
"""

active_logger: DatasetLogger = logger if logger else default_logger
Expand Down Expand Up @@ -557,8 +561,9 @@ def _log_eval_error(
steps=steps,
logger=active_logger,
kwargs=rollout_processor_kwargs or {},
exception_handler_config=exception_handler_config,
)
max_retry = int(os.getenv("EP_MAX_RETRY", "0"))

for i in range(num_runs):
# Regenerate outputs each run by deep-copying the pristine dataset
# so model responses are not reused across runs.
Expand Down Expand Up @@ -598,9 +603,7 @@ async def _execute_with_semaphore(row):
return result

# Use wrapper that handles retry logic internally
async for row in rollout_processor_with_retry(
rollout_processor, fresh_dataset, config, max_retry
):
async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
tasks.append(asyncio.create_task(_execute_with_semaphore(row)))

results = await asyncio.gather(*tasks)
Expand All @@ -623,11 +626,9 @@ async def _execute_with_semaphore(row):
)
lst = []

async def _collect_result(config, lst, max_retry):
async def _collect_result(config, lst):
result = []
async for row in rollout_processor_with_retry(
rollout_processor, lst, config, max_retry
):
async for row in rollout_processor_with_retry(rollout_processor, lst, config):
result.append(row)
return result

Expand All @@ -639,7 +640,7 @@ async def _collect_result(config, lst, max_retry):
)
copied_row.input_metadata.completion_params = cp
lst.append(copied_row)
tasks.append(asyncio.create_task(_collect_result(config, lst, max_retry)))
tasks.append(asyncio.create_task(_collect_result(config, lst)))
rollout_results = await asyncio.gather(*tasks)
for result in rollout_results:
for row in result:
Expand All @@ -656,9 +657,7 @@ async def _collect_result(config, lst, max_retry):
else:
# Batch mode: collect all results first, then evaluate (no pipelining)
input_dataset = []
async for row in rollout_processor_with_retry(
rollout_processor, fresh_dataset, config, max_retry
):
async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
input_dataset.append(row)
# NOTE: we will still evaluate errored rows (give users control over this)
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
Expand Down
126 changes: 126 additions & 0 deletions eval_protocol/pytest/exception_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
Exception handling configuration for rollout processors with backoff retry logic.
"""

import os
from dataclasses import dataclass, field
from typing import Callable, Set, Type, Union

import backoff


import requests
import httpx

# Default exceptions that should be retried with backoff.
# These are transient network/connectivity failures; any exception type not in
# this set is treated as non-retryable unless a caller supplies its own set.
DEFAULT_RETRYABLE_EXCEPTIONS: Set[Type[Exception]] = {
    # Standard library exceptions
    ConnectionError,
    TimeoutError,
    OSError,  # Covers network-related OS errors (and is the base of the two above)
    # Requests library exceptions
    # NOTE(review): RequestException is the base class of ConnectionError,
    # Timeout and HTTPError, so the three specific entries are redundant —
    # harmless, but the set could be reduced to RequestException alone.
    requests.exceptions.ConnectionError,
    requests.exceptions.Timeout,
    requests.exceptions.HTTPError,
    requests.exceptions.RequestException,
    # HTTPX library exceptions
    httpx.ConnectError,
    httpx.TimeoutException,
    httpx.NetworkError,
    httpx.RemoteProtocolError,
}


@dataclass
class BackoffConfig:
    """Settings that control how failed calls are retried with backoff.

    The attributes map onto the keyword arguments ultimately forwarded to
    ``backoff.on_exception``.
    """

    # Backoff strategy: 'expo' for exponential delays, 'constant' for a fixed delay.
    strategy: str = "expo"

    # Initial delay between attempts, in seconds.
    base_delay: float = 1.0

    # Ceiling on the delay, in seconds (only meaningful for 'expo').
    max_delay: float = 60.0

    # Maximum number of attempts (forwarded as backoff's ``max_tries``).
    max_tries: int = 3

    # Randomization of delays; None keeps the timing deterministic.
    jitter: Union[None, Callable] = None

    # Growth factor for the 'expo' strategy.
    factor: float = 2.0

    # Re-raise the final exception when giving up instead of returning it.
    raise_on_giveup: bool = True

    # Predicate that can abort retrying early; overrides default handling when set.
    giveup_func: Callable[[Exception], bool] = lambda e: False

    def get_backoff_decorator(self, exceptions: Set[Type[Exception]]):
        """Build a ``backoff.on_exception`` decorator for *exceptions*.

        Returns an identity decorator when *exceptions* is empty, so callers
        may apply the result unconditionally.

        Raises:
            ValueError: if ``strategy`` is neither 'expo' nor 'constant'.
        """
        if not exceptions:
            # Nothing to retry on: hand back a decorator that leaves the
            # function untouched.
            return lambda func: func

        retryable = tuple(exceptions)
        # Keyword arguments shared by both strategies.
        shared_kwargs = dict(
            max_tries=self.max_tries,
            jitter=self.jitter,
            giveup=self.giveup_func,
            raise_on_giveup=self.raise_on_giveup,
        )

        if self.strategy == "expo":
            return backoff.on_exception(
                backoff.expo,
                retryable,
                base=self.base_delay,
                max_value=self.max_delay,
                factor=self.factor,
                **shared_kwargs,
            )
        if self.strategy == "constant":
            return backoff.on_exception(
                backoff.constant,
                retryable,
                interval=self.base_delay,
                **shared_kwargs,
            )
        raise ValueError(f"Unknown backoff strategy: {self.strategy}")


@dataclass
class ExceptionHandlerConfig:
    """Configuration for exception handling in rollout processors."""

    # Exceptions that should be retried using backoff. Each instance gets its
    # own copy so mutating one config cannot affect the shared default set.
    retryable_exceptions: Set[Type[Exception]] = field(default_factory=lambda: DEFAULT_RETRYABLE_EXCEPTIONS.copy())

    # Backoff configuration. NOTE(review): __post_init__ mutates this object
    # in place, so a BackoffConfig shared across several
    # ExceptionHandlerConfig instances would see env overrides applied to all
    # of them — pass a fresh instance per config.
    backoff_config: BackoffConfig = field(default_factory=BackoffConfig)

    def __post_init__(self):
        """Apply environment-variable overrides after initialization.

        Recognized variables:
            EP_MAX_RETRY: positive integer overriding ``backoff_config.max_tries``;
                values <= 0 are ignored.
            EP_FAIL_ON_MAX_RETRY: "false" (case-insensitive) disables
                ``backoff_config.raise_on_giveup``; any other value enables it.

        Raises:
            ValueError: if EP_MAX_RETRY is set but is not a valid integer.
        """
        raw_max_retry = os.environ.get("EP_MAX_RETRY")
        if raw_max_retry is not None:
            try:
                max_retry = int(raw_max_retry)
            except ValueError:
                # Fail fast with a clear message instead of the bare
                # "invalid literal for int()" the user would otherwise get.
                raise ValueError(f"EP_MAX_RETRY must be an integer, got {raw_max_retry!r}") from None
            # NOTE(review): backoff's max_tries counts the *total* number of
            # calls, while the env var name reads like a retry count — confirm
            # whether this should be max_retry + 1.
            if max_retry > 0:
                self.backoff_config.max_tries = max_retry

        fail_on_max_retry = os.environ.get("EP_FAIL_ON_MAX_RETRY")
        if fail_on_max_retry is not None:
            # Only the literal string "false" (any case) disables raising.
            self.backoff_config.raise_on_giveup = fail_on_max_retry.lower() != "false"

    def get_backoff_decorator(self):
        """Get the backoff decorator configured for this exception handler."""
        return self.backoff_config.get_backoff_decorator(self.retryable_exceptions)


def get_default_exception_handler_config() -> ExceptionHandlerConfig:
    """Build and return a brand-new default exception handler configuration.

    A fresh instance is created on every call, so callers may mutate the
    result freely without affecting other users of the defaults.
    """
    default_config = ExceptionHandlerConfig()
    return default_config
4 changes: 4 additions & 0 deletions eval_protocol/pytest/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger

from ..models import CompletionParams, EvaluationRow, Message
from .exception_config import ExceptionHandlerConfig

ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
DatasetPathParam = str
Expand Down Expand Up @@ -47,3 +48,6 @@ class RolloutProcessorConfig:
steps: int = 30 # max number of rollout steps
logger: DatasetLogger = default_logger # logger to use during rollout for mid-rollout logs
kwargs: Dict[str, Any] = field(default_factory=dict) # any additional kwargs to pass to the rollout processor
exception_handler_config: Optional[ExceptionHandlerConfig] = (
None # configuration for exception handling with backoff
)
Loading
Loading