diff --git a/tenacity/retry.py b/tenacity/retry.py index df0cc4d6..60c4b626 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -126,7 +126,7 @@ def __init__( super().__init__(self._check) def _check(self, e: BaseException) -> bool: - return not isinstance(e, self.exception_types) + return isinstance(e, Exception) and not isinstance(e, self.exception_types) class retry_unless_exception_type(retry_if_exception): diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 7b5c6416..e8b3a942 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -159,6 +159,28 @@ def after(retry_state: RetryCallState) -> None: assert len(set(things)) == 1 assert list(attempt_nos2) == [1, 2, 3] + @asynctest + async def test_retry_if_not_exception_type_does_not_swallow_cancelled_error( + self, + ) -> None: + attempts = 0 + + @retry( + retry=tenacity.retry_if_not_exception_type(ValueError), + stop=stop_after_attempt(2), + wait=wait_fixed(0), + reraise=True, + ) + async def always_sleep() -> None: + nonlocal attempts + attempts += 1 + await asyncio.sleep(1) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(always_sleep(), 0.01) + + assert attempts == 1 + class TestAsyncEnabled(unittest.TestCase): @asynctest diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index 6a397392..63ec7c3a 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -1265,6 +1265,27 @@ def test_retry_except_exception_of_type(self) -> None: self.assertTrue(isinstance(err, IOError)) print(err) + def test_retry_except_exception_of_type_does_not_retry_base_exception(self) -> None: + calls = 0 + + class CustomBaseError(BaseException): + pass + + @retry( + stop=tenacity.stop_after_attempt(3), + retry=tenacity.retry_if_not_exception_type(IOError), + reraise=True, + ) + def raises_base_exception() -> None: + nonlocal calls + calls += 1 + raise CustomBaseError + + with pytest.raises(CustomBaseError): + raises_base_exception() + + assert calls == 1 + def test_retry_until_exception_of_type_attempt_number(self) -> None: try: self.assertTrue(