forked from karpathy/autoresearch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
798 lines (691 loc) · 32 KB
/
train.py
File metadata and controls
798 lines (691 loc) · 32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
"""
Autoresearch pretraining script. Single-GPU, single-file.
Cherry-picked and simplified from nanochat.
Usage: uv run train.py
"""
import os
# Configure the CUDA caching allocator and silence Hugging Face Hub progress bars.
# These are set before torch and the `kernels` package are imported below so the
# values are already in the environment when those libraries start up.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
import gc
import time
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from kernels import get_kernel
# (major, minor) compute capability of the current CUDA device. This is queried at
# import time, so importing this module needs a CUDA device.
cap = torch.cuda.get_device_capability()
# The FlashAttention-3 kernel is only used on compute capability (9, 0) (Hopper),
# the platform this script was tuned on. Every other GPU, including the
# Blackwell-class RTX 5070 where FA3 does not currently run, uses PyTorch SDPA.
USE_FLASH_ATTN3 = cap == (9, 0)
if USE_FLASH_ATTN3:
    # FA3 interface loaded through the `kernels` package.
    # NOTE(review): presumably fetched from the "varunneal/flash-attention-3" hub repo; confirm.
    fa3 = get_kernel("varunneal/flash-attention-3").flash_attn_interface
else:
    fa3 = None
from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb
def getenv_int(name, default):
    """Read environment variable ``name`` as an int, or return ``default`` when it is unset.

    A set-but-unparseable value raises ``ValueError`` from ``int()``.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return int(raw)
def getenv_str(name, default):
    """Return environment variable ``name`` verbatim, or ``default`` when it is unset.

    An empty string is a set value and is returned as-is.
    """
    return os.environ.get(name, default)
def getenv_bool(name, default):
    """Interpret environment variable ``name`` as a boolean flag.

    Unset returns ``default`` unchanged. A set value is True only if its
    lower-cased form is one of "1", "true", "yes", "on"; anything else is False.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    truthy_spellings = ("1", "true", "yes", "on")
    return raw.lower() in truthy_spellings
def preset_value(name, default, presets):
    """Look up setting ``name`` for the preset selected by ``AR_PRESET``.

    The preset name is stripped and lower-cased. With no preset selected, an
    unknown preset, or a preset that lacks ``name``, ``default`` is returned.
    """
    active_preset = os.getenv("AR_PRESET", "").strip().lower()
    if active_preset:
        return presets.get(active_preset, {}).get(name, default)
    return default
def get_device_peak_flops(device_name=None, cap=None):
    """Return ``(peak_flops, source)``: the FLOP/s denominator used for MFU reporting.

    ``AR_PEAK_FLOPS`` (a float, FLOP/s) overrides detection entirely. Otherwise the
    figure is picked from the CUDA device name and compute capability. Per-card
    values are dense (no structured sparsity) 16-bit tensor-core rates. Unknown
    cards fall back to the H100 figure, so MFU on them is only indicative.

    Args:
        device_name: Device name to classify; queried via ``torch.cuda`` when None.
        cap: ``(major, minor)`` compute capability; queried via ``torch.cuda`` when None.

    Returns:
        Tuple of (peak FLOP/s as float, short tag naming where the value came from).
    """
    override = os.getenv("AR_PEAK_FLOPS")
    if override is not None:
        return float(override), "AR_PEAK_FLOPS"
    if device_name is None:
        device_name = torch.cuda.get_device_name()
    if cap is None:
        cap = torch.cuda.get_device_capability()
    name_upper = device_name.upper()
    # Keep this mapping small and operational: enough to avoid reporting H100
    # MFU on clearly different RTX/A100-class cards without adding a full db.
    if cap == (10, 0) or "B200" in name_upper or "B100" in name_upper:
        return 2250e12, "auto:blackwell-datacenter"
    if cap == (12, 0):
        # Consumer Blackwell is marketed by "AI TOPS" = FP4 with structured
        # sparsity, which is 8x the dense 16-bit tensor rate (2x sparsity,
        # 2x FP4->FP8, 2x FP8->FP16). Using AI TOPS directly (e.g. 988 for the
        # RTX 5070) under-reports MFU ~8x. The /8 figure is the FP16-accumulate
        # rate, consistent with the Ampere/Ada entries below.
        # NOTE(review): FP32-accumulate may run slower on GeForce parts.
        # More specific names first: "RTX 5070 TI" also contains "RTX 5070".
        rtx_blackwell_ai_tops = (
            ("RTX 5090", 3352e12, "auto:rtx5090"),
            ("RTX 5080", 1801e12, "auto:rtx5080"),
            ("RTX 5070 TI", 1406e12, "auto:rtx5070ti"),
            ("RTX 5070", 988e12, "auto:rtx5070"),
        )
        for marker, ai_tops, source in rtx_blackwell_ai_tops:
            if marker in name_upper:
                return ai_tops / 8, source
        # Other sm_120 parts: assume RTX 5070-class throughput.
        return 988e12 / 8, "auto:blackwell-rtx"
    if cap == (9, 0):
        return 989.5e12, "auto:hopper"
    if "A100" in name_upper:
        return 312e12, "auto:a100"
    if "RTX 4090" in name_upper:
        return 330e12, "auto:rtx4090"
    if "RTX 4080" in name_upper:
        return 242e12, "auto:rtx4080"
    if "RTX 3090" in name_upper:
        return 142e12, "auto:rtx3090"
    if "RTX 3080" in name_upper:
        return 119e12, "auto:rtx3080"
    if "RTX 3070" in name_upper:
        return 81e12, "auto:rtx3070"
    return 989.5e12, "fallback:h100"
