diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index c2061c924e79e..fb0b4facda896 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -2071,8 +2071,50 @@ def wrapper(self): if self.rank == self.MAIN_PROCESS_RANK: logger.debug(f"Waiting for workers to finish {self.id()}") # noqa: G004 # Wait for the workers to finish the test - for i, completion_queue in enumerate(self.completion_queues): - rv = completion_queue.get() + for i, (p, completion_queue) in enumerate( + zip(self.processes, self.completion_queues) + ): + # Bounded wait. Without a timeout, a rank that dies or + # gets stuck before sending its completion message causes + # this get() to block forever. + _to = getattr(self.__class__, "timeout", None) + _timeout_s = (_to.total_seconds() if _to is not None else 120) + 60 + try: + rv = completion_queue.get(timeout=_timeout_s) + except queue.Empty: + _alive = p.is_alive() + _exitcode = p.exitcode + _pid = p.pid + # Tear down the whole (broken) worker pool so no rank + # is left stuck holding devices for the next test. + for _p in self.processes: + try: + if _p.is_alive(): + _p.terminate() + except Exception: + pass + for _p in self.processes: + try: + _p.join(timeout=10) + if _p.is_alive(): + _p.kill() + except Exception: + pass + # Force a FRESH pool to be spawned for the next test + # (recover instead of poisoning the class). + self.__class__._processes_spawned = False + if _alive: + raise TimeoutError( + f"Rank {i} (pid {_pid}) did not send completion for " + f"{self.id()} within {_timeout_s}s (rank still alive " + f"- likely stuck). Worker pool torn down; next test " + f"will use a fresh pool." + ) from None + raise RuntimeError( + f"Rank {i} (pid {_pid}) died with exit_code={_exitcode} " + f"before sending completion for {self.id()}. Worker pool " + f"torn down; next test will use a fresh pool." + ) from None if isinstance(rv, unittest.SkipTest): raise rv if isinstance(rv, BaseException):