diff --git a/notebooks/Atomic_1Bit_Train_Instruct.ipynb b/notebooks/Atomic_1Bit_Train_Instruct.ipynb new file mode 100644 index 0000000..e873e4d --- /dev/null +++ b/notebooks/Atomic_1Bit_Train_Instruct.ipynb @@ -0,0 +1,674 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ⚛️ Atomic-1Bit — Train Flagship Instruct Model (Colab)\n", + "\n", + "Train the **Flagship 12.5M** instruct model on [Alpaca Cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned).\n", + "\n", + "| Param | Value |\n", + "|---|---|\n", + "| **Params** | ~12.5M |\n", + "| **Dim** | 320 |\n", + "| **Depth** | 8 |\n", + "| **Heads** | 5 |\n", + "| **Vocab** | 4096 (frequency-filtered) |\n", + "| **Context** | 256 |\n", + "| **Effective Batch** | 256 (32 × 8 grad accum) |\n", + "| **Scheduler** | Cosine with linear warmup |\n", + "| **Gradient Clipping** | 1.0 |\n", + "| **Weight Decay** | 0.1 (non-norm/bias/emb) |\n", + "\n", + "**Runtime**: Select **GPU** (Runtime → Change runtime type → T4 GPU).\n", + "\n", + "> ⚡ This is the largest model. A T4 GPU is sufficient but training will be significantly faster on V100/A100 (Colab Pro)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1 · Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q torch tiktoken datasets numpy matplotlib tqdm pyyaml" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "\n", + "import os\n", + "DRIVE_DIR = '/content/drive/MyDrive/Atomic-1Bit/weights'\n", + "os.makedirs(DRIVE_DIR, exist_ok=True)\n", + "print(f'Checkpoints will be saved to: {DRIVE_DIR}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2 · Model Code (Inlined)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import math\n", + "from dataclasses import dataclass\n", + "\n", + "def activation_quant(x):\n", + " scale = 127.0 / x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5)\n", + " y = (x * scale).round().clamp(-127, 127)\n", + " y_ste = (y - x * scale).detach() + x * scale\n", + " return y_ste, scale\n", + "\n", + "def weight_quant(w):\n", + " scale = 1.0 / w.abs().mean().clamp(min=1e-5)\n", + " y = (w * scale).round().clamp(-1, 1)\n", + " y_ste = (y - w * scale).detach() + w * scale\n", + " return y_ste, scale\n", + "\n", + "class BitLinear(nn.Module):\n", + " def __init__(self, in_features, out_features, bias=False):\n", + " super().__init__()\n", + " self.in_features = in_features\n", + " self.out_features = out_features\n", + " self.weight = nn.Parameter(torch.randn(out_features, in_features))\n", + " if bias:\n", + " self.bias = nn.Parameter(torch.zeros(out_features))\n", + " else:\n", + " self.register_parameter('bias', None)\n", + " self.eps = 1e-5\n", + "\n", + " def forward(self, x):\n", + " x_f32 = x.float()\n", + " rms = torch.sqrt(torch.mean(x_f32 ** 2, dim=-1, keepdim=True) + self.eps)\n", + " x_norm = x_f32 / rms\n", + " x_quant_ste, scale_x = activation_quant(x_norm)\n", + " w_quant_ste, scale_w = weight_quant(self.weight)\n", + " y = F.linear(x_quant_ste, w_quant_ste)\n", + " y_out = y / (scale_x * scale_w)\n", + " if self.bias is not None:\n", + " y_out += self.bias\n", + " return y_out\n", + "\n", + "@dataclass\n", + "class AtomicConfig:\n", + " vocab_size: int = 50257\n", + " dim: int = 512\n", + " depth: int = 8\n", + " heads: int = 8\n", + " context_length: int = 1024\n", + "\n", + "class BitAttention(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " assert config.dim % config.heads == 0\n", + " self.dim = config.dim\n", + " self.heads = config.heads\n", + " self.head_dim = config.dim // config.heads\n", + " self.q_proj = BitLinear(config.dim, config.dim)\n", + " self.k_proj = BitLinear(config.dim, config.dim)\n", + " self.v_proj = BitLinear(config.dim, config.dim)\n", + " self.o_proj = BitLinear(config.dim, config.dim)\n", + "\n", + " def forward(self, x, kv_cache=None):\n", + " B, T, C = x.shape\n", + " q = self.q_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)\n", + " k = self.k_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)\n", + " v = self.v_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)\n", + " if kv_cache is not None:\n", + " cached_k, cached_v = kv_cache\n", + " k = torch.cat([cached_k, k], dim=2)\n", + " v = torch.cat([cached_v, v], dim=2)\n", + " new_kv_cache = (k, v)\n", + " T_total = k.shape[2]\n", + " att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))\n", + " mask = torch.ones(T, T_total, device=x.device, dtype=torch.bool)\n", + " mask = torch.triu(mask, diagonal=T_total - T + 1)\n", + " att = att.masked_fill(mask, float('-inf'))\n", + " att = F.softmax(att, dim=-1)\n", + " y = att @ v\n", + " y = y.transpose(1, 2).contiguous().view(B, T, C)\n", + " return self.o_proj(y), new_kv_cache\n", + "\n", + "class BitFeedForward(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " hidden_dim = 4 * config.dim\n", + " self.fc1 = BitLinear(config.dim, hidden_dim)\n", + " self.fc2 = BitLinear(hidden_dim, config.dim)\n", + " self.act = nn.GELU()\n", + "\n", + " def forward(self, x):\n", + " return self.fc2(self.act(self.fc1(x)))\n", + "\n", + "class AtomicBlock(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.ln1 = nn.RMSNorm(config.dim, eps=1e-5)\n", + " self.attn = BitAttention(config)\n", + " self.ln2 = nn.RMSNorm(config.dim, eps=1e-5)\n", + " self.mlp = BitFeedForward(config)\n", + "\n", + " def forward(self, x, kv_cache=None):\n", + " attn_out, new_kv_cache = self.attn(self.ln1(x), kv_cache=kv_cache)\n", + " x = x + attn_out\n", + " x = x + self.mlp(self.ln2(x))\n", + " return x, new_kv_cache\n", + "\n", + "class AtomicTransformer(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.config = config\n", + " self.token_emb = nn.Embedding(config.vocab_size, config.dim)\n", + " self.pos_emb = nn.Embedding(config.context_length, config.dim)\n", + " self.layers = nn.ModuleList([AtomicBlock(config) for _ in range(config.depth)])\n", + " self.ln_f = nn.RMSNorm(config.dim, eps=1e-5)\n", + " self.head = BitLinear(config.dim, config.vocab_size)\n", + "\n", + " def forward(self, idx, kv_cache=None):\n", + " B, T = idx.shape\n", + " if kv_cache is not None and kv_cache[0] is not None:\n", + " pos_offset = kv_cache[0][0].shape[2]\n", + " else:\n", + " pos_offset = 0\n", + " pos = torch.arange(pos_offset, pos_offset + T, dtype=torch.long, device=idx.device)\n", + " x = self.token_emb(idx) + self.pos_emb(pos)\n", + " new_kv_cache = []\n", + " for i, layer in enumerate(self.layers):\n", + " layer_cache = kv_cache[i] if kv_cache is not None else None\n", + " x, new_cache = layer(x, kv_cache=layer_cache)\n", + " new_kv_cache.append(new_cache)\n", + " x = self.ln_f(x)\n", + " logits = self.head(x)\n", + " if kv_cache is not None:\n", + " return logits, new_kv_cache\n", + " return logits\n", + "\n", + "\n", + "def init_weights(model):\n", + " \"\"\"Apply scaled Kaiming initialization to BitLinear latent weights.\"\"\"\n", + " for name, module in model.named_modules():\n", + " if isinstance(module, BitLinear):\n", + " torch.nn.init.kaiming_normal_(module.weight, a=math.sqrt(5))\n", + "\n", + "print('✅ Model code loaded.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3 · Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import numpy as np\n", + "import tiktoken\n", + "import time\n", + "from collections import Counter\n", + "from datasets import load_dataset\n", + "\n", + "# -------- Hyperparameters --------\n", + "BATCH_SIZE = 32\n", + "GRAD_ACCUM_STEPS = 8 # Effective batch = 32 × 8 = 256\n", + "CONTEXT_LEN = 256\n", + "DIM = 320\n", + "DEPTH = 8\n", + "HEADS = 5\n", + "VOCAB_SIZE = 4096\n", + "UNK_ID = VOCAB_SIZE - 1\n", + "LR = 3e-4\n", + "WARMUP_STEPS = 1000\n", + "CLIP_GRAD = 1.0\n", + "WEIGHT_DECAY = 0.1\n", + "# ---------------------------------\n", + "\n", + "class EfficientInstructDataset:\n", + " def __init__(self, split='train', context_length=256, vocab_file=None):\n", + " if vocab_file is None:\n", + " vocab_file = os.path.join(DRIVE_DIR, 'vocab_map_instruct.json')\n", + " print(f'Loading Alpaca Cleaned ({split})...')\n", + " raw_dataset = load_dataset('yahma/alpaca-cleaned', split=split)\n", + " self.enc = tiktoken.get_encoding('gpt2')\n", + " self.context_length = context_length\n", + " self.vocab_file = vocab_file\n", + "\n", + " print(f'Filtering dataset (Max Tokens: {context_length}, Min Tokens: 10)...')\n", + " def filter_fn(sample):\n", + " text = self.format_prompt(sample)\n", + " ids = self.enc.encode(text)\n", + " return 10 <= len(ids) + 1 <= context_length\n", + "\n", + " self.dataset = raw_dataset.filter(filter_fn)\n", + " print(f'Kept {len(self.dataset)}/{len(raw_dataset)} clean samples.')\n", + "\n", + " self.token_map = {}\n", + " self.reverse_map = {}\n", + " self._init_vocab()\n", + "\n", + " def _init_vocab(self):\n", + " if os.path.exists(self.vocab_file):\n", + " print(f'Loading vocab map from {self.vocab_file}...')\n", + " with open(self.vocab_file, 'r') as f:\n", + " data = json.load(f)\n", + " self.token_map = {int(k): v for k, v in data['token_map'].items()}\n", + " self.reverse_map = {int(k): v for k, v in data['reverse_map'].items()}\n", + " print(f'Loaded {len(self.token_map)} mapped tokens.')\n", + " return\n", + "\n", + " print('Building Frequency-Based Vocab (Scanning first 20k filtered samples)...')\n", + " counter = Counter()\n", + " scan_limit = min(20000, len(self.dataset))\n", + " for i in range(scan_limit):\n", + " row = self.dataset[i]\n", + " text = self.format_prompt(row)\n", + " ids = self.enc.encode(text)\n", + " counter.update(ids)\n", + " eot = self.enc.eot_token\n", + " most_common = counter.most_common(VOCAB_SIZE - 2)\n", + " new_id = 0\n", + " valid_gpt_ids = [k for k, v in most_common]\n", + " if eot not in valid_gpt_ids:\n", + " valid_gpt_ids.append(eot)\n", + " valid_gpt_ids = valid_gpt_ids[:VOCAB_SIZE - 1]\n", + " for gpt_id in valid_gpt_ids:\n", + " self.token_map[gpt_id] = new_id\n", + " self.reverse_map[new_id] = gpt_id\n", + " new_id += 1\n", + " self.unk_token = UNK_ID\n", + " print(f'Saving vocab map to {self.vocab_file}...')\n", + " os.makedirs(os.path.dirname(self.vocab_file), exist_ok=True)\n", + " with open(self.vocab_file, 'w') as f:\n", + " json.dump({'token_map': self.token_map, 'reverse_map': self.reverse_map}, f)\n", + "\n", + " def format_prompt(self, sample):\n", + " text = f\"### Instruction: {sample['instruction']}\\n\"\n", + " if sample.get('input', ''):\n", + " text += f\"### Input: {sample['input']}\\n\"\n", + " text += f\"### Response: {sample['output']}\"\n", + " return text\n", + "\n", + " def get_batch(self, batch_size):\n", + " indices = np.random.randint(0, len(self.dataset), batch_size)\n", + " rows = self.dataset.select(indices)\n", + " batch_input_ids, batch_targets = [], []\n", + " for i in range(len(rows)):\n", + " row = rows[i]\n", + " text = self.format_prompt(row)\n", + " gpt_ids = self.enc.encode(text)\n", + " gpt_ids.append(self.enc.eot_token)\n", + " pocket_ids = [self.token_map.get(gid, UNK_ID) for gid in gpt_ids]\n", + " if len(pocket_ids) < self.context_length + 1:\n", + " eot_mapped = self.token_map.get(self.enc.eot_token, UNK_ID)\n", + " pocket_ids += [eot_mapped] * (self.context_length + 1 - len(pocket_ids))\n", + " batch_input_ids.append(pocket_ids[:-1])\n", + " batch_targets.append(pocket_ids[1:])\n", + " return torch.tensor(batch_input_ids, dtype=torch.long), torch.tensor(batch_targets, dtype=torch.long)\n", + "\n", + "print('✅ Dataset class ready.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4 · Training Config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ADDITIONAL_STEPS = 5000\n", + "USE_AMP = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5 · Train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.optim as optim\n", + "from tqdm.auto import tqdm\n", + "import matplotlib.pyplot as plt\n", + "import csv\n", + "\n", + "def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.01):\n", + " \"\"\"Linear warmup followed by cosine decay.\"\"\"\n", + " def lr_lambda(step):\n", + " if step < warmup_steps:\n", + " return float(step) / float(max(1, warmup_steps))\n", + " progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))\n", + " return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))\n", + " return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n", + "\n", + "\n", + "def generate_demo(model, ds, instruction='What is AI?', max_tokens=60):\n", + " model.eval()\n", + " device = next(model.parameters()).device\n", + " prompt = f\"### Instruction: {instruction}\\n### Response:\"\n", + " gpt_ids = ds.enc.encode(prompt)\n", + " ids = [ds.token_map.get(gid, UNK_ID) for gid in gpt_ids]\n", + " x = torch.tensor([ids], dtype=torch.long).to(device)\n", + " eot_mapped = ds.token_map.get(ds.enc.eot_token, UNK_ID)\n", + " tokens = []\n", + " for _ in range(max_tokens):\n", + " if x.size(1) >= CONTEXT_LEN:\n", + " break\n", + " with torch.no_grad():\n", + " logits = model(x)\n", + " probs = F.softmax(logits[:, -1, :], dim=-1)\n", + " next_token = torch.multinomial(probs, 1)\n", + " pocket_id = next_token.item()\n", + " gpt_id = ds.reverse_map.get(pocket_id, ds.enc.eot_token)\n", + " try:\n", + " tokens.append(ds.enc.decode([gpt_id]))\n", + " except:\n", + " pass\n", + " x = torch.cat([x, next_token], dim=1)\n", + " if pocket_id == eot_mapped:\n", + " break\n", + " model.train()\n", + " return instruction + '\\n' + ''.join(tokens)\n", + "\n", + "# --- Setup ---\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "print(f'Device: {device}')\n", + "if device == 'cuda':\n", + " print(f'GPU: {torch.cuda.get_device_name(0)}')\n", + " print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')\n", + "\n", + "ds = EfficientInstructDataset(context_length=CONTEXT_LEN)\n", + "config = AtomicConfig(vocab_size=VOCAB_SIZE, dim=DIM, depth=DEPTH, heads=HEADS, context_length=CONTEXT_LEN)\n", + "model = AtomicTransformer(config).to(device)\n", + "\n", + "start_step = 0\n", + "ckpt_path = os.path.join(DRIVE_DIR, 'instruct_final.pt')\n", + "\n", + "# Separate weight decay groups\n", + "decay_params, no_decay_params = [], []\n", + "for name, param in model.named_parameters():\n", + " if not param.requires_grad:\n", + " continue\n", + " if 'ln' in name or 'bias' in name or 'emb' in name:\n", + " no_decay_params.append(param)\n", + " else:\n", + " decay_params.append(param)\n", + "\n", + "optimizer = optim.AdamW([\n", + " {'params': decay_params, 'weight_decay': WEIGHT_DECAY},\n", + " {'params': no_decay_params, 'weight_decay': 0.0},\n", + "], lr=LR)\n", + "\n", + "# Resume from checkpoint\n", + "checkpoint = None\n", + "if os.path.exists(ckpt_path):\n", + " print(f'>> Resuming from checkpoint: {ckpt_path}')\n", + " checkpoint = torch.load(ckpt_path, map_location=device)\n", + " if 'model_state_dict' in checkpoint:\n", + " model.load_state_dict(checkpoint['model_state_dict'])\n", + " if 'step' in checkpoint:\n", + " start_step = checkpoint['step']\n", + " if 'optimizer_state_dict' in checkpoint:\n", + " try:\n", + " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", + " print(' Optimizer state restored.')\n", + " except:\n", + " print(' Warning: Could not restore optimizer state')\n", + " if 'rng_state' in checkpoint:\n", + " torch.set_rng_state(checkpoint['rng_state'])\n", + " if 'np_rng_state' in checkpoint:\n", + " np.random.set_state(checkpoint['np_rng_state'])\n", + " print(' RNG state restored.')\n", + " else:\n", + " model.load_state_dict(checkpoint)\n", + " print(f' Resuming from step {start_step}')\n", + "else:\n", + " print('>> Starting fresh Flagship model')\n", + " init_weights(model)\n", + " print(' Applied Kaiming initialization to BitLinear weights.')\n", + "\n", + "total_params = sum(p.numel() for p in model.parameters()) / 1e6\n", + "print(f'Parameters: {total_params:.2f}M')\n", + "\n", + "total_steps = start_step + ADDITIONAL_STEPS\n", + "\n", + "# Cosine schedule with warmup\n", + "scheduler = get_cosine_schedule_with_warmup(optimizer, WARMUP_STEPS, total_steps)\n", + "\n", + "if start_step > 0:\n", + " for _ in range(start_step):\n", + " scheduler.step()\n", + " if checkpoint and 'scheduler_state_dict' in checkpoint:\n", + " try:\n", + " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", + " print(' Scheduler state restored.')\n", + " except:\n", + " print(' Warning: Could not restore scheduler, using re-computed state.')\n", + "\n", + "scaler = torch.amp.GradScaler('cuda', enabled=(USE_AMP and device == 'cuda'))\n", + "losses = []\n", + "\n", + "# Training log on Drive\n", + "log_file = os.path.join(DRIVE_DIR, 'training_log.csv')\n", + "if not os.path.exists(log_file):\n", + " with open(log_file, 'w') as f:\n", + " f.write('step,loss,lr,step_time_ms\\n')\n", + "\n", + "print(f'Training from step {start_step} to {total_steps} (Grad Accum: {GRAD_ACCUM_STEPS}, Eff Batch: {BATCH_SIZE * GRAD_ACCUM_STEPS})...')\n", + "optimizer.zero_grad()\n", + "pbar = tqdm(range(start_step, total_steps), desc='Training')\n", + "step_start_time = time.time()\n", + "\n", + "for step in pbar:\n", + " # Gradient Accumulation\n", + " loss_accum = 0.0\n", + " for _ in range(GRAD_ACCUM_STEPS):\n", + " x, y = ds.get_batch(BATCH_SIZE)\n", + " x, y = x.to(device), y.to(device)\n", + "\n", + " with torch.amp.autocast('cuda', enabled=(USE_AMP and device == 'cuda')):\n", + " logits = model(x)\n", + " loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), y.view(-1))\n", + " loss = loss / GRAD_ACCUM_STEPS\n", + "\n", + " scaler.scale(loss).backward()\n", + " loss_accum += loss.item()\n", + "\n", + " # Gradient clipping\n", + " scaler.unscale_(optimizer)\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)\n", + "\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " optimizer.zero_grad()\n", + " scheduler.step()\n", + "\n", + " # Timing\n", + " step_end_time = time.time()\n", + " step_duration_ms = (step_end_time - step_start_time) * 1000\n", + " step_start_time = step_end_time\n", + "\n", + " losses.append(loss_accum)\n", + " current_lr = scheduler.get_last_lr()[0]\n", + " pbar.set_postfix(loss=f'{loss_accum:.4f}', lr=f'{current_lr:.2e}', ms=f'{step_duration_ms:.0f}')\n", + "\n", + " # Log to CSV\n", + " if step % 10 == 0:\n", + " with open(log_file, 'a') as f:\n", + " f.write(f'{step},{loss_accum:.5f},{current_lr:.5e},{step_duration_ms:.1f}\\n')\n", + "\n", + " if step % 500 == 0 and step > 0:\n", + " sample = generate_demo(model, ds, 'What is AI?')\n", + " tqdm.write(f'\\n--- Step {step} Sample ---\\n{sample}\\n')\n", + "\n", + " if step > 0 and step % 1000 == 0:\n", + " save_dict = {\n", + " 'step': step,\n", + " 'model_state_dict': model.state_dict(),\n", + " 'optimizer_state_dict': optimizer.state_dict(),\n", + " 'scheduler_state_dict': scheduler.state_dict(),\n", + " 'rng_state': torch.get_rng_state(),\n", + " 'np_rng_state': np.random.get_state(),\n", + " 'config': {\n", + " 'vocab_size': VOCAB_SIZE, 'dim': DIM, 'depth': DEPTH,\n", + " 'heads': HEADS, 'context_length': CONTEXT_LEN,\n", + " },\n", + " }\n", + " torch.save(save_dict, ckpt_path)\n", + " tqdm.write(f'💾 Checkpoint saved at step {step}')\n", + "\n", + "# Final save\n", + "save_dict = {\n", + " 'step': total_steps,\n", + " 'model_state_dict': model.state_dict(),\n", + " 'optimizer_state_dict': optimizer.state_dict(),\n", + " 'scheduler_state_dict': scheduler.state_dict(),\n", + " 'rng_state': torch.get_rng_state(),\n", + " 'np_rng_state': np.random.get_state(),\n", + " 'config': {\n", + " 'vocab_size': VOCAB_SIZE, 'dim': DIM, 'depth': DEPTH,\n", + " 'heads': HEADS, 'context_length': CONTEXT_LEN,\n", + " },\n", + "}\n", + "torch.save(save_dict, ckpt_path)\n", + "print(f'\\n✅ Training complete! Checkpoint saved to {ckpt_path}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6 · Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n", + "\n", + "# Loss curve\n", + "axes[0].plot(losses, alpha=0.3, label='Raw')\n", + "window = min(100, len(losses) // 5) if len(losses) > 10 else 1\n", + "if window > 1:\n", + " smoothed = np.convolve(losses, np.ones(window)/window, mode='valid')\n", + " axes[0].plot(range(window-1, len(losses)), smoothed, label=f'Smoothed ({window})')\n", + "axes[0].set_xlabel('Step')\n", + "axes[0].set_ylabel('Loss')\n", + "axes[0].set_title('Training Loss')\n", + "axes[0].legend()\n", + "axes[0].grid(True, alpha=0.3)\n", + "\n", + "# LR schedule visualization\n", + "lr_schedule = []\n", + "temp_opt = optim.AdamW(model.parameters(), lr=LR)\n", + "temp_sched = get_cosine_schedule_with_warmup(temp_opt, WARMUP_STEPS, total_steps)\n", + "for _ in range(total_steps):\n", + " lr_schedule.append(temp_sched.get_last_lr()[0])\n", + " temp_sched.step()\n", + "axes[1].plot(lr_schedule)\n", + "axes[1].set_xlabel('Step')\n", + "axes[1].set_ylabel('Learning Rate')\n", + "axes[1].set_title('LR Schedule (Warmup + Cosine)')\n", + "axes[1].grid(True, alpha=0.3)\n", + "\n", + "plt.suptitle('Atomic-1Bit Flagship Instruct (12.5M)', fontsize=14, fontweight='bold')\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# Generate samples\n", + "prompts = ['What is AI?', 'Explain gravity simply.', 'Write a haiku about computers.', 'Count to 5.']\n", + "print('\\n📝 Generated Samples:')\n", + "print('=' * 60)\n", + "for p in prompts:\n", + " sample = generate_demo(model, ds, p)\n", + " print(f'\\n{sample}')\n", + " print('-' * 60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7 · Download Checkpoint\n", + "\n", + "Your checkpoint is already saved to Google Drive. To use it locally:\n", + "\n", + "1. Go to [Google Drive](https://drive.google.com) → `Atomic-1Bit/weights/`\n", + "2. Download `instruct_final.pt`\n", + "3. Place it in your local `weights/` directory\n", + "\n", + "Or download directly from Colab:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Optional: Download checkpoint directly from Colab\n", + "from google.colab import files\n", + "files.download(ckpt_path)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/Atomic_1Bit_Train_Pocket.ipynb b/notebooks/Atomic_1Bit_Train_Pocket.ipynb new file mode 100644 index 0000000..91661bd --- /dev/null +++ b/notebooks/Atomic_1Bit_Train_Pocket.ipynb @@ -0,0 +1,531 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ⚛️ Atomic-1Bit — Train Pocket Alpaca Model (Colab)\n", + "\n", + "Train the **Pocket** model (~10M params) on [Alpaca Cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned).\n", + "\n", + "| Param | Value |\n", + "|---|---|\n", + "| **Dim** | 320 |\n", + "| **Depth** | 8 |\n", + "| **Heads** | 5 |\n", + "| **Vocab** | 4096 (frequency-filtered) |\n", + "| **Context** | 128 |\n", + "| **Scheduler** | Cosine Annealing |\n", + "\n", + "**Runtime**: Select **GPU** (Runtime → Change runtime type → T4 GPU)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1 · Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q torch tiktoken datasets numpy matplotlib tqdm pyyaml" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "\n", + "import os\n", + "DRIVE_DIR = '/content/drive/MyDrive/Atomic-1Bit/weights'\n", + "os.makedirs(DRIVE_DIR, exist_ok=True)\n", + "print(f'Checkpoints will be saved to: {DRIVE_DIR}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2 · Model Code (Inlined)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import math\n", + "from dataclasses import dataclass\n", + "\n", + "def activation_quant(x):\n", + " scale = 127.0 / x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5)\n", + " y = (x * scale).round().clamp(-127, 127)\n", + " y_ste = (y - x * scale).detach() + x * scale\n", + " return y_ste, scale\n", + "\n", + "def weight_quant(w):\n", + " scale = 1.0 / w.abs().mean().clamp(min=1e-5)\n", + " y = (w * scale).round().clamp(-1, 1)\n", + " y_ste = (y - w * scale).detach() + w * scale\n", + " return y_ste, scale\n", + "\n", + "class BitLinear(nn.Module):\n", + " def __init__(self, in_features, out_features, bias=False):\n", + " super().__init__()\n", + " self.in_features = in_features\n", + " self.out_features = out_features\n", + " self.weight = nn.Parameter(torch.randn(out_features, in_features))\n", + " if bias:\n", + " self.bias = nn.Parameter(torch.zeros(out_features))\n", + " else:\n", + " self.register_parameter('bias', None)\n", + " self.eps = 1e-5\n", + "\n", + " def forward(self, x):\n", + " x_f32 = x.float()\n", + " rms = torch.sqrt(torch.mean(x_f32 ** 2, dim=-1, keepdim=True) + self.eps)\n", + " x_norm = x_f32 / rms\n", + " x_quant_ste, scale_x = activation_quant(x_norm)\n", + " w_quant_ste, scale_w = weight_quant(self.weight)\n", + " y = F.linear(x_quant_ste, w_quant_ste)\n", + " y_out = y / (scale_x * scale_w)\n", + " if self.bias is not None:\n", + " y_out += self.bias\n", + " return y_out\n", + "\n", + "@dataclass\n", + "class AtomicConfig:\n", + " vocab_size: int = 50257\n", + " dim: int = 512\n", + " depth: int = 8\n", + " heads: int = 8\n", + " context_length: int = 1024\n", + "\n", + "class BitAttention(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " assert config.dim % config.heads == 0\n", + " self.dim = config.dim\n", + " self.heads = config.heads\n", + " self.head_dim = config.dim // config.heads\n", + " self.q_proj = BitLinear(config.dim, config.dim)\n", + " self.k_proj = BitLinear(config.dim, config.dim)\n", + " self.v_proj = BitLinear(config.dim, config.dim)\n", + " self.o_proj = BitLinear(config.dim, config.dim)\n", + "\n", + " def forward(self, x, kv_cache=None):\n", + " B, T, C = x.shape\n", + " q = self.q_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)\n", + " k = self.k_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)\n", + " v = self.v_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)\n", + " if kv_cache is not None:\n", + " cached_k, cached_v = kv_cache\n", + " k = torch.cat([cached_k, k], dim=2)\n", + " v = torch.cat([cached_v, v], dim=2)\n", + " new_kv_cache = (k, v)\n", + " T_total = k.shape[2]\n", + " att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))\n", + " mask = torch.ones(T, T_total, device=x.device, dtype=torch.bool)\n", + " mask = torch.triu(mask, diagonal=T_total - T + 1)\n", + " att = att.masked_fill(mask, float('-inf'))\n", + " att = F.softmax(att, dim=-1)\n", + " y = att @ v\n", + " y = y.transpose(1, 2).contiguous().view(B, T, C)\n", + " return self.o_proj(y), new_kv_cache\n", + "\n", + "class BitFeedForward(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " hidden_dim = 4 * config.dim\n", + " self.fc1 = BitLinear(config.dim, hidden_dim)\n", + " self.fc2 = BitLinear(hidden_dim, config.dim)\n", + " self.act = nn.GELU()\n", + "\n", + " def forward(self, x):\n", + " return self.fc2(self.act(self.fc1(x)))\n", + "\n", + "class AtomicBlock(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.ln1 = nn.RMSNorm(config.dim, eps=1e-5)\n", + " self.attn = BitAttention(config)\n", + " self.ln2 = nn.RMSNorm(config.dim, eps=1e-5)\n", + " self.mlp = BitFeedForward(config)\n", + "\n", + " def forward(self, x, kv_cache=None):\n", + " attn_out, new_kv_cache = self.attn(self.ln1(x), kv_cache=kv_cache)\n", + " x = x + attn_out\n", + " x = x + self.mlp(self.ln2(x))\n", + " return x, new_kv_cache\n", + "\n", + "class AtomicTransformer(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.config = config\n", + " self.token_emb = nn.Embedding(config.vocab_size, config.dim)\n", + " self.pos_emb = nn.Embedding(config.context_length, config.dim)\n", + " self.layers = nn.ModuleList([AtomicBlock(config) for _ in range(config.depth)])\n", + " self.ln_f = nn.RMSNorm(config.dim, eps=1e-5)\n", + " self.head = BitLinear(config.dim, config.vocab_size)\n", + "\n", + " def forward(self, idx, kv_cache=None):\n", + " B, T = idx.shape\n", + " if kv_cache is not None and kv_cache[0] is not None:\n", + " pos_offset = kv_cache[0][0].shape[2]\n", + " else:\n", + " pos_offset = 0\n", + " pos = torch.arange(pos_offset, pos_offset + T, dtype=torch.long, device=idx.device)\n", + " x = self.token_emb(idx) + self.pos_emb(pos)\n", + " new_kv_cache = []\n", + " for i, layer in enumerate(self.layers):\n", + " layer_cache = kv_cache[i] if kv_cache is not None else None\n", + " x, new_cache = layer(x, kv_cache=layer_cache)\n", + " new_kv_cache.append(new_cache)\n", + " x = self.ln_f(x)\n", + " logits = self.head(x)\n", + " if kv_cache is not None:\n", + " return logits, new_kv_cache\n", + " return logits\n", + "\n", + "print('✅ Model code loaded.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3 · Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import numpy as np\n", + "import tiktoken\n", + "from collections import Counter\n", + "from datasets import load_dataset\n", + "\n", + "# -------- Hyperparameters --------\n", + "BATCH_SIZE = 32\n", + "CONTEXT_LEN = 128\n", + "DIM = 320\n", + "DEPTH = 8\n", + "HEADS = 5\n", + "VOCAB_SIZE = 4096\n", + "UNK_ID = VOCAB_SIZE - 1\n", + "LR = 6e-4\n", + "# ---------------------------------\n", + "\n", + "class PocketAlpacaDataset:\n", + " def __init__(self, split='train', context_length=128, vocab_file=None):\n", + " if vocab_file is None:\n", + " vocab_file = os.path.join(DRIVE_DIR, 'pocket_vocab_map.json')\n", + " print(f'Loading Alpaca Cleaned ({split})...')\n", + " raw_dataset = load_dataset('yahma/alpaca-cleaned', split=split)\n", + " self.enc = tiktoken.get_encoding('gpt2')\n", + " self.context_length = context_length\n", + " self.vocab_file = vocab_file\n", + "\n", + " print(f'Filtering dataset (Max Tokens: {context_length})...')\n", + " def filter_fn(sample):\n", + " text = self.format_prompt(sample)\n", + " ids = self.enc.encode(text)\n", + " return len(ids) + 1 <= context_length\n", + "\n", + " self.dataset = raw_dataset.filter(filter_fn)\n", + " print(f'Filtered: {len(raw_dataset)} → {len(self.dataset)} samples.')\n", + "\n", + " self.token_map = {}\n", + " self.reverse_map = {}\n", + " self._init_vocab()\n", + "\n", + " def _init_vocab(self):\n", + " if os.path.exists(self.vocab_file):\n", + " print(f'Loading vocab map from {self.vocab_file}...')\n", + " with open(self.vocab_file, 'r') as f:\n", + " data = json.load(f)\n", + " self.token_map = {int(k): v for k, v in data['token_map'].items()}\n", + " self.reverse_map = {int(k): v for k, v in data['reverse_map'].items()}\n", + " print(f'Loaded {len(self.token_map)} mapped tokens.')\n", + " return\n", + "\n", + " print('Building Frequency-Based Vocab (Scanning first 10k filtered samples)...')\n", + " counter = Counter()\n", + " scan_limit = min(10000, len(self.dataset))\n", + " for i in range(scan_limit):\n", + " row = self.dataset[i]\n", + " text = self.format_prompt(row)\n", + " ids = self.enc.encode(text)\n", + " counter.update(ids)\n", + " eot = self.enc.eot_token\n", + " most_common = counter.most_common(VOCAB_SIZE - 2)\n", + " new_id = 0\n", + " valid_gpt_ids = [k for k, v in most_common]\n", + " if eot not in valid_gpt_ids:\n", + " valid_gpt_ids.append(eot)\n", + " valid_gpt_ids = valid_gpt_ids[:VOCAB_SIZE - 1]\n", + " for gpt_id in valid_gpt_ids:\n", + " self.token_map[gpt_id] = new_id\n", + " self.reverse_map[new_id] = gpt_id\n", + " new_id += 1\n", + " self.unk_token = UNK_ID\n", + " print(f'Saving vocab map to {self.vocab_file}...')\n", + " os.makedirs(os.path.dirname(self.vocab_file), exist_ok=True)\n", + " with open(self.vocab_file, 'w') as f:\n", + " json.dump({'token_map': self.token_map, 'reverse_map': self.reverse_map}, f)\n", + "\n", + " def format_prompt(self, sample):\n", + " text = f\"### Instruction: {sample['instruction']}\\n\"\n", + " if sample.get('input', ''):\n", + " text += f\"### Input: {sample['input']}\\n\"\n", + " text += f\"### Response: {sample['output']}\"\n", + " return text\n", + "\n", + " def get_batch(self, batch_size):\n", + " indices = np.random.randint(0, len(self.dataset), batch_size)\n", + " rows = self.dataset.select(indices)\n", + " batch_input_ids, batch_targets = [], []\n", + " for i in range(len(rows)):\n", + " row = rows[i]\n", + " text = self.format_prompt(row)\n", + " gpt_ids = self.enc.encode(text)\n", + " gpt_ids.append(self.enc.eot_token)\n", + " pocket_ids = [self.token_map.get(gid, UNK_ID) for gid in gpt_ids]\n", + " if len(pocket_ids) < self.context_length + 1:\n", + " eot_mapped = self.token_map.get(self.enc.eot_token, UNK_ID)\n", + " pocket_ids += [eot_mapped] * (self.context_length + 1 - len(pocket_ids))\n", + " batch_input_ids.append(pocket_ids[:-1])\n", + " batch_targets.append(pocket_ids[1:])\n", + " return torch.tensor(batch_input_ids, dtype=torch.long), torch.tensor(batch_targets, dtype=torch.long)\n", + "\n", + "print('✅ Dataset class ready.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4 · Training Config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ADDITIONAL_STEPS = 5000\n", + "USE_AMP = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5 · Train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.optim as optim\n", + "from tqdm.auto import tqdm\n", + "import matplotlib.pyplot as plt\n", + "\n", + "def generate_demo(model, ds, instruction='Count to 5.', max_tokens=60):\n", + " model.eval()\n", + " device = next(model.parameters()).device\n", + " prompt = f\"### Instruction: {instruction}\\n### Response:\"\n", + " gpt_ids = ds.enc.encode(prompt)\n", + " ids = [ds.token_map.get(gid, UNK_ID) for gid in gpt_ids]\n", + " x = torch.tensor([ids], dtype=torch.long).to(device)\n", + " eot_mapped = ds.token_map.get(ds.enc.eot_token, UNK_ID)\n", + " tokens = []\n", + " for _ in range(max_tokens):\n", + " if x.size(1) >= CONTEXT_LEN:\n", + " break\n", + " with torch.no_grad():\n", + " logits = model(x)\n", + " probs = F.softmax(logits[:, -1, :], dim=-1)\n", + " next_token = torch.multinomial(probs, 1)\n", + " pocket_id = next_token.item()\n", + " gpt_id = ds.reverse_map.get(pocket_id, ds.enc.eot_token)\n", + " try:\n", + " tokens.append(ds.enc.decode([gpt_id]))\n", + " except:\n", + " pass\n", + " x = torch.cat([x, next_token], dim=1)\n", + " if pocket_id == eot_mapped:\n", + " break\n", + " model.train()\n", + " return instruction + '\\n' + ''.join(tokens)\n", + "\n", + "# --- Setup ---\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "print(f'Device: {device}')\n", + "\n", + "ds = PocketAlpacaDataset(context_length=CONTEXT_LEN)\n", + "config = AtomicConfig(vocab_size=VOCAB_SIZE, dim=DIM, depth=DEPTH, heads=HEADS, context_length=CONTEXT_LEN)\n", + "model = AtomicTransformer(config).to(device)\n", + "\n", + "start_step = 0\n", + "ckpt_path = os.path.join(DRIVE_DIR, 'pocket_final.pt')\n", + "optimizer = optim.AdamW(model.parameters(), lr=LR)\n", + "\n", + "if os.path.exists(ckpt_path):\n", + " print(f'>> Resuming from checkpoint: {ckpt_path}')\n", + " checkpoint = torch.load(ckpt_path, map_location=device)\n", + " if 'model_state_dict' in checkpoint:\n", + " model.load_state_dict(checkpoint['model_state_dict'])\n", + " if 'step' in checkpoint:\n", + " start_step = checkpoint['step']\n", + " if 'optimizer_state_dict' in checkpoint:\n", + " try:\n", + " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", + " except:\n", + " print(' Warning: Could not restore optimizer state')\n", + " else:\n", + " model.load_state_dict(checkpoint)\n", + " print(f' Resuming from step {start_step}')\n", + "else:\n", + " print('>> Starting fresh Pocket model')\n", + "\n", + "total_params = sum(p.numel() for p in model.parameters()) / 1e6\n", + "print(f'Parameters: {total_params:.2f}M')\n", + "\n", + "total_steps = start_step + ADDITIONAL_STEPS\n", + "scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ADDITIONAL_STEPS, eta_min=1e-5)\n", + "\n", + "# Restore scheduler if available\n", + "if os.path.exists(ckpt_path) and 'scheduler_state_dict' in checkpoint:\n", + " try:\n", + " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", + " print(' Scheduler state restored.')\n", + " except:\n", + " pass\n", + "\n", + "scaler = torch.amp.GradScaler('cuda', enabled=(USE_AMP and device == 'cuda'))\n", + "losses = []\n", + "\n", + "print(f'Training from step {start_step} to {total_steps}...')\n", + "pbar = tqdm(range(start_step, total_steps), desc='Training')\n", + "\n", + "for step in pbar:\n", + " x, y = ds.get_batch(BATCH_SIZE)\n", + " x, y = x.to(device), y.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " with torch.amp.autocast('cuda', enabled=(USE_AMP and device == 'cuda')):\n", + " logits = model(x)\n", + " loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), y.view(-1))\n", + "\n", + " scaler.scale(loss).backward()\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " scheduler.step()\n", + "\n", + " loss_val = loss.item()\n", + " losses.append(loss_val)\n", + " current_lr = scheduler.get_last_lr()[0]\n", + " pbar.set_postfix(loss=f'{loss_val:.4f}', lr=f'{current_lr:.2e}')\n", + "\n", + " if step % 500 == 0 and step > 0:\n", + " sample = generate_demo(model, ds, 'Hi')\n", + " tqdm.write(f'\\n--- Step {step} Sample ---\\n{sample}\\n')\n", + "\n", + " if step > 0 and step % 1000 == 0:\n", + " save_dict = {\n", + " 'step': step,\n", + " 'model_state_dict': model.state_dict(),\n", + " 'optimizer_state_dict': optimizer.state_dict(),\n", + " 'scheduler_state_dict': scheduler.state_dict(),\n", + " }\n", + " torch.save(save_dict, ckpt_path)\n", + " tqdm.write(f'💾 Checkpoint saved at step {step}')\n", + "\n", + "# Final save\n", + "save_dict = {\n", + " 'step': total_steps,\n", + " 'model_state_dict': model.state_dict(),\n", + " 'optimizer_state_dict': optimizer.state_dict(),\n", + " 'scheduler_state_dict': scheduler.state_dict(),\n", + "}\n", + "torch.save(save_dict, ckpt_path)\n", + "print(f'\\n✅ Training complete! Checkpoint saved to {ckpt_path}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6 · Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(10, 4))\n", + "plt.plot(losses, alpha=0.3, label='Raw')\n", + "window = min(100, len(losses) // 5) if len(losses) > 10 else 1\n", + "if window > 1:\n", + " smoothed = np.convolve(losses, np.ones(window)/window, mode='valid')\n", + " plt.plot(range(window-1, len(losses)), smoothed, label=f'Smoothed ({window})')\n", + "plt.xlabel('Step')\n", + "plt.ylabel('Loss')\n", + "plt.title('Atomic-1Bit Pocket — Training Loss')\n", + "plt.legend()\n", + "plt.grid(True, alpha=0.3)\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "prompts = ['What is AI?', 'Count to 5.', 'Tell me a joke.']\n", + "print('\\n📝 Generated Samples:')\n", + "print('=' * 60)\n", + "for p in prompts:\n", + " sample = generate_demo(model, ds, p)\n", + " print(f'\\n{sample}')\n", + " print('-' * 60)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/Atomic_1Bit_Train_Stories.ipynb b/notebooks/Atomic_1Bit_Train_Stories.ipynb new file mode 100644 index 0000000..b1d00b0 --- /dev/null +++ b/notebooks/Atomic_1Bit_Train_Stories.ipynb @@ -0,0 +1,513 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ⚛️ Atomic-1Bit — Train Stories Base Model (Colab)\n", + "\n", + "Train the **Stories Base** model (~1.3M params) on [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories).\n", + "\n", + "| Param | Value |\n", + "|---|---|\n", + "| **Dim** | 256 |\n", + "| **Depth** | 6 |\n", + "| **Heads** | 4 |\n", + "| **Vocab** | 4096 (frequency-filtered) |\n", + "| **Context** | 128 |\n", + "\n", + "**Runtime**: Select **GPU** (Runtime → Change runtime type → T4 GPU)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1 · Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q torch tiktoken datasets numpy matplotlib tqdm pyyaml" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Mount Google Drive for persistent checkpoints\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "\n", + "import os\n", + "DRIVE_DIR = '/content/drive/MyDrive/Atomic-1Bit/weights'\n", + "os.makedirs(DRIVE_DIR, exist_ok=True)\n", + "print(f'Checkpoints will be saved to: {DRIVE_DIR}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2 · Model Code (Inlined)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import math\n", + "from dataclasses import dataclass\n", + "\n", + "# ---------- BitLinear (1.58-bit) ----------\n", + "\n", + "def activation_quant(x):\n", + " \"\"\"Quantize activation to INT8 using AbsMax scaling with STE.\"\"\"\n", + " scale = 127.0 / x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5)\n", + " y = (x * scale).round().clamp(-127, 127)\n", + " y_ste = (y - x * scale).detach() + x * scale\n", + " return y_ste, scale\n", + "\n", + "def weight_quant(w):\n", + " \"\"\"Quantize weights to {-1, 0, 1} using Mean scaling with STE.\"\"\"\n", + " scale = 1.0 / w.abs().mean().clamp(min=1e-5)\n", + " y = (w * scale).round().clamp(-1, 1)\n", + " y_ste = (y - w * scale).detach() + w * scale\n", + " return y_ste, scale\n", + "\n", + "class BitLinear(nn.Module):\n", + " def __init__(self, in_features, out_features, bias=False):\n", + " super().__init__()\n", + " self.in_features = in_features\n", + " self.out_features = out_features\n", + " self.weight = nn.Parameter(torch.randn(out_features, in_features))\n", + " if bias:\n", + " self.bias = nn.Parameter(torch.zeros(out_features))\n", + " else:\n", + " self.register_parameter('bias', None)\n", + " self.eps = 1e-5\n", + "\n", + " def forward(self, x):\n", + " x_f32 = x.float()\n", + " rms = torch.sqrt(torch.mean(x_f32 ** 2, dim=-1, keepdim=True) + self.eps)\n", + " x_norm = x_f32 / rms\n", + " x_quant_ste, scale_x = activation_quant(x_norm)\n", + " w_quant_ste, scale_w = weight_quant(self.weight)\n", + " y = F.linear(x_quant_ste, w_quant_ste)\n", + " y_out = y / (scale_x * scale_w)\n", + " if self.bias is not None:\n", + " y_out += self.bias\n", + " return y_out\n", + "\n", + "# ---------- Transformer ----------\n", + "\n", + "@dataclass\n", + "class AtomicConfig:\n", + " vocab_size: int = 50257\n", + " dim: int = 512\n", + " depth: int = 8\n", + " heads: int = 8\n", + " context_length: int = 1024\n", + "\n", + "class BitAttention(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " assert config.dim % config.heads == 0\n", + " self.dim = config.dim\n", + " self.heads = config.heads\n", + " self.head_dim = config.dim // config.heads\n", + " self.q_proj = BitLinear(config.dim, config.dim)\n", + " self.k_proj = BitLinear(config.dim, config.dim)\n", + " self.v_proj = BitLinear(config.dim, config.dim)\n", + " self.o_proj = BitLinear(config.dim, config.dim)\n", + "\n", + " def forward(self, x, kv_cache=None):\n", + " B, T, C = x.shape\n", + " q = self.q_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)\n", + " k = self.k_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)\n", + " v = self.v_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)\n", + " if kv_cache is not None:\n", + " cached_k, cached_v = kv_cache\n", + " k = torch.cat([cached_k, k], dim=2)\n", + " v = torch.cat([cached_v, v], dim=2)\n", + " new_kv_cache = (k, v)\n", + " T_total = k.shape[2]\n", + " att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))\n", + " mask = torch.ones(T, T_total, device=x.device, dtype=torch.bool)\n", + " mask = torch.triu(mask, diagonal=T_total - T + 1)\n", + " att = att.masked_fill(mask, float('-inf'))\n", + " att = F.softmax(att, dim=-1)\n", + " y = att @ v\n", + " y = y.transpose(1, 2).contiguous().view(B, T, C)\n", + " return self.o_proj(y), new_kv_cache\n", + "\n", + "class BitFeedForward(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " hidden_dim = 4 * config.dim\n", + " self.fc1 = BitLinear(config.dim, hidden_dim)\n", + " self.fc2 = BitLinear(hidden_dim, config.dim)\n", + " self.act = nn.GELU()\n", + "\n", + " def forward(self, x):\n", + " return self.fc2(self.act(self.fc1(x)))\n", + "\n", + "class AtomicBlock(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.ln1 = nn.RMSNorm(config.dim, eps=1e-5)\n", + " self.attn = BitAttention(config)\n", + " self.ln2 = nn.RMSNorm(config.dim, eps=1e-5)\n", + " self.mlp = BitFeedForward(config)\n", + "\n", + " def forward(self, x, kv_cache=None):\n", + " attn_out, new_kv_cache = self.attn(self.ln1(x), kv_cache=kv_cache)\n", + " x = x + attn_out\n", + " x = x + self.mlp(self.ln2(x))\n", + " return x, new_kv_cache\n", + "\n", + "class AtomicTransformer(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.config = config\n", + " self.token_emb = nn.Embedding(config.vocab_size, config.dim)\n", + " self.pos_emb = nn.Embedding(config.context_length, config.dim)\n", + " self.layers = nn.ModuleList([AtomicBlock(config) for _ in range(config.depth)])\n", + " self.ln_f = nn.RMSNorm(config.dim, eps=1e-5)\n", + " self.head = BitLinear(config.dim, config.vocab_size)\n", + "\n", + " def forward(self, idx, kv_cache=None):\n", + " B, T = idx.shape\n", + " if kv_cache is not None and kv_cache[0] is not None:\n", + " pos_offset = kv_cache[0][0].shape[2]\n", + " else:\n", + " pos_offset = 0\n", + " pos = torch.arange(pos_offset, pos_offset + T, dtype=torch.long, device=idx.device)\n", + " x = self.token_emb(idx) + self.pos_emb(pos)\n", + " new_kv_cache = []\n", + " for i, layer in enumerate(self.layers):\n", + " layer_cache = kv_cache[i] if kv_cache is not None else None\n", + " x, new_cache = layer(x, kv_cache=layer_cache)\n", + " new_kv_cache.append(new_cache)\n", + " x = self.ln_f(x)\n", + " logits = self.head(x)\n", + " if kv_cache is not None:\n", + " return logits, new_kv_cache\n", + " return logits\n", + "\n", + "print('✅ Model code loaded.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3 · Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import numpy as np\n", + "import tiktoken\n", + "from collections import Counter\n", + "from datasets import load_dataset\n", + "\n", + "# -------- Hyperparameters (edit here) --------\n", + "BATCH_SIZE = 32\n", + "CONTEXT_LEN = 128\n", + "DIM = 256\n", + "DEPTH = 6\n", + "HEADS = 4\n", + "VOCAB_SIZE = 4096\n", + "UNK_ID = VOCAB_SIZE - 1\n", + "LR = 1e-3\n", + "# ----------------------------------------------\n", + "\n", + "class PocketStoriesDataset:\n", + " def __init__(self, split='train', context_length=128, vocab_file=None):\n", + " if vocab_file is None:\n", + " vocab_file = os.path.join(DRIVE_DIR, 'vocab_map_stories.json')\n", + " print(f'Loading TinyStories ({split})...')\n", + " self.dataset = load_dataset('roneneldan/TinyStories', split=f'{split}[:10%]')\n", + " self.enc = tiktoken.get_encoding('gpt2')\n", + " self.context_length = context_length\n", + " self.vocab_file = vocab_file\n", + " self.token_map = {}\n", + " self.reverse_map = {}\n", + " self._init_vocab()\n", + "\n", + " def _init_vocab(self):\n", + " if os.path.exists(self.vocab_file):\n", + " print(f'Loading vocab map from {self.vocab_file}...')\n", + " with open(self.vocab_file, 'r') as f:\n", + " data = json.load(f)\n", + " self.token_map = {int(k): v for k, v in data['token_map'].items()}\n", + " self.reverse_map = {int(k): v for k, v in data['reverse_map'].items()}\n", + " print(f'Loaded {len(self.token_map)} mapped tokens.')\n", + " return\n", + "\n", + " print('Building Frequency-Based Vocab (Scanning first 20k samples)...')\n", + " counter = Counter()\n", + " scan_limit = min(20000, len(self.dataset))\n", + " rows = self.dataset.select(range(scan_limit))\n", + " for text in rows['text']:\n", + " ids = self.enc.encode(text)\n", + " counter.update(ids)\n", + " eot = self.enc.eot_token\n", + " most_common = counter.most_common(VOCAB_SIZE - 2)\n", + " new_id = 0\n", + " valid_gpt_ids = [k for k, v in most_common]\n", + " if eot not in valid_gpt_ids:\n", + " valid_gpt_ids.append(eot)\n", + " valid_gpt_ids = valid_gpt_ids[:VOCAB_SIZE - 1]\n", + " for gpt_id in valid_gpt_ids:\n", + " self.token_map[gpt_id] = new_id\n", + " self.reverse_map[new_id] = gpt_id\n", + " new_id += 1\n", + " self.unk_token = UNK_ID\n", + " print(f'Saving vocab map to {self.vocab_file}...')\n", + " os.makedirs(os.path.dirname(self.vocab_file), exist_ok=True)\n", + " with open(self.vocab_file, 'w') as f:\n", + " json.dump({'token_map': self.token_map, 'reverse_map': self.reverse_map}, f)\n", + "\n", + " def get_batch(self, batch_size):\n", + " indices = np.random.randint(0, len(self.dataset), batch_size)\n", + " rows = self.dataset.select(indices)\n", + " batch_input_ids, batch_targets = [], []\n", + " for text in rows['text']:\n", + " gpt_ids = self.enc.encode(text)\n", + " gpt_ids.append(self.enc.eot_token)\n", + " pocket_ids = [self.token_map.get(gid, UNK_ID) for gid in gpt_ids]\n", + " if len(pocket_ids) < self.context_length + 1:\n", + " eot_mapped = self.token_map.get(self.enc.eot_token, UNK_ID)\n", + " pocket_ids += [eot_mapped] * (self.context_length + 1 - len(pocket_ids))\n", + " if len(pocket_ids) > self.context_length + 1:\n", + " start = np.random.randint(0, len(pocket_ids) - self.context_length - 1)\n", + " pocket_ids = pocket_ids[start : start + self.context_length + 1]\n", + " batch_input_ids.append(pocket_ids[:-1])\n", + " batch_targets.append(pocket_ids[1:])\n", + " x = torch.tensor(batch_input_ids, dtype=torch.long)\n", + " y = torch.tensor(batch_targets, dtype=torch.long)\n", + " return x, y\n", + "\n", + "print('✅ Dataset class ready.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4 · Training Config\n", + "\n", + "Edit `ADDITIONAL_STEPS` to control how many steps to train." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# -------- Training length (edit here) --------\n", + "ADDITIONAL_STEPS = 5000\n", + "USE_AMP = True # Mixed precision for faster training on GPU\n", + "# -----------------------------------------------" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5 · Train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.optim as optim\n", + "from tqdm.auto import tqdm\n", + "import matplotlib.pyplot as plt\n", + "\n", + "def generate_demo(model, ds, start_text='Once upon a time', max_tokens=60):\n", + " model.eval()\n", + " device = next(model.parameters()).device\n", + " gpt_ids = ds.enc.encode(start_text)\n", + " ids = [ds.token_map.get(gid, UNK_ID) for gid in gpt_ids]\n", + " x = torch.tensor([ids], dtype=torch.long).to(device)\n", + " eot_mapped = ds.token_map.get(ds.enc.eot_token, UNK_ID)\n", + " tokens = []\n", + " for _ in range(max_tokens):\n", + " if x.size(1) >= CONTEXT_LEN:\n", + " break\n", + " with torch.no_grad():\n", + " logits = model(x)\n", + " probs = F.softmax(logits[:, -1, :], dim=-1)\n", + " next_token = torch.multinomial(probs, 1)\n", + " pocket_id = next_token.item()\n", + " gpt_id = ds.reverse_map.get(pocket_id, ds.enc.eot_token)\n", + " try:\n", + " tokens.append(ds.enc.decode([gpt_id]))\n", + " except:\n", + " pass\n", + " x = torch.cat([x, next_token], dim=1)\n", + " if pocket_id == eot_mapped:\n", + " break\n", + " model.train()\n", + " return start_text + ''.join(tokens)\n", + "\n", + "# --- Setup ---\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "print(f'Device: {device}')\n", + "\n", + "ds = PocketStoriesDataset(context_length=CONTEXT_LEN)\n", + "config = AtomicConfig(vocab_size=VOCAB_SIZE, dim=DIM, depth=DEPTH, heads=HEADS, context_length=CONTEXT_LEN)\n", + "model = AtomicTransformer(config).to(device)\n", + "\n", + "start_step = 0\n", + "ckpt_path = os.path.join(DRIVE_DIR, 'stories_final.pt')\n", + "optimizer = optim.AdamW(model.parameters(), lr=LR)\n", + "\n", + "# Resume from checkpoint\n", + "if os.path.exists(ckpt_path):\n", + " print(f'>> Resuming from checkpoint: {ckpt_path}')\n", + " checkpoint = torch.load(ckpt_path, map_location=device)\n", + " model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))\n", + " if 'step' in checkpoint:\n", + " start_step = checkpoint['step']\n", + " if 'optimizer_state_dict' in checkpoint:\n", + " try:\n", + " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", + " except:\n", + " print(' Warning: Could not restore optimizer state')\n", + " print(f' Resuming from step {start_step}')\n", + "else:\n", + " print('>> Starting fresh model')\n", + "\n", + "total_params = sum(p.numel() for p in model.parameters()) / 1e6\n", + "print(f'Parameters: {total_params:.2f}M')\n", + "\n", + "total_steps = start_step + ADDITIONAL_STEPS\n", + "scaler = torch.amp.GradScaler('cuda', enabled=(USE_AMP and device == 'cuda'))\n", + "losses = []\n", + "\n", + "# --- Training Loop ---\n", + "print(f'Training from step {start_step} to {total_steps}...')\n", + "pbar = tqdm(range(start_step, total_steps), desc='Training')\n", + "\n", + "for step in pbar:\n", + " x, y = ds.get_batch(BATCH_SIZE)\n", + " x, y = x.to(device), y.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " with torch.amp.autocast('cuda', enabled=(USE_AMP and device == 'cuda')):\n", + " logits = model(x)\n", + " loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), y.view(-1))\n", + "\n", + " scaler.scale(loss).backward()\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + "\n", + " loss_val = loss.item()\n", + " losses.append(loss_val)\n", + " pbar.set_postfix(loss=f'{loss_val:.4f}')\n", + "\n", + " if step % 500 == 0 and step > 0:\n", + " sample = generate_demo(model, ds, 'One day,')\n", + " tqdm.write(f'\\n--- Step {step} Sample ---\\n{sample}\\n')\n", + "\n", + " if step > 0 and step % 1000 == 0:\n", + " save_dict = {\n", + " 'step': step,\n", + " 'model_state_dict': model.state_dict(),\n", + " 'optimizer_state_dict': optimizer.state_dict(),\n", + " }\n", + " torch.save(save_dict, ckpt_path)\n", + " tqdm.write(f'💾 Checkpoint saved at step {step}')\n", + "\n", + "# Final save\n", + "save_dict = {\n", + " 'step': total_steps,\n", + " 'model_state_dict': model.state_dict(),\n", + " 'optimizer_state_dict': optimizer.state_dict(),\n", + "}\n", + "torch.save(save_dict, ckpt_path)\n", + "print(f'\\n✅ Training complete! Checkpoint saved to {ckpt_path}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6 · Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# --- Loss Curve ---\n", + "plt.figure(figsize=(10, 4))\n", + "plt.plot(losses, alpha=0.3, label='Raw')\n", + "# Smoothed\n", + "window = min(100, len(losses) // 5) if len(losses) > 10 else 1\n", + "if window > 1:\n", + " smoothed = np.convolve(losses, np.ones(window)/window, mode='valid')\n", + " plt.plot(range(window-1, len(losses)), smoothed, label=f'Smoothed ({window})')\n", + "plt.xlabel('Step')\n", + "plt.ylabel('Loss')\n", + "plt.title('Atomic-1Bit Stories — Training Loss')\n", + "plt.legend()\n", + "plt.grid(True, alpha=0.3)\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# --- Generate Samples ---\n", + "prompts = ['Once upon a time', 'The little dog', 'She was very happy because']\n", + "print('\\n📝 Generated Samples:')\n", + "print('=' * 60)\n", + "for p in prompts:\n", + " sample = generate_demo(model, ds, p)\n", + " print(f'\\n{sample}')\n", + " print('-' * 60)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}