diff --git a/tenacity/__init__.py b/tenacity/__init__.py index 282e6dae..8148d1f8 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -54,6 +54,7 @@ retry_if_not_result, retry_if_result, retry_never, + retry_unless_exception_cause_type, retry_unless_exception_type, ) @@ -782,6 +783,7 @@ def wrap(f: t.Callable[P, R]) -> _RetryDecorated[P, R]: "retry_if_not_result", "retry_if_result", "retry_never", + "retry_unless_exception_cause_type", "retry_unless_exception_type", "sleep", "sleep_using_event", diff --git a/tenacity/retry.py b/tenacity/retry.py index df0cc4d6..2e6769a6 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -57,10 +57,28 @@ def __ror__(self, other: "RetryBaseT") -> "retry_any": return retry_any(*other.retries, self) return retry_any(other, self) + def __invert__(self) -> "retry_base": + """Return a retry strategy that is the logical inverse of this one.""" + return _retry_inverted(self) + RetryBaseT = retry_base | typing.Callable[["RetryCallState"], bool] +class _retry_inverted(retry_base): + """Retry strategy that inverts the decision of another retry strategy.""" + + def __init__(self, retry: "retry_base") -> None: + self.retry = retry + + def __call__(self, retry_state: "RetryCallState") -> bool: + return not self.retry(retry_state) + + def __invert__(self) -> "retry_base": + # Double inversion returns the original. + return self.retry + + class _retry_never(retry_base): """Retry strategy that never rejects any result.""" @@ -185,6 +203,39 @@ def __call__(self, retry_state: "RetryCallState") -> bool: return False +class retry_unless_exception_cause_type(retry_base): + """Retries unless any of the causes of the raised exception is of one or more types. + + This is the inverse of `retry_if_exception_cause_type`: it keeps retrying + as long as none of the causes in the exception chain match the given + type(s). As soon as a matching cause is found, it stops retrying. + + The check on the type of the cause of the exception is done recursively + (until finding an exception in the chain that has no `__cause__`). + """ + + def __init__( + self, + exception_types: type[BaseException] + | tuple[type[BaseException], ...] = Exception, + ) -> None: + self.exception_cause_types = exception_types + + def __call__(self, retry_state: "RetryCallState") -> bool: + if retry_state.outcome is None: + raise RuntimeError("__call__ called before outcome was set") + + if retry_state.outcome.failed: + exc = retry_state.outcome.exception() + while exc is not None: + if isinstance(exc.__cause__, self.exception_cause_types): + return False # a matching cause found — stop retrying + exc = exc.__cause__ + return True # no matching cause anywhere in the chain — keep retrying + + return False + + class retry_if_result(retry_base): """Retries if the result verifies a predicate.""" diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index 6a397392..cffdb975 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -28,8 +28,8 @@ from typeguard import check_type import tenacity -from tenacity import RetryCallState, RetryError, Retrying, retry -from tenacity.retry import retry_all, retry_any +from tenacity import RetryCallState, RetryError, Retrying, retry, stop_after_attempt +from tenacity.retry import retry_all, retry_any, retry_unless_exception_cause_type _unset = object() @@ -2125,5 +2125,41 @@ def succeed_on_third() -> str: assert calls == 3 +def test_retry_unless_exception_cause_type_logic() -> None: + class StopError(Exception): + pass + + class ContinueError(Exception): + pass + + stop_attempts = [] + + @retry( + retry=retry_unless_exception_cause_type(StopError), stop=stop_after_attempt(3) + ) + def fail_with_stop() -> None: + stop_attempts.append(1) + raise RuntimeError from StopError() + + continue_attempts = [] + + @retry( + retry=retry_unless_exception_cause_type(StopError), stop=stop_after_attempt(3) + ) + def fail_with_continue() -> None: + continue_attempts.append(1) + raise RuntimeError from ContinueError() + + # Test 1: Should stop immediately (raise the raw RuntimeError) + with contextlib.suppress(RuntimeError): + fail_with_stop() + assert len(stop_attempts) == 1 + + # Test 2: Should retry 3 times (hits limit, raises RetryError) + with contextlib.suppress(RetryError): + fail_with_continue() + assert len(continue_attempts) == 3 + + if __name__ == "__main__": unittest.main()