# ---------------------------------------------------------------------------
# GPT Model
# ---------------------------------------------------------------------------
@dataclass
class GPTConfig:
    """Architecture hyperparameters for the GPT model.

    These are class defaults; the training script builds its own instance in
    build_model_config().
    """
    # Maximum tokens per sequence; also sizes the rotary tables and SDPA masks.
    sequence_len: int = 2048
    # Number of token ids (embedding rows / lm_head outputs).
    vocab_size: int = 32768
    # Number of transformer Blocks.
    n_layer: int = 12
    # Query heads; head_dim = n_embd // n_head.
    n_head: int = 6
    # Key/value heads; n_head must be a multiple of this.
    n_kv_head: int = 6
    # Model (residual stream) width.
    n_embd: int = 768
    # Per-layer attention windows, tiled across layers: "L" = full sequence_len,
    # "S" = half of it. The last layer is always forced to full.
    window_pattern: str = "SSSL"
def norm(x):
return F.rms_norm(x, (x.size(-1),))
def has_ve(layer_idx, n_layer):
    """Return True if layer ``layer_idx`` gets a value embedding.

    Value embeddings go on every other layer, counted back from the final layer
    (index ``n_layer - 1``), which always has one.
    """
    distance_from_last = n_layer - 1 - layer_idx
    return distance_from_last % 2 == 0
def apply_rotary_emb(x, cos, sin):
    """Apply rotary position embedding to a 4-D tensor.

    The last dimension (dim 3) is split into a first half ``a`` and a second half
    ``b``, and the result is ``cat(a*cos + b*sin, b*cos - a*sin)`` along dim 3.
    ``cos``/``sin`` must broadcast against each half.
    """
    assert x.ndim == 4
    half = x.shape[3] // 2
    first = x[..., :half]
    second = x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = second * cos - first * sin
    return torch.cat((rotated_first, rotated_second), dim=3)
def make_sliding_window_mask(seq_len, left_window, device):
    """Build a boolean causal sliding-window attention mask of shape (seq_len, seq_len).

    Entry ``[q, k]`` is True when query ``q`` may attend to key ``k``, i.e.
    ``q - left_window <= k <= q``. This matches the inclusive
    ``window_size=(left_window, 0)`` semantics of the FlashAttention-3 path in
    run_attention. Both attention backends therefore see the same
    ``left_window + 1`` keys per query; the previous ``left_window - 1`` offset
    gave SDPA one key fewer than FA3.

    Args:
        seq_len: Number of query/key positions.
        left_window: How many earlier positions a query may look back.
        device: Device to allocate the mask on.

    Returns:
        ``torch.bool`` tensor of shape (seq_len, seq_len). True means "attend",
        as expected by SDPA's boolean ``attn_mask``.
    """
    q_idx = torch.arange(seq_len, device=device)[:, None]
    k_idx = torch.arange(seq_len, device=device)[None, :]
    lookback = q_idx - k_idx  # how many positions key k lies before query q
    return (lookback >= 0) & (lookback <= left_window)
def run_attention(q, k, v, window_size, attn_mask=None):
    """Causal self-attention over ``(B, T, H, D)`` tensors via FlashAttention-3 or PyTorch SDPA.

    Args:
        q: Queries, (B, T, n_head, D).
        k, v: Keys/values, (B, T, n_kv_head, D), where n_head is a multiple of n_kv_head.
        window_size: FA3-style ``(left, right)`` window; only used on the FA3 path.
        attn_mask: Boolean SDPA mask (True = attend), or None for plain causal
            attention. Only used on the SDPA path.

    Returns:
        Attention output, (B, T, n_head, D).
    """
    if USE_FLASH_ATTN3:
        return fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
    # SDPA works on (B, H, T, D).
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    # Grouped-query attention: FA3 accepts fewer KV heads than query heads, but
    # SDPA requires equal head counts unless `enable_gqa` (torch>=2.5) is passed.
    # Expand each KV head to the consecutive query heads that share it.
    n_rep = q.size(1) // k.size(1)
    if n_rep > 1:
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=attn_mask is None)
    return y.transpose(1, 2)
