Skip to content

Commit 8ee5078

Browse files
authored
Add deadlock safe function for calling fallible pieces of code on rank 0 (#4972)
1 parent ef4c1bf commit 8ee5078

3 files changed

Lines changed: 117 additions & 67 deletions

File tree

pyop2/compilation.py

Lines changed: 52 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -565,78 +565,63 @@ def make_so(compiler, code, extension, comm):
565565
exe = compiler.cc
566566
compiler_flags = compiler.cflags
567567

568-
# Compile on compilation communicator (ccomm) rank 0
569-
if ccomm.rank == 0:
570-
# Track exceptions as values so that they may be raised collectively
568+
def compile_single_rank():
569+
# Adding random 2-digit hexnum avoids using excessive filesystem inodes
570+
tempdir = MEM_TMP_DIR.joinpath(f"{randint(0, 255):02x}")
571+
tempdir.mkdir(parents=True, exist_ok=True)
572+
# This path + filename should be unique
573+
descriptor, filename = mkstemp(suffix=f".{extension}", dir=tempdir, text=True)
574+
filename = Path(filename)
575+
576+
cname = filename
577+
oname = filename.with_suffix(".o")
578+
soname = filename.with_suffix(".so")
579+
logfile = filename.with_suffix(".log")
580+
errfile = filename.with_suffix(".err")
581+
571582
try:
572-
# Adding random 2-digit hexnum avoids using excessive filesystem inodes
573-
tempdir = MEM_TMP_DIR.joinpath(f"{randint(0, 255):02x}")
574-
tempdir.mkdir(parents=True, exist_ok=True)
575-
# This path + filename should be unique
576-
descriptor, filename = mkstemp(suffix=f".{extension}", dir=tempdir, text=True)
577-
filename = Path(filename)
578-
579-
cname = filename
580-
oname = filename.with_suffix(".o")
581-
soname = filename.with_suffix(".so")
582-
logfile = filename.with_suffix(".log")
583-
errfile = filename.with_suffix(".err")
584-
except BaseException as e:
585-
result = e
586-
else:
587-
try:
588-
with progress(INFO, 'Compiling wrapper'):
589-
# Write source code to disk
590-
with open(cname, "w") as fh:
591-
fh.write(code)
592-
os.close(descriptor)
593-
594-
if not compiler.ld:
595-
# Compile and link
596-
cc = (exe,) + compiler_flags + ('-o', str(soname), str(cname)) + compiler.ldflags
597-
_run(cc, logfile, errfile)
598-
else:
599-
# Compile
600-
cc = (exe,) + compiler_flags + ('-c', '-o', str(oname), str(cname))
601-
_run(cc, logfile, errfile)
602-
# Extract linker specific "cflags" from ldflags and link
603-
ld = tuple(shlex.split(compiler.ld)) + ('-o', str(soname), str(oname)) + tuple(expandWl(compiler.ldflags))
604-
_run(ld, logfile, errfile, step="Linker", filemode="a")
605-
606-
result = soname
607-
except subprocess.CalledProcessError as e:
608-
msg = dedent(f"""
609-
Command "{e.cmd}" return error status {e.returncode}.
610-
Unable to compile code
611-
""")
612-
if os.environ.get("FIREDRAKE_CI", False):
613-
msg += dedent(f"""
614-
Code is:
615-
{code}
616-
""")
617-
with open(errfile) as err:
618-
msg += dedent(f"""
619-
Compiler output is:
620-
{''.join(err.readlines())}
621-
""")
583+
with progress(INFO, 'Compiling wrapper'):
584+
# Write source code to disk
585+
with open(cname, "w") as fh:
586+
fh.write(code)
587+
os.close(descriptor)
588+
589+
if not compiler.ld:
590+
# Compile and link
591+
cc = (exe,) + compiler_flags + ('-o', str(soname), str(cname)) + compiler.ldflags
592+
_run(cc, logfile, errfile)
622593
else:
594+
# Compile
595+
cc = (exe,) + compiler_flags + ('-c', '-o', str(oname), str(cname))
596+
_run(cc, logfile, errfile)
597+
# Extract linker specific "cflags" from ldflags and link
598+
ld = tuple(shlex.split(compiler.ld)) + ('-o', str(soname), str(oname)) + tuple(expandWl(compiler.ldflags))
599+
_run(ld, logfile, errfile, step="Linker", filemode="a")
600+
except subprocess.CalledProcessError as e:
601+
msg = dedent(f"""
602+
Command "{e.cmd}" return error status {e.returncode}.
603+
Unable to compile code
604+
""")
605+
if os.environ.get("FIREDRAKE_CI", False):
606+
msg += dedent(f"""
607+
Code is:
608+
{code}
609+
""")
610+
with open(errfile) as err:
623611
msg += dedent(f"""
624-
Compile log in {logfile!s}
625-
Compile errors in {errfile!s}
612+
Compiler output is:
613+
{''.join(err.readlines())}
626614
""")
627-
result = CompilationError(msg)
628-
result.__cause__ = e # equivalent to 'raise XXX from e'
629-
except BaseException as e:
630-
# catch and broadcast all exceptions to prevent deadlocks
631-
result = e
632-
else:
633-
result = None
615+
else:
616+
msg += dedent(f"""
617+
Compile log in {logfile!s}
618+
Compile errors in {errfile!s}
619+
""")
620+
raise CompilationError(msg) from e
621+
else:
622+
return soname
634623

635-
result = ccomm.bcast(result)
636-
if isinstance(result, BaseException):
637-
raise result
638-
else:
639-
return result
624+
return mpi.safe_noncollective(ccomm, compile_single_rank, root=0)
640625

641626

642627
def _run(cc, logfile, errfile, step="Compilation", filemode="w"):

pyop2/mpi.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"""PyOP2 MPI communicator."""
3535

3636

37+
from typing import Any, Callable
3738
from petsc4py import PETSc
3839
from mpi4py import MPI # noqa
3940
from itertools import count
@@ -565,6 +566,44 @@ def finalize_safe_debug():
565566
return debug
566567

567568

569+
def safe_noncollective(comm: MPI.Comm, func: Callable[[], Any], *, root: int) -> Any:
    """Run a function on a single rank of ``comm`` in a deadlock safe way.

    ``func`` is called on ``root`` only. Its outcome is then broadcast so that
    every rank of ``comm`` either returns the same value or raises the same
    exception, instead of ``root`` raising while the other ranks wait in a
    collective call.

    Parameters
    ----------
    comm
        The communicator.
    func
        The operation to be performed on a single rank. This should be a
        callable that takes no arguments.
    root
        The rank performing the operation.

    Returns
    -------
    Any
        The result of ``func``, broadcast to all ranks.

    Raises
    ------
    BaseException
        Whatever ``func`` raised on ``root``, re-raised on every rank.

    Notes
    -----
    The outcome is sent with ``bcast``, so both return values and exceptions
    produced by ``func`` must survive that transfer.
    NOTE(review): for mpi4py's lowercase ``bcast`` this presumably means they
    must be picklable - confirm.

    """
    # The outcome is an explicit ``(raised, payload)`` pair. Inspecting the
    # payload's type instead would mistake an exception instance that ``func``
    # merely *returned* for a failure.
    outcome = None
    if comm.rank == root:
        try:
            outcome = (False, func())
        except BaseException as e:
            # Deliberately catch everything: any exception escaping here would
            # stop this rank from reaching the broadcast below.
            outcome = (True, e)

    with temp_internal_comm(comm) as icomm:
        raised, payload = icomm.bcast(outcome, root=root)

    if raised:
        raise payload
    return payload
605+
606+
568607
@atexit.register
569608
def _free_comms():
570609
"""Free all outstanding communicators."""

tests/pyop2/test_mpi.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
import pytest_mpi
3+
from mpi4py import MPI
4+
5+
import pyop2.mpi
6+
7+
8+
def passing_test():
    """Return the string ``"pass"`` without raising."""
    outcome = "pass"
    return outcome
10+
11+
12+
def failing_test():
    """Unconditionally raise a ``RuntimeError``."""
    error = RuntimeError("This test has failed")
    raise error
14+
15+
16+
@pytest.mark.parallel(2)
@pytest.mark.parametrize("root", [0, 1])
def test_branches_on_rank_do_not_deadlock(root):
    """Check both outcomes of ``safe_noncollective`` for each choice of root.

    The value returned by ``passing_test`` must come back from the call, and
    the ``RuntimeError`` raised by ``failing_test`` must be raised by the call.
    """
    world = MPI.COMM_WORLD

    # Success path: the callable's return value is handed back.
    returned = pyop2.mpi.safe_noncollective(world, passing_test, root=root)
    pytest_mpi.parallel_assert(returned == "pass")

    # Failure path: capture whatever is raised so it can be checked with a
    # parallel assertion rather than propagating out of the test.
    try:
        caught = pyop2.mpi.safe_noncollective(world, failing_test, root=root)
    except BaseException as exc:
        caught = exc
    is_runtime_error = isinstance(caught, RuntimeError)
    pytest_mpi.parallel_assert(is_runtime_error and str(caught) == "This test has failed")

0 commit comments

Comments
 (0)