
Commit c5f375c

Clean up torch training cache, resume, and HDF5 API semantics
- make resumed training treat iterations as additional epochs per train() call
- add bounded trainer-owned runtime caches with optional warmup
- keep normalization stats collection from implicitly warming runtime caches
- add deterministic threaded HDF5 build workers with atomic rebuild handling
- support stable save_energies identifiers for HDF5-backed datasets
- reject unimplemented memory_mode='mixed' explicitly
- refresh docs, notebook examples, and regression coverage for the updated semantics
1 parent b965bea commit c5f375c

15 files changed

Lines changed: 1826 additions & 285 deletions

docs/source/usage/torch_datasets.rst

Lines changed: 25 additions & 3 deletions
@@ -336,7 +336,7 @@ Building the Database
    from aenet.geometry import AtomicStructure
    from glob import glob

-   # Define a top-level parser function (required for multiprocessing)
+   # Define the parser used during HDF5 database construction
    def parse_xsf(path: str):
        """Parse XSF file and return torch Structure(s)."""
        atomic_struct = AtomicStructure.from_file(path)
@@ -360,6 +360,7 @@ Building the Database

    db.build_database(
        show_progress=True,
+       build_workers=8,                 # optional build-time worker threads
        persist_descriptor=True,         # optional descriptor recovery step
        persist_features=True,           # optional persisted raw features
        persist_force_derivatives=True,  # optional sparse derivative cache
@@ -368,6 +369,14 @@ Building the Database
    # ``db`` is immediately reusable after build_database(); reopening the
    # file is only needed in a later session or when you want a separate handle.

+ .. note::
+
+    ``build_workers`` only affects the one-time ``build_database()`` call.
+    It parallelizes parser execution and optional persisted-cache preparation
+    with worker threads, while the parent process still performs all ordered
+    HDF5 writes. This is separate from training-time ``num_workers`` on
+    ``TorchTrainingConfig``.
+
.. note::

   ``persist_descriptor=True`` stores a small versioned descriptor manifest
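For reference, a minimal sketch of the build-time/training-time split described in the new note, assuming an import path and a path-only constructor that are not shown in this diff (the real ``HDF5StructureDataset`` construction also takes the parser and source structures shown earlier on this page); the keyword names ``show_progress``, ``build_workers``, and ``persist_*`` follow the documented call:

    # Sketch only: import path and constructor arguments are assumptions.
    from aenet.torch_training.dataset import HDF5StructureDataset

    db = HDF5StructureDataset("structures.h5")

    # Build-time parallelism: worker threads run the parser and prepare the
    # optional persisted caches, while the parent process performs the ordered
    # HDF5 writes.
    db.build_database(
        show_progress=True,
        build_workers=8,
        persist_features=True,
        persist_force_derivatives=True,
    )
    # Training-time parallelism is a separate knob: num_workers on
    # TorchTrainingConfig controls DataLoader workers, not build_workers.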
@@ -413,7 +422,8 @@ different purposes:
  for force-labeled entries to ``/torch_cache/force_derivatives``
* ``cache_features=True`` is a trainer-owned in-memory runtime cache attached
  to the current dataset instance; it speeds up repeated accesses within a run
- but does not modify the HDF5 file
+ but does not modify the HDF5 file. Its size is controlled by
+ ``cache_feature_max_entries`` on ``TorchTrainingConfig``

Runtime precedence is explicit:
@@ -456,7 +466,9 @@ Training from HDF5 Database
        force_fraction=0.3,
        force_sampling="random",
        cache_features=True,
+       cache_feature_max_entries=1024,
        cache_neighbors=True,
+       cache_neighbor_max_entries=512,
        num_workers=8,            # Parallel workers (each opens own handle)
        prefetch_factor=4,
        persistent_workers=True,
@@ -479,10 +491,18 @@ Key HDF5 Features
* **Multiprocessing-safe**: Each DataLoader worker opens its own read-only handle
* **Compression**: Built-in HDF5 compression (zlib, blosc) reduces disk usage
* **LRU caching**: Configurable in-memory cache per worker for frequently accessed entries
- * **Parser requirements**: Must be a top-level function (pickleable) when using ``num_workers > 0``
+ * **Build parallelism**: ``build_workers`` accelerates parser execution and
+   optional persisted-cache generation, but ordered HDF5 writes still happen
+   in the parent process
+ * **Parser concurrency**: When using ``build_workers > 1``, make sure the
+   parser callable is safe to invoke concurrently over independent file paths
* **Unified persisted cache**: Optional ``/torch_cache/features`` and
  ``/torch_cache/force_derivatives`` sections can be written once and reused
  lazily across later HDF5-backed runs
+ * **Separate trainer cache limits**: ``cache_feature_max_entries``,
+   ``cache_neighbor_max_entries``, and ``cache_force_triplet_max_entries`` bound
+   the trainer-owned runtime caches separately from HDF5 ``in_memory_cache_size``
* **Deterministic handle cleanup**: Call ``dataset.close()`` or use
  ``with HDF5StructureDataset(...) as dataset:``

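A minimal sketch of the deterministic handle cleanup named in the last bullet, showing the two equivalent forms; the constructor argument is an assumption and ``len(dataset)`` stands in for any read access:

    from aenet.torch_training.dataset import HDF5StructureDataset  # assumed import path

    # Context-manager form: the read-only handle is closed when the block exits.
    with HDF5StructureDataset("structures.h5") as dataset:
        n_structures = len(dataset)

    # Explicit form: close() releases the handle deterministically.
    dataset = HDF5StructureDataset("structures.h5")
    try:
        n_structures = len(dataset)
    finally:
        dataset.close()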
@@ -586,6 +606,8 @@ Set these on ``TorchTrainingConfig``:
  HDF5 features and does not write back to disk.
* **cache_neighbors**: Reuse neighbor search results for energy-view reuse and legacy non-graph paths
* **cache_force_triplets**: Cache CSR graphs and triplets instead of rebuilding them for the default sparse force-training path
+ * **cache_*_max_entries**: Bound the trainer-owned runtime caches per split and per process/worker
+ * **cache_warmup**: Optional single-process cache prefill before epoch 0; skipped automatically when ``num_workers > 0``

For repeated fixed-geometry HDF5 workflows, prefer build-time
``persist_features=True`` and ``persist_force_derivatives=True`` when you want

docs/source/usage/torch_training.rst

Lines changed: 99 additions & 11 deletions
@@ -201,6 +201,44 @@ The longer file-backed dataset workflow is intentionally kept in the training
notebook above so the ``torch_datasets`` page can stay focused on compact
API-facing examples.

+ Execution Model
+ ~~~~~~~~~~~~~~~~
+
+ The current trainer has two distinct runtime stages:
+
+ 1. Sample preparation happens in the main process when ``num_workers=0``, or
+    in ``DataLoader`` workers when ``num_workers > 0``. Structures are
+    converted to tensors on ``descriptor.device``, and descriptor
+    featurization, neighbor reuse, graph/triplet construction, and lazy HDF5
+    cache reads happen there.
+ 2. The collated batch is then moved onto ``config.device`` inside the
+    training loop. Model forward passes, normalization, loss computation, and
+    optimizer steps run on that device.
+
+ In practice, GPU training with ``num_workers > 0`` is best understood as
+ worker-side data preparation feeding a training loop on the selected device.
+ It is not currently a separate mixed CPU/GPU execution pipeline.
+
+ If ``descriptor.device`` and ``config.device`` match, featurization and model
+ compute happen on the same device. If they differ, samples are materialized on
+ ``descriptor.device`` and transferred before the forward pass. The compact
+ examples on this page create the descriptor on CPU, so later
+ ``device='cuda'`` examples describe CPU-side sample preparation feeding GPU
+ training unless you also move the descriptor to CUDA.
+
+ For HDF5-backed datasets, each worker reopens its own read-only file handle
+ and keeps its own bounded ``in_memory_cache_size`` LRU cache. Trainer-owned
+ runtime caches (``cache_features``, ``cache_neighbors``,
+ ``cache_force_triplets``) are also per process/worker, so
+ ``cache_warmup=True`` is skipped automatically when ``num_workers > 0``. See
+ :doc:`torch_datasets` for persisted HDF5 cache precedence and for the
+ distinction between build-time ``build_workers`` and training-time
+ ``num_workers``.
+
+ ``memory_mode='mixed'`` is reserved for a future real mixed-memory mode and
+ currently raises ``NotImplementedError`` if requested. Today, the supported
+ execution modes remain ``'cpu'`` and ``'gpu'``.
+
Performance Optimization Tips
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
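A minimal sketch of the two-stage execution model added above, assuming ``arch`` and ``descr`` are defined as in the compact examples; the import paths and surrounding values are assumptions, while the parameter names follow this page:

    from aenet.torch_training import TorchANNPotential, TorchTrainingConfig, Adam  # assumed imports

    # Stage 1: samples are prepared on descriptor.device (CPU in these examples),
    # either in the main process (num_workers=0) or in DataLoader workers.
    pot = TorchANNPotential(arch, descriptor=descr)   # arch, descr as defined earlier

    # Stage 2: collated batches are moved to config.device before the forward
    # pass; model, normalization, loss, and optimizer steps run on 'cuda'.
    cfg = TorchTrainingConfig(
        atomic_energies={'O': -432.503149303, 'Ti': -1604.604515075},
        iterations=20,
        method=Adam(mu=0.001, batchsize=32),
        testpercent=10,
        device='cuda',
        num_workers=4,        # worker-side sample preparation only
        prefetch_factor=2,
    )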

@@ -245,6 +283,10 @@ Performance Optimization Tips
  and legacy non-graph paths
* **cache_force_triplets**: Cache CSR graphs and triplets for the default sparse
  force-training path instead of rebuilding them on demand
+ * **cache_*_max_entries**: Bound the trainer-owned runtime caches per split
+   and per process/worker instead of letting them grow without limit
+ * **cache_warmup**: Optional single-process prefill of trainer-owned runtime
+   caches before epoch 0; skipped automatically when ``num_workers > 0``

These runtime caches are distinct from the on-disk HDF5 persisted cache
sections created with ``HDF5StructureDataset.build_database(...)``. For HDF5
@@ -339,10 +381,11 @@ Large Dataset (> 500 structures)
        method=Adam(mu=0.001, batchsize=64),   # Larger batches
        testpercent=10,
        force_weight=0.1,
-       device='cuda',                         # Use GPU for speedup
+       device='cuda',                         # Model/loss on GPU
        # Performance optimizations
        cache_features=True,                   # Runtime in-memory feature cache
-       num_workers=8,                         # Parallel data loading
+       cache_feature_max_entries=1024,
+       num_workers=8,                         # Parallel CPU-side sample preparation
        prefetch_factor=4
    )

@@ -356,7 +399,8 @@ Energy-Only with Maximum Speed
        method=Adam(mu=0.001, batchsize=32),
        testpercent=10,
        force_weight=0.0,              # Energy-only
-       cache_features=True,           # Eager/runtime feature cache for this run
+       cache_features=True,           # Bounded runtime feature cache for this run
+       cache_warmup=True,             # Optional single-process prefill
        device='cuda'
    )

@@ -371,8 +415,8 @@ Force Training with Optimizations
        testpercent=10,
        force_weight=0.1,
        force_fraction=0.3,            # Use 30% of forces (3× faster)
-       cache_neighbors=True,          # Cache neighbor lists
-       num_workers=4,
+       cache_neighbors=True,          # Cache worker-local neighbor lists
+       num_workers=4,                 # Parallel CPU-side sample preparation
        device='cuda'
    )

@@ -410,6 +454,13 @@ To resume training from a checkpoint, pass the checkpoint path to
``train(..., resume_from="checkpoints/checkpoint_epoch_0050.pt")``. The
notebook above contains the maintained checkpoint workflow.

+ When ``resume_from`` is provided, ``config.iterations`` means the number of
+ additional epochs to run in that ``train()`` call. For example, resuming a
+ checkpoint with ``iterations=10`` runs 10 more epochs after the saved
+ checkpoint epoch, regardless of how many epochs were completed in the
+ original run. This applies to numbered checkpoints and ``best_model.pt``
+ alike.
+
The trainer will automatically:

* Load model and optimizer state
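A minimal sketch of the resume semantics added above; the positional arguments to ``train()`` are elided in the docs, so the call shape here (an existing potential object, a structure list, then the config) is an assumption, while ``resume_from`` and the iterations-as-additional-epochs rule follow this section:

    # Sketch only: the import path, and the existence of `pot` and
    # `structure_files`, are assumed from earlier examples.
    from aenet.torch_training import TorchTrainingConfig, Adam  # assumed import path

    cfg = TorchTrainingConfig(
        atomic_energies={'O': -432.503149303, 'Ti': -1604.604515075},
        iterations=10,                 # 10 *additional* epochs after the checkpoint
        method=Adam(mu=0.001, batchsize=32),
        testpercent=10,
    )
    pot.train(
        structure_files,
        cfg,
        resume_from="checkpoints/checkpoint_epoch_0050.pt",
    )
    # The same rule applies when resuming from best_model.pt.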
@@ -512,26 +563,51 @@ Performance & Caching
  * For force training (``force_weight > 0``): Caches features for structures not
    selected for force supervision in current epoch (useful with ``force_fraction < 1.0``)

+ **cache_feature_max_entries** : int or None (default: 1024)
+   Maximum number of trainer-owned energy-view feature entries to retain per
+   split and per process/worker when ``cache_features=True``. Use ``None`` for
+   an explicit unbounded cache or ``0`` to suppress storage.
+
**cache_neighbors** : bool (default: False)
  Cache per-structure neighbor graphs (indices, displacement vectors) across
  epochs. Avoids repeated neighbor searches for fixed geometries on
  energy-view reuse and legacy non-graph paths. Supported force training
  does not require this option.

+ **cache_neighbor_max_entries** : int or None (default: 512)
+   Maximum number of trainer-owned neighbor payload entries to retain per
+   split and per process/worker when ``cache_neighbors=True``. Use ``None`` for
+   an explicit unbounded cache or ``0`` to suppress storage.
+
**cache_force_triplets** : bool (default: False)
  Cache CSR neighbor graphs and precompute angular triplet indices for the
  default sparse force-training path. Leaving this disabled still uses the
  sparse graph/triplet path, but rebuilds those graph payloads on demand.

+ **cache_force_triplet_max_entries** : int or None (default: 256)
+   Maximum number of trainer-owned graph/triplet payload entries to retain per
+   split and per process/worker when ``cache_force_triplets=True``. Use
+   ``None`` for an explicit unbounded cache or ``0`` to suppress storage.
+
**cache_persist_dir** : str (default: None)
  Directory for persisting graph/triplet caches to disk for reuse across runs.

**cache_scope** : str (default: 'all')
  Which dataset splits to cache: ``'train'``, ``'val'``, or ``'all'``.

+ **cache_warmup** : bool (default: False)
+   If True, pre-populate trainer-owned runtime caches before the first epoch
+   in single-process training. When all enabled caches have finite entry
+   limits, warmup stops once those limits are filled. Warmup is skipped
+   automatically when ``num_workers > 0`` because workers own their own cache
+   instances and the main-process warmup would not populate those worker-local
+   caches.
+
**num_workers** : int (default: 0)
- Number of parallel DataLoader workers for on-the-fly featurization.
- 0 = main process only. Values >0 enable parallel data loading.
+ Number of parallel ``DataLoader`` workers for structure loading, HDF5
+ reads, and on-the-fly featurization. ``0`` keeps sample preparation in the
+ main process. Values ``>0`` parallelize worker-side sample preparation; they
+ do not parallelize model compute.

**prefetch_factor** : int (default: 2)
  Number of batches to prefetch per worker when ``num_workers > 0``.
@@ -540,7 +616,9 @@ Performance & Caching
  Keep DataLoader workers alive between epochs for faster iteration.
  During training, this is disabled automatically when
  ``force_sampling='random'`` uses epoch-level resampling, because worker
- copies would otherwise keep a stale force-supervision subset.
+ copies would otherwise keep a stale force-supervision subset. Trainer-owned
+ runtime caches and HDF5 ``in_memory_cache_size`` state are also
+ worker-local when ``num_workers > 0``.


Data Filtering & Quality Control
@@ -593,7 +671,11 @@ Output & Diagnostics
  Save predicted energies for train/test sets to disk. The
  ``Path-of-input-file`` column preserves the original structure path or
  name when available; otherwise it uses a stable ``structure_XXXXXX``
- identifier from the pre-split input order.
+ identifier from the pre-split input order. For HDF5-backed datasets,
+ the identifier is reconstructed from persisted metadata as
+ ``path#frame=N`` when the source path is available, ``name#frame=N``
+ when only the persisted name is available, and
+ ``structure_XXXXXX#frame=N`` as the final fallback.

**save_forces** : bool (default: False)
  Save predicted forces for train/test sets to disk.
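A small self-contained sketch of the identifier precedence described for ``save_energies``; this illustrates the documented fallback order only and is not the library's actual implementation:

    def save_energies_identifier(index, frame, path=None, name=None):
        """Illustrative only: prefer path, then name, then structure_XXXXXX."""
        if path is not None:
            base = path
        elif name is not None:
            base = name
        else:
            base = f"structure_{index:06d}"
        return f"{base}#frame={frame}"

    # save_energies_identifier(42, 3) -> 'structure_000042#frame=3'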
@@ -618,11 +700,17 @@ Advanced Options

**memory_mode** : str (default: 'gpu')
  Memory management strategy: ``'cpu'``, ``'gpu'``, or ``'mixed'``.
- Controls where data and intermediate results are stored.
+ ``'mixed'`` is reserved for a future real mixed-memory implementation and
+ currently raises ``NotImplementedError``. Use ``'cpu'`` or ``'gpu'`` with
+ ``descriptor.device`` and ``device`` set explicitly to control the current
+ execution path.

**device** : str (default: None)
  PyTorch device: ``'cpu'``, ``'cuda'``, or ``'cuda:0'``. Auto-detected if
- None.
+ None. This selects the model/training-loop device. ``descriptor.device``
+ separately controls where structures are featurized. When the two differ,
+ samples are prepared on ``descriptor.device`` and moved to ``device``
+ before the forward pass.


Monitoring Training Progress

notebooks/example-05-torch-training.ipynb

Lines changed: 34 additions & 20 deletions
@@ -104,12 +104,16 @@
    "\n",
    "<sup>1</sup>By default, the cohesive energy is the training target and recommended.\n",
    "\n",
-   "<sup>2</sup>The `train()` method accepts strings and `Path` objects, as well as\n",
-   "lists of `AtomicStructure` objects. More advanced data handling is possible\n",
-   "using dataset classes from `aenet.torch_training.dataset`.\n",
-   "\n",
-   "Energy-only training can benefit significantly from feature caching. Setting\n",
-   "`cache_features=True` below takes the automatic cached-dataset path internally.\n"
+   "<sup>2</sup>The `train()` method accepts strings and `Path` objects, as well as\n",
+   "lists of `AtomicStructure` objects. More advanced data handling is possible\n",
+   "using dataset classes from `aenet.torch_training.dataset`.\n",
+   "\n",
+   "Energy-only training can benefit significantly from feature caching. Setting\n",
+   "`cache_features=True` below takes the automatic cached-dataset path internally.\n",
+   "The descriptor device controls where structures are featurized, while\n",
+   "`config.device` controls where the model, losses, and optimizer run.\n",
+   "In this notebook the descriptor stays on CPU, so a later GPU config would\n",
+   "mean CPU-side sample preparation feeding a GPU training loop.\n"
  ]
},
{
@@ -139,6 +143,7 @@
    "\n",
    "pot = TorchANNPotential(arch, descriptor=descr)\n",
    "\n",
+   "# Resume from the saved best checkpoint and run 12 additional epochs.\n",
    "cfg = TorchTrainingConfig(\n",
    "    atomic_energies={'O': -432.503149303, 'Ti': -1604.604515075},\n",
    "    testpercent=10,\n",
@@ -342,15 +347,21 @@
  "source": [
    "# 5. Force training\n",
    "\n",
-   "Training on both energies and forces is computationally significantly more\n",
-   "expensive and memory intensive. Feature caching is not effective for the\n",
-   "gradient evaluation required for force training, so additional caching of\n",
-   "neighbor lists and triplet vectors is available instead. Typically, a fraction\n",
-   "of all structures is randomly selected for force training to balance efficiency\n",
-   "and accuracy.\n",
-   "\n",
-   "The fixed-split dataset objects above are the more reliable approach when you\n",
-   "want train/test membership to stay unchanged across repeated runs.\n"
+   "Training on both energies and forces is computationally significantly more\n",
+   "expensive and memory intensive. Feature caching is not effective for the\n",
+   "gradient evaluation required for force training, so additional caching of\n",
+   "neighbor lists and triplet vectors is available instead. Typically, a fraction\n",
+   "of all structures is randomly selected for force training to balance efficiency\n",
+   "and accuracy.\n",
+   "\n",
+   "With `num_workers > 0`, loading and featurization stay on the worker side.\n",
+   "The collated batch is moved onto `config.device` before the forward and loss\n",
+   "steps, so worker parallelism and model-device selection are separate knobs.\n",
+   "`memory_mode='mixed'` is reserved for a future real mixed-memory mode and\n",
+   "currently raises `NotImplementedError` if requested.\n",
+   "\n",
+   "The fixed-split dataset objects above are the more reliable approach when you\n",
+   "want train/test membership to stay unchanged across repeated runs.\n"
  ]
},
{
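A minimal sketch of the worker/device split described in that markdown cell; the parameter names follow this commit's docs, while the import path and the surrounding values are assumptions:

    from aenet.torch_training import TorchTrainingConfig, Adam  # assumed import path

    cfg_force = TorchTrainingConfig(
        atomic_energies={'O': -432.503149303, 'Ti': -1604.604515075},
        iterations=20,
        method=Adam(mu=0.001, batchsize=32),
        testpercent=10,
        force_weight=0.1,
        force_fraction=0.3,                # random subset of force-labeled structures
        cache_force_triplets=True,         # worker-local CSR graph/triplet cache
        cache_force_triplet_max_entries=256,
        num_workers=4,                     # worker-side loading and featurization
        device='cuda',                     # model, loss, and optimizer device
    )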
@@ -458,11 +469,13 @@
    "derivative cache. The trainer will then load those derivatives lazily per\n",
    "force-supervised sample instead of recomputing them on every pass.\n",
    "\n",
-   "`build_database()` leaves the same dataset instance ready for read-only use, so\n",
-   "reopening the file is optional and mainly useful in a later session.\n",
-   "\n",
-   "This does **not** persist the network features themselves.\n",
-   "`persist_force_derivatives=True` writes sparse derivative blocks to disk for\n",
+   "`build_database()` leaves the same dataset instance ready for read-only use, so\n",
+   "reopening the file is optional and mainly useful in a later session.\n",
+   "When `num_workers > 0`, each worker reopens its own read-only HDF5 handle and\n",
+   "owns its own in-memory dataset/runtime caches.\n",
+   "\n",
+   "This does **not** persist the network features themselves.\n",
+   "`persist_force_derivatives=True` writes sparse derivative blocks to disk for\n",
    "cross-run reuse, whereas `cache_force_triplets=True` and `cache_features=True`\n",
    "are in-memory runtime caches configured on `TorchTrainingConfig`.\n"
  ]
@@ -522,6 +535,7 @@
    "# Reuse the dataset built above. Reopening the HDF5 file is optional.\n",
    "pot_hdf5 = TorchANNPotential(arch, descriptor=descr)\n",
    "\n",
+   "# Resume from the saved best checkpoint and run 6 additional epochs.\n",
    "cfg_hdf5 = TorchTrainingConfig(\n",
    "    atomic_energies={'O': -432.503149303, 'Ti': -1604.604515075},\n",
    "    iterations=6,\n",