class CausalSelfAttention(nn.Module):
    """Causal self-attention with RoPE, QK RMS-norm and an optional value-embedding mix.

    Layers for which has_ve() is True get a ``ve_gate``. GPT.forward passes them
    a value embedding ``ve``; other layers have ``ve_gate = None`` and get ``ve=None``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        # Fewer key/value heads than query heads are allowed, as long as they divide evenly.
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # The value-embedding gate reads only the first 32 channels of the block input.
        self.ve_gate_channels = 32
        self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None

    def forward(self, x, ve, cos_sin, window_size, attn_mask):
        """Attend over ``x`` of shape (B, T, C) and return a tensor of the same shape.

        Args:
            x: Normalized block input, (B, T, C).
            ve: Value embedding reshaped to (B, T, n_kv_head, head_dim), or None.
            cos_sin: ``(cos, sin)`` rotary tables passed to apply_rotary_emb.
            window_size: FA3-style window forwarded to run_attention.
            attn_mask: SDPA mask (or None) forwarded to run_attention.
        """
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        # Value residual (ResFormer): mix in value embedding with input-dependent gate per head.
        # gate = 2 * sigmoid(...) lies in (0, 2); with zero-initialized ve_gate weights it starts at 1.
        if ve is not None:
            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
            gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels]))
            v = v + gate.unsqueeze(-1) * ve
        cos, sin = cos_sin
        # Rotary embedding is applied first, then q/k are RMS-normalized per head.
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)
        y = run_attention(q, k, v, window_size, attn_mask=attn_mask)
        # Merge heads back to (B, T, C) before the output projection.
        y = y.contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y
class MLP(nn.Module):
    """Bias-free feed-forward layer: Linear(n_embd -> 4*n_embd), ReLU squared, Linear back."""

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)

    def forward(self, x):
        hidden = F.relu(self.c_fc(x))
        return self.c_proj(hidden.square())
class Block(nn.Module):
    """Pre-norm transformer block: residual attention sub-layer, then residual MLP sub-layer."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size, attn_mask):
        attn_out = self.attn(norm(x), ve, cos_sin, window_size, attn_mask)
        x = x + attn_out
        return x + self.mlp(norm(x))
