Summary
When the total number of samples is not evenly divisible by comm_size, the dlio_sampler class in torch_data_loader.py assigns fewer samples to the last rank than to all other ranks. With drop_last=True (the default for training), the last rank can then produce fewer batches per epoch than the other ranks (two fewer in the example below). The training loop calls comm.barrier() once per batch and once at end of epoch, so the mismatched batch count causes ranks to enter incompatible collectives and deadlock at the end of the first epoch.
Root cause
torch_data_loader.py, lines 418-423:
class dlio_sampler(Sampler):
    def __init__(self, rank, size, num_samples, epochs):
        samples_per_proc = int(math.ceil(num_samples / size))  # ← ceil
        start_sample = self.rank * samples_per_proc
        end_sample = (self.rank + 1) * samples_per_proc - 1
        if end_sample > num_samples - 1:
            end_sample = num_samples - 1  # ← clamp last rank
        self.indices = list(range(start_sample, end_sample + 1))
math.ceil(N / size) overshoots for the last rank, then the clamp at line 422 cuts it back. When N % size != 0, the last rank ends up with fewer samples than every other rank.
Example (N=100, 7 ranks, batch_size=3)
| Ranks | samples_per_proc | Actual samples | Batches (drop_last) |
|-------|------------------|----------------|---------------------|
| 0..5  | 15               | 15             | 5                   |
| 6     | 15               | 10 (clamped)   | 3                   |
The last rank gets 2 fewer batches. The same pattern applies at any scale where N % size != 0.
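The numbers in the table can be checked with a short standalone sketch. It mirrors the ceil-then-clamp arithmetic from the excerpt above but is not the project's code: the helper names ceil_clamp_counts and batches_drop_last are made up for illustration, and drop_last=True is modelled as plain integer division.

```python
import math


def ceil_clamp_counts(num_samples, size):
    """Per-rank sample counts under the ceil-then-clamp partition (illustrative helper)."""
    samples_per_proc = math.ceil(num_samples / size)
    counts = []
    for rank in range(size):
        start = rank * samples_per_proc
        # Clamp the end index to the last valid sample, as the sampler does.
        end = min((rank + 1) * samples_per_proc - 1, num_samples - 1)
        # Same length as range(start, end + 1), which is empty if start > end.
        counts.append(max(0, end - start + 1))
    return counts


def batches_drop_last(count, batch_size):
    """Number of full batches when the trailing partial batch is dropped."""
    return count // batch_size


counts = ceil_clamp_counts(100, 7)
batches = [batches_drop_last(c, 3) for c in counts]
print(counts)   # [15, 15, 15, 15, 15, 15, 10]
print(batches)  # [5, 5, 5, 5, 5, 5, 3]
```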
How the deadlock forms
main.py _train() has two barriers:
# line 456 — called once per batch step
self.comm.barrier()
# line 478 — called once at end of epoch
self.comm.barrier()
Execution trace (using the N=100, 7 ranks, batch=3 example above):
- All 7 ranks enter the training loop. Ranks 0..5 do 5 batches; rank 6 does 3.
- For the first 3 steps, all ranks call the line 456 barrier together. No problem.
- Step 4: ranks 0..5 call barrier at line 456. Rank 6 has finished its loop and calls barrier at line 478 (end of epoch). MPI matches these because it does not distinguish barrier call sites.
- All ranks unblock. Ranks 0..5 continue stepping; rank 6 has already moved on to reconfigure() for the next epoch.
- Step 5: ranks 0..5 call barrier at line 456. Rank 6 calls comm.reduce() from reconfigure(). These are incompatible collectives. Permanent deadlock.
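The trace can be replayed without MPI by listing the collectives each rank would issue, in order, and finding the first position where two ranks disagree. This is a simplified model and not the benchmark's code: it assumes each epoch consists of one barrier per batch, one end-of-epoch barrier, and then a single reduce issued from reconfigure(). The function names are hypothetical.

```python
def collective_sequence(batches_per_epoch, epochs=2):
    """Ordered list of collectives one rank issues under the simplified model."""
    calls = []
    for _ in range(epochs):
        calls += ["barrier"] * batches_per_epoch  # per-step barrier
        calls.append("barrier")                   # end-of-epoch barrier
        calls.append("reduce")                    # reduce assumed to come from reconfigure()
    return calls


def first_mismatch(seq_a, seq_b):
    """Index of the first position where two ranks issue different collectives."""
    for i, (a, b) in enumerate(zip(seq_a, seq_b)):
        if a != b:
            return i
    return None


rank0 = collective_sequence(5)  # ranks 0..5 run 5 batches per epoch
rank6 = collective_sequence(3)  # rank 6 runs 3 batches per epoch
idx = first_mismatch(rank0, rank6)
print(idx, rank0[idx], rank6[idx])  # 4 barrier reduce
```

Index 4 is the fifth collective call, which corresponds to step 5 in the trace: the first four calls on every rank are barriers (so they match regardless of call site), and the fifth pairs a barrier with a reduce.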
Proposed fix
Replace ceil with floor division so all ranks get the same sample count:
samples_per_proc = num_samples // size
start_sample = self.rank * samples_per_proc
end_sample = (self.rank + 1) * samples_per_proc - 1
self.indices = list(range(start_sample, end_sample + 1))
This drops at most size - 1 samples per epoch. The clamp is no longer needed since floor never overshoots.
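As a quick sanity check of the proposed arithmetic, the sketch below applies floor division to the same N=100, 7 ranks, batch_size=3 case. It is a standalone illustration, not a patch: floor_partition is a hypothetical helper, and drop_last=True is again modelled as integer division.

```python
def floor_partition(num_samples, size):
    """Contiguous, equally sized index ranges per rank (illustrative helper)."""
    samples_per_proc = num_samples // size
    return [
        range(rank * samples_per_proc, (rank + 1) * samples_per_proc)
        for rank in range(size)
    ]


parts = floor_partition(100, 7)
counts = [len(p) for p in parts]
dropped = 100 - sum(counts)
batches = [c // 3 for c in counts]

print(counts)   # [14, 14, 14, 14, 14, 14, 14]
print(dropped)  # 2
print(batches)  # [4, 4, 4, 4, 4, 4, 4]
```

Every rank now has 14 samples and 4 batches, at the cost of 2 unused samples per epoch, which is within the size - 1 bound.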
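Acceptance criteria
- Every rank receives the same number of samples, regardless of N % size.
- No more than size - 1 samples are dropped per epoch.
- N=100, size=7, batch=3, drop_last=True produces equal batch counts on all ranks.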
Impact
- A total_samples value not divisible by comm_size can deadlock the run at the end of the first epoch, whenever the last rank's shortfall reduces its batch count. Since dataset sizes are rarely exact multiples of the rank count, this affects most multi-node multi-epoch runs.
- The failure mode is silent: no error, no timeout, just permanent CPU spin. Operators have to manually cancel and have no diagnostic pointing at the sampler.