Skip to content

Commit 1ba8790

Browse files
JHopeCollinsdham
andauthored
Ensemble.sequential context manager (#4964)
Co-authored-by: David A. Ham <david.ham@imperial.ac.uk>
1 parent 285ed4c commit 1ba8790

2 files changed

Lines changed: 132 additions & 0 deletions

File tree

firedrake/ensemble/ensemble.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from functools import wraps
22
import weakref
3+
from contextlib import contextmanager
34
from itertools import zip_longest
5+
from types import SimpleNamespace
46

57
from firedrake.petsc import PETSc
68
from firedrake.function import Function
@@ -584,3 +586,93 @@ def isendrecv(self, fsend: Function | Cofunction, dest: int, sendtag: int = 0,
584586
requests.extend([self._ensemble_comm.Irecv(dat.data, source=source, tag=recvtag)
585587
for dat in frecv.dat])
586588
return requests
589+
590+
@contextmanager
591+
def sequential(self, *, synchronise: bool = False, reverse: bool = False, **kwargs):
592+
"""
593+
Context manager for executing code on each ensemble
594+
member consecutively (ordered by increasing
595+
:attr:`~.Ensemble.ensemble_rank`).
596+
597+
Any data in ``kwargs`` will be made available in the returned
598+
context and will be communicated forward after each ensemble
599+
member exits. :class:`.Function` or :class:`.Cofunction`
600+
``kwargs`` will be sent with the corresponding Ensemble methods.
601+
602+
For example:
603+
604+
.. code-block:: python3
605+
606+
with ensemble.sequential(index=0) as ctx:
607+
print(ensemble.ensemble_rank, ctx.index)
608+
ctx.index += 2
609+
610+
Would print:
611+
612+
.. code-block::
613+
614+
0 0
615+
1 2
616+
2 4
617+
3 6
618+
...
619+
620+
If ``reverse is True`` then the ensemble ranks will be looped through
621+
in decreasing order i.e. ``ensemble_rank == (ensemble_size - 1)`` will
622+
run first, then ``ensemble_rank == (ensemble_size - 2)`` etc.
623+
624+
Parameters
625+
----------
626+
synchronise :
627+
If True then MPI_Barrier will be called on the ``global_comm``
628+
at the beginning and end of this method.
629+
630+
reverse :
631+
If True then will iterate through spatial comms in order of
632+
decreasing ``ensemble_rank``.
633+
634+
kwargs :
635+
Data to be passed forward by each rank and made available
636+
in the returned ``ctx``.
637+
"""
638+
rank = self.ensemble_rank
639+
if reverse: # send backwards
640+
src = rank + 1
641+
dst = rank - 1
642+
first_rank = (rank == self.ensemble_size - 1)
643+
last_rank = (rank == 0)
644+
else: # send forwards
645+
src = rank - 1
646+
dst = rank + 1
647+
first_rank = (rank == 0)
648+
last_rank = (rank == self.ensemble_size - 1)
649+
650+
if synchronise:
651+
self.global_comm.Barrier()
652+
653+
if not first_rank:
654+
for i, (k, v) in enumerate(kwargs.items()):
655+
if isinstance(v, (Function, Cofunction)):
656+
# Functions are sent in-place, everything else is pickled
657+
recv_args = [kwargs[k]]
658+
else:
659+
recv_args = []
660+
kwargs[k] = self.recv(*recv_args, source=src, tag=rank+i*100)
661+
662+
ctx = SimpleNamespace(**kwargs)
663+
yield ctx
664+
665+
if not last_rank:
666+
for i, v in enumerate((getattr(ctx, k)
667+
for k in kwargs.keys())):
668+
try:
669+
self.send(v, dest=dst, tag=dst+i*100)
670+
except Exception as error:
671+
raise TypeError(
672+
"Failed to send object of type {type(v)__name__}. kwargs for"
673+
" Ensemble.sequential must be Functions, Cofunctions,"
674+
" or acceptable arguments to mpi4py.MPI.Comm.send."
675+
) from error
676+
677+
if synchronise:
678+
self.global_comm.Barrier()

tests/firedrake/ensemble/test_ensemble.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,43 @@ def test_ensemble_solvers(ensemble, W, urank, urank_sum):
380380
ensemble.allreduce(u_separate, usum)
381381

382382
parallel_assert(errornorm(u_combined, usum) < 1e-8)
383+
384+
385+
@pytest.mark.parallel(nprocs=6)
386+
@pytest.mark.parametrize("direction", ["forward", "reverse"])
387+
def test_ensemble_sequential(ensemble, direction):
388+
"""
389+
Test that the sequential context manager sends forward
390+
the correct values after each rank has executed, for both
391+
intrinsic types (float) and Firedrake types (Function).
392+
"""
393+
394+
rank = ensemble.ensemble_rank
395+
mesh = UnitIntervalMesh(1, comm=ensemble.comm)
396+
R = FunctionSpace(mesh, "R", 0)
397+
398+
reverse = direction == "reverse"
399+
400+
idx_i = 0
401+
idx_f = Function(R).zero()
402+
two = Function(R).assign(2)
403+
404+
with ensemble.sequential(reverse=reverse, idx_i=idx_i, idx_f=idx_f) as ctx:
405+
recv_i = float(ctx.idx_i)
406+
recv_f = float(ctx.idx_f)
407+
408+
ctx.idx_i += 2
409+
ctx.idx_f += two
410+
411+
if reverse:
412+
expected = 2*(ensemble.ensemble_size - 1 - rank)
413+
else:
414+
expected = 2*rank
415+
416+
parallel_assert(
417+
recv_i == expected,
418+
msg=f"Failed to send int properly. Expecting {expected} but received {recv_i}")
419+
420+
parallel_assert(
421+
abs(float(recv_f)-expected) < 1e-12,
422+
msg=f"Failed to send Function properly. Expecting {expected} but received {float(recv_f)}")

0 commit comments

Comments
 (0)