class GPT(nn.Module):
    """Decoder-only transformer trained by this script.

    Forward pass, as implemented below:
    1. Token embedding, then RMS norm (kept as ``x0``).
    2. Per layer, the stream is re-mixed as ``resid_lambdas[i] * x + x0_lambdas[i] * x0``
       and passed through a Block (value embeddings on alternating layers).
    3. Final RMS norm, ``lm_head``, tanh softcap at 15.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.window_sizes = self._compute_window_sizes(config)
        self.sdpa_mask_names = self._build_sdpa_mask_names()
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
        # Value embeddings
        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
        self.value_embeds = nn.ModuleDict({
            str(i): nn.Embedding(config.vocab_size, kv_dim)
            for i in range(config.n_layer) if has_ve(i, config.n_layer)
        })
        # Rotary embeddings
        self.rotary_seq_len = config.sequence_len
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        # One boolean mask buffer per distinct short window (SDPA path only).
        if not USE_FLASH_ATTN3:
            for mask_name, left_window in self.sdpa_mask_names.items():
                mask = make_sliding_window_mask(config.sequence_len, left_window, device="cpu")
                self.register_buffer(mask_name, mask, persistent=False)

    @torch.no_grad()
    def init_weights(self):
        """Initialize all parameters and rebuild non-persistent buffers in place.

        The script constructs the model on the meta device and then calls
        ``to_empty()``. That gives every parameter *and buffer* fresh
        uninitialized storage, so the rotary tables and SDPA masks must be
        recomputed here, not only the weights.
        """
        # Embedding and unembedding
        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
        # Transformer blocks
        n_embd = self.config.n_embd
        s = 3**0.5 * n_embd**-0.5
        for block in self.transformer.h:
            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
            torch.nn.init.zeros_(block.attn.c_proj.weight)
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
        # Per-layer scalars
        self.resid_lambdas.fill_(1.0)
        self.x0_lambdas.fill_(0.1)
        # Value embeddings
        for ve in self.value_embeds.values():
            torch.nn.init.uniform_(ve.weight, -s, s)
        # Gate weights init to zero (sigmoid(0)=0.5, scaled by 2 -> 1.0 = neutral)
        for block in self.transformer.h:
            if block.attn.ve_gate is not None:
                torch.nn.init.zeros_(block.attn.ve_gate.weight)
        # Rotary embeddings
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin
        # SDPA sliding-window masks: like cos/sin, these buffers were wiped by
        # to_empty(). Without this, attention on the SDPA path would use
        # uninitialized boolean masks. Assigning to a registered buffer name
        # replaces the buffer.
        if not USE_FLASH_ATTN3:
            mask_device = self.transformer.wte.weight.device
            for mask_name, left_window in self.sdpa_mask_names.items():
                mask = make_sliding_window_mask(self.config.sequence_len, left_window, device=mask_device)
                setattr(self, mask_name, mask)
        # Cast embeddings to bf16
        self.transformer.wte.to(dtype=torch.bfloat16)
        for ve in self.value_embeds.values():
            ve.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        """Return bf16 ``(cos, sin)`` tables of shape (1, seq_len, 1, head_dim // 2)."""
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

    def _compute_window_sizes(self, config):
        """Return one FA-style ``(left, right)`` window per layer from ``config.window_pattern``.

        The pattern is tiled across layers ("L" = full ``sequence_len``, "S" = half of
        it), and the final layer is always forced to the full window.

        Raises:
            ValueError: if the pattern is empty or contains characters other than S/L.
        """
        pattern = config.window_pattern.upper()
        # Explicit check (not assert): the pattern comes from AR_WINDOW_PATTERN, and an
        # empty pattern would otherwise fail later with a ZeroDivisionError.
        if not pattern or any(c not in "SL" for c in pattern):
            raise ValueError(
                f"window_pattern must be a non-empty string of 'S'/'L', got {config.window_pattern!r}"
            )
        long_window = config.sequence_len
        short_window = long_window // 2
        char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
        window_sizes = []
        for layer_idx in range(config.n_layer):
            char = pattern[layer_idx % len(pattern)]
            window_sizes.append(char_to_window[char])
        window_sizes[-1] = (long_window, 0)
        return window_sizes

    def _build_sdpa_mask_names(self):
        """Map buffer name -> left window for every window shorter than sequence_len (SDPA only)."""
        mask_names = {}
        if USE_FLASH_ATTN3:
            return mask_names
        for left_window, _ in self.window_sizes:
            if left_window < self.config.sequence_len:
                mask_names[f"sdpa_mask_{left_window}"] = left_window
        return mask_names

    def _get_sdpa_attn_mask(self, window_size, seq_len):
        """Return the (seq_len, seq_len) mask for ``window_size``, or None for plain causal attention."""
        if USE_FLASH_ATTN3:
            return None
        left_window = window_size[0]
        if left_window >= self.config.sequence_len or left_window >= seq_len:
            return None
        mask_name = f"sdpa_mask_{left_window}"
        mask = getattr(self, mask_name)
        return mask[:seq_len, :seq_len]

    def estimate_flops(self):
        """Estimated FLOPs per token (forward + backward).

        Uses 6 * (non-embedding parameters) plus an attention term per layer that
        uses that layer's window length, capped at sequence_len.
        """
        nparams = sum(p.numel() for p in self.parameters())
        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
        h = self.config.n_head
        q = self.config.n_embd // self.config.n_head
        t = self.config.sequence_len
        attn_flops = 0
        for window_size in self.window_sizes:
            window = window_size[0]
            effective_seq = t if window < 0 else min(window, t)
            attn_flops += 12 * h * q * effective_seq
        return 6 * (nparams - nparams_exclude) + attn_flops

    def num_scaling_params(self):
        """Return parameter counts broken down by component, plus their total."""
        wte = sum(p.numel() for p in self.transformer.wte.parameters())
        value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
        lm_head = sum(p.numel() for p in self.lm_head.parameters())
        transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
        total = wte + value_embeds + lm_head + transformer_matrices + scalars
        return {
            'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head,
            'transformer_matrices': transformer_matrices, 'scalars': scalars, 'total': total,
        }

    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
                        weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5,
                        embedding_weight_decay=0.0, value_embed_weight_decay=0.0,
                        lm_head_weight_decay=0.0):
        """Build a MuonAdamW optimizer for this model.

        Embeddings, lm_head and the per-layer scalars use AdamW; all block
        parameters use Muon, grouped by shape. Each group's starting LR is stored
        as ``initial_lr`` for the schedule.
        """
        model_dim = self.config.n_embd
        matrix_params = list(self.transformer.h.parameters())
        value_embeds_params = list(self.value_embeds.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) +
            len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params))
        # Scale LR ∝ 1/√dmodel (tuned at 768 dim)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
        param_groups = [
            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=lm_head_weight_decay),
            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=embedding_weight_decay),
            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=value_embed_weight_decay),
            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),
        ]
        # Muon stacks each group into one tensor, so every group must hold same-shape params.
        for shape in sorted({p.shape for p in matrix_params}):
            group_params = [p for p in matrix_params if p.shape == shape]
            param_groups.append(dict(
                kind='muon', params=group_params, lr=matrix_lr,
                momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
            ))
        optimizer = MuonAdamW(param_groups)
        for group in optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        return optimizer

    def forward(self, idx, targets=None, reduction='mean'):
        """Return softcapped logits, or the cross-entropy loss when ``targets`` is given.

        ``idx``/``targets`` are (B, T) token ids; target value -1 is ignored in the loss.
        """
        B, T = idx.size()
        assert T <= self.cos.size(1)
        cos_sin = self.cos[:, :T], self.sin[:, :T]
        x = self.transformer.wte(idx)
        x = norm(x)
        x0 = x
        for i, block in enumerate(self.transformer.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
            attn_mask = self._get_sdpa_attn_mask(self.window_sizes[i], T)
            x = block(x, ve, cos_sin, self.window_sizes[i], attn_mask)
        x = norm(x)
        softcap = 15
        logits = self.lm_head(x)
        logits = logits.float()
        logits = softcap * torch.tanh(logits / softcap)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                                   ignore_index=-1, reduction=reduction)
            return loss
        return logits
