Count forward passes, not optimizer steps, in steps_per_epoch.

train.py previously computed
    steps_per_epoch = effective_num_samples // effective_batch_size
where effective_batch_size = per_device_batch_size * accumulate_steps *
dist.world_size. This patch divides by (per_device_batch_size *
dist.world_size) instead, so accumulate_steps no longer shrinks the number
of loop iterations per epoch.

Also adds TestStepsPerEpochCalculation to test_training_utils.py.

NOTE(review): the new tests call a local helper that mirrors the formula
rather than importing it from train.py, so they pin the arithmetic but will
not notice if train.py itself drifts again.
NOTE(review): this patch was recovered from a whitespace-collapsed copy. The
indentation of the train.py context lines and the blank context line after
the changed statement are reconstructed; confirm against the target file
(or apply with `git apply --ignore-whitespace`).

diff --git a/code/tests/test_training_utils.py b/code/tests/test_training_utils.py
index 4d9c68e..b6b5f8b 100644
--- a/code/tests/test_training_utils.py
+++ b/code/tests/test_training_utils.py
@@ -128,3 +128,38 @@ def test_non_tensor_passthrough(self):
         out = dict_to_device(d, torch.device("cpu"))
         self.assertEqual(out["a"], 1)
         self.assertEqual(out["b"], "x")
+
+
+class TestStepsPerEpochCalculation(unittest.TestCase):
+    """steps_per_epoch must count forward passes (one generate_batch per step),
+    not optimizer steps. accumulate_steps only controls optimizer.step()
+    frequency and must not reduce the loop iteration count."""
+
+    def _steps_per_epoch(self, num_samples, per_device_bs, world_size):
+        """Mirror the calculation in train.py:main() (line ~1298)."""
+        return num_samples // (per_device_bs * world_size)
+
+    def test_accumulate_steps_does_not_reduce_data_seen(self):
+        num_samples = 8_388_608
+        batch_size = 512
+        world_size = 1
+
+        steps = self._steps_per_epoch(num_samples, batch_size, world_size)
+        data_seen = steps * batch_size * world_size
+
+        self.assertEqual(data_seen, num_samples)
+
+    def test_multi_gpu(self):
+        num_samples = 8_388_608
+        batch_size = 512
+        world_size = 4
+
+        steps = self._steps_per_epoch(num_samples, batch_size, world_size)
+        data_seen = steps * batch_size * world_size
+
+        self.assertEqual(data_seen, num_samples)
+
+    def test_default_production_values(self):
+        steps = self._steps_per_epoch(67_108_864, 512, 1)
+        self.assertEqual(steps, 131072)
+        self.assertEqual(steps * 512, 67_108_864)
diff --git a/code/training/train.py b/code/training/train.py
index 1929baa..89e1151 100644
--- a/code/training/train.py
+++ b/code/training/train.py
@@ -1295,7 +1295,7 @@ def _print_gen(name, g):
         per_device_batch_size = get_current_per_device_batch_size(epoch, cfg)
         accumulate_steps = cfg.train.accumulate_steps
         effective_batch_size = per_device_batch_size * accumulate_steps * dist.world_size
-        steps_per_epoch = effective_num_samples // effective_batch_size
+        steps_per_epoch = effective_num_samples // (per_device_batch_size * dist.world_size)
 
         if dist.rank == 0:
             print(