Skip to content

Commit 8ee3e46

Browse files
committed
Added Configurable Exception Handler
1 parent a473f0c commit 8ee3e46

File tree

8 files changed

+453
-87
lines changed

8 files changed

+453
-87
lines changed

eval_protocol/benchmarks/test_tau_bench_retail.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
from typing import Any, Dict, List
1212

1313
from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
14-
from eval_protocol.pytest import evaluation_test
14+
from eval_protocol.pytest import evaluation_test, ExceptionHandlerConfig
1515
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
16+
import litellm
1617
from vendor.tau2.data_model.message import (
1718
AssistantMessage,
1819
SystemMessage,
@@ -87,6 +88,12 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
8788
mode="pointwise",
8889
max_concurrent_rollouts=50,
8990
server_script_path=_get_server_script_path(),
91+
exception_handler_config=ExceptionHandlerConfig(
92+
retryable_exceptions={
93+
litellm.RateLimitError,
94+
litellm.APIConnectionError,
95+
}
96+
),
9097
)
9198
def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
9299
"""

eval_protocol/pytest/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .default_no_op_rollout_processor import NoOpRolloutProcessor
55
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
66
from .evaluation_test import evaluation_test
7+
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
78
from .rollout_processor import RolloutProcessor
89
from .types import RolloutProcessorConfig
910

@@ -16,4 +17,7 @@
1617
"default_dataset_adapter",
1718
"RolloutProcessorConfig",
1819
"evaluation_test",
20+
"ExceptionHandlerConfig",
21+
"BackoffConfig",
22+
"get_default_exception_handler_config",
1923
]

eval_protocol/pytest/evaluation_test.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
rollout_processor_with_retry,
5252
sanitize_filename,
5353
)
54+
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig
5455
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
5556

5657
from ..common_utils import load_jsonl
@@ -76,6 +77,7 @@ def evaluation_test( # noqa: C901
7677
mode: EvaluationTestMode = "batch",
7778
combine_datasets: bool = True,
7879
logger: Optional[DatasetLogger] = None,
80+
exception_handler_config: Optional[ExceptionHandlerConfig] = None,
7981
) -> Callable[
8082
[TestFunction],
8183
TestFunction,
@@ -140,6 +142,8 @@ def evaluation_test( # noqa: C901
140142
full dataset. "pointwise" applies test function to each row. If your evaluation requires
141143
the full rollout of all rows to compute the score, use
142144
logger: DatasetLogger to use for logging. If not provided, a default logger will be used.
145+
exception_handler_config: Configuration for exception handling and backoff retry logic.
146+
If not provided, a default configuration will be used with common retryable exceptions.
143147
"""
144148

145149
active_logger: DatasetLogger = logger if logger else default_logger
@@ -365,10 +369,9 @@ def _log_eval_error(
365369
steps=steps,
366370
logger=active_logger,
367371
kwargs=rollout_processor_kwargs or {},
372+
exception_handler_config=exception_handler_config,
368373
)
369374

370-
max_retry = int(os.getenv("EP_MAX_RETRY", "0"))
371-
372375
for i in range(num_runs):
373376
# Regenerate outputs each run by deep-copying the pristine dataset
374377
# so model responses are not reused across runs.
@@ -408,9 +411,7 @@ async def _execute_with_semaphore(row):
408411
return result
409412

410413
# Use wrapper that handles retry logic internally
411-
async for row in rollout_processor_with_retry(
412-
rollout_processor, fresh_dataset, config, max_retry
413-
):
414+
async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
414415
tasks.append(asyncio.create_task(_execute_with_semaphore(row)))
415416