# ---------------------------------------------------------------------------
# Optimizer (MuonAdamW, single GPU only)
# ---------------------------------------------------------------------------
# Per-iteration (a, b, c) coefficients for the Polar Express orthogonalization in
# muon_step_fused. With A the Gram matrix of X (taken on the smaller side),
# iteration i computes X <- a*X + b*A X + c*A^2 X (mirrored for tall matrices).
# Only the first ns_steps rows are used.
# NOTE(review): values are taken as given (presumably from upstream nanochat /
# the Polar Express method); do not hand-edit without re-deriving them.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """In-place AdamW update of a single parameter tensor (compiled).

    Hyperparameters arrive as 0-D float32 CPU tensors (see MuonAdamW.__init__), so
    changing their values between calls does not force a recompile.

    Args:
        p: Parameter, updated in place.
        grad: Gradient of ``p``.
        exp_avg, exp_avg_sq: First/second-moment buffers, updated in place.
        step_t: Step count used for bias correction (1 on the first call, see
            MuonAdamW._step_adamw).
        lr_t, beta1_t, beta2_t, eps_t, wd_t: Current hyperparameters.
    """
    # Decoupled weight decay: shrink p by (1 - lr * wd) before the Adam step.
    p.mul_(1 - lr_t * wd_t)
    # Moment EMAs: m <- beta1*m + (1-beta1)*g ; v <- beta2*v + (1-beta2)*g^2
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    # Bias corrections for the zero-initialized moments.
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """One Muon update (with NorMuon-style variance normalization) for a stack of same-shape matrices.

    Args:
        stacked_grads: Gradients stacked along dim 0, (N, rows, cols), as built by
            MuonAdamW._step_muon. Overwritten in place with the Nesterov blend.
        stacked_params: Parameters stacked the same way; updated in place (the
            caller copies them back into the real parameters).
        momentum_buffer: First-moment buffer, same shape as ``stacked_grads``.
        second_momentum_buffer: Running second moment, (N, rows, 1) or (N, 1, cols)
            to match ``red_dim``.
        momentum_t, lr_t, wd_t, beta2_t: 0-D float32 tensors with current hyperparameters.
        ns_steps: How many ``polar_express_coeffs`` iterations to run.
        red_dim: -1 to reduce over columns (rows >= cols) or -2 to reduce over rows.
    """
    # Nesterov momentum: buffer <- lerp(buffer, grad, 1 - m); g <- lerp(grad, buffer, m)
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization (bf16). First scale by the Frobenius norm
    # (with 2% slack and an epsilon), then iterate on the Gram matrix of the
    # smaller side to keep the matmuls small.
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction: track an EMA of the mean squared update along
    # red_dim and rescale each row/column by its rsqrt. Then renormalize so the
    # stack keeps the Frobenius norm it had before this per-row/column scaling.
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay + parameter update: decay is applied only to entries
    # where the update g and the parameter have the same sign (g * p >= 0).
    lr = lr_t.to(g.dtype)
    wd = wd_t.to(g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Each param group carries a ``kind`` key, ``'adamw'`` or ``'muon'``. Muon groups
    are processed as one stacked tensor, so all params in a Muon group must share
    a shape (GPT.setup_optimizer groups them that way). ``step()`` takes no closure.
    """

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group):
        """Apply one AdamW step to every param with a gradient in ``group`` (per-tensor state)."""
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad
            state = self.state[p]
            # Lazily create moment buffers on the first step for this param.
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            state['step'] += 1
            self._adamw_step_t.fill_(state['step'])
            self._adamw_lr_t.fill_(group['lr'])
            self._adamw_beta1_t.fill_(group['betas'][0])
            self._adamw_beta2_t.fill_(group['betas'][1])
            self._adamw_eps_t.fill_(group['eps'])
            self._adamw_wd_t.fill_(group['weight_decay'])
            adamw_step_fused(p, grad, state['exp_avg'], state['exp_avg_sq'],
                            self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                            self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Apply one Muon step to all params of ``group`` at once.

        The whole group's state lives on its first param. Gradients are assumed
        present for every param (torch.stack would fail on None).
        """
        params = group['params']
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        # Second moment is kept per row for tall/square matrices, per column for wide ones.
        if "second_momentum_buffer" not in state:
            state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        # torch.stack copies, so the fused step can mutate these without touching p.grad / p directly.
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # LR is scaled up by sqrt(rows / cols) for tall matrices (never scaled down).
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # Write the updated stack back into the real parameter tensors.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self):
        """Run one optimization step over all groups, dispatching on ``group['kind']``."""
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                self._step_adamw(group)
            elif group['kind'] == 'muon':
                self._step_muon(group)
# ---------------------------------------------------------------------------
# Experiment surface
# ---------------------------------------------------------------------------
PRESET_NAME = os.getenv("AR_PRESET", "").strip().lower()
# Hardware presets selected via AR_PRESET. Each setting resolves as:
# AR_* env var (if set) > active preset's entry > built-in default.
# preset_value() only reads the key it is asked for, so one shared table is
# equivalent to a separate table per setting.
_PRESETS = {
    "8gb": {
        "aspect_ratio": 48,
        "head_dim": 128,
        "window_pattern": "SSSL",
        "total_batch_size": 2**15,
        "warmdown_ratio": 0.5,
        "final_lr_frac": 0.0,
        "depth": 6,
        "device_batch_size": 2,
        "compile_model": False,
    },
    "12gb": {
        "aspect_ratio": 64,
        "head_dim": 128,
        "window_pattern": "L",
        "total_batch_size": 2**15,
        "warmdown_ratio": 0.75,
        "final_lr_frac": 0.05,
        "depth": 8,
        "device_batch_size": 4,
        "compile_model": False,
    },
    "h100": {
        "aspect_ratio": 64,
        "head_dim": 128,
        "window_pattern": "SSSL",
        "total_batch_size": 2**16,
        "warmdown_ratio": 0.5,
        "final_lr_frac": 0.0,
        "depth": 8,
        "device_batch_size": 8,
        "compile_model": True,
    },
}
# Model architecture
ASPECT_RATIO = getenv_int("AR_ASPECT_RATIO", preset_value("aspect_ratio", 64, _PRESETS))  # model_dim = depth * ASPECT_RATIO
HEAD_DIM = getenv_int("AR_HEAD_DIM", preset_value("head_dim", 128, _PRESETS))  # target head dimension for attention
WINDOW_PATTERN = getenv_str("AR_WINDOW_PATTERN", preset_value("window_pattern", "SSSL", _PRESETS))  # L=full, S=half context
# Optimization
TOTAL_BATCH_SIZE = getenv_int("AR_TOTAL_BATCH_SIZE", preset_value("total_batch_size", 2**16, _PRESETS))  # global tokens / optimizer step
EMBEDDING_LR = 0.6  # learning rate for token embeddings (Adam)
UNEMBEDDING_LR = 0.004  # learning rate for lm_head (Adam)
MATRIX_LR = 0.04  # learning rate for matrix parameters (Muon)
SCALAR_LR = 0.5  # learning rate for per-layer scalars (Adam)
WEIGHT_DECAY = 0.2  # cautious weight decay for Muon
EMBEDDING_WEIGHT_DECAY = float(os.getenv("AR_EMBEDDING_WEIGHT_DECAY", "0.0"))
VALUE_EMBED_WEIGHT_DECAY = float(os.getenv("AR_VALUE_EMBED_WEIGHT_DECAY", "0.0"))
LM_HEAD_WEIGHT_DECAY = float(os.getenv("AR_LM_HEAD_WEIGHT_DECAY", "0.0"))
ADAM_BETAS = (0.8, 0.95)  # Adam beta1, beta2
WARMUP_RATIO = 0.0  # fraction of time budget for LR warmup
WARMDOWN_RATIO = float(os.getenv("AR_WARMDOWN_RATIO", str(preset_value("warmdown_ratio", 0.5, _PRESETS))))  # fraction of time budget for LR warmdown
FINAL_LR_FRAC = float(os.getenv("AR_FINAL_LR_FRAC", str(preset_value("final_lr_frac", 0.0, _PRESETS))))  # final LR as fraction of initial
# Model size
DEPTH = getenv_int("AR_DEPTH", preset_value("depth", 8, _PRESETS))  # number of transformer layers
DEVICE_BATCH_SIZE = getenv_int("AR_DEVICE_BATCH_SIZE", preset_value("device_batch_size", 8, _PRESETS))  # per-device microbatch size
# Platform controls
COMPILE_MODEL = getenv_bool("AR_COMPILE", preset_value("compile_model", True, _PRESETS))
ATTN_BACKEND = "flash-attn3" if USE_FLASH_ATTN3 else "sdpa"
PEAK_FLOPS, PEAK_FLOPS_SOURCE = get_device_peak_flops()
# ---------------------------------------------------------------------------
# Setup: tokenizer, model, optimizer, dataloader
# ---------------------------------------------------------------------------
# Wall-clock reference for the total_seconds figure printed in the final summary.
t_start = time.perf_counter()
# Fixed seeds: weight init (init_weights) draws from these generators.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# "high" lets float32 matmuls use faster reduced-precision internal paths (e.g. TF32).
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
# Reused around every forward pass (training and final eval): bfloat16 autocast.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
tokenizer = Tokenizer.from_directory()
vocab_size = tokenizer.get_vocab_size()
device_name = torch.cuda.get_device_name()
# Run banner: echo the resolved platform and config so each log is self-describing.
print(f"Vocab size: {vocab_size:,}")
if PRESET_NAME:
    print(f"Preset: {PRESET_NAME}")
print(f"GPU: {device_name}")
print(f"Compute capability: {cap}")
print(f"Peak FLOPS: {PEAK_FLOPS / 1e12:.1f} TFLOPS ({PEAK_FLOPS_SOURCE})")
print(f"Attention backend: {ATTN_BACKEND}")
if not USE_FLASH_ATTN3:
    print(f"Flash Attention 3 disabled for compute capability {cap}; using PyTorch SDPA.")
print(f"Torch compile: {COMPILE_MODEL}")
print(f"Device batch size: {DEVICE_BATCH_SIZE}")
print(f"Total batch size: {TOTAL_BATCH_SIZE}")
print(f"Warmdown ratio: {WARMDOWN_RATIO}")
print(f"Final LR frac: {FINAL_LR_FRAC}")
print(f"AdamW decay (embed/ve/lm_head): {EMBEDDING_WEIGHT_DECAY}/{VALUE_EMBED_WEIGHT_DECAY}/{LM_HEAD_WEIGHT_DECAY}")
def build_model_config(depth):
    """Derive a GPTConfig from the layer count.

    The width is ``depth * ASPECT_RATIO`` rounded up to a whole number of
    HEAD_DIM-sized heads. The number of KV heads equals the number of query heads.
    """
    # Ceiling division: -(-a // b) == ceil(a / b) for positive b.
    n_heads = -(-(depth * ASPECT_RATIO) // HEAD_DIM)
    width = n_heads * HEAD_DIM
    return GPTConfig(
        sequence_len=MAX_SEQ_LEN,
        vocab_size=vocab_size,
        n_layer=depth,
        n_head=n_heads,
        n_kv_head=n_heads,
        n_embd=width,
        window_pattern=WINDOW_PATTERN,
    )
config = build_model_config(DEPTH)
print(f"Model config: {asdict(config)}")
# Construct on the meta device (shapes only, no allocation), then allocate
# uninitialized GPU storage with to_empty(). init_weights() fills the
# parameters and recomputes the rotary cos/sin buffers.
with torch.device("meta"):
    model = GPT(config)
model.to_empty(device=device)
model.init_weights()
param_counts = model.num_scaling_params()
print("Parameter counts:")
for key, value in param_counts.items():
    print(f" {key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
# Each forward/backward micro-step processes DEVICE_BATCH_SIZE full sequences;
# gradients are accumulated until TOTAL_BATCH_SIZE tokens have been seen.
# Explicit errors (not asserts): these sizes come from env vars / presets and
# must still be checked under `python -O`.
tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
if tokens_per_fwdbwd <= 0:
    raise ValueError(
        f"DEVICE_BATCH_SIZE * MAX_SEQ_LEN must be positive, got "
        f"{DEVICE_BATCH_SIZE} * {MAX_SEQ_LEN}"
    )
if TOTAL_BATCH_SIZE <= 0 or TOTAL_BATCH_SIZE % tokens_per_fwdbwd != 0:
    raise ValueError(
        f"TOTAL_BATCH_SIZE ({TOTAL_BATCH_SIZE}) must be a positive multiple of "
        f"DEVICE_BATCH_SIZE * MAX_SEQ_LEN ({DEVICE_BATCH_SIZE} * {MAX_SEQ_LEN} = {tokens_per_fwdbwd}); "
        "adjust AR_TOTAL_BATCH_SIZE or AR_DEVICE_BATCH_SIZE"
    )
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
optimizer = model.setup_optimizer(
    unembedding_lr=UNEMBEDDING_LR,
    embedding_lr=EMBEDDING_LR,
    scalar_lr=SCALAR_LR,
    adam_betas=ADAM_BETAS,
    matrix_lr=MATRIX_LR,
    weight_decay=WEIGHT_DECAY,
    embedding_weight_decay=EMBEDDING_WEIGHT_DECAY,
    value_embed_weight_decay=VALUE_EMBED_WEIGHT_DECAY,
    lm_head_weight_decay=LM_HEAD_WEIGHT_DECAY,
)
# The optimizer holds the uncompiled module's parameters, which the compiled wrapper shares.
if COMPILE_MODEL:
    model = torch.compile(model, dynamic=False)
train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
x, y, epoch = next(train_loader)  # prefetch first batch
print(f"Time budget: {TIME_BUDGET}s")
print(f"Gradient accumulation steps: {grad_accum_steps}")
# Schedules (all based on progress = training_time / TIME_BUDGET)
def get_lr_multiplier(progress):
    """Return the LR multiplier for ``progress`` (fraction of the time budget used, 0..1).

    The schedule is:
    - linear warmup over the first WARMUP_RATIO of the budget;
    - flat at 1.0;
    - linear decay from 1.0 to FINAL_LR_FRAC over the last WARMDOWN_RATIO.

    With WARMDOWN_RATIO <= 0 (settable via AR_WARMDOWN_RATIO) there is no decay
    phase and the multiplier stays 1.0 instead of dividing by zero.
    """
    if progress < WARMUP_RATIO:
        return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
    if progress < 1.0 - WARMDOWN_RATIO:
        return 1.0
    if WARMDOWN_RATIO <= 0:
        return 1.0
    cooldown = (1.0 - progress) / WARMDOWN_RATIO
    return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC
def get_muon_momentum(step):
    """Muon momentum: linear ramp from 0.85 at step 0 to 0.95 at step 300, then constant."""
    ramp = step / 300
    if ramp > 1:
        ramp = 1
    return (1 - ramp) * 0.85 + ramp * 0.95
def get_weight_decay(progress):
    """Muon weight decay, decayed linearly from WEIGHT_DECAY at progress 0 to zero at progress 1."""
    remaining = 1 - progress
    return WEIGHT_DECAY * remaining
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
import math

t_start_training = time.perf_counter()
smooth_train_loss = 0
total_training_time = 0
step = 0
while True:
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    # Gradient accumulation over grad_accum_steps micro-batches; the next batch is
    # fetched right after each backward so it is ready for the following micro-step.
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps
        loss.backward()
        x, y, epoch = next(train_loader)
    # Progress and schedules
    progress = min(total_training_time / TIME_BUDGET, 1.0)
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    muon_weight_decay = get_weight_decay(progress)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
            group["weight_decay"] = muon_weight_decay
    optimizer.step()
    model.zero_grad(set_to_none=True)
    train_loss_f = train_loss.item()
    # Fast fail: abort if loss is exploding or has diverged to NaN/inf.
    # A NaN loss never satisfies `> 100`, so it must be checked explicitly;
    # otherwise a dead run keeps training for the whole time budget.
    if not math.isfinite(train_loss_f) or train_loss_f > 100:
        print("FAIL")
        exit(1)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    dt = t1 - t0
    # The first 10 steps (compilation / warmup) do not count against the time budget.
    if step >= 10:
        total_training_time += dt
    # Logging
    ema_beta = 0.9
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
    pct_done = 100 * progress
    tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
    mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / PEAK_FLOPS
    remaining = max(0, TIME_BUDGET - total_training_time)
    print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True)
    # GC management (Python's GC causes ~500ms stalls)
    if step == 0:
        gc.collect()
        gc.freeze()
        gc.disable()
    elif (step + 1) % 5000 == 0:
        gc.collect()
    step += 1
    # Time's up — but only stop after warmup steps so we don't count compilation
    if step > 10 and total_training_time >= TIME_BUDGET:
        break
print()  # newline after \r training log
# Includes the untimed warmup steps.
total_tokens = step * TOTAL_BATCH_SIZE
# Final eval
model.eval()
with autocast_ctx:
    val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
# Final summary
t_end = time.perf_counter()
# NOTE(review): startup_time is computed but not printed below.
startup_time = t_start_training - t_start
# MFU over the timed steps only: total_training_time covers steps 10..step-1,
# i.e. (step - 10) optimizer steps.
steady_state_mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10) / total_training_time / PEAK_FLOPS if total_training_time > 0 else 0
peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
# Machine-readable "key: value" summary lines after the "---" separator.
print("---")
print(f"val_bpb: {val_bpb:.6f}")
print(f"training_seconds: {total_training_time:.1f}")
print(f"total_seconds: {t_end - t_start:.1f}")
print(f"peak_vram_mb: {peak_vram_mb:.1f}")
print(f"mfu_percent: {steady_state_mfu:.2f}")
print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
print(f"num_steps: {step}")
print(f"num_params_M: {num_params / 1e6:.1f}")
print(f"depth: {DEPTH}")