Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions code/tests/test_training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,38 @@ def test_non_tensor_passthrough(self):
out = dict_to_device(d, torch.device("cpu"))
self.assertEqual(out["a"], 1)
self.assertEqual(out["b"], "x")


class TestStepsPerEpochCalculation(unittest.TestCase):
"""steps_per_epoch must count forward passes (one generate_batch per step),
not optimizer steps. accumulate_steps only controls optimizer.step()
frequency and must not reduce the loop iteration count."""

def _steps_per_epoch(self, num_samples, per_device_bs, world_size):
"""Mirror the calculation in train.py:main() (line ~1298)."""
return num_samples // (per_device_bs * world_size)

def test_accumulate_steps_does_not_reduce_data_seen(self):
num_samples = 8_388_608
batch_size = 512
world_size = 1

steps = self._steps_per_epoch(num_samples, batch_size, world_size)
data_seen = steps * batch_size * world_size

self.assertEqual(data_seen, num_samples)

def test_multi_gpu(self):
num_samples = 8_388_608
batch_size = 512
world_size = 4

steps = self._steps_per_epoch(num_samples, batch_size, world_size)
data_seen = steps * batch_size * world_size

self.assertEqual(data_seen, num_samples)

def test_default_production_values(self):
steps = self._steps_per_epoch(67_108_864, 512, 1)
self.assertEqual(steps, 131072)
self.assertEqual(steps * 512, 67_108_864)
2 changes: 1 addition & 1 deletion code/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ def _print_gen(name, g):
per_device_batch_size = get_current_per_device_batch_size(epoch, cfg)
accumulate_steps = cfg.train.accumulate_steps
effective_batch_size = per_device_batch_size * accumulate_steps * dist.world_size
steps_per_epoch = effective_num_samples // effective_batch_size
steps_per_epoch = effective_num_samples // (per_device_batch_size * dist.world_size)

if dist.rank == 0:
print(
Expand Down
Loading