diff --git a/deadpool.py b/deadpool.py index b66d5c6..ed24461 100644 --- a/deadpool.py +++ b/deadpool.py @@ -642,6 +642,9 @@ def submit( return fut def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: + if self.closed: + return + logger.debug(f"shutdown: {wait=} {cancel_futures=}") # No more new tasks can be submitted @@ -706,16 +709,14 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - if not self.closed: - kwargs = {} - if self.shutdown_wait is not None: - kwargs["wait"] = self.shutdown_wait - - if self.shutdown_cancel_futures is not None: - kwargs["cancel_futures"] = self.shutdown_cancel_futures + kwargs = {} + if self.shutdown_wait is not None: + kwargs["wait"] = self.shutdown_wait - self.shutdown(**kwargs) + if self.shutdown_cancel_futures is not None: + kwargs["cancel_futures"] = self.shutdown_cancel_futures + self.shutdown(**kwargs) self.runner_thread.join() return False diff --git a/tests/test_deadpool.py b/tests/test_deadpool.py index 9034e87..9040647 100644 --- a/tests/test_deadpool.py +++ b/tests/test_deadpool.py @@ -344,6 +344,14 @@ def test_shutdown(logging_initializer, wait, cancel_futures): assert result == 123 +@pytest.mark.parametrize("wait", [True, False]) +def test_shutdown_idempotent(wait): + """Calling shutdown() twice must not deadlock.""" + d = deadpool.Deadpool(max_workers=2) + d.shutdown(wait=wait) + d.shutdown(wait=wait) + + @pytest.mark.parametrize("wait", [True, False]) @pytest.mark.parametrize("cancel_futures", [True, False]) def test_shutdown_manual(logging_initializer, wait, cancel_futures):