Skip to content

Commit 2d3aa95

Browse files
sghelichkhaniJHopeCollinsleo-collins
authored
Add per-rank disk checkpointing for adjoint tape (#4891)
Enable saving and loading Functions to local storage on subcommunicators rather than shared storage on mesh.comm. Co-authored-by: Josh Hope-Collins <joshua.hope-collins13@imperial.ac.uk> Co-authored-by: Leo Collins <leocollins511@gmail.com>
1 parent 15a9980 commit 2d3aa95

3 files changed

Lines changed: 684 additions & 76 deletions

File tree

firedrake/adjoint_utils/checkpointing.py

Lines changed: 206 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import shutil
88
import atexit
9+
import warnings
910
from abc import ABC, abstractmethod
1011
from numbers import Number
1112
_enable_disk_checkpoint = False
@@ -49,7 +50,8 @@ def __exit__(self, *args):
4950
_checkpoint_init_data = self._init
5051

5152

52-
def enable_disk_checkpointing(dirname=None, comm=COMM_WORLD, cleanup=True):
53+
def enable_disk_checkpointing(dirname=None, comm=COMM_WORLD, cleanup=True,
54+
checkpoint_comm=None, checkpoint_dir=None):
5355
"""Add a DiskCheckpointer to the current tape.
5456
5557
Disk checkpointing is fully enabled by calling::
@@ -68,23 +70,48 @@ def enable_disk_checkpointing(dirname=None, comm=COMM_WORLD, cleanup=True):
6870
`checkpoint_schedules` provides other schedules for checkpointing to memory, disk,
6971
or a combination of both.
7072
73+
For HPC systems with fast node-local storage, function data can be
74+
checkpointed on a sub-communicator to avoid parallel HDF5 overhead::
75+
76+
enable_disk_checkpointing(checkpoint_comm=MPI.COMM_SELF,
77+
checkpoint_dir="/local/scratch")
78+
7179
Parameters
7280
----------
7381
dirname : str
74-
The directory in which the disk checkpoints should be stored. If not
75-
specified then the current working directory is used. Checkpoints are
76-
stored in a temporary subdirectory of this directory.
82+
The directory in which the shared disk checkpoints should be stored.
83+
If not specified then the current working directory is used.
84+
Checkpoints are stored in a temporary subdirectory of this directory.
7785
comm : mpi4py.MPI.Intracomm
7886
The MPI communicator over which the computation to be disk checkpointed
7987
is defined. This will usually match the communicator on which the
8088
mesh(es) are defined.
8189
cleanup : bool
8290
If set to False, checkpoint files will not be deleted when no longer
8391
required. This is usually only useful for debugging.
92+
checkpoint_comm : mpi4py.MPI.Intracomm or None
93+
If specified, function data is checkpointed using PETSc Vec I/O on
94+
this communicator instead of using Firedrake's CheckpointFile. This
95+
bypasses parallel HDF5 and is ideal for node-local storage on HPC
96+
systems. Passing ``MPI.COMM_SELF`` gives each rank its own file,
97+
while a shared node communicator groups ranks that share storage.
98+
The mesh checkpoint (via ``checkpointable_mesh``) always uses shared
99+
storage. Requires the same communicator layout on restore.
100+
checkpoint_dir : str or None
101+
The directory in which checkpoint_comm files are stored. Only used
102+
when ``checkpoint_comm`` is not None. Each group of ranks sharing
103+
a checkpoint_comm creates a temporary subdirectory here. This
104+
directory must be accessible from all ranks within each
105+
checkpoint_comm group. For example, using a node-local path like
106+
/tmp is safe when checkpoint_comm groups ranks on the same node,
107+
but would fail if checkpoint_comm spans nodes whose filesystems
108+
are not shared.
84109
"""
85110
tape = get_working_tape()
86111
if "firedrake" not in tape._package_data:
87-
tape._package_data["firedrake"] = DiskCheckpointer(dirname, comm, cleanup)
112+
tape._package_data["firedrake"] = DiskCheckpointer(
113+
dirname, comm, cleanup, checkpoint_comm, checkpoint_dir
114+
)
88115

89116

90117
def disk_checkpointing():
@@ -120,14 +147,29 @@ def __exit__(self, *args):
120147

121148
class CheckPointFileReference:
122149
"""A filename which deletes the associated file when it is destroyed."""
123-
def __init__(self, name, comm, cleanup=False):
150+
def __init__(self, name, comm, cleanup=False, checkpoint_comm=None):
124151
self.name = name
125152
self.comm = comm
126153
self.cleanup = cleanup
154+
self.checkpoint_comm = checkpoint_comm
127155

128156
def __del__(self):
129-
if self.cleanup and self.comm.rank == 0 and os.path.exists(self.name):
130-
os.remove(self.name)
157+
if self.cleanup and os.path.exists(self.name):
158+
if self.comm.rank == 0:
159+
os.remove(self.name)
160+
# Prune the index-tracking entry for this file from CheckpointFunction.
161+
# This is safe for the following reasons:
162+
# (1) CheckpointFunction holds self.file as a direct strong reference,
163+
# so __del__ here can only fire after every CheckpointFunction that
164+
# wrote to this filepath has already been garbage-collected.
165+
# (2) restore() never reads _checkpoint_indices — it uses stored_name
166+
# and stored_index baked into the CheckpointFunction at save time.
167+
# (3) Under revolve schedules the tape checkpoint store holds the
168+
# CheckPointFileReference alive until forward re-execution is done,
169+
# so there is no risk of premature pruning.
170+
# (4) pop is a no-op for init files where no CheckpointFunction ever
171+
# wrote an entry (e.g. checkpointable_mesh files).
172+
CheckpointFunction._checkpoint_indices.pop(self.name, None)
131173

132174

133175
class DiskCheckpointer(TapePackageData):
@@ -136,52 +178,128 @@ class DiskCheckpointer(TapePackageData):
136178
Parameters
137179
----------
138180
dirname : str
139-
The directory in which the disk checkpoints should be stored. If not
140-
specified then the current working directory is used. Checkpoints are
141-
stored in a temporary subdirectory of this directory.
181+
The directory in which the shared disk checkpoints should be stored.
182+
If not specified then the current working directory is used.
183+
Checkpoints are stored in a temporary subdirectory of this directory.
142184
comm : mpi4py.MPI.Intracomm
143185
The MPI communicator over which the computation to be disk checkpointed
144186
is defined. This will usually match the communicator on which the
145187
mesh(es) are defined.
146188
cleanup : bool
147189
If set to False, checkpoint files will not be deleted when no longer
148190
required. This is usually only useful for debugging.
191+
checkpoint_comm : mpi4py.MPI.Intracomm or None
192+
If specified, function data is checkpointed on this communicator.
193+
checkpoint_dir : str or None
194+
Directory for checkpoint_comm files. This directory must be
195+
accessible from all ranks within each checkpoint_comm group.
196+
For example, using a node-local path like /tmp is safe when
197+
checkpoint_comm groups ranks on the same node, but would fail
198+
if checkpoint_comm spans nodes whose filesystems are not shared.
149199
"""
150200

151-
def __init__(self, dirname=None, comm=COMM_WORLD, cleanup=True):
152-
153-
if comm.rank == 0:
154-
self.dirname = comm.bcast(tempfile.mkdtemp(
155-
prefix="firedrake_adjoint_checkpoint_", dir=dirname or os.getcwd()
156-
))
157-
else:
158-
self.dirname = comm.bcast("")
201+
def __init__(self, dirname=None, comm=COMM_WORLD, cleanup=True,
202+
checkpoint_comm=None, checkpoint_dir=None):
203+
self.checkpoint_comm = checkpoint_comm
159204
self.comm = comm
160205
self.cleanup = cleanup
206+
207+
# Shared directory (for mesh checkpoint and init data). The bcast
208+
# uses comm (COMM_WORLD) so every rank knows the shared path.
209+
path = tempfile.mkdtemp(
210+
prefix="firedrake_adjoint_checkpoint_", dir=dirname or os.getcwd()
211+
) if comm.rank == 0 else None
212+
self.dirname = comm.bcast(path)
161213
if self.cleanup and comm.rank == 0:
162-
# Delete the checkpoint folder on process exit.
214+
# Delete the shared checkpoint folder on process exit.
163215
atexit.register(shutil.rmtree, self.dirname)
164-
# # A checkpoint file holding the state of block variables set outside
165-
# the tape.
166-
self.init_checkpoint_file = self.new_checkpoint_file()
167-
self.current_checkpoint_file = self.new_checkpoint_file()
168216

169-
def new_checkpoint_file(self):
170-
"""Set up a disk checkpointing file."""
217+
# Local directory (for function data on checkpoint_comm). The bcast
218+
# uses checkpoint_comm, not comm: only ranks within the same
219+
# checkpoint_comm group share a local filesystem, so we must not
220+
# perform a COMM_WORLD collective here.
221+
if self.checkpoint_comm is not None:
222+
if checkpoint_dir is None:
223+
warnings.warn(
224+
"checkpoint_comm without checkpoint_dir defaults to cwd, "
225+
"which is usually on the shared filesystem. Without a "
226+
"node-local path the collective CheckpointFile is more "
227+
"suitable. Consider setting checkpoint_dir.",
228+
UserWarning
229+
)
230+
base_dir = checkpoint_dir or os.getcwd()
231+
if checkpoint_comm.rank == 0:
232+
# ignore_cleanup_errors avoids tracebacks if the finalizer fires
233+
# during interpreter shutdown after MPI has already finalized.
234+
self._local_tmpdir = tempfile.TemporaryDirectory(
235+
prefix="firedrake_adjoint_checkpoint_cc_",
236+
dir=base_dir,
237+
delete=cleanup,
238+
ignore_cleanup_errors=True,
239+
)
240+
local_path = self._local_tmpdir.name
241+
else:
242+
self._local_tmpdir = None
243+
local_path = None
244+
self._local_dirname = checkpoint_comm.bcast(local_path)
245+
else:
246+
self._local_tmpdir = None
247+
self._local_dirname = None
248+
249+
# A checkpoint file holding the state of block variables set outside
250+
# the tape (always shared, used by checkpointable_mesh).
251+
self.init_checkpoint_file = self._new_shared_checkpoint_file()
252+
self.current_checkpoint_file = self._new_checkpoint_file()
253+
254+
def _new_shared_checkpoint_file(self):
255+
"""Set up a shared disk checkpointing file (all ranks use same file)."""
171256
from firedrake.checkpointing import CheckpointFile
172257
if self.comm.rank == 0:
173-
_, checkpoint_file = tempfile.mkstemp(
174-
dir=self.dirname, suffix=".h5"
175-
)
176-
checkpoint_file = self.comm.bcast(checkpoint_file)
258+
_, checkpoint_file = tempfile.mkstemp(dir=self.dirname, suffix=".h5")
177259
else:
178-
checkpoint_file = self.comm.bcast("")
260+
checkpoint_file = None
261+
checkpoint_file = self.comm.bcast(checkpoint_file)
179262
# Let h5py create a file at this location just to be sure.
180-
with CheckpointFile(checkpoint_file, 'w'):
263+
with CheckpointFile(checkpoint_file, 'w', comm=self.comm):
181264
pass
182265
return CheckPointFileReference(checkpoint_file, self.comm,
183266
self.cleanup)
184267

268+
def _new_checkpoint_comm_file(self):
269+
"""Set up a checkpoint file on the checkpoint communicator."""
270+
from firedrake.checkpointing import TemporaryFunctionCheckpointFile
271+
if self.checkpoint_comm.rank == 0:
272+
fd, filepath = tempfile.mkstemp(dir=self._local_dirname, suffix=".h5")
273+
os.close(fd)
274+
else:
275+
filepath = None
276+
filepath = self.checkpoint_comm.bcast(filepath)
277+
# Initialise an empty HDF5 file. Opened in 'w' mode and immediately
278+
# closed so that subsequent 'a' opens from save_function find a valid
279+
# file.
280+
with TemporaryFunctionCheckpointFile(self.checkpoint_comm, filepath, 'w'):
281+
pass
282+
return CheckPointFileReference(filepath, self.checkpoint_comm, self.cleanup,
283+
checkpoint_comm=self.checkpoint_comm)
284+
285+
def _new_checkpoint_file(self):
286+
"""Set up a checkpoint file for function data."""
287+
if self.checkpoint_comm is not None:
288+
return self._new_checkpoint_comm_file()
289+
else:
290+
return self._new_shared_checkpoint_file()
291+
292+
def new_checkpoint_file(self):
293+
"""Set up a disk checkpointing file."""
294+
warnings.warn(
295+
"'new_checkpoint_file' is deprecated and will be removed in a "
296+
"future release. Checkpoint file management is now handled "
297+
"internally; to advance to a new checkpoint file call "
298+
"'reset()' on the DiskCheckpointer instead.",
299+
FutureWarning
300+
)
301+
return self._new_checkpoint_file()
302+
185303
def clear(self, init=True):
186304
"""Reset the DiskCheckPointer.
187305
@@ -198,8 +316,8 @@ def clear(self, init=True):
198316
if not self.cleanup:
199317
return
200318
if init:
201-
self.init_checkpoint_file = self.new_checkpoint_file()
202-
self.current_checkpoint_file = self.new_checkpoint_file()
319+
self.init_checkpoint_file = self._new_shared_checkpoint_file()
320+
self.current_checkpoint_file = self._new_checkpoint_file()
203321

204322
def reset(self):
205323
self.clear(init=False)
@@ -254,9 +372,9 @@ def checkpointable_mesh(mesh):
254372
"No current checkpoint file. Call enable_disk_checkpointing()."
255373
)
256374

