Add Colab training notebooks for all 3 model variants #2
- Atomic_1Bit_Train_Stories.ipynb: Stories Base (~1.3M params, TinyStories)
- Atomic_1Bit_Train_Pocket.ipynb: Pocket (~10M params, Alpaca)
- Atomic_1Bit_Train_Instruct.ipynb: Flagship (~12.5M params, Alpaca)
All notebooks are self-contained with inlined model code, Google Drive persistence, AMP support, and tqdm progress bars.
Pull request overview
Adds three self-contained Google Colab notebooks to train the Atomic-1Bit models (Stories Base, Pocket, Flagship Instruct), including inlined model code, Drive checkpointing, AMP support, and basic progress/visualization.
Changes:
- Added a TinyStories training notebook for the ~1.3M “Stories Base” variant.
- Added an Alpaca Cleaned training notebook for the ~10M “Pocket” variant (with cosine annealing).
- Added an Alpaca Cleaned training notebook for the ~12.5M “Flagship Instruct” variant (warmup+cosine, grad accumulation, logging, checkpoint download).
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| notebooks/Atomic_1Bit_Train_Stories.ipynb | Colab notebook to train Stories Base with vocab filtering + Drive checkpointing. |
| notebooks/Atomic_1Bit_Train_Pocket.ipynb | Colab notebook to train Pocket on Alpaca with cosine annealing + Drive checkpointing. |
| notebooks/Atomic_1Bit_Train_Instruct.ipynb | Colab notebook to train Flagship Instruct with grad accumulation, warmup+cosine schedule, logging, and checkpoint download. |
| " print(f'Loading TinyStories ({split})...')\n", | ||
| " self.dataset = load_dataset('roneneldan/TinyStories', split=f'{split}[:10%]')\n", | ||
| " self.enc = tiktoken.get_encoding('gpt2')\n", |
The notebook intro says it trains on TinyStories, but the dataset loader uses split=f'{split}[:10%]', which silently limits training to 10% of the split. Consider making the percentage an explicit hyperparameter (or defaulting to the full split) and documenting it in the markdown so users aren’t surprised by lower data volume.
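One way to make the data fraction explicit is sketched below. The names DATA_PERCENT and build_split are hypothetical and do not come from the notebook; the slice syntax is the same one the notebook already passes to load_dataset.

```python
# Hypothetical hyperparameter: percentage of the split to load (100 = full split).
DATA_PERCENT = 10


def build_split(split: str, percent: int = DATA_PERCENT) -> str:
    """Return a split string such as 'train[:10%]', or the bare name for 100%."""
    if not 1 <= percent <= 100:
        raise ValueError(f"percent must be in [1, 100], got {percent}")
    return split if percent == 100 else f"{split}[:{percent}%]"


# The result would be passed on unchanged, for example:
#   load_dataset('roneneldan/TinyStories', split=build_split('train'))
print(build_split("train"))       # train[:10%]
print(build_split("train", 100))  # train
```

Keeping the percentage next to the other hyperparameters (and mentioning it in the intro markdown) makes the reduced data volume visible without changing the default behavior.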
| "LR = 1e-3\n", | ||
| "# ----------------------------------------------\n", | ||
| "\n", | ||
| "class PocketStoriesDataset:\n", |
In the Stories notebook, the dataset wrapper is named PocketStoriesDataset, which is easy to confuse with the Pocket model notebook. Renaming it to something specific like TinyStoriesDataset/StoriesDataset would make the notebook clearer and reduce copy/paste confusion.
| "class PocketStoriesDataset:\n", | |
| "class TinyStoriesDataset:\n", |
| " checkpoint = torch.load(ckpt_path, map_location=device)\n", | ||
| " model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))\n", | ||
| " if 'step' in checkpoint:\n", | ||
| " start_step = checkpoint['step']\n", |
When resuming, start_step = checkpoint['step'] will cause the training loop range(start_step, total_steps) to repeat the last saved step (since checkpoints are written after completing step). Consider storing/loading the next step to run (e.g., save step + 1 in the checkpoint, or set start_step = checkpoint['step'] + 1 on load) to avoid duplicate updates.
| " start_step = checkpoint['step']\n", | |
| " start_step = checkpoint['step'] + 1\n", |
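The off-by-one can be reproduced with a toy loop that has no model at all. This is only a sketch: train, save_every, and the updates list are hypothetical stand-ins for the notebook's training loop, and the checkpoint is assumed to store the last completed step.

```python
def train(start_step, total_steps, save_every, updates):
    """Toy loop: record every step that would run an update, return the last checkpoint."""
    checkpoint = None
    for step in range(start_step, total_steps):
        updates.append(step)                 # stands in for optimizer.step()
        if (step + 1) % save_every == 0:
            checkpoint = {"step": step}      # written after completing `step`
    return checkpoint


first_run = []
ckpt = train(0, 6, save_every=3, updates=first_run)   # last checkpoint holds step 5

# Resuming at checkpoint['step'] repeats the step that was already completed.
buggy = list(first_run)
train(ckpt["step"], 10, save_every=3, updates=buggy)

# Resuming at checkpoint['step'] + 1 continues with the next step to run.
fixed = list(first_run)
train(ckpt["step"] + 1, 10, save_every=3, updates=fixed)

print(buggy.count(5), fixed.count(5))
```

Either convention works (store the last completed step and add 1 on load, or store the next step to run), as long as the periodic save, the final save, and the load all agree on it.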
| "\n", | ||
| "# Final save\n", | ||
| "save_dict = {\n", | ||
| " 'step': total_steps,\n", |
The final checkpoint stores 'step': total_steps, but the loop runs for step in range(start_step, total_steps) so the last completed step is total_steps - 1. This off-by-one makes resume logic ambiguous; consider persisting the next step to run (or persist the last completed step consistently) and aligning the final save accordingly.
| " 'step': total_steps,\n", | |
| " 'step': step,\n", |
| "# Final save\n", | ||
| "save_dict = {\n", | ||
| " 'step': total_steps,\n", | ||
| " 'model_state_dict': model.state_dict(),\n", | ||
| " 'optimizer_state_dict': optimizer.state_dict(),\n", | ||
| " 'scheduler_state_dict': scheduler.state_dict(),\n", | ||
| "}\n", |
Final checkpoint stores 'step': total_steps, but the loop runs range(start_step, total_steps) so the last completed step index is total_steps - 1. Align the persisted step semantics (either save the next step to run, or save the last completed step consistently) to avoid confusion and off-by-one behavior on resume.
| " if 'model_state_dict' in checkpoint:\n", | ||
| " model.load_state_dict(checkpoint['model_state_dict'])\n", | ||
| " if 'step' in checkpoint:\n", | ||
| " start_step = checkpoint['step']\n", |
On resume, start_step = checkpoint['step'] will cause the loop range(start_step, total_steps) to repeat the last saved step (since checkpoints are written after completing step). Consider persisting/loading the next step to run (e.g., save step + 1, or set start_step = checkpoint['step'] + 1 after loading).
| " start_step = checkpoint['step']\n", | |
| " start_step = checkpoint['step'] + 1\n", |
| "scheduler = get_cosine_schedule_with_warmup(optimizer, WARMUP_STEPS, total_steps)\n", | ||
| "\n", | ||
| "if start_step > 0:\n", | ||
| " for _ in range(start_step):\n", | ||
| " scheduler.step()\n", | ||
| " if checkpoint and 'scheduler_state_dict' in checkpoint:\n", | ||
| " try:\n", | ||
| " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", | ||
| " print(' Scheduler state restored.')\n", | ||
| " except:\n", | ||
| " print(' Warning: Could not restore scheduler, using re-computed state.')\n", |
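The scheduler restore logic runs for _ in range(start_step): scheduler.step() and then potentially loads scheduler_state_dict as well. The loop is unnecessary work (it can be very slow for large start_step) and triggers the common "lr_scheduler.step() before optimizer.step()" warning. When a scheduler_state_dict exists, load it directly without pre-stepping. When none is available, prefer constructing the scheduler with last_epoch=start_step-1. Note that PyTorch schedulers built with last_epoch other than -1 expect initial_lr in every optimizer param group, which is normally present only after the optimizer state has been restored from a checkpoint saved while a scheduler was attached; otherwise set initial_lr manually first. Calling scheduler.step(start_step) once is an alternative, but the epoch argument is deprecated in recent PyTorch releases.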
| "scheduler = get_cosine_schedule_with_warmup(optimizer, WARMUP_STEPS, total_steps)\n", | |
| "\n", | |
| "if start_step > 0:\n", | |
| " for _ in range(start_step):\n", | |
| " scheduler.step()\n", | |
| " if checkpoint and 'scheduler_state_dict' in checkpoint:\n", | |
| " try:\n", | |
| " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", | |
| " print(' Scheduler state restored.')\n", | |
| " except:\n", | |
| " print(' Warning: Could not restore scheduler, using re-computed state.')\n", | |
| "scheduler = get_cosine_schedule_with_warmup(optimizer, WARMUP_STEPS, total_steps, last_epoch=start_step-1)\n", | |
| "\n", | |
| "if start_step > 0 and checkpoint and 'scheduler_state_dict' in checkpoint:\n", | |
| " try:\n", | |
| " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", | |
| " print(' Scheduler state restored.')\n", | |
| " except:\n", | |
| " print(' Warning: Could not restore scheduler, using re-computed state.')\n", |
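The reason pre-stepping is unnecessary is that a warmup+cosine schedule is a pure function of the step index. The sketch below assumes the notebook's get_cosine_schedule_with_warmup has the usual linear-warmup, half-cosine shape (as the Hugging Face transformers function of that name does with its default arguments); the notebook inlines its own code, so its exact formula may differ. warmup_cosine_factor and the step counts are illustrative names and values.

```python
import math


def warmup_cosine_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier at `step`: linear warmup to 1.0, then half-cosine decay to 0.0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Illustrative values, not taken from the notebook.
WARMUP_STEPS, TOTAL_STEPS = 100, 1000

# The multiplier at any resume point is available directly, with no replay loop.
for resume_step in (0, 50, 100, 550, 1000):
    factor = warmup_cosine_factor(resume_step, WARMUP_STEPS, TOTAL_STEPS)
    print(resume_step, round(factor, 4))
```

Because the multiplier depends only on the step index, a scheduler positioned at start_step (via last_epoch or a restored state dict) yields the same learning rate as one stepped start_step times.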
| "# Final save\n", | ||
| "save_dict = {\n", | ||
| " 'step': total_steps,\n", | ||
| " 'model_state_dict': model.state_dict(),\n", | ||
| " 'optimizer_state_dict': optimizer.state_dict(),\n", | ||
| " 'scheduler_state_dict': scheduler.state_dict(),\n", | ||
| " 'rng_state': torch.get_rng_state(),\n", | ||
| " 'np_rng_state': np.random.get_state(),\n", | ||
| " 'config': {\n", | ||
| " 'vocab_size': VOCAB_SIZE, 'dim': DIM, 'depth': DEPTH,\n", | ||
| " 'heads': HEADS, 'context_length': CONTEXT_LEN,\n", | ||
| " },\n", | ||
| "}\n", | ||
| "torch.save(save_dict, ckpt_path)\n", | ||
| "print(f'\\n✅ Training complete! Checkpoint saved to {ckpt_path}')" |
Final checkpoint stores 'step': total_steps, but the loop runs range(start_step, total_steps), meaning the last completed step index is total_steps - 1. Align the meaning of the stored step (last completed vs next-to-run) so resuming doesn't introduce off-by-one behavior.
| " if 'model_state_dict' in checkpoint:\n", | ||
| " model.load_state_dict(checkpoint['model_state_dict'])\n", | ||
| " if 'step' in checkpoint:\n", | ||
| " start_step = checkpoint['step']\n", |
Similar to the Stories notebook, resuming sets start_step = checkpoint['step'], but the checkpoint is saved after completing step, so range(start_step, total_steps) will repeat that step on resume. Consider saving/loading the next step to run (e.g., persist step + 1 or increment start_step after loading).
| " start_step = checkpoint['step']\n", | |
| " start_step = checkpoint['step'] + 1\n", |
| "# Restore scheduler if available\n", | ||
| "if os.path.exists(ckpt_path) and 'scheduler_state_dict' in checkpoint:\n", | ||
| " try:\n", | ||
| " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", | ||
| " print(' Scheduler state restored.')\n", | ||
| " except:\n", | ||
| " pass\n", | ||
| "\n", |
The LR scheduler configuration is inconsistent with the resume/extension logic: total_steps = start_step + ADDITIONAL_STEPS extends training beyond the original run, but the scheduler uses T_max=ADDITIONAL_STEPS and then (optionally) restores scheduler_state_dict. If you resume and keep training past T_max, CosineAnnealingLR will start increasing LR again. Consider either (a) setting T_max to the intended total training horizon and advancing it to start_step, or (b) treating each run as a fresh schedule and not restoring scheduler_state_dict when total_steps is extended.
| "# Restore scheduler if available\n", | |
| "if os.path.exists(ckpt_path) and 'scheduler_state_dict' in checkpoint:\n", | |
| " try:\n", | |
| " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", | |
| " print(' Scheduler state restored.')\n", | |
| " except:\n", | |
| " pass\n", | |
| "\n", |