Exact one-epoch training from local data (padding-free distributed sampler)#72

Open
ksd3 wants to merge 1 commit into main from feat/exact-epoch-local

@ksd3 ksd3 commented Jun 13, 2026

Collaborator

Adds an exact one-epoch training path for scaling studies that must train every config on identical data, with each example seen exactly once. Streaming + split_dataset_by_node + buffered shuffle can't do that: the data order depends on GPU/worker count, and the standard distributed samplers either drop the tail or pad it with duplicates when len % world_size != 0.

src/astropt/exact_epoch.py (new, unit-tested)

  • ExactDistributedSampler — partitions range(N) across ranks with no padding, no dropping: a single seeded shuffle, sliced strided (idx[rank::world_size]). Exact coverage, reproducible (depends only on data_seed), and independent of num_workers (the sampler fixes the order; workers only fetch).
  • one_epoch_loop — consumes the loader exactly once; an optimizer step every grad_accum_per_rank micro-batches with a short final step. The exact partition gives every rank the same micro-batch count, so the short final step is identical across ranks → the per-step all-reduce count stays balanced and DDP needs no Join. Gradients sync every backward (all-reduce is linear ⇒ identical to accumulate-then-sync).
  • Self-tests (python -m astropt.exact_epoch): exact partition for many sizes incl. a prime; the loop consumes exactly N with the expected step count. For the galaxies set (N=8,474,566), W∈{1,2,4,8} all balance to 13,242 steps.
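
The strided partition described above can be illustrated with a minimal, torch-free sketch. It is not the code in src/astropt/exact_epoch.py: the function names `exact_partition` and `check_exact_coverage` are hypothetical, and the standard-library `random` module stands in for whatever seeded generator the real sampler uses.

```python
import random


def exact_partition(n, rank, world_size, seed):
    """Indices for one rank: a strided slice of one seeded shuffle of range(n).

    Hypothetical stand-in for the sampler's index logic (assumption:
    stdlib `random` replaces the real seeded generator).
    """
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # identical permutation on every rank
    return idx[rank::world_size]      # no padding, no dropping


def check_exact_coverage(n, world_size, seed=0):
    """True if the shards cover range(n) exactly once and differ in size by at most 1."""
    shards = [exact_partition(n, r, world_size, seed) for r in range(world_size)]
    flat = [i for shard in shards for i in shard]
    sizes = [len(shard) for shard in shards]
    return sorted(flat) == list(range(n)) and max(sizes) - min(sizes) <= 1
```

Because the permutation depends only on `n` and `seed`, changing `world_size` changes which rank sees an index but not the shuffled order itself, and the number of loader workers never enters the computation.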

scripts/train.py

When data_dir is set: load the galaxies dataset map-style from local parquet, shard it with ExactDistributedSampler, run one_epoch_loop, and assert that the all-reduced example count equals the dataset size, so a run either self-proves exact coverage or fails loudly. The streaming path is unchanged.

Notes

Add an exact one-epoch training path (gated by the new `data_dir` config),
for scaling studies that must train every config on the IDENTICAL data,
each example exactly once.

New `astropt.exact_epoch`:
* `ExactDistributedSampler` -- partitions range(N) across ranks with NO
  padding and NO dropping (strided slice of one seeded shuffle). Exact
  coverage, reproducible (depends only on `data_seed`), and independent of
  num_workers. Unlike DistributedSampler it never duplicates/drops the tail.
* `one_epoch_loop` -- consumes the loader exactly once; an optimizer step
  every `grad_accum_per_rank` micro-batches with a short final step. Because
  the exact partition gives every rank the same micro-batch count, the final
  step is short by the same amount on every rank, so the per-step all-reduce
  count stays balanced and DDP needs no Join.
* Self-tests (`python -m astropt.exact_epoch`) prove exact partition for
  many sizes (incl. a prime) and that the loop consumes exactly N with the
  expected step count; for the galaxies set (N=8,474,566) all of W in
  {1,2,4,8} balance to 13,242 steps.
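
The step accounting can be sketched in plain Python. This is an illustration under assumptions, not the `one_epoch_loop` implementation: `step_groups`, `micro_batches_per_rank`, and the `micro_batch_size` parameter are hypothetical names. Per-rank shard sizes can differ by one when N is not divisible by the world size, so the second helper makes it easy to check that the micro-batch counts really do agree for a given configuration.

```python
import math


def step_groups(micro_batches, grad_accum_per_rank):
    """Group micro-batches into optimizer steps; the last group may be short."""
    group = []
    for mb in micro_batches:
        group.append(mb)
        if len(group) == grad_accum_per_rank:
            yield group
            group = []
    if group:
        yield group  # short final step


def micro_batches_per_rank(n, world_size, micro_batch_size):
    """Micro-batch count on each rank for a strided partition of n examples."""
    counts = []
    for rank in range(world_size):
        n_rank = len(range(rank, n, world_size))  # size of idx[rank::world_size]
        counts.append(math.ceil(n_rank / micro_batch_size))
    return counts
```

If `micro_batches_per_rank` returns the same value on every rank, each rank runs `math.ceil(count / grad_accum_per_rank)` optimizer steps and the final short step has the same length everywhere, which is the balance condition the description relies on.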

`scripts/train.py`: when `data_dir` is set, load the galaxies map-style from
local parquet, shard with `ExactDistributedSampler`, run `one_epoch_loop`,
and assert the all-reduced example count equals the dataset size (so a run
self-proves exact coverage or fails loudly). The streaming path is
unchanged; `data_seed` is also added (overlaps with the data-order PR).

Note: the DDP path is arithmetically balanced + self-asserting but has not
been run multi-GPU here -- needs a smoke test on a real 4-GPU node.