257-
with CheckpointFile(checkpoint_file.name, 'a') as outfile:
375+
with CheckpointFile(checkpoint_file.name, 'a', comm=checkpoint_file.comm) as outfile:
258376
outfile.save_mesh(mesh)
259-
with CheckpointFile(checkpoint_file.name, 'r') as outfile:
377+
with CheckpointFile(checkpoint_file.name, 'r', comm=checkpoint_file.comm) as outfile:
260378
return outfile.load_mesh(mesh.name)
261379

262380

@@ -290,7 +408,6 @@ class CheckpointFunction(CheckpointBase, OverloadedType):
290408
_checkpoint_indices = {}
291409

292410
def __init__(self, function):
293-
from firedrake.checkpointing import CheckpointFile
294411
self.name = function.name()
295412
self.mesh = function.function_space().mesh()
296413
self.file = current_checkpoint_file()
@@ -300,31 +417,70 @@ def __init__(self, function):
300417
"No current checkpoint file. Call enable_disk_checkpointing()."
301418
)
302419

420+
self.count = function.count()
421+
422+
# Compute stored_name and stored_index once, shared by both checkpoint
423+
# paths. stored_name encodes the function space (mesh name + element
424+
# family/degree) so that functions on different meshes or spaces never
425+
# collide. stored_index disambiguates successive saves of the same
426+
# space to the same file.
427+
from firedrake.checkpointing import _generate_function_space_name
303428
stored_names = CheckpointFunction._checkpoint_indices
304429
if self.file.name not in stored_names:
305430
stored_names[self.file.name] = {}
431+
self.stored_name = _generate_function_space_name(function.function_space())
432+
indices = stored_names[self.file.name]
433+
indices.setdefault(self.stored_name, 0)
434+
indices[self.stored_name] += 1
435+
self.stored_index = indices[self.stored_name]
436+
437+
if self.file.checkpoint_comm is not None:
438+
self._function_space = function.function_space()
439+
self._save_local_checkpoint(function)
440+
else:
441+
self._save_shared_checkpoint(function)
306442

