Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions torch/testing/_internal/common_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,8 +2071,50 @@ def wrapper(self):
if self.rank == self.MAIN_PROCESS_RANK:
logger.debug(f"Waiting for workers to finish {self.id()}") # noqa: G004
# Wait for the workers to finish the test
for i, completion_queue in enumerate(self.completion_queues):
rv = completion_queue.get()
for i, (p, completion_queue) in enumerate(
zip(self.processes, self.completion_queues)
):
# Bounded wait. Without a timeout, a rank that dies or
# gets stuck before sending its completion message causes
# this get() to block forever.
_to = getattr(self.__class__, "timeout", None)
_timeout_s = (_to.total_seconds() if _to is not None else 120) + 60
try:
rv = completion_queue.get(timeout=_timeout_s)
except queue.Empty:
_alive = p.is_alive()
_exitcode = p.exitcode
_pid = p.pid
# Tear down the whole (broken) worker pool so no rank
# is left stuck holding devices for the next test.
for _p in self.processes:
try:
if _p.is_alive():
_p.terminate()
except Exception:
pass
for _p in self.processes:
try:
_p.join(timeout=10)
if _p.is_alive():
_p.kill()
except Exception:
pass
# Force a FRESH pool to be spawned for the next test
# (recover instead of poisoning the class).
self.__class__._processes_spawned = False
if _alive:
raise TimeoutError(
f"Rank {i} (pid {_pid}) did not send completion for "
f"{self.id()} within {_timeout_s}s (rank still alive "
f"- likely stuck). Worker pool torn down; next test "
f"will use a fresh pool."
) from None
raise RuntimeError(
f"Rank {i} (pid {_pid}) died with exit_code={_exitcode} "
f"before sending completion for {self.id()}. Worker pool "
f"torn down; next test will use a fresh pool."
) from None
if isinstance(rv, unittest.SkipTest):
raise rv
if isinstance(rv, BaseException):
Expand Down