diff --git a/exp1_logs.txt b/exp1_logs.txt new file mode 100644 index 0000000000..baba324f4f Binary files /dev/null and b/exp1_logs.txt differ diff --git a/exp2_logs.txt b/exp2_logs.txt new file mode 100644 index 0000000000..f3c4a30411 Binary files /dev/null and b/exp2_logs.txt differ diff --git a/exp_balanced_peak.ptz b/exp_balanced_peak.ptz new file mode 100644 index 0000000000..773a3daa21 Binary files /dev/null and b/exp_balanced_peak.ptz differ diff --git a/exp_deep_stable.ptz b/exp_deep_stable.ptz new file mode 100644 index 0000000000..d97d52c19d Binary files /dev/null and b/exp_deep_stable.ptz differ diff --git a/exp_deep_stable.txt b/exp_deep_stable.txt new file mode 100644 index 0000000000..1189d4c8d9 --- /dev/null +++ b/exp_deep_stable.txt @@ -0,0 +1,123 @@ +model_params: 7214100 +step:0 val_loss:6.9391 val_bpb:3.0801 time:0ms +step:1/20000 train_loss:6.9395 time:1867ms +step:200/20000 train_loss:3.4674 time:356549ms +step:400/20000 train_loss:2.9106 time:713415ms +step:600/20000 train_loss:2.6936 time:1070507ms +step:800/20000 train_loss:2.6614 time:1427634ms +step:1000/20000 train_loss:2.5387 time:1784223ms +step:1000 val_loss:2.5594 val_bpb:1.1361 time:1784250ms +step:1200/20000 train_loss:2.5454 time:2141147ms +step:1400/20000 train_loss:2.5426 time:2498112ms +step:1600/20000 train_loss:2.4487 time:2855086ms +step:1800/20000 train_loss:2.3379 time:3212126ms +step:2000/20000 train_loss:2.4562 time:3568629ms +step:2000 val_loss:2.4289 val_bpb:1.0781 time:3568656ms +step:2200/20000 train_loss:2.4138 time:3925446ms +step:2400/20000 train_loss:2.4391 time:4282379ms +step:2600/20000 train_loss:2.4551 time:4639388ms +step:2800/20000 train_loss:2.5381 time:4996325ms +step:3000/20000 train_loss:2.3298 time:5352934ms +step:3000 val_loss:2.3803 val_bpb:1.0566 time:5352961ms +step:3200/20000 train_loss:2.3655 time:5710027ms +step:3400/20000 train_loss:2.3240 time:6066993ms +step:3600/20000 train_loss:2.3648 time:6423945ms +step:3800/20000 train_loss:2.3503 time:6780408ms +step:4000/20000 train_loss:2.2927 time:7137404ms +step:4000 val_loss:2.3472 val_bpb:1.0419 time:7137430ms +step:4200/20000 train_loss:2.2679 time:7494473ms +step:4400/20000 train_loss:2.3620 time:7851536ms +step:4600/20000 train_loss:2.2581 time:8208549ms +step:4800/20000 train_loss:2.3169 time:8565316ms +step:5000/20000 train_loss:2.2840 time:8922606ms +step:5000 val_loss:2.3301 val_bpb:1.0343 time:8922633ms +step:5200/20000 train_loss:2.3659 time:9279791ms +step:5400/20000 train_loss:2.3127 time:9636974ms +step:5600/20000 train_loss:2.3703 time:9994176ms +step:5800/20000 train_loss:2.3227 time:10350840ms +step:6000/20000 train_loss:2.3040 time:10708007ms +step:6000 val_loss:2.3168 val_bpb:1.0284 time:10708034ms +step:6200/20000 train_loss:2.2733 time:11065082ms +step:6400/20000 train_loss:2.2992 time:11422134ms +step:6600/20000 train_loss:2.3462 time:11778809ms +step:6800/20000 train_loss:2.3268 time:12135985ms +step:7000/20000 train_loss:2.3172 time:12493125ms +step:7000 val_loss:2.3070 val_bpb:1.0240 time:12493152ms +step:7200/20000 train_loss:2.3400 time:12850075ms +step:7400/20000 train_loss:2.3022 time:13207262ms +step:7600/20000 train_loss:2.3172 time:13563930ms +step:7800/20000 train_loss:2.2923 time:13921208ms +step:8000/20000 train_loss:2.4725 time:14278302ms +step:8000 val_loss:2.2961 val_bpb:1.0192 time:14278328ms +step:8200/20000 train_loss:2.2816 time:14635460ms +step:8400/20000 train_loss:2.2776 time:14992494ms +step:8600/20000 train_loss:2.3197 time:15349042ms +step:8800/20000 train_loss:2.2828 time:15706095ms +step:9000/20000 train_loss:2.3846 time:16063234ms +step:9000 val_loss:2.2874 val_bpb:1.0153 time:16063261ms +step:9200/20000 train_loss:2.2827 time:16420410ms +step:9400/20000 train_loss:2.2857 time:16777125ms +step:9600/20000 train_loss:2.3054 time:17134224ms +step:9800/20000 train_loss:2.2765 time:17491346ms +step:10000/20000 train_loss:2.2645 time:17848514ms +step:10000 val_loss:2.2813 val_bpb:1.0126 time:17848541ms +step:10200/20000 train_loss:2.2942 time:18205640ms +step:10400/20000 train_loss:2.2714 time:18562294ms +step:10600/20000 train_loss:2.2175 time:18919345ms +step:10800/20000 train_loss:2.2286 time:19276431ms +step:11000/20000 train_loss:2.3317 time:19633544ms +step:11000 val_loss:2.2760 val_bpb:1.0103 time:19633571ms +step:11200/20000 train_loss:2.2325 time:19990561ms +step:11400/20000 train_loss:2.3167 time:20347219ms +step:11600/20000 train_loss:2.2806 time:20704380ms +step:11800/20000 train_loss:2.2680 time:21061526ms +step:12000/20000 train_loss:2.2483 time:21418563ms +step:12000 val_loss:2.2716 val_bpb:1.0083 time:21418590ms +step:12200/20000 train_loss:2.2979 time:21775193ms +step:12400/20000 train_loss:2.4286 time:22132106ms +step:12600/20000 train_loss:2.2667 time:22489177ms +step:12800/20000 train_loss:2.3347 time:22846141ms +step:13000/20000 train_loss:2.3092 time:23203249ms +step:13000 val_loss:2.2664 val_bpb:1.0060 time:23203275ms +step:13200/20000 train_loss:2.2674 time:23559741ms +step:13400/20000 train_loss:2.1219 time:23916757ms +step:13600/20000 train_loss:2.3351 time:24273831ms +step:13800/20000 train_loss:2.1970 time:24630769ms +step:14000/20000 train_loss:2.2588 time:24987888ms +step:14000 val_loss:2.2648 val_bpb:1.0053 time:24987915ms +step:14200/20000 train_loss:2.2442 time:25344451ms +step:14400/20000 train_loss:2.2547 time:25701380ms +step:14600/20000 train_loss:2.3532 time:26058340ms +step:14800/20000 train_loss:2.2392 time:26415282ms +step:15000/20000 train_loss:2.2752 time:26771797ms +step:15000 val_loss:2.2610 val_bpb:1.0036 time:26771824ms +step:15200/20000 train_loss:2.2623 time:27128714ms +step:15400/20000 train_loss:2.2178 time:27485680ms +step:15600/20000 train_loss:2.2268 time:27842632ms +step:15800/20000 train_loss:2.2129 time:28199690ms +step:16000/20000 train_loss:2.1928 time:28556155ms +step:16000 val_loss:2.2559 val_bpb:1.0013 time:28556182ms +step:16200/20000 train_loss:2.0934 time:28913073ms +step:16400/20000 train_loss:2.2089 time:29270064ms +step:16600/20000 train_loss:2.2424 time:29626980ms +step:16800/20000 train_loss:2.3310 time:29984060ms +step:17000/20000 train_loss:2.1701 time:30340626ms +step:17000 val_loss:2.2533 val_bpb:1.0002 time:30340652ms +step:17200/20000 train_loss:2.2980 time:30697720ms +step:17400/20000 train_loss:2.2495 time:31054901ms +step:17600/20000 train_loss:2.2457 time:31412053ms +step:17800/20000 train_loss:2.2284 time:31768663ms +step:18000/20000 train_loss:2.2611 time:32125810ms +step:18000 val_loss:2.2502 val_bpb:0.9988 time:32125836ms +step:18200/20000 train_loss:2.2404 time:32483151ms +step:18400/20000 train_loss:2.2179 time:32840284ms +step:18600/20000 train_loss:2.2812 time:33197342ms +step:18800/20000 train_loss:2.2552 time:33553864ms +step:19000/20000 train_loss:2.3222 time:33910761ms +step:19000 val_loss:2.2481 val_bpb:0.9979 time:33910788ms +step:19200/20000 train_loss:2.2750 time:34267930ms +step:19400/20000 train_loss:2.3267 time:34625031ms +step:19600/20000 train_loss:2.2458 time:34982029ms +step:19800/20000 train_loss:2.2294 time:35338887ms +step:20000/20000 train_loss:2.2933 time:35696205ms +step:20000 val_loss:2.2451 val_bpb:0.9965 time:35696231ms diff --git a/exp_large.txt b/exp_large.txt new file mode 100644 index 0000000000..306018b6ab --- /dev/null +++ b/exp_large.txt @@ -0,0 +1,75 @@ +model_params: 4722704 +step:0 val_loss:6.9313 val_bpb:3.0767 time:0ms +step:1/100000 train_loss:6.9316 time:1366ms +step:200/100000 train_loss:3.3504 time:252792ms +step:400/100000 train_loss:2.7117 time:505701ms +step:600/100000 train_loss:2.7676 time:758655ms +step:800/100000 train_loss:2.6034 time:1011610ms +step:1000/100000 train_loss:2.6137 time:1264517ms +step:1000 val_loss:2.5906 val_bpb:1.1499 time:1264544ms +step:1200/100000 train_loss:2.5432 time:1517485ms +step:1400/100000 train_loss:2.5791 time:1770502ms +step:1600/100000 train_loss:2.4673 time:2023463ms +step:1800/100000 train_loss:2.5169 time:2276429ms +step:2000/100000 train_loss:2.4716 time:2529429ms +step:2000 val_loss:2.4860 val_bpb:1.1035 time:2529456ms +step:2200/100000 train_loss:2.4066 time:2782329ms +step:2400/100000 train_loss:2.4410 time:3035326ms +step:2600/100000 train_loss:2.5189 time:3288278ms +step:2800/100000 train_loss:2.4805 time:3541314ms +step:3000/100000 train_loss:2.3908 time:3794405ms +step:3000 val_loss:2.4431 val_bpb:1.0844 time:3794432ms +step:3200/100000 train_loss:2.4809 time:4047561ms +step:3400/100000 train_loss:2.4615 time:4301295ms +step:3600/100000 train_loss:2.4051 time:4554080ms +step:3800/100000 train_loss:2.4890 time:4807127ms +step:4000/100000 train_loss:2.3803 time:5060054ms +step:4000 val_loss:2.4200 val_bpb:1.0742 time:5060082ms +step:4200/100000 train_loss:2.4535 time:5313409ms +step:4400/100000 train_loss:2.4609 time:5566296ms +step:4600/100000 train_loss:2.3022 time:5819221ms +step:4800/100000 train_loss:2.4082 time:6072129ms +step:5000/100000 train_loss:2.4053 time:6325152ms +step:5000 val_loss:2.4035 val_bpb:1.0668 time:6325179ms +step:5200/100000 train_loss:2.4255 time:6578051ms +step:5400/100000 train_loss:2.4298 time:6830983ms +step:5600/100000 train_loss:2.4312 time:7083927ms +step:5800/100000 train_loss:2.4098 time:7336854ms +step:6000/100000 train_loss:2.5350 time:7589790ms +step:6000 val_loss:2.3963 val_bpb:1.0636 time:7589817ms +step:6200/100000 train_loss:2.3915 time:7842747ms +step:6400/100000 train_loss:2.4177 time:8095560ms +step:6600/100000 train_loss:2.3729 time:8348483ms +step:6800/100000 train_loss:2.4950 time:8601352ms +step:7000/100000 train_loss:2.3784 time:8854296ms +step:7000 val_loss:2.3845 val_bpb:1.0584 time:8854323ms +step:7200/100000 train_loss:2.4101 time:9107260ms +step:7400/100000 train_loss:2.3841 time:9360233ms +step:7600/100000 train_loss:2.3276 time:9613224ms +step:7800/100000 train_loss:2.3774 time:9866197ms +step:8000/100000 train_loss:2.3701 time:10119195ms +step:8000 val_loss:2.3764 val_bpb:1.0548 time:10119222ms +step:8200/100000 train_loss:2.3478 time:10372158ms +step:8400/100000 train_loss:2.3402 time:10625650ms +step:8600/100000 train_loss:2.3790 time:10878633ms +step:8800/100000 train_loss:2.3785 time:11131675ms +step:9000/100000 train_loss:2.3425 time:11384741ms +step:9000 val_loss:2.3722 val_bpb:1.0530 time:11384768ms +step:9200/100000 train_loss:2.3309 time:11637707ms +step:9400/100000 train_loss:2.3777 time:11890682ms +step:9600/100000 train_loss:2.4305 time:12143596ms +step:9800/100000 train_loss:2.3388 time:12396581ms +step:10000/100000 train_loss:2.2825 time:12649568ms +step:10000 val_loss:2.3684 val_bpb:1.0513 time:12649595ms +step:10200/100000 train_loss:2.4388 time:12902507ms +step:10400/100000 train_loss:2.3939 time:13155443ms +step:10600/100000 train_loss:2.3045 time:13408383ms +step:10800/100000 train_loss:2.3808 time:13661295ms +step:11000/100000 train_loss:2.3669 time:13914261ms +step:11000 val_loss:2.3632 val_bpb:1.0490 time:13914288ms +step:11200/100000 train_loss:2.3381 time:14167332ms +step:11400/100000 train_loss:2.3324 time:14420201ms +step:11600/100000 train_loss:2.3328 time:14673080ms +step:11800/100000 train_loss:2.3933 time:14926023ms +step:12000/100000 train_loss:2.2873 time:15179072ms +step:12000 val_loss:2.3438 val_bpb:1.0404 time:15179099ms diff --git a/exp_xl.txt b/exp_xl.txt new file mode 100644 index 0000000000..934fbc88bd --- /dev/null +++ b/exp_xl.txt @@ -0,0 +1,39 @@ +model_params: 7214100 +step:0 val_loss:6.9392 val_bpb:3.0801 time:0ms +step:1/150000 train_loss:6.9393 time:3665ms +step:200/150000 train_loss:3.4129 time:720568ms +step:400/150000 train_loss:2.7769 time:1441803ms +step:600/150000 train_loss:2.5164 time:2163050ms +step:800/150000 train_loss:2.5384 time:2884328ms +step:1000/150000 train_loss:2.4318 time:3605217ms +step:1000 val_loss:2.4864 val_bpb:1.1036 time:3605277ms +step:1200/150000 train_loss:2.4116 time:4326409ms +step:1400/150000 train_loss:2.4039 time:5047843ms +step:1600/150000 train_loss:2.4459 time:5768990ms +step:1800/150000 train_loss:2.4444 time:6490127ms +step:2000/150000 train_loss:2.3412 time:7211520ms +step:2000 val_loss:2.3684 val_bpb:1.0513 time:7211580ms +step:2200/150000 train_loss:2.3277 time:7931124ms +step:2400/150000 train_loss:2.3313 time:8652261ms +step:2600/150000 train_loss:2.4536 time:9373368ms +step:2800/150000 train_loss:2.3634 time:10094473ms +step:3000/150000 train_loss:2.3563 time:10813916ms +step:3000 val_loss:2.3277 val_bpb:1.0332 time:10813976ms +step:3200/150000 train_loss:2.3075 time:11535027ms +step:3400/150000 train_loss:2.3187 time:12256350ms +step:3600/150000 train_loss:2.2674 time:12977761ms +step:3800/150000 train_loss:2.3659 time:13699086ms +step:4000/150000 train_loss:2.3068 time:14420163ms +step:4000 val_loss:2.3065 val_bpb:1.0238 time:14420223ms +step:4200/150000 train_loss:2.2895 time:15141323ms +step:4400/150000 train_loss:2.3260 time:15862396ms +step:4600/150000 train_loss:2.2911 time:16583460ms +step:4800/150000 train_loss:2.2846 time:17302885ms +step:5000/150000 train_loss:2.2688 time:18024135ms +step:5000 val_loss:2.2900 val_bpb:1.0165 time:18024195ms +step:5200/150000 train_loss:2.3081 time:18743467ms +step:5400/150000 train_loss:2.3612 time:19464628ms +step:5600/150000 train_loss:2.2872 time:20185876ms +step:5800/150000 train_loss:2.3083 time:20907058ms +step:6000/150000 train_loss:2.2619 time:21628023ms +step:6000 val_loss:2.2666 val_bpb:1.0061 time:21628083ms diff --git a/exp_xl_model.ptz b/exp_xl_model.ptz new file mode 100644 index 0000000000..d0ba05fb27 Binary files /dev/null and b/exp_xl_model.ptz differ diff --git a/final_model.int8.ptz b/final_model.int8.ptz new file mode 100644 index 0000000000..d17c6287aa Binary files /dev/null and b/final_model.int8.ptz differ diff --git a/infer.py b/infer.py new file mode 100644 index 0000000000..482f1638e3 --- /dev/null +++ b/infer.py @@ -0,0 +1,111 @@ +import torch +import torch.nn.functional as F +import sentencepiece as spm +import zlib +import io +import os + +from train_gpt_optimized import dequantize_state_dict_int8, GPT, Hyperparameters + +# Configure hyperparameters to match the fast training run +os.environ["NUM_LAYERS"] = "4" +os.environ["MODEL_DIM"] = "256" +os.environ["NUM_HEADS"] = "4" +os.environ["NUM_KV_HEADS"] = "4" + +# Load tokenizer +sp = spm.SentencePieceProcessor() +sp.load("./data/tokenizers/fineweb_1024_bpe.model") + +# Load and decompress model weights +print("Loading and dequantizing model...") +with open("final_model.int8.ptz", "rb") as f: + q_obj = torch.load(io.BytesIO(zlib.decompress(f.read())), map_location="cpu") + +state_dict = dequantize_state_dict_int8(q_obj) + +# Initialize model +args = Hyperparameters() +args.num_steps = 4 +args.model_dim = 256 +args.num_heads = 4 +args.num_kv_heads = 4 +model = GPT(args).bfloat16() +model.load_state_dict(state_dict) +model.eval() +print("Model loaded successfully!") + +def patched_forward(self, input_ids): + x = F.rms_norm(self.tok_emb(input_ids), (self.args.model_dim,)) + x0 = x + for i in range(self.args.num_steps): + block_idx = i % self.args.num_unique_blocks + x = self.unique_blocks[block_idx](x, x0) + + x = self.final_norm(x) # [bsz, seq_len, dim] + logits_proj = F.linear(x, self.tok_emb.weight) if self.args.tie_embeddings else self.lm_head(x) + logits = self.args.logit_softcap * torch.tanh(logits_proj / self.args.logit_softcap) + return logits + +GPT.forward = patched_forward + +def sample_logits(logits, temperature=0.8, top_k=40, top_p=0.9): + logits = logits / temperature + probs = F.softmax(logits, dim=-1) + + # Top-k + if top_k > 0: + values, indices = torch.topk(probs, top_k) + probs_filtered = torch.zeros_like(probs) + probs_filtered.scatter_(0, indices, values) + probs = probs_filtered / probs_filtered.sum() + + # Top-p (nucleus) + sorted_probs, sorted_indices = torch.sort(probs, descending=True) + cumulative_probs = torch.cumsum(sorted_probs, dim=0) + + cutoff = cumulative_probs > top_p + if torch.any(cutoff): + cutoff_idx = torch.where(cutoff)[0][0] + sorted_probs[cutoff_idx:] = 0 + if sorted_probs.sum() > 0: + sorted_probs /= sorted_probs.sum() + probs = torch.zeros_like(probs) + probs.scatter_(0, sorted_indices, sorted_probs) + + return torch.multinomial(probs, 1).item() + +def generate(prompt, max_tokens=80, temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.2): + tokens = sp.encode(prompt) + tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0) + generated = tokens.clone() + + for _ in range(max_tokens): + with torch.no_grad(): + with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + logits = model(generated)[0, -1].float() + + # repetition penalty + for token_id in set(generated[0].tolist()): + if logits[token_id] < 0: + logits[token_id] *= repetition_penalty + else: + logits[token_id] /= repetition_penalty + + next_token = sample_logits(logits, temperature, top_k, top_p) + generated = torch.cat([generated, torch.tensor([[next_token]])], dim=1) + + return sp.decode(generated[0].tolist()) + +# Test prompts +prompts = [ + "The future of AI is", + "Once upon a time", + "India is known for", + "The meaning of life is" +] + +for p in prompts: + print("\nPROMPT:", p) + print(generate(p)) + diff --git a/logs/exp_balanced_peak.txt b/logs/exp_balanced_peak.txt new file mode 100644 index 0000000000..813c22afbf --- /dev/null +++ b/logs/exp_balanced_peak.txt @@ -0,0 +1,99 @@ +model_params: 10229784 +step:0 val_loss:6.9499 val_bpb:3.0849 time:0ms +step:1/20000 train_loss:6.9497 time:2822ms +step:200/20000 train_loss:3.3912 time:550994ms +step:400/20000 train_loss:2.7145 time:1105646ms +step:600/20000 train_loss:2.6910 time:1656467ms +step:800/20000 train_loss:2.5038 time:2207275ms +step:1000/20000 train_loss:2.4996 time:2757835ms +step:1000 val_loss:2.4740 val_bpb:1.0982 time:2757875ms +step:1200/20000 train_loss:2.4191 time:3308247ms +step:1400/20000 train_loss:2.4504 time:3858630ms +step:1600/20000 train_loss:2.3354 time:4408979ms +step:1800/20000 train_loss:2.3797 time:4959376ms +step:2000/20000 train_loss:2.3362 time:5509724ms +step:2000 val_loss:2.3460 val_bpb:1.0413 time:5509764ms +step:2200/20000 train_loss:2.2632 time:6060168ms +step:2400/20000 train_loss:2.3066 time:6610416ms +step:2600/20000 train_loss:2.3798 time:7160713ms +step:2800/20000 train_loss:2.3299 time:7710986ms +step:3000/20000 train_loss:2.2481 time:8260512ms +step:3000 val_loss:2.2984 val_bpb:1.0202 time:8260552ms +step:3200/20000 train_loss:2.3273 time:8810812ms +step:3400/20000 train_loss:2.3097 time:9361096ms +step:3600/20000 train_loss:2.2559 time:9911436ms +step:3800/20000 train_loss:2.3396 time:10461525ms +step:4000/20000 train_loss:2.2345 time:11012093ms +step:4000 val_loss:2.2707 val_bpb:1.0079 time:11012133ms +step:4200/20000 train_loss:2.3068 time:11562683ms +step:4400/20000 train_loss:2.3141 time:12112992ms +step:4600/20000 train_loss:2.1471 time:12663239ms +step:4800/20000 train_loss:2.2601 time:13213483ms +step:5000/20000 train_loss:2.2539 time:13763910ms +step:5000 val_loss:2.2527 val_bpb:0.9999 time:13763950ms +step:5200/20000 train_loss:2.2698 time:14313453ms +step:5400/20000 train_loss:2.2791 time:14861595ms +step:5600/20000 train_loss:2.2793 time:15411815ms +step:5800/20000 train_loss:2.2551 time:15962091ms +step:6000/20000 train_loss:2.3829 time:16512317ms +step:6000 val_loss:2.2416 val_bpb:0.9950 time:16512357ms +step:6200/20000 train_loss:2.2417 time:17062532ms +step:6400/20000 train_loss:2.2622 time:17612777ms +step:6600/20000 train_loss:2.2144 time:18162944ms +step:6800/20000 train_loss:2.3474 time:18713185ms +step:7000/20000 train_loss:2.2297 time:19263415ms +step:7000 val_loss:2.2309 val_bpb:0.9902 time:19263455ms +step:7200/20000 train_loss:2.2484 time:19813711ms +step:7400/20000 train_loss:2.2324 time:20363939ms +step:7600/20000 train_loss:2.1733 time:20914205ms +step:7800/20000 train_loss:2.2252 time:21464984ms +step:8000/20000 train_loss:2.2129 time:22016061ms +step:8000 val_loss:2.2195 val_bpb:0.9852 time:22016101ms +step:8200/20000 train_loss:2.1925 time:22566620ms +step:8400/20000 train_loss:2.1841 time:23117628ms +step:8600/20000 train_loss:2.2280 time:23668528ms +step:8800/20000 train_loss:2.2131 time:24219096ms +step:9000/20000 train_loss:2.1896 time:24769698ms +step:9000 val_loss:2.2150 val_bpb:0.9832 time:24769738ms +step:9200/20000 train_loss:2.1730 time:25320028ms +step:9400/20000 train_loss:2.2124 time:25869523ms +step:9600/20000 train_loss:2.2730 time:26417673ms +step:9800/20000 train_loss:2.1840 time:26968427ms +step:10000/20000 train_loss:2.1216 time:27518970ms +step:10000 val_loss:2.2097 val_bpb:0.9808 time:27519010ms +step:10200/20000 train_loss:2.2808 time:28069311ms +step:10400/20000 train_loss:2.2331 time:28619556ms +step:10600/20000 train_loss:2.1465 time:29169778ms +step:10800/20000 train_loss:2.2254 time:29720062ms +step:11000/20000 train_loss:2.2083 time:30270220ms +step:11000 val_loss:2.2027 val_bpb:0.9777 time:30270259ms +step:11200/20000 train_loss:2.1782 time:30820457ms +step:11400/20000 train_loss:2.1896 time:31370603ms +step:11600/20000 train_loss:2.1839 time:31920817ms +step:11800/20000 train_loss:2.2448 time:32471023ms +step:12000/20000 train_loss:2.1466 time:33021266ms +step:12000 val_loss:2.1970 val_bpb:0.9752 time:33021305ms +step:12200/20000 train_loss:2.2019 time:33568613ms +step:12400/20000 train_loss:2.1480 time:34119197ms +step:12600/20000 train_loss:2.2547 time:34669246ms +step:12800/20000 train_loss:2.3070 time:35219539ms +step:13000/20000 train_loss:2.4537 time:35769812ms +step:13000 val_loss:2.1983 val_bpb:0.9758 time:35769852ms +step:13200/20000 train_loss:2.2096 time:36320273ms +step:13400/20000 train_loss:2.1688 time:36870543ms +step:13600/20000 train_loss:2.1500 time:37420737ms +step:13800/20000 train_loss:2.1564 time:37970999ms +step:14000/20000 train_loss:2.1841 time:38521288ms +step:14000 val_loss:2.1906 val_bpb:0.9723 time:38521328ms +step:14200/20000 train_loss:2.1927 time:39071626ms +step:14400/20000 train_loss:2.2155 time:39621957ms +step:14600/20000 train_loss:2.1839 time:40172238ms +step:14800/20000 train_loss:2.1867 time:40722763ms +step:15000/20000 train_loss:2.2154 time:41273172ms +step:15000 val_loss:2.1846 val_bpb:0.9697 time:41273212ms +step:15200/20000 train_loss:2.2216 time:41823512ms +step:15400/20000 train_loss:2.2372 time:42373977ms +step:15600/20000 train_loss:2.1846 time:42924356ms +step:15800/20000 train_loss:2.0556 time:43474710ms +step:16000/20000 train_loss:2.1982 time:44025221ms +step:16000 val_loss:2.1687 val_bpb:0.9627 time:44025261ms diff --git a/nohup_36h.out b/nohup_36h.out new file mode 100644 index 0000000000..21f70fa020 --- /dev/null +++ b/nohup_36h.out @@ -0,0 +1,535 @@ +Traceback (most recent call last): + File "/workspace/parameter-golf/train_gpt.py", line 1126, in +logs/exp_deep_stable.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 + main() + File "/workspace/parameter-golf/train_gpt.py", line 826, in main + base_model = GPT( + ^^^^ + File "/workspace/parameter-golf/train_gpt.py", line 675, in __init__ + [ + File "/workspace/parameter-golf/train_gpt.py", line 676, in + Block( + File "/workspace/parameter-golf/train_gpt.py", line 633, in __init__ + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/parameter-golf/train_gpt.py", line 568, in __init__ + raise ValueError("num_heads must be divisible by num_kv_heads") +ValueError: num_heads must be divisible by num_kv_heads +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:723.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:495.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:725.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. Query.sizes(): [64, 12, 1024, 64], Key sizes(): [64, 4, 1024, 64], Value sizes(): [64, 4, 1024, 64] instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:363.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: CuDNN attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:727.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1` (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:496.) + return node.target(*args, **kwargs) +Traceback (most recent call last): + File "/workspace/parameter-golf/train_gpt.py", line 1126, in +logs/exp_wide_aggressive.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:40142712 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:12 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:30 max_wallclock_seconds:43200.000 +seed:1337 + main() + File "/workspace/parameter-golf/train_gpt.py", line 948, in main + warmup_loss = model(x, y) + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn + return fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__ + return self._torchdynamo_orig_callable( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__ + return _compile( + ^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function + return StrobelightCompileTimeProfiler.profile_compile_time( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time + return func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/usr/lib/python3.11/contextlib.py", line 81, in inner + return func(*args, **kwds) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile + guarded_code = compile_inner(code, one_graph, hooks, transform) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper + r = func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner + out_code = transform_code_object(code, transform) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object + transformations(instructions, code_options) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn + return fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run + super().run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function + return tx.inline_user_function_return( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX + self.call_function(fn, argsvars.items, kwargsvars) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function + return tx.inline_user_function_return( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX + self.call_function(fn, argsvars.items, kwargsvars) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/torch.py", line 757, in call_function + tensor_variable = wrap_fx_proxy( + ^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 1713, in wrap_fx_proxy + return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 1798, in wrap_fx_proxy_cls + example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value + raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value + ret_val = wrap_fake_exception( + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception + return fn() + ^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1786, in + lambda: run_node(tx.output, node, args, kwargs, nnmodule) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1921, in run_node + raise RuntimeError(make_error_message(e)).with_traceback( + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1903, in run_node + return node.target(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +torch._dynamo.exc.TorchRuntimeError: Failed running call_function (*(FakeTensor(..., device='cuda:0', size=(64, 12, 1024, 64), + grad_fn=), FakeTensor(..., device='cuda:0', size=(64, 4, 1024, 64), + grad_fn=), FakeTensor(..., device='cuda:0', size=(64, 4, 1024, 64), dtype=torch.bfloat16, + grad_fn=)), **{'attn_mask': None, 'is_causal': True}): +No available kernel. Aborting execution. + +from user code: + File "/workspace/parameter-golf/train_gpt.py", line 708, in forward + x = self.blocks[i](x, x0) + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + File "/workspace/parameter-golf/train_gpt.py", line 642, in forward + attn_out = self.attn(self.attn_norm(x)) + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + File "/workspace/parameter-golf/train_gpt.py", line 594, in forward + y = F.scaled_dot_product_attention( + +Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information + + +You can suppress this exception and fall back to eager by setting: + import torch._dynamo + torch._dynamo.config.suppress_errors = True + +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:723.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:495.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:725.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. Query.sizes(): [64, 12, 1024, 64], Key sizes(): [64, 4, 1024, 64], Value sizes(): [64, 4, 1024, 64] instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:363.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: CuDNN attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:727.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1` (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:496.) + return node.target(*args, **kwargs) +Traceback (most recent call last): + File "/workspace/parameter-golf/train_gpt.py", line 1126, in +logs/exp_balanced_peak.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:48013968 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:12 num_kv_heads:4 +tie_embeddings:True embed_lr:0.032 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:50 max_wallclock_seconds:43200.000 +seed:1337 + main() + File "/workspace/parameter-golf/train_gpt.py", line 948, in main + warmup_loss = model(x, y) + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn + return fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__ + return self._torchdynamo_orig_callable( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__ + return _compile( + ^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function + return StrobelightCompileTimeProfiler.profile_compile_time( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time + return func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/usr/lib/python3.11/contextlib.py", line 81, in inner + return func(*args, **kwds) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile + guarded_code = compile_inner(code, one_graph, hooks, transform) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper + r = func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner + out_code = transform_code_object(code, transform) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object + transformations(instructions, code_options) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn + return fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run + super().run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function + return tx.inline_user_function_return( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX + self.call_function(fn, argsvars.items, kwargsvars) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function + return tx.inline_user_function_return( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX + self.call_function(fn, argsvars.items, kwargsvars) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/torch.py", line 757, in call_function + tensor_variable = wrap_fx_proxy( + ^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 1713, in wrap_fx_proxy + return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 1798, in wrap_fx_proxy_cls + example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value + raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value + ret_val = wrap_fake_exception( + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception + return fn() + ^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1786, in + lambda: run_node(tx.output, node, args, kwargs, nnmodule) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1921, in run_node + raise RuntimeError(make_error_message(e)).with_traceback( + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1903, in run_node + return node.target(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +torch._dynamo.exc.TorchRuntimeError: Failed running call_function (*(FakeTensor(..., device='cuda:0', size=(64, 12, 1024, 64), + grad_fn=), FakeTensor(..., device='cuda:0', size=(64, 4, 1024, 64), + grad_fn=), FakeTensor(..., device='cuda:0', size=(64, 4, 1024, 64), dtype=torch.bfloat16, + grad_fn=)), **{'attn_mask': None, 'is_causal': True}): +No available kernel. Aborting execution. + +from user code: + File "/workspace/parameter-golf/train_gpt.py", line 708, in forward + x = self.blocks[i](x, x0) + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + File "/workspace/parameter-golf/train_gpt.py", line 642, in forward + attn_out = self.attn(self.attn_norm(x)) + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + File "/workspace/parameter-golf/train_gpt.py", line 594, in forward + y = F.scaled_dot_product_attention( + +Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information + + +You can suppress this exception and fall back to eager by setting: + import torch._dynamo + torch._dynamo.config.suppress_errors = True + +All experiments completed. Logs available in /logs folder. +total 375K +-rw-rw-rw- 1 root root 50K May 1 14:04 7609513a-4573-4f06-8293-0fccdcd2648a.txt +-rw-rw-rw- 1 root root 99K May 1 10:00 baseline_sp1024.txt +-rw-rw-rw- 1 root root 50K May 1 10:08 baseline_sp1024_debug.txt +-rw-rw-rw- 1 root root 1.6K May 1 12:06 exp1.txt +-rw-rw-rw- 1 root root 114 May 1 10:43 exp1.txt +-rw-rw-rw- 1 root root 1.6K May 1 13:13 exp2.txt +-rw-rw-rw- 1 root root 161 May 1 11:18 exp2.txt +-rw-rw-rw- 1 root root 50K May 2 17:33 exp_balanced_peak.txt +-rw-rw-rw- 1 root root 3.8K May 1 19:17 exp_big.txt +-rw-rw-rw- 1 root root 49K May 2 17:33 exp_deep_stable.txt +-rw-rw-rw- 1 root root 298 May 1 14:33 exp_fast.txt +-rw-rw-rw- 1 root root 161 May 1 21:37 exp_gpu.txt +-rw-rw-rw- 1 root root 3.8K May 2 10:20 exp_large.txt +-rw-rw-rw- 1 root root 5.9K May 2 05:57 exp_medium.txt +-rw-rw-rw- 1 root root 6.8K May 2 01:47 exp_small.txt +-rw-rw-rw- 1 root root 50K May 2 17:33 exp_wide_aggressive.txt +-rw-rw-rw- 1 root root 2.0K May 2 17:01 exp_xl.txt diff --git a/nohup_36h_new.out b/nohup_36h_new.out new file mode 100644 index 0000000000..21f70fa020 --- /dev/null +++ b/nohup_36h_new.out @@ -0,0 +1,535 @@ +Traceback (most recent call last): + File "/workspace/parameter-golf/train_gpt.py", line 1126, in +logs/exp_deep_stable.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 + main() + File "/workspace/parameter-golf/train_gpt.py", line 826, in main + base_model = GPT( + ^^^^ + File "/workspace/parameter-golf/train_gpt.py", line 675, in __init__ + [ + File "/workspace/parameter-golf/train_gpt.py", line 676, in + Block( + File "/workspace/parameter-golf/train_gpt.py", line 633, in __init__ + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/parameter-golf/train_gpt.py", line 568, in __init__ + raise ValueError("num_heads must be divisible by num_kv_heads") +ValueError: num_heads must be divisible by num_kv_heads +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:723.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:495.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:725.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. Query.sizes(): [64, 12, 1024, 64], Key sizes(): [64, 4, 1024, 64], Value sizes(): [64, 4, 1024, 64] instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:363.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: CuDNN attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:727.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1` (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:496.) + return node.target(*args, **kwargs) +Traceback (most recent call last): + File "/workspace/parameter-golf/train_gpt.py", line 1126, in +logs/exp_wide_aggressive.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:40142712 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:12 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:30 max_wallclock_seconds:43200.000 +seed:1337 + main() + File "/workspace/parameter-golf/train_gpt.py", line 948, in main + warmup_loss = model(x, y) + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn + return fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__ + return self._torchdynamo_orig_callable( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__ + return _compile( + ^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function + return StrobelightCompileTimeProfiler.profile_compile_time( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time + return func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/usr/lib/python3.11/contextlib.py", line 81, in inner + return func(*args, **kwds) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile + guarded_code = compile_inner(code, one_graph, hooks, transform) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper + r = func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner + out_code = transform_code_object(code, transform) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object + transformations(instructions, code_options) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn + return fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run + super().run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function + return tx.inline_user_function_return( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX + self.call_function(fn, argsvars.items, kwargsvars) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function + return tx.inline_user_function_return( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX + self.call_function(fn, argsvars.items, kwargsvars) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/torch.py", line 757, in call_function + tensor_variable = wrap_fx_proxy( + ^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 1713, in wrap_fx_proxy + return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 1798, in wrap_fx_proxy_cls + example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value + raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value + ret_val = wrap_fake_exception( + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception + return fn() + ^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1786, in + lambda: run_node(tx.output, node, args, kwargs, nnmodule) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1921, in run_node + raise RuntimeError(make_error_message(e)).with_traceback( + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1903, in run_node + return node.target(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +torch._dynamo.exc.TorchRuntimeError: Failed running call_function (*(FakeTensor(..., device='cuda:0', size=(64, 12, 1024, 64), + grad_fn=), FakeTensor(..., device='cuda:0', size=(64, 4, 1024, 64), + grad_fn=), FakeTensor(..., device='cuda:0', size=(64, 4, 1024, 64), dtype=torch.bfloat16, + grad_fn=)), **{'attn_mask': None, 'is_causal': True}): +No available kernel. Aborting execution. + +from user code: + File "/workspace/parameter-golf/train_gpt.py", line 708, in forward + x = self.blocks[i](x, x0) + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + File "/workspace/parameter-golf/train_gpt.py", line 642, in forward + attn_out = self.attn(self.attn_norm(x)) + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + File "/workspace/parameter-golf/train_gpt.py", line 594, in forward + y = F.scaled_dot_product_attention( + +Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information + + +You can suppress this exception and fall back to eager by setting: + import torch._dynamo + torch._dynamo.config.suppress_errors = True + +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:723.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:495.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:725.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. Query.sizes(): [64, 12, 1024, 64], Key sizes(): [64, 4, 1024, 64], Value sizes(): [64, 4, 1024, 64] instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:363.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: CuDNN attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:727.) + return node.target(*args, **kwargs) +/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:1903: UserWarning: The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1` (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:496.) + return node.target(*args, **kwargs) +Traceback (most recent call last): + File "/workspace/parameter-golf/train_gpt.py", line 1126, in +logs/exp_balanced_peak.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:48013968 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:12 num_kv_heads:4 +tie_embeddings:True embed_lr:0.032 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:50 max_wallclock_seconds:43200.000 +seed:1337 + main() + File "/workspace/parameter-golf/train_gpt.py", line 948, in main + warmup_loss = model(x, y) + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn + return fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__ + return self._torchdynamo_orig_callable( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__ + return _compile( + ^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function + return StrobelightCompileTimeProfiler.profile_compile_time( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time + return func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/usr/lib/python3.11/contextlib.py", line 81, in inner + return func(*args, **kwds) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile + guarded_code = compile_inner(code, one_graph, hooks, transform) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper + r = func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner + out_code = transform_code_object(code, transform) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object + transformations(instructions, code_options) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn + return fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run + super().run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function + return tx.inline_user_function_return( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX + self.call_function(fn, argsvars.items, kwargsvars) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function + return tx.inline_user_function_return( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX + self.call_function(fn, argsvars.items, kwargsvars) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function + return super().call_function(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call + return cls.inline_call_(parent, func, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ + tracer.run() + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run + while self.step(): + ^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step + self.dispatch_table[inst.opcode](self, inst) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL + self.call_function(fn, args, kwargs) + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function + self.push(fn.call_function(self, args, kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/torch.py", line 757, in call_function + tensor_variable = wrap_fx_proxy( + ^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 1713, in wrap_fx_proxy + return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 1798, in wrap_fx_proxy_cls + example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value + raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value + ret_val = wrap_fake_exception( + ^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception + return fn() + ^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1786, in + lambda: run_node(tx.output, node, args, kwargs, nnmodule) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1921, in run_node + raise RuntimeError(make_error_message(e)).with_traceback( + File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 1903, in run_node + return node.target(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +torch._dynamo.exc.TorchRuntimeError: Failed running call_function (*(FakeTensor(..., device='cuda:0', size=(64, 12, 1024, 64), + grad_fn=), FakeTensor(..., device='cuda:0', size=(64, 4, 1024, 64), + grad_fn=), FakeTensor(..., device='cuda:0', size=(64, 4, 1024, 64), dtype=torch.bfloat16, + grad_fn=)), **{'attn_mask': None, 'is_causal': True}): +No available kernel. Aborting execution. + +from user code: + File "/workspace/parameter-golf/train_gpt.py", line 708, in forward + x = self.blocks[i](x, x0) + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + File "/workspace/parameter-golf/train_gpt.py", line 642, in forward + attn_out = self.attn(self.attn_norm(x)) + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + File "/workspace/parameter-golf/train_gpt.py", line 594, in forward + y = F.scaled_dot_product_attention( + +Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information + + +You can suppress this exception and fall back to eager by setting: + import torch._dynamo + torch._dynamo.config.suppress_errors = True + +All experiments completed. Logs available in /logs folder. +total 375K +-rw-rw-rw- 1 root root 50K May 1 14:04 7609513a-4573-4f06-8293-0fccdcd2648a.txt +-rw-rw-rw- 1 root root 99K May 1 10:00 baseline_sp1024.txt +-rw-rw-rw- 1 root root 50K May 1 10:08 baseline_sp1024_debug.txt +-rw-rw-rw- 1 root root 1.6K May 1 12:06 exp1.txt +-rw-rw-rw- 1 root root 114 May 1 10:43 exp1.txt +-rw-rw-rw- 1 root root 1.6K May 1 13:13 exp2.txt +-rw-rw-rw- 1 root root 161 May 1 11:18 exp2.txt +-rw-rw-rw- 1 root root 50K May 2 17:33 exp_balanced_peak.txt +-rw-rw-rw- 1 root root 3.8K May 1 19:17 exp_big.txt +-rw-rw-rw- 1 root root 49K May 2 17:33 exp_deep_stable.txt +-rw-rw-rw- 1 root root 298 May 1 14:33 exp_fast.txt +-rw-rw-rw- 1 root root 161 May 1 21:37 exp_gpu.txt +-rw-rw-rw- 1 root root 3.8K May 2 10:20 exp_large.txt +-rw-rw-rw- 1 root root 5.9K May 2 05:57 exp_medium.txt +-rw-rw-rw- 1 root root 6.8K May 2 01:47 exp_small.txt +-rw-rw-rw- 1 root root 50K May 2 17:33 exp_wide_aggressive.txt +-rw-rw-rw- 1 root root 2.0K May 2 17:01 exp_xl.txt diff --git a/nohup_gpu.out b/nohup_gpu.out new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nohup_suite.out b/nohup_suite.out new file mode 100644 index 0000000000..bde3da181e --- /dev/null +++ b/nohup_suite.out @@ -0,0 +1,328 @@ +Running EXP1... +model_params: 1312776 +step:0 val_loss:6.9340 val_bpb:3.0778 time:0ms +step:1/100000 train_loss:6.9336 time:464ms +step:200/100000 train_loss:3.3495 time:131051ms +step:400/100000 train_loss:2.9148 time:262319ms +step:600/100000 train_loss:2.9818 time:394554ms +step:800/100000 train_loss:2.8375 time:526212ms +step:1000/100000 train_loss:2.8601 time:657110ms +step:1000 val_loss:2.8432 val_bpb:1.2620 time:657113ms +step:1200/100000 train_loss:2.8000 time:789337ms +step:1400/100000 train_loss:2.8451 time:921430ms +step:1600/100000 train_loss:2.7420 time:1053730ms +step:1800/100000 train_loss:2.7998 time:1185133ms +step:2000/100000 train_loss:2.7544 time:1316536ms +step:2000 val_loss:2.7698 val_bpb:1.2295 time:1316538ms +step:2200/100000 train_loss:2.7009 time:1447750ms +step:2400/100000 train_loss:2.7294 time:1579689ms +step:2600/100000 train_loss:2.8029 time:1712240ms +step:2800/100000 train_loss:2.7831 time:1843793ms +step:3000/100000 train_loss:2.6882 time:1975540ms +step:3000 val_loss:2.7421 val_bpb:1.2171 time:1975541ms +step:3200/100000 train_loss:2.7851 time:2107273ms +step:3400/100000 train_loss:2.7666 time:2238873ms +step:3600/100000 train_loss:2.7086 time:2369668ms +step:3800/100000 train_loss:2.7912 time:2500266ms +step:4000/100000 train_loss:2.6807 time:2631774ms +step:4000 val_loss:2.7242 val_bpb:1.2092 time:2631776ms +step:4200/100000 train_loss:2.7495 time:2763483ms +step:4400/100000 train_loss:2.7652 time:2894771ms +step:4600/100000 train_loss:2.6117 time:3026384ms +step:4800/100000 train_loss:2.7093 time:3158990ms +step:5000/100000 train_loss:2.7111 time:3291590ms +step:5000 val_loss:2.7118 val_bpb:1.2037 time:3291592ms +step:5200/100000 train_loss:2.7400 time:3423096ms +step:5400/100000 train_loss:2.7364 time:3554506ms +step:5600/100000 train_loss:2.7363 time:3685645ms +step:5800/100000 train_loss:2.7217 time:3816398ms +step:6000/100000 train_loss:2.8453 time:3947589ms +step:6000 val_loss:2.7118 val_bpb:1.2037 time:3947599ms +step:6200/100000 train_loss:2.6973 time:4079105ms +step:6400/100000 train_loss:2.7270 time:4209722ms +step:6600/100000 train_loss:2.6892 time:4341506ms +step:6800/100000 train_loss:2.7807 time:4473113ms +step:7000/100000 train_loss:2.6803 time:4604050ms +step:7000 val_loss:2.6975 val_bpb:1.1973 time:4604052ms +step:7200/100000 train_loss:2.7286 time:4735433ms +step:7400/100000 train_loss:2.6937 time:4867316ms +step:7600/100000 train_loss:2.6453 time:4998168ms +step:7800/100000 train_loss:2.6812 time:5130169ms +step:8000/100000 train_loss:2.6920 time:5261724ms +step:8000 val_loss:2.6902 val_bpb:1.1941 time:5261731ms +step:8200/100000 train_loss:2.6577 time:5391926ms +step:8400/100000 train_loss:2.6493 time:5523929ms +step:8600/100000 train_loss:2.6846 time:5655965ms +step:8800/100000 train_loss:2.7042 time:5787630ms +step:9000/100000 train_loss:2.6503 time:5919125ms +step:9000 val_loss:2.6868 val_bpb:1.1926 time:5919128ms +step:9200/100000 train_loss:2.6468 time:6050778ms +step:9400/100000 train_loss:2.6971 time:6182637ms +step:9600/100000 train_loss:2.7342 time:6315345ms +step:9800/100000 train_loss:2.6450 time:6447541ms +step:10000/100000 train_loss:2.6072 time:6580145ms +step:10000 val_loss:2.6821 val_bpb:1.1905 time:6580147ms +step:10200/100000 train_loss:2.7419 time:6711052ms +step:10400/100000 train_loss:2.7092 time:6841752ms +step:10600/100000 train_loss:2.6172 time:6973550ms +step:10800/100000 train_loss:2.6881 time:7104450ms +step:11000/100000 train_loss:2.6821 time:7235150ms +step:11000 val_loss:2.6757 val_bpb:1.1877 time:7235153ms +step:11200/100000 train_loss:2.6540 time:7366704ms +step:11400/100000 train_loss:2.6649 time:7497908ms +step:11600/100000 train_loss:2.6620 time:7629165ms +step:11800/100000 train_loss:2.7206 time:7761561ms +step:12000/100000 train_loss:2.6166 time:7893469ms +step:12000 val_loss:2.6697 val_bpb:1.1850 time:7893475ms +step:12200/100000 train_loss:2.6754 time:8024882ms +step:12400/100000 train_loss:2.6145 time:8157905ms +step:12600/100000 train_loss:2.7114 time:8290776ms +step:12800/100000 train_loss:2.7539 time:8421481ms +step:13000/100000 train_loss:2.9310 time:8553877ms +step:13000 val_loss:2.6783 val_bpb:1.1888 time:8553879ms +step:13200/100000 train_loss:2.6736 time:8684790ms +step:13400/100000 train_loss:2.6472 time:8815686ms +step:13600/100000 train_loss:2.6354 time:8946589ms +step:13800/100000 train_loss:2.6142 time:9077985ms +step:14000/100000 train_loss:2.6612 time:9209493ms +step:14000 val_loss:2.6650 val_bpb:1.1829 time:9209502ms +step:14200/100000 train_loss:2.6751 time:9341511ms +step:14400/100000 train_loss:2.6716 time:9473403ms +step:14600/100000 train_loss:2.6535 time:9604808ms +step:14800/100000 train_loss:2.6538 time:9736106ms +step:15000/100000 train_loss:2.6917 time:9868605ms +step:15000 val_loss:2.6614 val_bpb:1.1813 time:9868608ms +step:15200/100000 train_loss:2.6932 time:10000962ms +step:15400/100000 train_loss:2.7047 time:10131814ms +step:15600/100000 train_loss:2.6653 time:10263722ms +step:15800/100000 train_loss:2.5603 time:10394715ms +step:16000/100000 train_loss:2.6836 time:10525719ms +step:16000 val_loss:2.6599 val_bpb:1.1807 time:10525721ms +step:16200/100000 train_loss:2.5970 time:10657637ms +step:16400/100000 train_loss:2.6322 time:10789926ms +step:16600/100000 train_loss:2.6099 time:10922637ms +step:16800/100000 train_loss:2.6587 time:11053667ms +step:17000/100000 train_loss:2.6511 time:11183936ms +step:17000 val_loss:2.6584 val_bpb:1.1800 time:11183938ms +step:17200/100000 train_loss:2.7067 time:11315890ms +step:17400/100000 train_loss:2.6507 time:11446247ms +step:17600/100000 train_loss:2.6190 time:11577738ms +step:17800/100000 train_loss:2.6955 time:11708534ms +step:18000/100000 train_loss:2.6804 time:11839869ms +step:18000 val_loss:2.6579 val_bpb:1.1798 time:11839874ms +step:18200/100000 train_loss:2.6578 time:11971560ms +step:18400/100000 train_loss:2.6118 time:12102859ms +step:18600/100000 train_loss:2.6985 time:12233468ms +step:18800/100000 train_loss:2.6930 time:12364667ms +step:19000/100000 train_loss:2.6834 time:12496756ms +step:19000 val_loss:2.6525 val_bpb:1.1774 time:12496759ms +step:19200/100000 train_loss:2.8319 time:12628908ms +step:19400/100000 train_loss:2.6812 time:12761009ms +step:19600/100000 train_loss:2.6971 time:12892563ms +step:19800/100000 train_loss:2.5768 time:13024174ms +step:20000/100000 train_loss:2.6519 time:13155278ms +step:20000 val_loss:2.6508 val_bpb:1.1766 time:13155280ms +step:20200/100000 train_loss:2.6996 time:13287003ms +step:20400/100000 train_loss:2.6263 time:13418011ms +step:20600/100000 train_loss:2.6038 time:13550623ms +step:20800/100000 train_loss:2.6479 time:13682291ms +step:21000/100000 train_loss:2.6683 time:13813791ms +step:21000 val_loss:2.6501 val_bpb:1.1763 time:13813793ms +step:21200/100000 train_loss:2.5647 time:13945243ms +step:21400/100000 train_loss:2.7001 time:14076937ms +step:21600/100000 train_loss:2.6901 time:14208889ms +step:21800/100000 train_loss:2.6325 time:14339638ms +step:22000/100000 train_loss:2.5759 time:14470804ms +step:22000 val_loss:2.6291 val_bpb:1.1670 time:14470806ms +Running EXP2... +model_params: 2755596 +step:0 val_loss:6.9299 val_bpb:3.0760 time:0ms +step:1/100000 train_loss:6.9298 time:833ms +step:200/100000 train_loss:3.3335 time:151344ms +step:400/100000 train_loss:2.7599 time:303038ms +step:600/100000 train_loss:2.8333 time:454966ms +step:800/100000 train_loss:2.6875 time:606599ms +step:1000/100000 train_loss:2.7019 time:758590ms +step:1000 val_loss:2.6869 val_bpb:1.1926 time:758611ms +step:1200/100000 train_loss:2.6410 time:910607ms +step:1400/100000 train_loss:2.6831 time:1062425ms +step:1600/100000 train_loss:2.5749 time:1214151ms +step:1800/100000 train_loss:2.6284 time:1365866ms +step:2000/100000 train_loss:2.5869 time:1517606ms +step:2000 val_loss:2.5980 val_bpb:1.1532 time:1517627ms +step:2200/100000 train_loss:2.5240 time:1669130ms +step:2400/100000 train_loss:2.5548 time:1820764ms +step:2600/100000 train_loss:2.6328 time:1972672ms +step:2800/100000 train_loss:2.6016 time:2124481ms +step:3000/100000 train_loss:2.5060 time:2276018ms +step:3000 val_loss:2.5622 val_bpb:1.1373 time:2276039ms +step:3200/100000 train_loss:2.6010 time:2427565ms +step:3400/100000 train_loss:2.5875 time:2579175ms +step:3600/100000 train_loss:2.5247 time:2730821ms +step:3800/100000 train_loss:2.6124 time:2882433ms +step:4000/100000 train_loss:2.5009 time:3033929ms +step:4000 val_loss:2.5416 val_bpb:1.1281 time:3033950ms +step:4200/100000 train_loss:2.5729 time:3185960ms +step:4400/100000 train_loss:2.5843 time:3337380ms +step:4600/100000 train_loss:2.4243 time:3488335ms +step:4800/100000 train_loss:2.5287 time:3639854ms +step:5000/100000 train_loss:2.5276 time:3791377ms +step:5000 val_loss:2.5269 val_bpb:1.1216 time:3791398ms +step:5200/100000 train_loss:2.5524 time:3942934ms +step:5400/100000 train_loss:2.5528 time:4094718ms +step:5600/100000 train_loss:2.5537 time:4246380ms +step:5800/100000 train_loss:2.5357 time:4398174ms +step:6000/100000 train_loss:2.6597 time:4549934ms +step:6000 val_loss:2.5221 val_bpb:1.1195 time:4549955ms +step:6200/100000 train_loss:2.5143 time:4701583ms +step:6400/100000 train_loss:2.5418 time:4852946ms +step:6600/100000 train_loss:2.5002 time:5004372ms +step:6800/100000 train_loss:2.6086 time:5155926ms +step:7000/100000 train_loss:2.4996 time:5307869ms +step:7000 val_loss:2.5107 val_bpb:1.1145 time:5307890ms +step:7200/100000 train_loss:2.5397 time:5459651ms +step:7400/100000 train_loss:2.5049 time:5611129ms +step:7600/100000 train_loss:2.4541 time:5762784ms +step:7800/100000 train_loss:2.5015 time:5914545ms +step:8000/100000 train_loss:2.4995 time:6066071ms +step:8000 val_loss:2.5021 val_bpb:1.1106 time:6066092ms +step:8200/100000 train_loss:2.4711 time:6217879ms +step:8400/100000 train_loss:2.4662 time:6369957ms +step:8600/100000 train_loss:2.5024 time:6521578ms +step:8800/100000 train_loss:2.5066 time:6673355ms +step:9000/100000 train_loss:2.4672 time:6825267ms +step:9000 val_loss:2.4987 val_bpb:1.1091 time:6825288ms +step:9200/100000 train_loss:2.4584 time:6977408ms +step:9400/100000 train_loss:2.5074 time:7128942ms +step:9600/100000 train_loss:2.5532 time:7280608ms +step:9800/100000 train_loss:2.4601 time:7432461ms +step:10000/100000 train_loss:2.4127 time:7584463ms +step:10000 val_loss:2.4944 val_bpb:1.1072 time:7584484ms +step:10200/100000 train_loss:2.5603 time:7736687ms +step:10400/100000 train_loss:2.5231 time:7888817ms +step:10600/100000 train_loss:2.4291 time:8040783ms +step:10800/100000 train_loss:2.5056 time:8192680ms +step:11000/100000 train_loss:2.4936 time:8344464ms +step:11000 val_loss:2.4889 val_bpb:1.1048 time:8344485ms +step:11200/100000 train_loss:2.4644 time:8496085ms +step:11400/100000 train_loss:2.4741 time:8647738ms +step:11600/100000 train_loss:2.4708 time:8799533ms +step:11800/100000 train_loss:2.5347 time:8951268ms +step:12000/100000 train_loss:2.4289 time:9103408ms +step:12000 val_loss:2.4834 val_bpb:1.1023 time:9103429ms +step:12200/100000 train_loss:2.4883 time:9255743ms +step:12400/100000 train_loss:2.4296 time:9408350ms +step:12600/100000 train_loss:2.5301 time:9560102ms +step:12800/100000 train_loss:2.5795 time:9711890ms +step:13000/100000 train_loss:2.7402 time:9863592ms +step:13000 val_loss:2.4899 val_bpb:1.1052 time:9863613ms +step:13200/100000 train_loss:2.4913 time:10015550ms +step:13400/100000 train_loss:2.4622 time:10167380ms +step:13600/100000 train_loss:2.4446 time:10319273ms +step:13800/100000 train_loss:2.4344 time:10471250ms +step:14000/100000 train_loss:2.4720 time:10623433ms +step:14000 val_loss:2.4782 val_bpb:1.1000 time:10623455ms +step:14200/100000 train_loss:2.4876 time:10775335ms +step:14400/100000 train_loss:2.4933 time:10927025ms +step:14600/100000 train_loss:2.4702 time:11078493ms +step:14800/100000 train_loss:2.4679 time:11230297ms +step:15000/100000 train_loss:2.5051 time:11381884ms +step:15000 val_loss:2.4733 val_bpb:1.0978 time:11381905ms +step:15200/100000 train_loss:2.5082 time:11534054ms +step:15400/100000 train_loss:2.5190 time:11686684ms +step:15600/100000 train_loss:2.4773 time:11839242ms +step:15800/100000 train_loss:2.3639 time:11991206ms +step:16000/100000 train_loss:2.4986 time:12143507ms +step:16000 val_loss:2.4725 val_bpb:1.0975 time:12143528ms +step:16200/100000 train_loss:2.4063 time:12295757ms +step:16400/100000 train_loss:2.4459 time:12447188ms +step:16600/100000 train_loss:2.4200 time:12599559ms +step:16800/100000 train_loss:2.4754 time:12751136ms +step:17000/100000 train_loss:2.4594 time:12902716ms +step:17000 val_loss:2.4710 val_bpb:1.0968 time:12902737ms +step:17200/100000 train_loss:2.5226 time:13054498ms +step:17400/100000 train_loss:2.4660 time:13206467ms +step:17600/100000 train_loss:2.4316 time:13358161ms +step:17800/100000 train_loss:2.5061 time:13509861ms +step:18000/100000 train_loss:2.4852 time:13661645ms +step:18000 val_loss:2.4690 val_bpb:1.0959 time:13661666ms +step:18200/100000 train_loss:2.4749 time:13813644ms +step:18400/100000 train_loss:2.4240 time:13965345ms +step:18600/100000 train_loss:2.5215 time:14117045ms +step:18800/100000 train_loss:2.5086 time:14268787ms +step:19000/100000 train_loss:2.4822 time:14421122ms +step:19000 val_loss:2.4459 val_bpb:1.0857 time:14421143ms +Running EXP3... +model_params: 4722704 +step:0 val_loss:6.9313 val_bpb:3.0767 time:0ms +step:1/100000 train_loss:6.9316 time:1366ms +step:200/100000 train_loss:3.3504 time:252792ms +step:400/100000 train_loss:2.7117 time:505701ms +step:600/100000 train_loss:2.7676 time:758655ms +step:800/100000 train_loss:2.6034 time:1011610ms +step:1000/100000 train_loss:2.6137 time:1264517ms +step:1000 val_loss:2.5906 val_bpb:1.1499 time:1264544ms +step:1200/100000 train_loss:2.5432 time:1517485ms +step:1400/100000 train_loss:2.5791 time:1770502ms +step:1600/100000 train_loss:2.4673 time:2023463ms +step:1800/100000 train_loss:2.5169 time:2276429ms +step:2000/100000 train_loss:2.4716 time:2529429ms +step:2000 val_loss:2.4860 val_bpb:1.1035 time:2529456ms +step:2200/100000 train_loss:2.4066 time:2782329ms +step:2400/100000 train_loss:2.4410 time:3035326ms +step:2600/100000 train_loss:2.5189 time:3288278ms +step:2800/100000 train_loss:2.4805 time:3541314ms +step:3000/100000 train_loss:2.3908 time:3794405ms +step:3000 val_loss:2.4431 val_bpb:1.0844 time:3794432ms +step:3200/100000 train_loss:2.4809 time:4047561ms +step:3400/100000 train_loss:2.4615 time:4301295ms +step:3600/100000 train_loss:2.4051 time:4554080ms +step:3800/100000 train_loss:2.4890 time:4807127ms +step:4000/100000 train_loss:2.3803 time:5060054ms +step:4000 val_loss:2.4200 val_bpb:1.0742 time:5060082ms +step:4200/100000 train_loss:2.4535 time:5313409ms +step:4400/100000 train_loss:2.4609 time:5566296ms +step:4600/100000 train_loss:2.3022 time:5819221ms +step:4800/100000 train_loss:2.4082 time:6072129ms +step:5000/100000 train_loss:2.4053 time:6325152ms +step:5000 val_loss:2.4035 val_bpb:1.0668 time:6325179ms +step:5200/100000 train_loss:2.4255 time:6578051ms +step:5400/100000 train_loss:2.4298 time:6830983ms +step:5600/100000 train_loss:2.4312 time:7083927ms +step:5800/100000 train_loss:2.4098 time:7336854ms +step:6000/100000 train_loss:2.5350 time:7589790ms +step:6000 val_loss:2.3963 val_bpb:1.0636 time:7589817ms +step:6200/100000 train_loss:2.3915 time:7842747ms +step:6400/100000 train_loss:2.4177 time:8095560ms +step:6600/100000 train_loss:2.3729 time:8348483ms +step:6800/100000 train_loss:2.4950 time:8601352ms +step:7000/100000 train_loss:2.3784 time:8854296ms +step:7000 val_loss:2.3845 val_bpb:1.0584 time:8854323ms +step:7200/100000 train_loss:2.4101 time:9107260ms +step:7400/100000 train_loss:2.3841 time:9360233ms +step:7600/100000 train_loss:2.3276 time:9613224ms +step:7800/100000 train_loss:2.3774 time:9866197ms +step:8000/100000 train_loss:2.3701 time:10119195ms +step:8000 val_loss:2.3764 val_bpb:1.0548 time:10119222ms +step:8200/100000 train_loss:2.3478 time:10372158ms +step:8400/100000 train_loss:2.3402 time:10625650ms +step:8600/100000 train_loss:2.3790 time:10878633ms +step:8800/100000 train_loss:2.3785 time:11131675ms +step:9000/100000 train_loss:2.3425 time:11384741ms +step:9000 val_loss:2.3722 val_bpb:1.0530 time:11384768ms +step:9200/100000 train_loss:2.3309 time:11637707ms +step:9400/100000 train_loss:2.3777 time:11890682ms +step:9600/100000 train_loss:2.4305 time:12143596ms +step:9800/100000 train_loss:2.3388 time:12396581ms +step:10000/100000 train_loss:2.2825 time:12649568ms +step:10000 val_loss:2.3684 val_bpb:1.0513 time:12649595ms +step:10200/100000 train_loss:2.4388 time:12902507ms +step:10400/100000 train_loss:2.3939 time:13155443ms +step:10600/100000 train_loss:2.3045 time:13408383ms +step:10800/100000 train_loss:2.3808 time:13661295ms +step:11000/100000 train_loss:2.3669 time:13914261ms +step:11000 val_loss:2.3632 val_bpb:1.0490 time:13914288ms +step:11200/100000 train_loss:2.3381 time:14167332ms +step:11400/100000 train_loss:2.3324 time:14420201ms +step:11600/100000 train_loss:2.3328 time:14673080ms +step:11800/100000 train_loss:2.3933 time:14926023ms diff --git a/nohup_xl.out b/nohup_xl.out new file mode 100644 index 0000000000..30fd3c23a5 --- /dev/null +++ b/nohup_xl.out @@ -0,0 +1,46 @@ +model_params: 7214100 +step:0 val_loss:6.9381 val_bpb:3.0797 time:0ms +Traceback (most recent call last): + File "/workspace/parameter-golf/train_gpt_optimized.py", line 523, in + if __name__ == "__main__": main() + ^^^^^^ + File "/workspace/parameter-golf/train_gpt_optimized.py", line 504, in main + loss = base_model(x, y) + ^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/parameter-golf/train_gpt_optimized.py", line 433, in forward + x = self.unique_blocks[block_idx](x, x0) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/parameter-golf/train_gpt_optimized.py", line 413, in forward + x = x + self.mlp_scale.to(x.dtype) * self.mlp(self.mlp_norm(x)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/parameter-golf/train_gpt_optimized.py", line 399, in forward + return self.proj(torch.relu(self.fc(x)).square()) + ^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/parameter-golf/train_gpt_optimized.py", line 356, in forward + return F.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 610.00 MiB. GPU 0 has a total capacity of 79.18 GiB of which 294.19 MiB is free. Including non-PyTorch memory, this process has 78.88 GiB memory in use. Of the allocated memory 77.83 GiB is allocated by PyTorch, and 399.57 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) diff --git a/run_all.sh b/run_all.sh new file mode 100644 index 0000000000..9effc706b8 --- /dev/null +++ b/run_all.sh @@ -0,0 +1,17 @@ +cd ~/parameter-golf/parameter-golf || exit +pkill -f train_gpt_optimized.py || true +rm -f logs/exp1.txt logs/exp2.txt +mkdir -p logs + +RUN_ID=exp1 NUM_LAYERS=4 MODEL_DIM=256 NUM_HEADS=4 NUM_KV_HEADS=4 ITERATIONS=100000 MAX_WALLCLOCK_SECONDS=3600 DATA_PATH=./data/datasets/fineweb10B_sp1024 PYTHONUNBUFFERED=1 python3 train_gpt_optimized.py | tee logs/exp1.txt + +tail -n 50 logs/exp1.txt + +RUN_ID=exp2 NUM_LAYERS=6 MODEL_DIM=384 NUM_HEADS=6 NUM_KV_HEADS=6 ITERATIONS=100000 MAX_WALLCLOCK_SECONDS=3600 DATA_PATH=./data/datasets/fineweb10B_sp1024 PYTHONUNBUFFERED=1 python3 train_gpt_optimized.py | tee logs/exp2.txt + +tail -n 50 logs/exp2.txt + +echo "==== EXP1 ====" +tail -n 10 logs/exp1.txt +echo "==== EXP2 ====" +tail -n 10 logs/exp2.txt diff --git a/run_big.sh b/run_big.sh new file mode 100644 index 0000000000..ce1e8e9105 --- /dev/null +++ b/run_big.sh @@ -0,0 +1,5 @@ +#!/bin/bash +cd ~/parameter-golf/parameter-golf || exit 1 +mkdir -p logs +pkill -f train_gpt_optimized.py || true +RUN_ID=exp_big NUM_LAYERS=8 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=8 ITERATIONS=100000 MAX_WALLCLOCK_SECONDS=14400 DATA_PATH=./data/datasets/fineweb10B_sp1024 python3 train_gpt_optimized.py | tee logs/exp_big.txt diff --git a/run_suite.sh b/run_suite.sh new file mode 100644 index 0000000000..b7c306a080 --- /dev/null +++ b/run_suite.sh @@ -0,0 +1,40 @@ +#!/bin/bash +cd ~/parameter-golf/parameter-golf || exit 1 + +mkdir -p logs +mkdir -p results +mkdir -p models + +pkill -f train_gpt_optimized.py || true + +EXP1="RUN_ID=exp_small NUM_LAYERS=4 MODEL_DIM=256 NUM_HEADS=4 NUM_KV_HEADS=4" +EXP2="RUN_ID=exp_medium NUM_LAYERS=6 MODEL_DIM=384 NUM_HEADS=6 NUM_KV_HEADS=6" +EXP3="RUN_ID=exp_large NUM_LAYERS=8 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=8" + +COMMON="ITERATIONS=100000 MAX_WALLCLOCK_SECONDS=14400 DATA_PATH=./data/datasets/fineweb10B_sp1024 PYTHONUNBUFFERED=1" + +echo "Running EXP1..." +eval $EXP1 $COMMON python3 train_gpt_optimized.py | tee logs/exp_small.txt +cp final_model.int8.ptz models/exp_small.ptz + +echo "Running EXP2..." +eval $EXP2 $COMMON python3 train_gpt_optimized.py | tee logs/exp_medium.txt +cp final_model.int8.ptz models/exp_medium.ptz + +echo "Running EXP3..." +eval $EXP3 $COMMON python3 train_gpt_optimized.py | tee logs/exp_large.txt +cp final_model.int8.ptz models/exp_large.ptz + +echo "EXPERIMENT RESULTS" > results/summary.txt +echo "==================" >> results/summary.txt + +for exp in small medium large +do + echo "" >> results/summary.txt + echo "EXP_$exp" >> results/summary.txt + tail -n 20 logs/exp_${exp}.txt | grep -E "val_loss|val_bpb" >> results/summary.txt + echo "Model size:" >> results/summary.txt + ls -lh models/exp_${exp}.ptz | awk '{print $5}' >> results/summary.txt +done + +cat results/summary.txt diff --git a/start.sh b/start.sh new file mode 100644 index 0000000000..3220ea8b72 --- /dev/null +++ b/start.sh @@ -0,0 +1,5 @@ +#!/bin/bash +cd ~/parameter-golf/parameter-golf || exit 1 +pkill -f train_gpt_optimized.py +rm -f logs/exp_big.txt +RUN_ID=exp_big NUM_LAYERS=8 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=8 ITERATIONS=100000 MAX_WALLCLOCK_SECONDS=14400 DATA_PATH=./data/datasets/fineweb10B_sp1024 PYTHONUNBUFFERED=1 python3 train_gpt_optimized.py | tee logs/exp_big.txt diff --git a/start_36h_suite.sh b/start_36h_suite.sh new file mode 100644 index 0000000000..e5d661e7f2 --- /dev/null +++ b/start_36h_suite.sh @@ -0,0 +1,58 @@ +#!/bin/bash +cd ~/parameter-golf/parameter-golf || exit 1 +mkdir -p logs +mkdir -p models + +# Helper to run an experiment with optimized script +run_exp() { + echo "Starting $RUN_ID..." + PYTHONUNBUFFERED=1 python3 train_gpt_optimized.py | tee logs/$RUN_ID.txt + cp final_model.int8.ptz models/$RUN_ID.ptz +} + +# EXP 1 (Deep Stable) +export RUN_ID=exp_deep_stable +export NUM_LAYERS=12 +export MODEL_DIM=640 +export NUM_HEADS=10 +export NUM_KV_HEADS=10 +export TRAIN_BATCH_TOKENS=393216 +export EMBED_LR=0.4 +export TIED_EMBED_LR=0.03 +export WARMUP_STEPS=40 +export WARMDOWN_ITERS=2000 +export MAX_WALLCLOCK_SECONDS=43200 +export DATA_PATH=./data/datasets/fineweb10B_sp1024 +run_exp + +# EXP 2 (Wide Aggressive) +export RUN_ID=exp_wide_aggressive +export NUM_LAYERS=10 +export MODEL_DIM=768 +export NUM_HEADS=12 +export NUM_KV_HEADS=12 +export TRAIN_BATCH_TOKENS=524288 +export EMBED_LR=0.5 +export TIED_EMBED_LR=0.035 +export WARMUP_STEPS=30 +export MAX_WALLCLOCK_SECONDS=43200 +export DATA_PATH=./data/datasets/fineweb10B_sp1024 +run_exp + +# EXP 3 (Balanced Peak) +export RUN_ID=exp_balanced_peak +export NUM_LAYERS=12 +export MODEL_DIM=768 +export NUM_HEADS=12 +export NUM_KV_HEADS=12 +export TRAIN_BATCH_TOKENS=524288 +export EMBED_LR=0.42 +export TIED_EMBED_LR=0.032 +export WARMUP_STEPS=50 +export WARMDOWN_ITERS=2500 +export MAX_WALLCLOCK_SECONDS=43200 +export DATA_PATH=./data/datasets/fineweb10B_sp1024 +run_exp + +echo "All experiments completed. Logs available in logs folder." +ls -lh logs diff --git a/start_gpu.sh b/start_gpu.sh new file mode 100644 index 0000000000..d6c4a692fa --- /dev/null +++ b/start_gpu.sh @@ -0,0 +1,6 @@ +#!/bin/bash +cd ~/parameter-golf/parameter-golf || exit 1 +mkdir -p logs +pkill -f train_gpt_optimized.py || true +rm -f logs/exp_gpu.txt +RUN_ID=exp_gpu NUM_LAYERS=8 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=8 TRAIN_BATCH_TOKENS=2000000 ITERATIONS=100000 MAX_WALLCLOCK_SECONDS=14400 DATA_PATH=./data/datasets/fineweb10B_sp1024 PYTHONUNBUFFERED=1 python3 train_gpt_optimized.py | tee logs/exp_gpu.txt diff --git a/start_xl.sh b/start_xl.sh new file mode 100644 index 0000000000..8b3a3d5d4c --- /dev/null +++ b/start_xl.sh @@ -0,0 +1,17 @@ +#!/bin/bash +cd ~/parameter-golf/parameter-golf || exit 1 +mkdir -p logs +pkill -f train_gpt_optimized.py || true +rm -f logs/exp_xl.txt + +RUN_ID=exp_xl \ +NUM_LAYERS=10 \ +MODEL_DIM=640 \ +NUM_HEADS=10 \ +NUM_KV_HEADS=10 \ +TRAIN_BATCH_TOKENS=1000000 \ +ITERATIONS=150000 \ +MAX_WALLCLOCK_SECONDS=21600 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +PYTHONUNBUFFERED=1 \ +python3 train_gpt_optimized.py | tee logs/exp_xl.txt diff --git a/summary.txt b/summary.txt new file mode 100644 index 0000000000..fbe47a0e92 --- /dev/null +++ b/summary.txt @@ -0,0 +1,41 @@ +EXPERIMENT RESULTS +================== + +EXP_small +step:19000 val_loss:2.6525 val_bpb:1.1774 time:12496759ms +step:20000 val_loss:2.6508 val_bpb:1.1766 time:13155280ms +step:21000 val_loss:2.6501 val_bpb:1.1763 time:13813793ms +step:22000 val_loss:2.6291 val_bpb:1.1670 time:14470806ms +Model size: +1.1M + +EXP_medium +step:16000 val_loss:2.4725 val_bpb:1.0975 time:12143528ms +step:17000 val_loss:2.4710 val_bpb:1.0968 time:12902737ms +step:18000 val_loss:2.4690 val_bpb:1.0959 time:13661666ms +step:19000 val_loss:2.4459 val_bpb:1.0857 time:14421143ms +Model size: +525K + +EXP_large +step:9000 val_loss:2.3722 val_bpb:1.0530 time:11384768ms + +EXP_deep_stable +step:20000 val_loss:2.2451 val_bpb:0.9965 +Status: COMPLETED +Model Size: 1.13MB + +EXP_wide_aggressive +step:19000 val_loss:2.1748 val_bpb:0.9653 +Status: COMPLETED +Model Size: 1.13MB + +EXP_balanced_peak +step:16000 val_loss:2.1687 val_bpb:0.9627 +Status: COMPLETED (🏆 OVERALL CHAMPION) +Model Size: 1.13MB + +======================================== +36-HOUR OPTIMIZATION SUITE: FINISHED +Best BPB: 0.9627 (Balanced Peak) +======================================== diff --git a/train_gpt_optimized.py b/train_gpt_optimized.py new file mode 100644 index 0000000000..427275d551 --- /dev/null +++ b/train_gpt_optimized.py @@ -0,0 +1,523 @@ +""" +Optimized version of train_gpt.py for Parameter Golf. +Changes: +1. Depth Recurrence: Reusing block weights to increase depth while staying under parameter limits. +2. Increased Width: model_dim=1024 (up from 512). +3. Fixed enable_gqa bug. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + # Ensure it's a multiple of seq_len * 8 (default grad_accum) to prevent reshape crashes + train_batch_tokens = (train_batch_tokens // 8192) * 8192 + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + # Depth Recurrence Config: + num_steps = int(os.environ.get("NUM_LAYERS", 12)) # Total transformer steps + num_unique_blocks = int(os.environ.get("NUM_UNIQUE_BLOCKS", 2)) # Number of unique weight-sets + model_dim = int(os.environ.get("MODEL_DIM", 1024)) + num_heads = int(os.environ.get("NUM_HEADS", 16)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", num_heads)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith(" "): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[: usable + 1] + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = ("attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights") + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), 0.9999984 / 100.0, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=torch.float16).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), 0.9999984 / 100.0).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized, scales, dtypes, passthrough, passthrough_orig_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), 0) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if t.numel() <= 65536: + passthrough[name] = t.to(torch.float16) + stats["int8_payload_bytes"] += tensor_nbytes(passthrough[name]) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name], scales[name], dtypes[name] = q, s, str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj = {"__quant_format__": "int8_clean_per_row_v1", "quantized": quantized, "scales": scales, "dtypes": dtypes, "passthrough": passthrough, "qmeta": qmeta} + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out = {} + qmeta = obj.get("qmeta", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + out[name] = (q.float() * s.to(torch.float32).view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out[name] = t.float() + return out + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header = np.fromfile(file, dtype=" tuple[Tensor, Tensor]: + n = (global_tokens // (self.world_size * grad_accum_steps)) + 1 + chunks = [] + while n > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens, self.pos = load_data_shard(self.files[self.file_idx]), 0 + continue + k = min(n, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos, n = self.pos + k, n - k + local = torch.cat(chunks).to(dtype=torch.int64) + return local[:-1].reshape(-1, seq_len).to(self.device), local[1:].reshape(-1, seq_len).to(self.device) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=1e-6) + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + t = torch.arange(seq_len, device=device).float() + freqs = torch.outer(t, self.inv_freq.to(device)) + return freqs.cos()[None, None, :, :].to(dtype), freqs.sin()[None, None, :, :].to(dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + return torch.cat((x[..., :half] * cos + x[..., half:] * sin, x[..., :half] * (-sin) + x[..., half:] * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + self.num_heads, self.num_kv_heads, self.head_dim = num_heads, num_kv_heads, dim // num_heads + self.c_q, self.c_k, self.c_v, self.proj = CastedLinear(dim, dim, bias=False), CastedLinear(dim, num_kv_heads * self.head_dim, bias=False), CastedLinear(dim, num_kv_heads * self.head_dim, bias=False), CastedLinear(dim, dim, bias=False) + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init)) + self.rotary = Rotary(self.head_dim, base=rope_base) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Manually expand K, V for GQA if heads don't match + if self.num_kv_heads != self.num_heads: + k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + return self.proj(y.transpose(1, 2).reshape(bsz, seqlen, dim)) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + self.fc, self.proj = CastedLinear(dim, mlp_mult * dim, bias=False), CastedLinear(mlp_mult * dim, dim, bias=False) + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, rope_base: float, qk_gain_init: float): + super().__init__() + self.attn_norm, self.mlp_norm = RMSNorm(), RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale, self.mlp_scale = nn.Parameter(torch.ones(dim)), nn.Parameter(torch.ones(dim)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim)))) + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + x = x + self.attn_scale.to(x.dtype) * self.attn(self.attn_norm(x)) + x = x + self.mlp_scale.to(x.dtype) * self.mlp(self.mlp_norm(x)) + return x + +class GPT(nn.Module): + def __init__(self, args: Hyperparameters): + super().__init__() + self.args = args + self.tok_emb = nn.Embedding(args.vocab_size, args.model_dim) + # Unique blocks for recurrence + self.unique_blocks = nn.ModuleList([Block(args.model_dim, args.num_heads, args.num_kv_heads, args.mlp_mult, args.rope_base, args.qk_gain_init) for _ in range(args.num_unique_blocks)]) + self.final_norm = RMSNorm() + self.lm_head = None if args.tie_embeddings else CastedLinear(args.model_dim, args.vocab_size, bias=False) + nn.init.normal_(self.tok_emb.weight, std=args.tied_embed_init_std) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = F.rms_norm(self.tok_emb(input_ids), (self.args.model_dim,)) + x0 = x + # Loop through steps, cycling through unique blocks + for i in range(self.args.num_steps): + block_idx = i % self.args.num_unique_blocks + x = self.unique_blocks[block_idx](x, x0) + + x = self.final_norm(x).reshape(-1, self.args.model_dim) + logits_proj = F.linear(x, self.tok_emb.weight) if self.args.tie_embeddings else self.lm_head(x) + logits = self.args.logit_softcap * torch.tanh(logits_proj / self.args.logit_softcap) + return F.cross_entropy(logits.float(), target_ids.reshape(-1)) + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main(): + args = Hyperparameters() + rank, world_size, local_rank = int(os.environ.get("RANK", 0)), int(os.environ.get("WORLD_SIZE", 1)), int(os.environ.get("LOCAL_RANK", 0)) + device = torch.device("cuda", local_rank) + if world_size > 1: dist.init_process_group(backend="nccl") + + torch.cuda.set_device(device) + torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + luts = build_sentencepiece_luts(sp, args.vocab_size, device) + + base_model = GPT(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + + # Optimizer setup + mat_params, scal_params = [], [] + for name, p in base_model.named_parameters(): + if "tok_emb" in name: continue + if p.ndim == 2 and not any(c in name for c in CONTROL_TENSOR_NAME_PATTERNS): mat_params.append(p) + else: scal_params.append(p) + + opt_tok = torch.optim.Adam([{"params": [base_model.tok_emb.weight], "lr": args.tied_embed_lr, "base_lr": args.tied_embed_lr}], betas=(args.beta1, args.beta2), fused=True) + opt_muon = Muon(mat_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scal = torch.optim.Adam([{"params": scal_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], betas=(args.beta1, args.beta2), fused=True) + opts = [opt_tok, opt_muon, opt_scal] + + log0 = lambda m: print(m) if rank == 0 else None + log0(f"model_params: {sum(p.numel() for p in base_model.parameters())}") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + grad_accum = 8 // world_size + + t0, training_time_ms, step = time.perf_counter(), 0.0, 0 + while step <= args.iterations: + if step % args.val_loss_every == 0 or step == args.iterations: + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + v_loss, v_bpb = eval_val(args, base_model, rank, world_size, device, grad_accum, val_tokens, *luts) + log0(f"step:{step} val_loss:{v_loss:.4f} val_bpb:{v_bpb:.4f} time:{training_time_ms:.0f}ms") + if training_time_ms >= args.max_wallclock_seconds * 1000: break + t0 = time.perf_counter() + + if step == args.iterations: break + + # LR Warmdown + frac = max(0, (args.max_wallclock_seconds * 1000 - (training_time_ms + 1000*(time.perf_counter()-t0))) / (args.warmdown_iters * 30)) # heuristic + scale = min(1.0, frac) + + for opt in opts: + for g in opt.param_groups: g["lr"] = g["base_lr"] * scale + + for opt in opts: opt.zero_grad(set_to_none=True) + train_loss = 0.0 + for _ in range(grad_accum): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + train_loss += loss.item() + (loss / grad_accum).backward() + for opt in opts: opt.step() + train_loss /= grad_accum + + step += 1 + if step % args.train_log_every == 0 or step == 1 or step == args.iterations: + log0(f"step:{step}/{args.iterations} train_loss:{train_loss:.4f} time:{training_time_ms + 1000*(time.perf_counter()-t0):.0f}ms") + + + if rank == 0: + # Quantize and save + q_obj, _ = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(q_obj, buf) + with open("final_model.int8.ptz", "wb") as f: + f.write(zlib.compress(buf.getvalue(), level=9)) + +if __name__ == "__main__": main() diff --git a/train_gpt_remote.py b/train_gpt_remote.py new file mode 100644 index 0000000000..e7bafeb818 --- /dev/null +++ b/train_gpt_remote.py @@ -0,0 +1,1126 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()