@@ -1,6 +1,8 @@
 from functools import wraps
 import weakref
+from contextlib import contextmanager
 from itertools import zip_longest
+from types import SimpleNamespace
 
 from firedrake.petsc import PETSc
 from firedrake.function import Function
@@ -584,3 +586,93 @@ def isendrecv(self, fsend: Function | Cofunction, dest: int, sendtag: int = 0, |
         requests.extend([self._ensemble_comm.Irecv(dat.data, source=source, tag=recvtag)
                          for dat in frecv.dat])
         return requests
+
+    @contextmanager
+    def sequential(self, *, synchronise: bool = False, reverse: bool = False, **kwargs):
+        """
+        Context manager for executing code on each ensemble
+        member consecutively (ordered by increasing
+        :attr:`~.Ensemble.ensemble_rank`).
+
+        Any data in ``kwargs`` is made available in the returned context
+        and communicated forward after each ensemble member exits.
+        :class:`.Function` and :class:`.Cofunction` kwargs are sent in-place
+        with the corresponding Ensemble methods; other values are pickled.
+
+        For example:
+
+        .. code-block:: python3
+
+            with ensemble.sequential(index=0) as ctx:
+                print(ensemble.ensemble_rank, ctx.index)
+                ctx.index += 2
+
+        Would print:
+
+        .. code-block::
+
+            0 0
+            1 2
+            2 4
+            3 6
+            ...
+
+        If ``reverse`` is ``True`` then the ensemble ranks are looped through
+        in decreasing order, i.e. ``ensemble_rank == (ensemble_size - 1)`` runs
+        first, then ``ensemble_rank == (ensemble_size - 2)``, and so on.
+
+        Parameters
+        ----------
+        synchronise :
+            If ``True`` then an ``MPI_Barrier`` is called on the
+            ``global_comm`` when entering and exiting the context.
+
+        reverse :
+            If ``True`` then iterate through the ensemble members in
+            order of decreasing ``ensemble_rank``.
+
+        kwargs :
+            Data to be passed forward by each rank and made available
+            in the returned ``ctx``.
+        """
+        rank = self.ensemble_rank
+        if reverse:  # send backwards
+            src = rank + 1
+            dst = rank - 1
+            first_rank = (rank == self.ensemble_size - 1)
+            last_rank = (rank == 0)
+        else:  # send forwards
+            src = rank - 1
+            dst = rank + 1
+            first_rank = (rank == 0)
+            last_rank = (rank == self.ensemble_size - 1)
+
+        if synchronise:
+            self.global_comm.Barrier()
+
+        if not first_rank:
+            for i, (k, v) in enumerate(kwargs.items()):
+                if isinstance(v, (Function, Cofunction)):
+                    # Functions are sent in-place, everything else is pickled
+                    recv_args = [kwargs[k]]
+                else:
+                    recv_args = []
+                kwargs[k] = self.recv(*recv_args, source=src, tag=rank+i*100)
+
+        ctx = SimpleNamespace(**kwargs)
+        yield ctx
+
+        if not last_rank:
+            for i, v in enumerate((getattr(ctx, k)
+                                   for k in kwargs.keys())):
+                try:
+                    self.send(v, dest=dst, tag=dst+i*100)
+                except Exception as error:
+                    raise TypeError(
+                        f"Failed to send object of type {type(v).__name__}. kwargs for"
+                        " Ensemble.sequential must be Functions, Cofunctions,"
+                        " or acceptable arguments to mpi4py.MPI.Comm.send."
+                    ) from error
+
+        if synchronise:
+            self.global_comm.Barrier()
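
Beyond the docstring example, here is a minimal usage sketch (not part of this diff; it assumes a standard Firedrake setup with one spatial rank per ensemble member, run with e.g. `mpiexec -n 4`). A `Function` kwarg is received in place from the previous member and forwarded on exit, so each member sees the accumulated field:

```python
from firedrake import (COMM_WORLD, Ensemble, Function,
                       FunctionSpace, UnitSquareMesh)

# Split COMM_WORLD into ensemble members with 1 spatial rank each.
ensemble = Ensemble(COMM_WORLD, 1)

# Build the mesh on each member's spatial communicator.
mesh = UnitSquareMesh(4, 4, comm=ensemble.comm)
V = FunctionSpace(mesh, "CG", 1)
u = Function(V)  # starts at zero on every member

with ensemble.sequential(u=u) as ctx:
    # ctx.u arrives holding the previous member's field (still zero on
    # the first member) and is sent on after the block exits, so
    # ensemble member r ends up with the value r + 1 everywhere.
    ctx.u.assign(ctx.u + 1)
```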
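
Similarly, a sketch of the `reverse` and `synchronise` options with plain (pickled) data, assuming four ensemble members; the sweep starts from the highest `ensemble_rank` and a barrier on `global_comm` brackets the whole sweep:

```python
with ensemble.sequential(trace=[], reverse=True, synchronise=True) as ctx:
    # The list is pickled and passed backwards through the members.
    ctx.trace.append(ensemble.ensemble_rank)
    print(ctx.trace)
```

With four members this prints `[3]`, then `[3, 2]`, `[3, 2, 1]` and finally `[3, 2, 1, 0]`.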