307-
self.count = function.count()
308-
with CheckpointFile(self.file.name, 'a') as outfile:
309-
self.stored_name = outfile._generate_function_space_name(
310-
function.function_space()
311-
)
312-
indices = stored_names[self.file.name]
313-
indices.setdefault(self.stored_name, 0)
314-
indices[self.stored_name] += 1
315-
self.stored_index = indices[self.stored_name]
443+
def _save_shared_checkpoint(self, function):
444+
"""Save function data to a shared HDF5 file via CheckpointFile."""
445+
from firedrake.checkpointing import CheckpointFile
446+
with CheckpointFile(self.file.name, 'a', self.file.comm) as outfile:
316447
outfile.save_function(function, name=self.stored_name,
317448
idx=self.stored_index)
318449

450+
def _save_local_checkpoint(self, function):
451+
"""Save function data to a local HDF5 file via PETSc Vec I/O."""
452+
from firedrake.checkpointing import TemporaryFunctionCheckpointFile
453+
with TemporaryFunctionCheckpointFile(
454+
self.file.checkpoint_comm, self.file.name, 'a'
455+
) as outfile:
456+
outfile.save_function(function, self.stored_name, self.stored_index)
457+
319458
def restore(self):
320459
"""Read and return this Function from the checkpoint."""
321-
from firedrake.checkpointing import CheckpointFile
322-
with CheckpointFile(self.file.name, 'r') as infile:
323-
function = infile.load_function(self.mesh, self.stored_name,
324-
idx=self.stored_index)
460+
if self.file.checkpoint_comm is not None:
461+
function = self._restore_local_checkpoint()
462+
else:
463+
function = self._restore_shared_checkpoint()
325464
return type(function)(function.function_space(),
326465
function.dat, name=self.name, count=self.count)
327466

467+
def _restore_shared_checkpoint(self):
468+
"""Load function data from a shared HDF5 file via :class:`.CheckpointFile`."""
469+
from firedrake.checkpointing import CheckpointFile
470+
with CheckpointFile(self.file.name, 'r', comm=self.file.comm) as infile:
471+
return infile.load_function(self.mesh, self.stored_name,
472+
idx=self.stored_index)
473+
474+
def _restore_local_checkpoint(self):
475+
"""Load function data via :class:`TemporaryFunctionCheckpointFile`."""
476+
from firedrake.checkpointing import TemporaryFunctionCheckpointFile
477+
with TemporaryFunctionCheckpointFile(
478+
self.file.checkpoint_comm, self.file.name, 'r'
479+
) as infile:
480+
return infile.load_function(
481+
self._function_space, self.stored_name, self.stored_index
482+
)
483+
328484
def _ad_restore_at_checkpoint(self, checkpoint):
329485
return checkpoint.restore()
330486

0 commit comments

Comments
 (0)