416417
results = await asyncio.gather(*tasks)
@@ -420,9 +421,7 @@ async def _execute_with_semaphore(row):
420421
else:
421422
# Batch mode: collect all results first, then evaluate (no pipelining)
422423
input_dataset = []
423-
async for row in rollout_processor_with_retry(
424-
rollout_processor, fresh_dataset, config, max_retry
425-
):
424+
async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
426425
input_dataset.append(row)
427426
# NOTE: we will still evaluate errored rows (give users control over this)
428427
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
eval_protocol/pytest/exception_config.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""
2+
Exception handling configuration for rollout processors with backoff retry logic.
3+
"""
4+
5+
import os
6+
from dataclasses import dataclass, field
7+
from typing import Callable, Set, Type, Union
8+
9+
import backoff
10+
11+
12+
import requests
13+
import httpx
14+
15+
# Default exceptions that should be retried with backoff.
# These are transient network/transport failures; application-level errors
# (e.g. HTTP 4xx raised by callers directly) are intentionally not listed here
# beyond requests.exceptions.HTTPError below.
DEFAULT_RETRYABLE_EXCEPTIONS: Set[Type[Exception]] = {
    # Standard library exceptions
    ConnectionError,
    TimeoutError,
    # NOTE(review): ConnectionError and TimeoutError are subclasses of OSError,
    # so OSError alone would cover them; the explicit entries are kept for
    # readability.
    OSError,  # Covers network-related OS errors
    # Requests library exceptions
    requests.exceptions.ConnectionError,
    requests.exceptions.Timeout,
    requests.exceptions.HTTPError,
    # NOTE(review): RequestException is the base class of the three entries
    # above, so it subsumes them — kept explicit for documentation value.
    requests.exceptions.RequestException,
    # HTTPX library exceptions
    httpx.ConnectError,
    httpx.TimeoutException,
    httpx.NetworkError,
    httpx.RemoteProtocolError,
}
32+
33+
34+
@dataclass
class BackoffConfig:
    """Configuration for backoff behavior.

    Produces a ``backoff.on_exception`` decorator via
    :meth:`get_backoff_decorator`; that decorator retries the wrapped callable
    whenever one of the configured exception types is raised.
    """

    # Backoff strategy: 'expo' for exponential, 'constant' for constant delay
    strategy: str = "expo"

    # Base delay in seconds: the first retry delay for 'expo', or the fixed
    # interval between retries for 'constant'
    base_delay: float = 1.0

    # Maximum delay in seconds (cap for exponential growth)
    max_delay: float = 60.0

    # Maximum number of retry attempts
    max_tries: int = 3

    # Jitter: adds randomness to backoff delays (None = no jitter for predictable timing)
    jitter: Union[None, Callable] = None

    # Growth rate for exponential backoff (only used if strategy == 'expo');
    # each successive delay is multiplied by this factor
    factor: float = 2.0

    # Whether to raise the exception when giving up (instead of returning it)
    raise_on_giveup: bool = True

    # Optional custom giveup function - if provided, overrides the default exception handling logic
    giveup_func: Callable[[Exception], bool] = lambda e: False

    def get_backoff_decorator(self, exceptions: Set[Type[Exception]]):
        """Get the appropriate backoff decorator based on configuration.

        Args:
            exceptions: Exception types that trigger a retry. An empty set
                disables retries entirely and a no-op decorator is returned.

        Returns:
            A decorator wrapping a callable with retry-on-exception logic.

        Raises:
            ValueError: If ``strategy`` is neither ``'expo'`` nor ``'constant'``.
        """
        if not exceptions:
            # If no exceptions specified, return a no-op decorator
            def no_op_decorator(func):
                return func

            return no_op_decorator

        if self.strategy == "expo":
            # BUGFIX: backoff.expo yields factor * base**n for n = 0, 1, 2, ...
            # Map base_delay -> `factor` and factor -> `base` so the first
            # delay is base_delay seconds and each subsequent delay grows by
            # `factor`. The previous mapping (base=base_delay, factor=factor)
            # degenerated to a constant delay whenever base_delay == 1.0,
            # since 1.0**n == 1 for all n.
            return backoff.on_exception(
                backoff.expo,
                tuple(exceptions),
                max_tries=self.max_tries,
                base=self.factor,
                max_value=self.max_delay,
                factor=self.base_delay,
                jitter=self.jitter,
                giveup=self.giveup_func,
                raise_on_giveup=self.raise_on_giveup,
            )
        elif self.strategy == "constant":
            return backoff.on_exception(
                backoff.constant,
                tuple(exceptions),
                max_tries=self.max_tries,
                interval=self.base_delay,
                jitter=self.jitter,
                giveup=self.giveup_func,
                raise_on_giveup=self.raise_on_giveup,
            )
        else:
            raise ValueError(f"Unknown backoff strategy: {self.strategy}")
95+
96+
97+
@dataclass
class ExceptionHandlerConfig:
    """Configuration for exception handling in rollout processors.

    Bundles the set of retryable exception types with a :class:`BackoffConfig`
    and applies environment-variable overrides on construction.
    """

    # Exception types that should be retried using backoff (fresh copy of the
    # module-level defaults so per-instance mutation is safe)
    retryable_exceptions: Set[Type[Exception]] = field(default_factory=lambda: DEFAULT_RETRYABLE_EXCEPTIONS.copy())

    # Backoff configuration controlling retry timing
    backoff_config: BackoffConfig = field(default_factory=BackoffConfig)

    def __post_init__(self):
        """Apply environment-variable overrides right after construction."""
        # EP_MAX_RETRY: a positive integer overrides the configured max_tries.
        raw_retry = os.environ.get("EP_MAX_RETRY")
        if raw_retry is not None:
            override = int(raw_retry)
            if override > 0:
                self.backoff_config.max_tries = override

        # EP_FAIL_ON_MAX_RETRY: any value other than "false" (case-insensitive)
        # keeps raising on give-up; "false" makes give-up non-fatal.
        raw_fail = os.environ.get("EP_FAIL_ON_MAX_RETRY")
        if raw_fail is not None:
            self.backoff_config.raise_on_giveup = raw_fail.lower() != "false"

    def get_backoff_decorator(self):
        """Return the backoff decorator configured for this exception handler."""
        return self.backoff_config.get_backoff_decorator(self.retryable_exceptions)
122+
123+
124+
def get_default_exception_handler_config() -> ExceptionHandlerConfig:
    """Build and return a fresh, default-configured exception handler config."""
    default_config = ExceptionHandlerConfig()
    return default_config

eval_protocol/pytest/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1111

1212
from ..models import CompletionParams, EvaluationRow, Message
13+
from .exception_config import ExceptionHandlerConfig
1314

1415
ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
1516
DatasetPathParam = str
@@ -50,3 +51,6 @@ class RolloutProcessorConfig:
5051
steps: int = 30 # max number of rollout steps
5152
logger: DatasetLogger = default_logger # logger to use during rollout for mid-rollout logs
5253
kwargs: Dict[str, Any] = field(default_factory=dict) # any additional kwargs to pass to the rollout processor
54+
exception_handler_config: Optional[ExceptionHandlerConfig] = (
55+
None # configuration for exception handling with backoff
56+
)

0 commit comments

Comments
 (0)