From a477efbc9f92e0c5a330d66bfcba66455e9392ac Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 6 Apr 2026 16:22:04 -0700 Subject: [PATCH 01/80] Use ReGLU, Decoder-only, tune hyperparameters, torch.jit, .to optimizations, and partial I/O improvements --- train_gpt.py | 100 ++++++++++++++++++++++++++++----------------------- 1 file changed, 55 insertions(+), 45 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 651beb2b89..8c7e9c640f 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -47,13 +47,13 @@ class Hyperparameters: # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 2000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 400)) # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 16)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -61,28 +61,28 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_layers = int(os.environ.get("NUM_LAYERS", 8)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) + model_dim = int(os.environ.get("MODEL_DIM", 384)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) + embed_lr = float(os.environ.get("EMBED_LR", 0.5)) + head_lr = float(os.environ.get("HEAD_LR", 0.006)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.035)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.035)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.98)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) + beta2 = float(os.environ.get("BETA2", 0.975)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) @@ -93,7 +93,8 @@ class Hyperparameters: # As borrowed from modded-nanogpt # Background on Muon: https://kellerjordan.github.io/posts/muon/ -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: +@torch.jit.script +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. # Muon uses this to normalize matrix-shaped gradients before applying them. a, b, c = (3.4445, -4.7750, 2.0315) @@ -252,6 +253,12 @@ def eval_val( batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) raw_start = batch_seq_start * args.train_seq_len raw_end = batch_seq_end * args.train_seq_len + 1 + if batch_seq_end < seq_end: + next_raw_start = batch_seq_end * args.train_seq_len + next_raw_end = min(next_raw_start + args.train_seq_len * local_batch_seqs, + seq_end * args.train_seq_len + 1) + torch.cuda.stream(torch.cuda.Stream()).wait_stream(torch.cuda.current_stream()) + _ = val_tokens[next_raw_start:next_raw_end].to(device=device, dtype=torch.int64, non_blocking=True) local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) x = local[:-1].reshape(-1, args.train_seq_len) y = local[1:].reshape(-1, args.train_seq_len) @@ -488,10 +495,11 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> per_rank_span = local_tokens + 1 chunk = self.stream.take(per_rank_span * self.world_size) start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + # Combine dtype conversion + device transfer in single operation + local = chunk[start : start + per_rank_span].to(device=self.device, dtype=torch.int64, non_blocking=True) x = local[:-1].reshape(-1, seq_len) y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + return x, y # ----------------------------- # TRANSFORMER MODULES @@ -507,10 +515,24 @@ def forward(self, x: Tensor) -> Tensor: class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._cached_weight_dtype = None + self._cached_weight = None + self._cached_bias_dtype = None + self._cached_bias = None + def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) + target_dtype = x.dtype + if self._cached_weight_dtype != target_dtype: + self._cached_weight = self.weight.to(target_dtype) + self._cached_weight_dtype = target_dtype + if self.bias is not None and self._cached_bias_dtype != target_dtype: + self._cached_bias = self.bias.to(target_dtype) + self._cached_bias_dtype = target_dtype + else: + self._cached_bias = None + return F.linear(x, self._cached_weight, self._cached_bias) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: @@ -545,7 +567,7 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup self._seq_len_cached = seq_len return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - +@torch.jit.script def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] @@ -604,17 +626,17 @@ def forward(self, x: Tensor) -> Tensor: class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup + # Using ReGLU as described in Shazeer (2020) but without bias. def __init__(self, dim: int, mlp_mult: int): super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) + hidden = mlp_mult * dim // 1.5 + self.fc1 = CastedLinear(dim, hidden, bias=False) + self.fc2 = CastedLinear(dim, hidden, bias=False) self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) + x = torch.relu(self.fc1(x)) * self.fc2(x) + return self.proj(x) class Block(nn.Module): @@ -636,9 +658,7 @@ def __init__( self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + def forward(self, x: Tensor) -> Tensor: attn_out = self.attn(self.attn_norm(x)) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) @@ -667,10 +687,7 @@ def __init__( self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.num_layers = num_layers self.blocks = nn.ModuleList( [ Block( @@ -681,7 +698,7 @@ def __init__( rope_base, qk_gain_init, ) - for i in range(num_layers) + for _ in range(num_layers) ] ) self.final_norm = RMSNorm() @@ -700,17 +717,10 @@ def _init_weights(self) -> None: def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) + for i in range(self.num_layers): + x = self.blocks[i](x) x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) @@ -1082,7 +1092,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: if master_process: with open("final_model.int8.ptz", "wb") as f: f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") + quant_file_bytes = len(quant_blob) # ← Use len() instead of getsize() code_bytes = len(code.encode("utf-8")) ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) log0( @@ -1095,7 +1105,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: dist.barrier() with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob)), map_location="cpu") base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) torch.cuda.synchronize() t_qeval = time.perf_counter() From 117effb674ffb443f797de7b9496e605e53a80ec Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 6 Apr 2026 17:23:08 -0700 Subject: [PATCH 02/80] Fix some stuff --- train_gpt.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 8c7e9c640f..a17ed62e10 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -80,7 +80,7 @@ class Hyperparameters: muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.98)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.975)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) @@ -93,7 +93,7 @@ class Hyperparameters: # As borrowed from modded-nanogpt # Background on Muon: https://kellerjordan.github.io/posts/muon/ -@torch.jit.script + def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. # Muon uses this to normalize matrix-shaped gradients before applying them. @@ -296,7 +296,7 @@ def eval_val( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain", ).split(",") if pattern ) @@ -582,6 +582,7 @@ def __init__( num_kv_heads: int, rope_base: float, qk_gain_init: float, + use_rope: bool=True, ): super().__init__() if dim % num_heads != 0: @@ -600,7 +601,7 @@ def __init__( self.proj = CastedLinear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) + self.rotary = Rotary(self.head_dim, base=rope_base) if use_rope else None def forward(self, x: Tensor) -> Tensor: bsz, seqlen, dim = x.shape @@ -609,9 +610,10 @@ def forward(self, x: Tensor) -> Tensor: v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) + if self.rotary: + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] y = F.scaled_dot_product_attention( q, @@ -718,7 +720,6 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) - # First half stores skips; second half reuses them in reverse order. for i in range(self.num_layers): x = self.blocks[i](x) @@ -869,8 +870,6 @@ def log0(msg: str, console: bool = True) -> None: for name, p in block_named_params if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr optimizer_tok = torch.optim.Adam( [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], From 9611ceaae14ae175f98325f2a6a5eb2554bc2713 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 6 Apr 2026 17:46:58 -0700 Subject: [PATCH 03/80] Integer lol --- train_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gpt.py b/train_gpt.py index a17ed62e10..deef7f1a05 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -631,7 +631,7 @@ class MLP(nn.Module): # Using ReGLU as described in Shazeer (2020) but without bias. def __init__(self, dim: int, mlp_mult: int): super().__init__() - hidden = mlp_mult * dim // 1.5 + hidden = int(mlp_mult * dim // 1.5) self.fc1 = CastedLinear(dim, hidden, bias=False) self.fc2 = CastedLinear(dim, hidden, bias=False) self.proj = CastedLinear(hidden, dim, bias=False) From b4c48cb7fc7f9a7cac2cbdca60304dd3b425f623 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 6 Apr 2026 18:06:30 -0700 Subject: [PATCH 04/80] FEAT: Colab --- train_gpt.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index deef7f1a05..e7e1ebf8da 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -257,7 +257,8 @@ def eval_val( next_raw_start = batch_seq_end * args.train_seq_len next_raw_end = min(next_raw_start + args.train_seq_len * local_batch_seqs, seq_end * args.train_seq_len + 1) - torch.cuda.stream(torch.cuda.Stream()).wait_stream(torch.cuda.current_stream()) + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) _ = val_tokens[next_raw_start:next_raw_end].to(device=device, dtype=torch.int64, non_blocking=True) local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) x = local[:-1].reshape(-1, args.train_seq_len) @@ -777,7 +778,7 @@ def main() -> None: enable_cudnn_sdp(False) enable_flash_sdp(True) enable_mem_efficient_sdp(False) - enable_math_sdp(False) + enable_math_sdp(True) logfile = None if master_process: @@ -851,8 +852,8 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + compiled_model = base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=True) if distributed else compiled_model # Optimizer split: # - token embedding (Adam) uses EMBED_LR From 506ef4c87a14ed0aa9b1da92b7670012d84c1a4a Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 6 Apr 2026 20:24:25 -0700 Subject: [PATCH 05/80] LFMShortConv --- train_gpt.py | 109 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 94 insertions(+), 15 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index e7e1ebf8da..fb1e2660da 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -47,13 +47,13 @@ class Hyperparameters: # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 2000)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 400)) # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 16)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -61,13 +61,13 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 8)) + num_layers = int(os.environ.get("NUM_LAYERS", 12)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 384)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) + model_dim = int(os.environ.get("MODEL_DIM", 768)) + num_heads = int(os.environ.get("NUM_HEADS", 12)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + rope_base = float(os.environ.get("ROPE_BASE", 512.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. @@ -75,8 +75,8 @@ class Hyperparameters: head_lr = float(os.environ.get("HEAD_LR", 0.006)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.035)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.035)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.98)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) @@ -628,6 +628,47 @@ def forward(self, x: Tensor) -> Tensor: return self.proj(y) +class LFMShortConv(nn.Module): + # Based on Liquid Foundation Models + + def __init__(self, dim, kernel_size=4): + super().__init__() + self.dim = dim + self.kernel_size = kernel_size + + # Use CastedLinear to match the script's optimization for H100s + self.proj = CastedLinear(dim, dim * 3, bias=False) + + # Depthwise convolution: dim parameters instead of dim^2 + self.conv = nn.Conv1d( + dim, dim, + kernel_size=kernel_size, + padding=kernel_size - 1, + groups=dim, + bias=False + ) + + def forward(self, x: Tensor) -> Tensor: + B, L, D = x.shape + + # 1. Project to input (u), gate_b, and gate_c + projected = self.proj(x) + u, gate_b, gate_c = projected.chunk(3, dim=-1) + + # 2. Apply Liquid-style gating + gate_b = torch.sigmoid(gate_b) + gate_c = torch.sigmoid(gate_c) + + # 3. Short-range Convolution (Parallelized local mixing) + u_conv = u.transpose(1, 2) + # Shift/truncate to maintain causality (only look at past tokens) + u_conv = self.conv(u_conv)[..., :L] + u_conv = u_conv.transpose(1, 2) + + # 4. Multiplicative interaction + return (u_conv * gate_b) * gate_c + + class MLP(nn.Module): # Using ReGLU as described in Shazeer (2020) but without bias. def __init__(self, dim: int, mlp_mult: int): @@ -667,6 +708,35 @@ def forward(self, x: Tensor) -> Tensor: x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) return x +class BlockConv(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, # Kept for signature compatibility + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.lfm_norm = RMSNorm() + + # Global Context (Standard Attention) + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + + # Local Liquid Dynamics (Your new LFM layer) + self.lfm = LFMShortConv(dim, kernel_size=4) + + # Learnable residual scales (Standard in the competition script) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.lfm_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + # Standard residual pattern with RMSNorm + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) + x = x + self.lfm_scale.to(dtype=x.dtype)[None, None, :] * self.lfm(self.lfm_norm(x)) + return x class GPT(nn.Module): def __init__( @@ -691,8 +761,19 @@ def __init__( self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) self.num_layers = num_layers - self.blocks = nn.ModuleList( - [ + self.blocks = nn.ModuleList() + for _ in range(num_layers // 4): + self.blocks.extend([ + BlockConv( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for _ in range(3) + ] + [ Block( model_dim, num_heads, @@ -701,9 +782,7 @@ def __init__( rope_base, qk_gain_init, ) - for _ in range(num_layers) - ] - ) + ]) self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: @@ -1133,4 +1212,4 @@ def lr_mul(step: int, elapsed_ms: float) -> float: if __name__ == "__main__": - main() + main() \ No newline at end of file From 253f2278aad3b4199efcabf82ec85cd3db05b3c0 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 6 Apr 2026 21:41:50 -0700 Subject: [PATCH 06/80] Compile + No LFMConv --- train_gpt.py | 155 ++++++++------------------------------------------- 1 file changed, 23 insertions(+), 132 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index fb1e2660da..ca6d7a0fb9 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -511,29 +511,15 @@ def __init__(self, eps: float | None = None): super().__init__() self.eps = eps + @torch._dynamo.disable def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) class CastedLinear(nn.Linear): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._cached_weight_dtype = None - self._cached_weight = None - self._cached_bias_dtype = None - self._cached_bias = None - def forward(self, x: Tensor) -> Tensor: - target_dtype = x.dtype - if self._cached_weight_dtype != target_dtype: - self._cached_weight = self.weight.to(target_dtype) - self._cached_weight_dtype = target_dtype - if self.bias is not None and self._cached_bias_dtype != target_dtype: - self._cached_bias = self.bias.to(target_dtype) - self._cached_bias_dtype = target_dtype - else: - self._cached_bias = None - return F.linear(x, self._cached_weight, self._cached_bias) + return F.linear(x, self.weight.to(x.dtype), + self.bias.to(x.dtype) if self.bias is not None else None) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: @@ -545,28 +531,16 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None: class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): + def __init__(self, dim: int, max_seq_len: int = 1024, base: float = 10000.0): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None + self.max_seq_len = max_seq_len def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + return freqs.cos()[None, None, :, :].to(dtype), freqs.sin()[None, None, :, :].to(dtype) @torch.jit.script def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: @@ -628,49 +602,8 @@ def forward(self, x: Tensor) -> Tensor: return self.proj(y) -class LFMShortConv(nn.Module): - # Based on Liquid Foundation Models - - def __init__(self, dim, kernel_size=4): - super().__init__() - self.dim = dim - self.kernel_size = kernel_size - - # Use CastedLinear to match the script's optimization for H100s - self.proj = CastedLinear(dim, dim * 3, bias=False) - - # Depthwise convolution: dim parameters instead of dim^2 - self.conv = nn.Conv1d( - dim, dim, - kernel_size=kernel_size, - padding=kernel_size - 1, - groups=dim, - bias=False - ) - - def forward(self, x: Tensor) -> Tensor: - B, L, D = x.shape - - # 1. Project to input (u), gate_b, and gate_c - projected = self.proj(x) - u, gate_b, gate_c = projected.chunk(3, dim=-1) - - # 2. Apply Liquid-style gating - gate_b = torch.sigmoid(gate_b) - gate_c = torch.sigmoid(gate_c) - - # 3. Short-range Convolution (Parallelized local mixing) - u_conv = u.transpose(1, 2) - # Shift/truncate to maintain causality (only look at past tokens) - u_conv = self.conv(u_conv)[..., :L] - u_conv = u_conv.transpose(1, 2) - - # 4. Multiplicative interaction - return (u_conv * gate_b) * gate_c - - class MLP(nn.Module): - # Using ReGLU as described in Shazeer (2020) but without bias. + # Using SwiGLU as described in Shazeer (2020) but without bias. def __init__(self, dim: int, mlp_mult: int): super().__init__() hidden = int(mlp_mult * dim // 1.5) @@ -679,8 +612,7 @@ def __init__(self, dim: int, mlp_mult: int): self.proj = CastedLinear(hidden, dim, bias=False) def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc1(x)) * self.fc2(x) - return self.proj(x) + return self.proj(F.silu(self.fc1(x)) * self.fc2(x)) class Block(nn.Module): @@ -700,7 +632,6 @@ def __init__( self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) def forward(self, x: Tensor) -> Tensor: attn_out = self.attn(self.attn_norm(x)) @@ -708,35 +639,6 @@ def forward(self, x: Tensor) -> Tensor: x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) return x -class BlockConv(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, # Kept for signature compatibility - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.lfm_norm = RMSNorm() - - # Global Context (Standard Attention) - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - - # Local Liquid Dynamics (Your new LFM layer) - self.lfm = LFMShortConv(dim, kernel_size=4) - - # Learnable residual scales (Standard in the competition script) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.lfm_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - # Standard residual pattern with RMSNorm - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) - x = x + self.lfm_scale.to(dtype=x.dtype)[None, None, :] * self.lfm(self.lfm_norm(x)) - return x class GPT(nn.Module): def __init__( @@ -761,28 +663,16 @@ def __init__( self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) self.num_layers = num_layers - self.blocks = nn.ModuleList() - for _ in range(num_layers // 4): - self.blocks.extend([ - BlockConv( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for _ in range(3) - ] + [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - ]) + self.blocks = nn.ModuleList([ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) for _ in range(num_layers) + ]) self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: @@ -865,6 +755,7 @@ def main() -> None: logfile = f"logs/{args.run_id}.txt" print(logfile) + @torch.compiler.disable def log0(msg: str, console: bool = True) -> None: if not master_process: return @@ -931,8 +822,8 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = base_model - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=True) if distributed else compiled_model + compiled_model = torch.compile(base_model) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model # Optimizer split: # - token embedding (Adam) uses EMBED_LR From 38eadfd42d5e01245caa7b0cc92415b919c1f093 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 6 Apr 2026 21:59:53 -0700 Subject: [PATCH 07/80] Update Hyperparameters --- train_gpt.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index ca6d7a0fb9..ccb5e4a62f 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -53,18 +53,18 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 64)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 1_048_576)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 12)) + num_layers = int(os.environ.get("NUM_LAYERS", 8)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 768)) - num_heads = int(os.environ.get("NUM_HEADS", 12)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 512.0)) From eef75a1c3c238413d2512cec2d78f84e01b105a3 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 6 Apr 2026 22:21:36 -0700 Subject: [PATCH 08/80] Partial RoPE + Hyperparameter Tuning --- train_gpt.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index ccb5e4a62f..be00be1ccb 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -48,24 +48,24 @@ class Hyperparameters: # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 400)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 64)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 1_048_576)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 8)) + num_layers = int(os.environ.get("NUM_LAYERS", 12)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) + model_dim = int(os.environ.get("MODEL_DIM", 384)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 512.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) @@ -82,7 +82,7 @@ class Hyperparameters: muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.975)) + beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) @@ -94,7 +94,7 @@ class Hyperparameters: # Background on Muon: https://kellerjordan.github.io/posts/muon/ -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 4, eps: float = 1e-7) -> Tensor: # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. # Muon uses this to normalize matrix-shaped gradients before applying them. a, b, c = (3.4445, -4.7750, 2.0315) @@ -603,7 +603,7 @@ def forward(self, x: Tensor) -> Tensor: class MLP(nn.Module): - # Using SwiGLU as described in Shazeer (2020) but without bias. + # Using SwiGLU as introduced in Shazeer (2020) def __init__(self, dim: int, mlp_mult: int): super().__init__() hidden = int(mlp_mult * dim // 1.5) @@ -624,11 +624,12 @@ def __init__( mlp_mult: int, rope_base: float, qk_gain_init: float, + use_rope: bool=True, ): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, use_rope) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) @@ -671,7 +672,8 @@ def __init__( mlp_mult, rope_base, qk_gain_init, - ) for _ in range(num_layers) + use_rope=(i % 2 == 0) + ) for i in range(num_layers) ]) self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) From e05af07731fd6e4e8956c2eccba3adcb567d8a78 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 7 Apr 2026 08:07:35 -0700 Subject: [PATCH 09/80] Minor optimizations --- train_gpt.py | 49 ++++++++++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index be00be1ccb..502513ddf2 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -67,7 +67,7 @@ class Hyperparameters: num_heads = int(os.environ.get("NUM_HEADS", 8)) mlp_mult = int(os.environ.get("MLP_MULT", 3)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 512.0)) + rope_base = float(os.environ.get("ROPE_BASE", 1024.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. @@ -78,9 +78,9 @@ class Hyperparameters: matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.98)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 192)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) @@ -516,12 +516,6 @@ def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - return F.linear(x, self.weight.to(x.dtype), - self.bias.to(x.dtype) if self.bias is not None else None) - - def restore_low_dim_params_to_fp32(module: nn.Module) -> None: # Keep small/control parameters in fp32 even when the model body runs in bf16. with torch.no_grad(): @@ -542,11 +536,12 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup freqs = torch.outer(t, self.inv_freq) return freqs.cos()[None, None, :, :].to(dtype), freqs.sin()[None, None, :, :].to(dtype) -@torch.jit.script def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + out = torch.empty_like(x) half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + out[..., :half] = x[..., :half] * cos - x[..., half:] * sin + out[..., half:] = x[..., :half] * sin + x[..., half:] * cos + return out class CausalSelfAttention(nn.Module): @@ -570,10 +565,10 @@ def __init__( if self.head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE") kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) + self.c_q = nn.Linear(dim, dim, bias=False) + self.c_k = nn.Linear(dim, kv_dim, bias=False) + self.c_v = nn.Linear(dim, kv_dim, bias=False) + self.proj = nn.Linear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, base=rope_base) if use_rope else None @@ -606,13 +601,15 @@ class MLP(nn.Module): # Using SwiGLU as introduced in Shazeer (2020) def __init__(self, dim: int, mlp_mult: int): super().__init__() - hidden = int(mlp_mult * dim // 1.5) - self.fc1 = CastedLinear(dim, hidden, bias=False) - self.fc2 = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) + self.hidden = int(mlp_mult * dim // 1.5) + # Combine fc1 and fc2 into one "fused" layer + self.fused_fc = nn.Linear(dim, 2 * self.hidden, bias=False) + self.proj = nn.Linear(self.hidden, dim, bias=False) def forward(self, x: Tensor) -> Tensor: - return self.proj(F.silu(self.fc1(x)) * self.fc2(x)) + fused_out = self.fused_fc(x) + x1, x2 = fused_out.chunk(2, dim=-1) + return self.proj(F.silu(x1) * x2) class Block(nn.Module): @@ -635,10 +632,8 @@ def __init__( self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) def forward(self, x: Tensor) -> Tensor: - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) + return x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) class GPT(nn.Module): @@ -676,7 +671,7 @@ def __init__( ) for i in range(num_layers) ]) self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True self._init_weights() @@ -821,7 +816,7 @@ def log0(msg: str, console: bool = True) -> None: qk_gain_init=args.qk_gain_init, ).to(device).bfloat16() for module in base_model.modules(): - if isinstance(module, CastedLinear): + if isinstance(module, nn.Linear): module.float() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model) From ed564a5c41ddb1b60026650de654f392d01ed7da Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:21:40 -0700 Subject: [PATCH 10/80] Hyperparam tuning, resid scaling, and attention qkv projection fusing --- train_gpt.py | 55 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 502513ddf2..dddde8cb2b 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -54,7 +54,7 @@ class Hyperparameters: iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 64)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) @@ -64,7 +64,7 @@ class Hyperparameters: num_layers = int(os.environ.get("NUM_LAYERS", 12)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 384)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_heads = int(os.environ.get("NUM_HEADS", 12)) mlp_mult = int(os.environ.get("MLP_MULT", 3)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 1024.0)) @@ -74,9 +74,9 @@ class Hyperparameters: embed_lr = float(os.environ.get("EMBED_LR", 0.5)) head_lr = float(os.environ.get("HEAD_LR", 0.006)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.035)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.98)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) @@ -94,7 +94,7 @@ class Hyperparameters: # Background on Muon: https://kellerjordan.github.io/posts/muon/ -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 4, eps: float = 1e-7) -> Tensor: +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. # Muon uses this to normalize matrix-shaped gradients before applying them. a, b, c = (3.4445, -4.7750, 2.0315) @@ -109,7 +109,6 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 4, eps: float = 1e-7) -> X = a * X + B @ X return X.T if transposed else X - class Muon(torch.optim.Optimizer): def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): super().__init__( @@ -564,10 +563,8 @@ def __init__( self.head_dim = dim // num_heads if self.head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = nn.Linear(dim, dim, bias=False) - self.c_k = nn.Linear(dim, kv_dim, bias=False) - self.c_v = nn.Linear(dim, kv_dim, bias=False) + self.kv_dim = self.num_kv_heads * self.head_dim + self.c_qkv = nn.Linear(dim, dim + 2 * self.kv_dim, bias=False) self.proj = nn.Linear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) @@ -575,15 +572,20 @@ def __init__( def forward(self, x: Tensor) -> Tensor: bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + + qkv = self.c_qkv(x) + q, k, v = qkv.split([dim, self.kv_dim, self.kv_dim], dim=-1) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) if self.rotary: cos, sin = self.rotary(seqlen, x.device, q.dtype) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] y = F.scaled_dot_product_attention( q, @@ -602,6 +604,7 @@ class MLP(nn.Module): def __init__(self, dim: int, mlp_mult: int): super().__init__() self.hidden = int(mlp_mult * dim // 1.5) + self.hidden = (self.hidden + 63) // 64 * 64 # Pad to multiple of 64 # Combine fc1 and fc2 into one "fused" layer self.fused_fc = nn.Linear(dim, 2 * self.hidden, bias=False) self.proj = nn.Linear(self.hidden, dim, bias=False) @@ -629,11 +632,13 @@ def __init__( self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, use_rope) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) def forward(self, x: Tensor) -> Tensor: - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) - return x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + x = self.resid_attn_scale.to(dtype=x.dtype)[None, None, :] * x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) + return self.resid_mlp_scale.to(dtype=x.dtype)[None, None, :] * x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) class GPT(nn.Module): @@ -693,13 +698,19 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) + logits = F.linear(x, self.tok_emb.weight) else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") + logits = self.lm_head(x) + + # In-place operations are fine for the first two steps + logits.div_(self.logit_softcap) + torch.tanh_(logits) + + # DO NOT use .mul_() here. Use out-of-place multiplication (*) + # to create a new tensor for the loss function. + logits = logits * self.logit_softcap + + return F.cross_entropy(logits.float(), targets) # ----------------------------- @@ -711,7 +722,7 @@ def main() -> None: code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, mode="max-autotune", fullgraph=True) # ----------------------------- # DISTRIBUTED + CUDA SETUP From ae5ce6736b8ae5b1c2faa52d25dbc9aa06482eb1 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 7 Apr 2026 13:19:47 -0700 Subject: [PATCH 11/80] Hyperparameter tuning --- train_gpt.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index dddde8cb2b..b0cd01bbb4 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -54,7 +54,7 @@ class Hyperparameters: iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 64)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) @@ -63,21 +63,21 @@ class Hyperparameters: vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) num_layers = int(os.environ.get("NUM_LAYERS", 12)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 384)) - num_heads = int(os.environ.get("NUM_HEADS", 12)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 1024.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.5)) - head_lr = float(os.environ.get("HEAD_LR", 0.006)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.035)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.98)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 192)) From ad7ba338ecd1350e9eeeecbc66adf671ef94a92e Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:06:49 -0700 Subject: [PATCH 12/80] LoRA and stuff --- train_gpt.py | 288 +++++++++++++++++++++++++-------------------------- 1 file changed, 140 insertions(+), 148 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index b0cd01bbb4..b077a39df5 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -48,27 +48,29 @@ class Hyperparameters: # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 40)) # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 64)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.2)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 12)) + num_layers = int(os.environ.get("NUM_LAYERS", 32)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 1024.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) + rank_min = int(os.environ.get("RANK_MIN", 32)) + rank_step = int(os.environ.get("RANK_STEP", 32)) + rank_max = int(os.environ.get("RANK_MAX", 128)) # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.5)) @@ -84,7 +86,8 @@ class Hyperparameters: beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 2.0)) + # ----------------------------- # MUON OPTIMIZER @@ -96,10 +99,9 @@ class Hyperparameters: def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() - X /= X.norm() + eps + X = X / (X.norm().to(X.dtype) + eps) transposed = G.size(0) > G.size(1) if transposed: X = X.T @@ -510,7 +512,6 @@ def __init__(self, eps: float | None = None): super().__init__() self.eps = eps - @torch._dynamo.disable def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) @@ -542,176 +543,176 @@ def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: out[..., half:] = x[..., :half] * sin + x[..., half:] * cos return out +class LoRALinear(nn.Module): + def __init__(self, out_features: int, in_features: int, rank: int = 64): + super().__init__() + # Unique adapters for this specific block + self.lora_A = nn.Parameter(torch.zeros((in_features, rank))) + self.lora_B = nn.Parameter(torch.zeros((rank, out_features))) + + # Init: A is random, B is zero so the layer starts as an identity of Master + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + self.scaling = 1.0 / rank + + def forward(self, x: Tensor, master_weight: Tensor) -> Tensor: + # Reconstruct weight: W = W_master + (A @ B).T + # We transpose (A@B) to match nn.Linear weight shape [out, in] + weight = master_weight + (self.lora_A @ self.lora_B).T * self.scaling + return F.linear(x, weight) + class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - use_rope: bool=True, - ): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rank=64): super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") self.kv_dim = self.num_kv_heads * self.head_dim - self.c_qkv = nn.Linear(dim, dim + 2 * self.kv_dim, bias=False) - self.proj = nn.Linear(dim, dim, bias=False) - self.proj._zero_init = True + + qkv_out_dim = num_heads * self.head_dim + 2 * self.kv_dim + self.c_qkv = LoRALinear(qkv_out_dim, dim, rank=rank) + self.proj = LoRALinear(dim, num_heads * self.head_dim, rank=rank) + + # Layer-specific gain parameters self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) if use_rope else None + self.rotary = Rotary(self.head_dim, base=rope_base) - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape + def forward(self, x: Tensor, qkv_master: Tensor, proj_master: Tensor) -> Tensor: + bsz, seqlen, _ = x.shape + qkv = self.c_qkv(x, qkv_master) + + # Split into Q, K, V based on GQA dimensions + q, k, v = qkv.split([self.num_heads * self.head_dim, self.kv_dim, self.kv_dim], dim=-1) - qkv = self.c_qkv(x) - q, k, v = qkv.split([dim, self.kv_dim, self.kv_dim], dim=-1) q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - + + # Apply RMSNorm to Q and K as per the baseline's stability strategy q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) - if self.rotary: - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - + + # Apply Rotary Positional Embeddings + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + + # Apply the Q-Gain scaling q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + + # Scaled Dot Product Attention with GQA support y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), + q, k, v, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads) ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) + + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, -1) + return self.proj(y, proj_master) class MLP(nn.Module): - # Using SwiGLU as introduced in Shazeer (2020) - def __init__(self, dim: int, mlp_mult: int): + def __init__(self, dim, hidden, rank=64): super().__init__() - self.hidden = int(mlp_mult * dim // 1.5) - self.hidden = (self.hidden + 63) // 64 * 64 # Pad to multiple of 64 - # Combine fc1 and fc2 into one "fused" layer - self.fused_fc = nn.Linear(dim, 2 * self.hidden, bias=False) - self.proj = nn.Linear(self.hidden, dim, bias=False) + self.fused_fc = LoRALinear(2 * hidden, dim, rank=rank) + self.proj = LoRALinear(dim, hidden, rank=rank) - def forward(self, x: Tensor) -> Tensor: - fused_out = self.fused_fc(x) + def forward(self, x: Tensor, fc1_master: Tensor, fc2_master: Tensor) -> Tensor: + fused_out = self.fused_fc(x, fc1_master) x1, x2 = fused_out.chunk(2, dim=-1) - return self.proj(F.silu(x1) * x2) - + return self.proj(F.silu(x1) * x2, fc2_master) class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - use_rope: bool=True, - ): + def __init__(self, dim, hidden, num_heads, num_kv_heads, rope_base, qk_gain_init, rank=64): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, use_rope) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + # Now passing all required positional arguments to CausalSelfAttention + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rank=rank) + self.mlp = MLP(dim, hidden, rank=rank) + + self.attn_scale = nn.Parameter(torch.ones(dim)) + self.mlp_scale = nn.Parameter(torch.ones(dim)) - def forward(self, x: Tensor) -> Tensor: - x = self.resid_attn_scale.to(dtype=x.dtype)[None, None, :] * x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) - return self.resid_mlp_scale.to(dtype=x.dtype)[None, None, :] * x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + def forward(self, x: Tensor, qkv_master: Tensor, proj_master: Tensor, fc1_master: Tensor, fc2_master: Tensor) -> Tensor: + x = x + self.attn_scale * self.attn(self.attn_norm(x), qkv_master, proj_master) + x = x + self.mlp_scale * self.mlp(self.mlp_norm(x), fc1_master, fc2_master) + return x +# Calculate U-shaped ranks +def get_rank(layer_idx, num_layers, r_min, r_max, step=32): + mid = (num_layers - 1) / 2 + dist = abs(layer_idx - mid) / mid + rank = r_min + (r_max - r_min) * (dist ** 2) + return int(round(rank / step) * step) class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): + def __init__(self, args: Hyperparameters): super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_layers = num_layers - self.blocks = nn.ModuleList([ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - use_rope=(i % 2 == 0) - ) for i in range(num_layers) - ]) + # Store attributes needed in forward() + self.num_layers = args.num_layers + self.tie_embeddings = args.tie_embeddings + + hidden = int(args.mlp_mult * args.model_dim // 1.5) + hidden = (hidden + 63) // 64 * 64 + + self.masters = nn.ParameterDict({ + 'qkv': nn.Parameter(torch.randn(args.model_dim + 2 * (args.num_kv_heads * (args.model_dim // args.num_heads)), args.model_dim)), + 'proj': nn.Parameter(torch.randn(args.model_dim, args.model_dim)), + 'fc1': nn.Parameter(torch.randn(2 * hidden, args.model_dim)), + 'fc2': nn.Parameter(torch.randn(args.model_dim, hidden)), + }) + + self.tok_emb = nn.Embedding(args.vocab_size, args.model_dim) + + # Initialize lm_head as None or Linear to avoid torch.compile errors + self.lm_head = None if args.tie_embeddings else nn.Linear(args.model_dim, args.vocab_size, bias=False) + + self.blocks = nn.ModuleList() + for i in range(args.num_layers): + layer_rank = get_rank(i, args.num_layers, args.rank_min, args.rank_max, args.rank_step) + self.blocks.append( + Block( + args.model_dim, + hidden, + args.num_heads, + args.num_kv_heads, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + rank=layer_rank + ) + ) self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True self._init_weights() - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) + def _init_weights(self): + for p in self.masters.values(): + nn.init.normal_(p, std=0.02) + nn.init.normal_(self.tok_emb.weight, std=0.0075) + if self.lm_head is not None: + nn.init.normal_(self.lm_head.weight, std=0.02) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) - for i in range(self.num_layers): - x = self.blocks[i](x) + qkv_master = self.masters['qkv'] + proj_master = self.masters['proj'] + fc1_master = self.masters['fc1'] + fc2_master = self.masters['fc2'] + + for block in self.blocks: + x = block(x, qkv_master, proj_master, fc1_master, fc2_master) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) + x = self.final_norm(x) if self.tie_embeddings: logits = F.linear(x, self.tok_emb.weight) else: logits = self.lm_head(x) - # In-place operations are fine for the first two steps - logits.div_(self.logit_softcap) - torch.tanh_(logits) - - # DO NOT use .mul_() here. Use out-of-place multiplication (*) - # to create a new tensor for the loss function. - logits = logits * self.logit_softcap - - return F.cross_entropy(logits.float(), targets) - + return F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1)) # ----------------------------- # TRAINING @@ -813,19 +814,7 @@ def log0(msg: str, console: bool = True) -> None: # MODEL + OPTIMIZER SETUP # ----------------------------- - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() + base_model = GPT(args).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, nn.Linear): module.float() @@ -838,15 +827,18 @@ def log0(msg: str, console: bool = True) -> None: # - untied lm_head (Adam) uses HEAD_LR # - matrix params in transformer blocks use MATRIX_LR via Muon # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) + body_named_params = [ + (n, p) for n, p in base_model.named_parameters() + if "tok_emb" not in n and "lm_head" not in n + ] matrix_params = [ p - for name, p in block_named_params + for name, p in body_named_params if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] scalar_params = [ p - for name, p in block_named_params + for name, p in body_named_params if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr From 9679737e768507b02ab8b70cf043456b9231d7f6 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 7 Apr 2026 22:05:45 -0700 Subject: [PATCH 13/80] Attempt LoRA improvement --- train_gpt.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index b077a39df5..cffc010d14 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -546,21 +546,20 @@ def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: class LoRALinear(nn.Module): def __init__(self, out_features: int, in_features: int, rank: int = 64): super().__init__() - # Unique adapters for this specific block self.lora_A = nn.Parameter(torch.zeros((in_features, rank))) self.lora_B = nn.Parameter(torch.zeros((rank, out_features))) - - # Init: A is random, B is zero so the layer starts as an identity of Master nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_B) self.scaling = 1.0 / rank def forward(self, x: Tensor, master_weight: Tensor) -> Tensor: - # Reconstruct weight: W = W_master + (A @ B).T - # We transpose (A@B) to match nn.Linear weight shape [out, in] - weight = master_weight + (self.lora_A @ self.lora_B).T * self.scaling - return F.linear(x, weight) + res = F.linear(x, master_weight) + + orig_shape = x.shape + x_flat = x.view(-1, orig_shape[-1]) + lora_res = (x_flat @ self.lora_A) @ self.lora_B + return res + (lora_res * self.scaling).view(orig_shape[0], orig_shape[1], -1) class CausalSelfAttention(nn.Module): def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rank=64): From c4446e46f926a49563c7d17b4ab188a715de9fd4 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 7 Apr 2026 22:32:50 -0700 Subject: [PATCH 14/80] Hyperparam config --- train_gpt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index cffc010d14..cb7b5440fc 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -53,23 +53,23 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 64)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.2)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 32)) + num_layers = int(os.environ.get("NUM_LAYERS", 12)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) - rank_min = int(os.environ.get("RANK_MIN", 32)) - rank_step = int(os.environ.get("RANK_STEP", 32)) + rank_min = int(os.environ.get("RANK_MIN", 64)) + rank_step = int(os.environ.get("RANK_STEP", 64)) rank_max = int(os.environ.get("RANK_MAX", 128)) # Optimizer hyperparameters. From 1893d8098c547ff691307bb6b35d61dba8505e7c Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 7 Apr 2026 23:07:13 -0700 Subject: [PATCH 15/80] Revert LoRA for now. Trying extreme hyperparameters. --- train_gpt.py | 293 ++++++++++++++++++++++++++------------------------- 1 file changed, 151 insertions(+), 142 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index cb7b5440fc..db2655cde3 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -54,30 +54,28 @@ class Hyperparameters: iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 64)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.2)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 12)) + num_layers = int(os.environ.get("NUM_LAYERS", 6)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) + model_dim = int(os.environ.get("MODEL_DIM", 768)) + num_heads = int(os.environ.get("NUM_HEADS", 12)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) - rank_min = int(os.environ.get("RANK_MIN", 64)) - rank_step = int(os.environ.get("RANK_STEP", 64)) - rank_max = int(os.environ.get("RANK_MAX", 128)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.5)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.03)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) @@ -86,8 +84,7 @@ class Hyperparameters: beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 2.0)) - + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) # ----------------------------- # MUON OPTIMIZER @@ -99,9 +96,10 @@ class Hyperparameters: def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() - X = X / (X.norm().to(X.dtype) + eps) + X /= X.norm() + eps transposed = G.size(0) > G.size(1) if transposed: X = X.T @@ -512,6 +510,7 @@ def __init__(self, eps: float | None = None): super().__init__() self.eps = eps + @torch._dynamo.disable def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) @@ -543,175 +542,176 @@ def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: out[..., half:] = x[..., :half] * sin + x[..., half:] * cos return out -class LoRALinear(nn.Module): - def __init__(self, out_features: int, in_features: int, rank: int = 64): - super().__init__() - self.lora_A = nn.Parameter(torch.zeros((in_features, rank))) - self.lora_B = nn.Parameter(torch.zeros((rank, out_features))) - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - self.scaling = 1.0 / rank - - def forward(self, x: Tensor, master_weight: Tensor) -> Tensor: - res = F.linear(x, master_weight) - - orig_shape = x.shape - x_flat = x.view(-1, orig_shape[-1]) - lora_res = (x_flat @ self.lora_A) @ self.lora_B - - return res + (lora_res * self.scaling).view(orig_shape[0], orig_shape[1], -1) class CausalSelfAttention(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rank=64): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + use_rope: bool=True, + ): super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") self.kv_dim = self.num_kv_heads * self.head_dim - - qkv_out_dim = num_heads * self.head_dim + 2 * self.kv_dim - self.c_qkv = LoRALinear(qkv_out_dim, dim, rank=rank) - self.proj = LoRALinear(dim, num_heads * self.head_dim, rank=rank) - - # Layer-specific gain parameters + self.c_qkv = nn.Linear(dim, dim + 2 * self.kv_dim, bias=False) + self.proj = nn.Linear(dim, dim, bias=False) + self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) + self.rotary = Rotary(self.head_dim, base=rope_base) if use_rope else None - def forward(self, x: Tensor, qkv_master: Tensor, proj_master: Tensor) -> Tensor: - bsz, seqlen, _ = x.shape - qkv = self.c_qkv(x, qkv_master) - - # Split into Q, K, V based on GQA dimensions - q, k, v = qkv.split([self.num_heads * self.head_dim, self.kv_dim, self.kv_dim], dim=-1) + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv = self.c_qkv(x) + q, k, v = qkv.split([dim, self.kv_dim, self.kv_dim], dim=-1) q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - - # Apply RMSNorm to Q and K as per the baseline's stability strategy + q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) - - # Apply Rotary Positional Embeddings - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - - # Apply the Q-Gain scaling + if self.rotary: + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - - # Scaled Dot Product Attention with GQA support y = F.scaled_dot_product_attention( - q, k, v, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads) + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), ) - - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, -1) - return self.proj(y, proj_master) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) class MLP(nn.Module): - def __init__(self, dim, hidden, rank=64): + # Using SwiGLU as introduced in Shazeer (2020) + def __init__(self, dim: int, mlp_mult: int): super().__init__() - self.fused_fc = LoRALinear(2 * hidden, dim, rank=rank) - self.proj = LoRALinear(dim, hidden, rank=rank) + self.hidden = int(mlp_mult * dim // 1.5) + self.hidden = (self.hidden + 63) // 64 * 64 # Pad to multiple of 64 + # Combine fc1 and fc2 into one "fused" layer + self.fused_fc = nn.Linear(dim, 2 * self.hidden, bias=False) + self.proj = nn.Linear(self.hidden, dim, bias=False) - def forward(self, x: Tensor, fc1_master: Tensor, fc2_master: Tensor) -> Tensor: - fused_out = self.fused_fc(x, fc1_master) + def forward(self, x: Tensor) -> Tensor: + fused_out = self.fused_fc(x) x1, x2 = fused_out.chunk(2, dim=-1) - return self.proj(F.silu(x1) * x2, fc2_master) + return self.proj(F.silu(x1) * x2) + class Block(nn.Module): - def __init__(self, dim, hidden, num_heads, num_kv_heads, rope_base, qk_gain_init, rank=64): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_rope: bool=True, + ): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - - # Now passing all required positional arguments to CausalSelfAttention - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rank=rank) - self.mlp = MLP(dim, hidden, rank=rank) - - self.attn_scale = nn.Parameter(torch.ones(dim)) - self.mlp_scale = nn.Parameter(torch.ones(dim)) + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, use_rope) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - def forward(self, x: Tensor, qkv_master: Tensor, proj_master: Tensor, fc1_master: Tensor, fc2_master: Tensor) -> Tensor: - x = x + self.attn_scale * self.attn(self.attn_norm(x), qkv_master, proj_master) - x = x + self.mlp_scale * self.mlp(self.mlp_norm(x), fc1_master, fc2_master) - return x + def forward(self, x: Tensor) -> Tensor: + x = self.resid_attn_scale.to(dtype=x.dtype)[None, None, :] * x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) + return self.resid_mlp_scale.to(dtype=x.dtype)[None, None, :] * x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) -# Calculate U-shaped ranks -def get_rank(layer_idx, num_layers, r_min, r_max, step=32): - mid = (num_layers - 1) / 2 - dist = abs(layer_idx - mid) / mid - rank = r_min + (r_max - r_min) * (dist ** 2) - return int(round(rank / step) * step) class GPT(nn.Module): - def __init__(self, args: Hyperparameters): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): super().__init__() - # Store attributes needed in forward() - self.num_layers = args.num_layers - self.tie_embeddings = args.tie_embeddings - - hidden = int(args.mlp_mult * args.model_dim // 1.5) - hidden = (hidden + 63) // 64 * 64 - - self.masters = nn.ParameterDict({ - 'qkv': nn.Parameter(torch.randn(args.model_dim + 2 * (args.num_kv_heads * (args.model_dim // args.num_heads)), args.model_dim)), - 'proj': nn.Parameter(torch.randn(args.model_dim, args.model_dim)), - 'fc1': nn.Parameter(torch.randn(2 * hidden, args.model_dim)), - 'fc2': nn.Parameter(torch.randn(args.model_dim, hidden)), - }) - - self.tok_emb = nn.Embedding(args.vocab_size, args.model_dim) - - # Initialize lm_head as None or Linear to avoid torch.compile errors - self.lm_head = None if args.tie_embeddings else nn.Linear(args.model_dim, args.vocab_size, bias=False) - - self.blocks = nn.ModuleList() - for i in range(args.num_layers): - layer_rank = get_rank(i, args.num_layers, args.rank_min, args.rank_max, args.rank_step) - self.blocks.append( - Block( - args.model_dim, - hidden, - args.num_heads, - args.num_kv_heads, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - rank=layer_rank - ) - ) + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_layers = num_layers + self.blocks = nn.ModuleList([ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + use_rope=(i % 2 == 0) + ) for i in range(num_layers) + ]) self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True self._init_weights() - def _init_weights(self): - for p in self.masters.values(): - nn.init.normal_(p, std=0.02) - nn.init.normal_(self.tok_emb.weight, std=0.0075) - if self.lm_head is not None: - nn.init.normal_(self.lm_head.weight, std=0.02) + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) - qkv_master = self.masters['qkv'] - proj_master = self.masters['proj'] - fc1_master = self.masters['fc1'] - fc2_master = self.masters['fc2'] - - for block in self.blocks: - x = block(x, qkv_master, proj_master, fc1_master, fc2_master) + for i in range(self.num_layers): + x = self.blocks[i](x) - x = self.final_norm(x) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) if self.tie_embeddings: logits = F.linear(x, self.tok_emb.weight) else: logits = self.lm_head(x) - return F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1)) + # In-place operations are fine for the first two steps + logits.div_(self.logit_softcap) + torch.tanh_(logits) + + # DO NOT use .mul_() here. Use out-of-place multiplication (*) + # to create a new tensor for the loss function. + logits = logits * self.logit_softcap + + return F.cross_entropy(logits.float(), targets) + # ----------------------------- # TRAINING @@ -736,7 +736,7 @@ def main() -> None: raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") if 8 % world_size != 0: raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size + grad_accum_steps = 1 #// world_size # Original: 8 // world_size for distributed. grad_scale = 1.0 / grad_accum_steps if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") @@ -813,7 +813,19 @@ def log0(msg: str, console: bool = True) -> None: # MODEL + OPTIMIZER SETUP # ----------------------------- - base_model = GPT(args).to(device).bfloat16() + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, nn.Linear): module.float() @@ -826,18 +838,15 @@ def log0(msg: str, console: bool = True) -> None: # - untied lm_head (Adam) uses HEAD_LR # - matrix params in transformer blocks use MATRIX_LR via Muon # - vectors/scalars use SCALAR_LR via Adam - body_named_params = [ - (n, p) for n, p in base_model.named_parameters() - if "tok_emb" not in n and "lm_head" not in n - ] + block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ p - for name, p in body_named_params + for name, p in block_named_params if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] scalar_params = [ p - for name, p in body_named_params + for name, p in block_named_params if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr From a6fb8158ea88eb081c4065d19d593f4e2b93cbaa Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 8 Apr 2026 09:22:43 -0700 Subject: [PATCH 16/80] Resid + Model Adjustments + Better GPU contiguous dimensions --- train_gpt.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index db2655cde3..935af8bdcf 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -61,10 +61,10 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 6)) + num_layers = int(os.environ.get("NUM_LAYERS", 8)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 768)) - num_heads = int(os.environ.get("NUM_HEADS", 12)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) @@ -598,14 +598,17 @@ def forward(self, x: Tensor) -> Tensor: y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) +def calculate_hidden(mlp_mult: int, dim: int): + raw_hidden = int(mlp_mult * dim // 1.5) + multiplier = raw_hidden / 64 + return 64 if multiplier == 0 else (2**round(math.log2(multiplier))) * 64 + class MLP(nn.Module): # Using SwiGLU as introduced in Shazeer (2020) def __init__(self, dim: int, mlp_mult: int): super().__init__() - self.hidden = int(mlp_mult * dim // 1.5) - self.hidden = (self.hidden + 63) // 64 * 64 # Pad to multiple of 64 - # Combine fc1 and fc2 into one "fused" layer + self.hidden = calculate_hidden(mlp_mult, dim) self.fused_fc = nn.Linear(dim, 2 * self.hidden, bias=False) self.proj = nn.Linear(self.hidden, dim, bias=False) @@ -632,13 +635,12 @@ def __init__( self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, use_rope) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.125)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) def forward(self, x: Tensor) -> Tensor: - x = self.resid_attn_scale.to(dtype=x.dtype)[None, None, :] * x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) - return self.resid_mlp_scale.to(dtype=x.dtype)[None, None, :] * x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + y = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) + return self.resid_scale.to(dtype=x.dtype)[None, None, :] * x + y + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(y)) class GPT(nn.Module): @@ -672,7 +674,7 @@ def __init__( mlp_mult, rope_base, qk_gain_init, - use_rope=(i % 2 == 0) + use_rope=(i % 2 == 1) ) for i in range(num_layers) ]) self.final_norm = RMSNorm() @@ -722,7 +724,7 @@ def main() -> None: code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, mode="max-autotune", fullgraph=True) + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) # ----------------------------- # DISTRIBUTED + CUDA SETUP From dc5a514400306535bf8b472f7c09335f4abd418e Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 8 Apr 2026 12:35:48 -0700 Subject: [PATCH 17/80] Dynamic U-shaped KV head count --- train_gpt.py | 64 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 935af8bdcf..9d856d2ec7 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -53,19 +53,19 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 64)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 256)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 8)) + num_layers = int(os.environ.get("NUM_LAYERS", 12)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) @@ -73,11 +73,11 @@ class Hyperparameters: # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.5)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.03)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.9)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 192)) @@ -93,7 +93,7 @@ class Hyperparameters: # As borrowed from modded-nanogpt # Background on Muon: https://kellerjordan.github.io/posts/muon/ - +@torch.compile(mode="max-autotune", fullgraph=True) def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. # Muon uses this to normalize matrix-shaped gradients before applying them. @@ -643,6 +643,19 @@ def forward(self, x: Tensor) -> Tensor: return self.resid_scale.to(dtype=x.dtype)[None, None, :] * x + y + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(y)) +def get_diamond_kv_heads(layer_idx, total_layers, num_heads, p=2.0): + midpoint = (total_layers - 1) / 2 + norm_dist = abs(layer_idx - midpoint) / midpoint + curve = norm_dist ** p + max_kv = max(1, num_heads // 2) + raw_kv = 1 + (max_kv - 1) * curve + kv_heads = 2 ** round(math.log2(raw_kv)) + while num_heads % kv_heads != 0: + kv_heads //= 2 + + return int(max(2, kv_heads)) + + class GPT(nn.Module): def __init__( self, @@ -657,6 +670,7 @@ def __init__( logit_softcap: float, rope_base: float, qk_gain_init: float, + ramp_len: int=32, ): super().__init__() if logit_softcap <= 0.0: @@ -670,7 +684,7 @@ def __init__( Block( model_dim, num_heads, - num_kv_heads, + get_diamond_kv_heads(i, num_layers, num_heads), mlp_mult, rope_base, qk_gain_init, @@ -683,6 +697,11 @@ def __init__( self.lm_head._zero_init = True self._init_weights() + ramp = torch.linspace(0.0, 1.0, ramp_len) + # Assuming args.train_seq_len is 1024 + full_mask = torch.cat([ramp, torch.ones(1024 - ramp_len)]) + self.register_buffer("loss_mask", full_mask, persistent=False) + def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) @@ -699,20 +718,32 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) + if self.tie_embeddings: logits = F.linear(x, self.tok_emb.weight) else: logits = self.lm_head(x) - # In-place operations are fine for the first two steps logits.div_(self.logit_softcap) torch.tanh_(logits) - - # DO NOT use .mul_() here. Use out-of-place multiplication (*) - # to create a new tensor for the loss function. logits = logits * self.logit_softcap - return F.cross_entropy(logits.float(), targets) + # --- MODIFIED LOSS CALCULATION --- + if self.training: + # reduction='none' gives us a loss value for every single token + loss = F.cross_entropy(logits.float(), targets, reduction='none') + + # Reshape loss back to [batch, seq_len] to align with our mask + loss = loss.view(input_ids.size(0), input_ids.size(1)) + + # Apply the ramp [0.0 ... 1.0 ... 1.0] + # The mask broadcasts across the batch dimension automatically + weighted_loss = loss * self.loss_mask + + return weighted_loss.mean() + else: + # Validation remains standard mean cross-entropy + return F.cross_entropy(logits.float(), targets) # ----------------------------- @@ -720,11 +751,8 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: # ----------------------------- def main() -> None: - global zeropower_via_newtonschulz5 - code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) # ----------------------------- # DISTRIBUTED + CUDA SETUP @@ -738,7 +766,7 @@ def main() -> None: raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") if 8 % world_size != 0: raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 1 #// world_size # Original: 8 // world_size for distributed. + grad_accum_steps = 8 // world_size if distributed else 1 grad_scale = 1.0 / grad_accum_steps if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") From 94b6c695b287b4be4580c22e323056bf959d4dbe Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 8 Apr 2026 14:18:48 -0700 Subject: [PATCH 18/80] Optimize --- train_gpt.py | 97 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 58 insertions(+), 39 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 9d856d2ec7..e2ea603569 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -30,7 +30,7 @@ # ----------------------------- # HYPERPARAMETERS # ----------------------------- -# Default Simple Baseline run: +# Default Simple Baseline run (not actually reflected within Hyperparameters): # - 9 transformer blocks at width 512 # - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion # - vocab size 1024, sequence length 1024, tied embeddings @@ -53,8 +53,8 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 256)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) @@ -65,7 +65,7 @@ class Hyperparameters: num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) @@ -93,7 +93,7 @@ class Hyperparameters: # As borrowed from modded-nanogpt # Background on Muon: https://kellerjordan.github.io/posts/muon/ -@torch.compile(mode="max-autotune", fullgraph=True) +@torch.compile(mode="max-autotune", fullgraph=True, dynamic=True) def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. # Muon uses this to normalize matrix-shaped gradients before applying them. @@ -247,19 +247,19 @@ def eval_val( val_byte_count = torch.zeros((), device=device, dtype=torch.float64) model.eval() + # Pre-load the first batch + raw_start = seq_start * args.train_seq_len + raw_end = min(raw_start + local_batch_seqs * args.train_seq_len + 1, val_tokens.numel()) + next_batch = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - if batch_seq_end < seq_end: - next_raw_start = batch_seq_end * args.train_seq_len - next_raw_end = min(next_raw_start + args.train_seq_len * local_batch_seqs, - seq_end * args.train_seq_len + 1) - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - _ = val_tokens[next_raw_start:next_raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + torch.cuda.current_stream().wait_stream(torch.cuda.default_stream()) + local = next_batch + next_seq_start = batch_seq_start + local_batch_seqs + if next_seq_start < seq_end: + n_raw_start = next_seq_start * args.train_seq_len + n_raw_end = min(n_raw_start + local_batch_seqs * args.train_seq_len + 1, val_tokens.numel()) + next_batch = val_tokens[n_raw_start:n_raw_end].to(device=device, dtype=torch.int64, non_blocking=True) x = local[:-1].reshape(-1, args.train_seq_len) y = local[1:].reshape(-1, args.train_seq_len) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): @@ -524,16 +524,19 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None: class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int = 1024, base: float = 10000.0): + def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.max_seq_len = max_seq_len + t = torch.arange(max_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + cos = freqs.cos()[None, None, :, :].to(torch.bfloat16) + sin = freqs.sin()[None, None, :, :].to(torch.bfloat16) + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq) - return freqs.cos()[None, None, :, :].to(dtype), freqs.sin()[None, None, :, :].to(dtype) + def forward(self): + # Return the pre-calculated, pre-casted buffers + return self.cos, self.sin def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: out = torch.empty_like(x) @@ -551,6 +554,7 @@ def __init__( num_kv_heads: int, rope_base: float, qk_gain_init: float, + seq_len: int=1024, use_rope: bool=True, ): super().__init__() @@ -568,7 +572,8 @@ def __init__( self.proj = nn.Linear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) if use_rope else None + self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, base=rope_base) + self.use_rope = use_rope def forward(self, x: Tensor) -> Tensor: bsz, seqlen, dim = x.shape @@ -581,8 +586,8 @@ def forward(self, x: Tensor) -> Tensor: q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) - if self.rotary: - cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.use_rope: + cos, sin = self.rotary() q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) @@ -611,6 +616,7 @@ def __init__(self, dim: int, mlp_mult: int): self.hidden = calculate_hidden(mlp_mult, dim) self.fused_fc = nn.Linear(dim, 2 * self.hidden, bias=False) self.proj = nn.Linear(self.hidden, dim, bias=False) + self.proj._zero_init = True def forward(self, x: Tensor) -> Tensor: fused_out = self.fused_fc(x) @@ -627,12 +633,13 @@ def __init__( mlp_mult: int, rope_base: float, qk_gain_init: float, + seq_len: int=1024, use_rope: bool=True, ): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, use_rope) + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.125)) @@ -671,6 +678,7 @@ def __init__( rope_base: float, qk_gain_init: float, ramp_len: int=32, + seq_len: int=1024, ): super().__init__() if logit_softcap <= 0.0: @@ -680,14 +688,15 @@ def __init__( self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) self.num_layers = num_layers - self.blocks = nn.ModuleList([ + self.blocks = nn.Sequential(*[ Block( model_dim, num_heads, - get_diamond_kv_heads(i, num_layers, num_heads), + get_diamond_kv_heads(i, num_layers, num_kv_heads), mlp_mult, rope_base, qk_gain_init, + seq_len=seq_len, use_rope=(i % 2 == 1) ) for i in range(num_layers) ]) @@ -699,7 +708,7 @@ def __init__( ramp = torch.linspace(0.0, 1.0, ramp_len) # Assuming args.train_seq_len is 1024 - full_mask = torch.cat([ramp, torch.ones(1024 - ramp_len)]) + full_mask = torch.cat([ramp, torch.ones(seq_len - ramp_len)]) self.register_buffer("loss_mask", full_mask, persistent=False) def _init_weights(self) -> None: @@ -713,8 +722,7 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) - for i in range(self.num_layers): - x = self.blocks[i](x) + x = self.blocks(x) x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) @@ -855,9 +863,10 @@ def log0(msg: str, console: bool = True) -> None: logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + seq_len=args.train_seq_len ).to(device).bfloat16() for module in base_model.modules(): - if isinstance(module, nn.Linear): + if isinstance(module, (nn.Linear, nn.Embedding)): module.float() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model) @@ -940,15 +949,25 @@ def zero_grad_all() -> None: max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None def lr_mul(step: int, elapsed_ms: float) -> float: + warmup_mul = min(step / args.warmup_steps, 1.0) if args.warmup_steps > 0 else 1.0 + + # --- Warmdown Logic --- if args.warmdown_iters <= 0: - return 1.0 + return warmup_mul + if max_wallclock_ms is None: + # Iteration-based warmdown warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + warmdown_mul = max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if step >= warmdown_start else 1.0 + else: + # Time-based warmdown + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + warmdown_mul = remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # The effective multiplier is the intersection of warmup and warmdown + return min(warmup_mul, warmdown_mul) # Warmup primes the compiled forward/backward/optimizer paths, then we restore the # initial weights/optimizer state so measured training starts from the true init. From 5debd9249de0ea283c52c5aa439beddc9cc1500a Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 8 Apr 2026 17:26:01 -0700 Subject: [PATCH 19/80] Decent --- train_gpt.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index e2ea603569..8f1c6e9706 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -54,7 +54,7 @@ class Hyperparameters: iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) @@ -62,12 +62,12 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) num_layers = int(os.environ.get("NUM_LAYERS", 12)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 6)) + model_dim = int(os.environ.get("MODEL_DIM", 384)) + num_heads = int(os.environ.get("NUM_HEADS", 6)) mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) + rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. @@ -78,7 +78,7 @@ class Hyperparameters: matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.9)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 192)) beta1 = float(os.environ.get("BETA1", 0.9)) @@ -539,11 +539,8 @@ def forward(self): return self.cos, self.sin def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - out = torch.empty_like(x) - half = x.size(-1) // 2 - out[..., :half] = x[..., :half] * cos - x[..., half:] * sin - out[..., half:] = x[..., :half] * sin + x[..., half:] * cos - return out + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1) class CausalSelfAttention(nn.Module): @@ -608,7 +605,6 @@ def calculate_hidden(mlp_mult: int, dim: int): multiplier = raw_hidden / 64 return 64 if multiplier == 0 else (2**round(math.log2(multiplier))) * 64 - class MLP(nn.Module): # Using SwiGLU as introduced in Shazeer (2020) def __init__(self, dim: int, mlp_mult: int): @@ -706,8 +702,7 @@ def __init__( self.lm_head._zero_init = True self._init_weights() - ramp = torch.linspace(0.0, 1.0, ramp_len) - # Assuming args.train_seq_len is 1024 + ramp = torch.sin(torch.linspace(0, math.pi/2, ramp_len))**2 full_mask = torch.cat([ramp, torch.ones(seq_len - ramp_len)]) self.register_buffer("loss_mask", full_mask, persistent=False) @@ -732,10 +727,6 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: else: logits = self.lm_head(x) - logits.div_(self.logit_softcap) - torch.tanh_(logits) - logits = logits * self.logit_softcap - # --- MODIFIED LOSS CALCULATION --- if self.training: # reduction='none' gives us a loss value for every single token @@ -750,7 +741,6 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: return weighted_loss.mean() else: - # Validation remains standard mean cross-entropy return F.cross_entropy(logits.float(), targets) @@ -774,7 +764,7 @@ def main() -> None: raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") if 8 % world_size != 0: raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size if distributed else 1 + grad_accum_steps = 8 // world_size grad_scale = 1.0 / grad_accum_steps if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") From 6efb59ee82a48a092a90e8fecb75dfda849b85de Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 8 Apr 2026 21:34:06 -0700 Subject: [PATCH 20/80] Testing --- train_gpt.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 8f1c6e9706..47734c9ae1 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -54,24 +54,24 @@ class Hyperparameters: iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 12)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 6)) - model_dim = int(os.environ.get("MODEL_DIM", 384)) - num_heads = int(os.environ.get("NUM_HEADS", 6)) + num_layers = int(os.environ.get("NUM_LAYERS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.5)) + embed_lr = float(os.environ.get("EMBED_LR", 0.4)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) @@ -639,10 +639,11 @@ def __init__( self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.125)) + self.emb_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.025)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - y = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) + def forward(self, x: Tensor, emb: Tensor) -> Tensor: + y = self.emb_scale.to(dtype=x.dtype)[None, None, :] * emb + x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) return self.resid_scale.to(dtype=x.dtype)[None, None, :] * x + y + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(y)) @@ -684,7 +685,7 @@ def __init__( self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) self.num_layers = num_layers - self.blocks = nn.Sequential(*[ + self.blocks = nn.ModuleList([ Block( model_dim, num_heads, @@ -714,10 +715,11 @@ def _init_weights(self) -> None: nn.init.zeros_(module.weight) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) + emb = self.tok_emb(input_ids) + x = F.rms_norm(emb, (emb.size(-1),)) - x = self.blocks(x) + for block in self.blocks: + x = block(x, emb) x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) @@ -764,7 +766,7 @@ def main() -> None: raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") if 8 % world_size != 0: raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size + grad_accum_steps = 1 #8 // world_size grad_scale = 1.0 / grad_accum_steps if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") From e667cd5986e284664fad08eda89226b8e71c3a31 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:10:46 -0700 Subject: [PATCH 21/80] QK and V split --- train_gpt.py | 58 +++++++++++++++++++--------------------------------- 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 47734c9ae1..6dd8cdfd6c 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -54,15 +54,15 @@ class Hyperparameters: iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 8)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) + num_layers = int(os.environ.get("NUM_LAYERS", 12)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) mlp_mult = int(os.environ.get("MLP_MULT", 2)) @@ -72,13 +72,13 @@ class Hyperparameters: # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.4)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) + head_lr = float(os.environ.get("HEAD_LR", 0.01)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.9)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 192)) beta1 = float(os.environ.get("BETA1", 0.9)) @@ -544,58 +544,42 @@ def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - seq_len: int=1024, - use_rope: bool=True, - ): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True): super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") self.kv_dim = self.num_kv_heads * self.head_dim - self.c_qkv = nn.Linear(dim, dim + 2 * self.kv_dim, bias=False) + self.c_qk = nn.Linear(dim, dim + self.kv_dim, bias=False) + self.c_v = nn.Linear(dim, self.kv_dim, bias=False) self.proj = nn.Linear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, base=rope_base) self.use_rope = use_rope - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor, emb: Tensor) -> Tensor: bsz, seqlen, dim = x.shape - - qkv = self.c_qkv(x) - q, k, v = qkv.split([dim, self.kv_dim, self.kv_dim], dim=-1) + qk = self.c_qk(x) + q, k = qk.split([dim, self.kv_dim], dim=-1) + v = self.c_v(emb) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - + q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) if self.use_rope: cos, sin = self.rotary() q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) - + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), + q, k, v, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads) ) y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) @@ -639,11 +623,11 @@ def __init__( self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.125)) - self.emb_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.025)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) def forward(self, x: Tensor, emb: Tensor) -> Tensor: - y = self.emb_scale.to(dtype=x.dtype)[None, None, :] * emb + x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) + attn_out = self.attn(self.attn_norm(x), emb) + y = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out return self.resid_scale.to(dtype=x.dtype)[None, None, :] * x + y + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(y)) @@ -766,7 +750,7 @@ def main() -> None: raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") if 8 % world_size != 0: raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 1 #8 // world_size + grad_accum_steps = 8 // world_size grad_scale = 1.0 / grad_accum_steps if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") From b390a288b918ff2a330cf8d83060ea8a581d7c6c Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:55:05 -0700 Subject: [PATCH 22/80] Architecture tuning --- train_gpt.py | 60 ++++++++++++++++++---------------------------------- 1 file changed, 20 insertions(+), 40 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 6dd8cdfd6c..d3e6d89417 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -61,14 +61,13 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 12)) + num_layers = int(os.environ.get("NUM_LAYERS", 14)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.4)) @@ -510,7 +509,6 @@ def __init__(self, eps: float | None = None): super().__init__() self.eps = eps - @torch._dynamo.disable def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) @@ -615,6 +613,7 @@ def __init__( qk_gain_init: float, seq_len: int=1024, use_rope: bool=True, + resid_scale: float=0.125, ): super().__init__() self.attn_norm = RMSNorm() @@ -622,7 +621,7 @@ def __init__( self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.125)) + self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(resid_scale)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) def forward(self, x: Tensor, emb: Tensor) -> Tensor: @@ -631,12 +630,16 @@ def forward(self, x: Tensor, emb: Tensor) -> Tensor: return self.resid_scale.to(dtype=x.dtype)[None, None, :] * x + y + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(y)) -def get_diamond_kv_heads(layer_idx, total_layers, num_heads, p=2.0): - midpoint = (total_layers - 1) / 2 - norm_dist = abs(layer_idx - midpoint) / midpoint - curve = norm_dist ** p - max_kv = max(1, num_heads // 2) - raw_kv = 1 + (max_kv - 1) * curve +def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): + # Progresses from 2 heads at layer 0 to num_heads at the final layer + min_kv = 2 + max_kv = num_heads + + # Simple linear interpolation from layer 0 to total_layers + fraction = layer_idx / (total_layers - 1) + raw_kv = min_kv + (max_kv - min_kv) * fraction + + # Constraints: Must be power of 2 and divide num_heads kv_heads = 2 ** round(math.log2(raw_kv)) while num_heads % kv_heads != 0: kv_heads //= 2 @@ -655,30 +658,26 @@ def __init__( mlp_mult: int, tie_embeddings: bool, tied_embed_init_std: float, - logit_softcap: float, rope_base: float, qk_gain_init: float, - ramp_len: int=32, seq_len: int=1024, ): super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) self.num_layers = num_layers self.blocks = nn.ModuleList([ Block( model_dim, num_heads, - get_diamond_kv_heads(i, num_layers, num_kv_heads), + get_linear_progression_kv_heads(i, num_layers, num_kv_heads), mlp_mult, rope_base, qk_gain_init, seq_len=seq_len, - use_rope=(i % 2 == 1) + use_rope=(i % 2 == 1), + resid_scale=1/math.sqrt(2 * num_layers) ) for i in range(num_layers) ]) self.final_norm = RMSNorm() @@ -687,10 +686,6 @@ def __init__( self.lm_head._zero_init = True self._init_weights() - ramp = torch.sin(torch.linspace(0, math.pi/2, ramp_len))**2 - full_mask = torch.cat([ramp, torch.ones(seq_len - ramp_len)]) - self.register_buffer("loss_mask", full_mask, persistent=False) - def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) @@ -713,21 +708,7 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: else: logits = self.lm_head(x) - # --- MODIFIED LOSS CALCULATION --- - if self.training: - # reduction='none' gives us a loss value for every single token - loss = F.cross_entropy(logits.float(), targets, reduction='none') - - # Reshape loss back to [batch, seq_len] to align with our mask - loss = loss.view(input_ids.size(0), input_ids.size(1)) - - # Apply the ramp [0.0 ... 1.0 ... 1.0] - # The mask broadcasts across the batch dimension automatically - weighted_loss = loss * self.loss_mask - - return weighted_loss.mean() - else: - return F.cross_entropy(logits.float(), targets) + return F.cross_entropy(logits.float(), targets) # ----------------------------- @@ -769,7 +750,7 @@ def main() -> None: enable_cudnn_sdp(False) enable_flash_sdp(True) enable_mem_efficient_sdp(False) - enable_math_sdp(True) + enable_math_sdp(False) logfile = None if master_process: @@ -836,7 +817,6 @@ def log0(msg: str, console: bool = True) -> None: mlp_mult=args.mlp_mult, tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, seq_len=args.train_seq_len From cfe4bdd494b187a26f9b21851af0c4324c464937 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:49:40 -0700 Subject: [PATCH 23/80] fp8 --- train_gpt.py | 123 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 105 insertions(+), 18 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index d3e6d89417..9e2f0acd4c 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -61,11 +61,11 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 14)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) @@ -77,7 +77,7 @@ class Hyperparameters: matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.9)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 192)) beta1 = float(os.environ.get("BETA1", 0.9)) @@ -504,6 +504,82 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> # TRANSFORMER MODULES # ----------------------------- +def _is_fp8_supported(): + if not torch.cuda.is_available(): + return False + cc = torch.cuda.get_device_capability() + # FP8 is supported on compute capability 8.9 (Ada) and 9.0+ (Hopper) + return cc[0] >= 9 or (cc[0] == 8 and cc[1] >= 9) + +FP8_SUPPORTED = _is_fp8_supported() + +def _fp8_matmul(x: Tensor, weight: Tensor) -> Tensor: + x_2d = x.reshape(-1, x.size(-1)) + x_scale = (x_2d.abs().max().clamp(min=1e-12) / 448.0).float() + w_scale = (weight.abs().max().clamp(min=1e-12) / 448.0).float() + x_fp8 = (x_2d / x_scale).to(torch.float8_e4m3fn) + w_fp8 = (weight / w_scale).to(torch.float8_e4m3fn) + # weight is (Out, In), so w_fp8.t() is (In, Out) and is column-major + out = torch._scaled_mm(x_fp8, w_fp8.t(), scale_a=x_scale, scale_b=w_scale, out_dtype=x.dtype) + return out.reshape(*x.shape[:-1], weight.size(0)) + +class FP8LinearFunction(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward(ctx, x, weight): + ctx.x_shape = x.shape + ctx.dtype = x.dtype + ctx.weight_dtype = weight.dtype + + x_2d = x.reshape(-1, x.size(-1)) + x_scale = (x_2d.abs().max().clamp(min=1e-12) / 448.0).float() + w_scale = (weight.abs().max().clamp(min=1e-12) / 448.0).float() + + x_fp8 = (x_2d / x_scale).to(torch.float8_e4m3fn) + w_fp8 = (weight / w_scale).to(torch.float8_e4m3fn) + + ctx.save_for_backward(x_fp8, w_fp8, x_scale, w_scale) + + # mat2 must be col-major: w_fp8.t() is col-major because w_fp8 is row-major + out = torch._scaled_mm(x_fp8, w_fp8.t(), scale_a=x_scale, scale_b=w_scale, out_dtype=ctx.dtype) + return out.reshape(*ctx.x_shape[:-1], weight.size(0)) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_output): + x_fp8, w_fp8, x_scale, w_scale = ctx.saved_tensors + grad_output_2d = grad_output.reshape(-1, grad_output.size(-1)) + + go_scale = (grad_output_2d.abs().max().clamp(min=1e-12) / 448.0).float() + go_fp8 = (grad_output_2d / go_scale).to(torch.float8_e4m3fn) + + # 1. Grad_x = Grad_output @ Weight + # To make Weight (mat2) col-major, we transpose it, make it contiguous, and transpose back. + # This creates a tensor with the same shape but (1, Rows) strides. + w_fp8_col_major = w_fp8.t().contiguous().t() + grad_x = torch._scaled_mm(go_fp8, w_fp8_col_major, scale_a=go_scale, scale_b=w_scale, out_dtype=ctx.dtype) + + # 2. Grad_weight = Grad_output^T @ x + # mat1 must be row-major, mat2 must be col-major. + go_fp8_t = go_fp8.t().contiguous() + x_fp8_col_major = x_fp8.t().contiguous().t() + grad_weight = torch._scaled_mm(go_fp8_t, x_fp8_col_major, scale_a=go_scale, scale_b=x_scale, out_dtype=ctx.weight_dtype) + + return grad_x.reshape(*ctx.x_shape), grad_weight + +class FP8Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + if not FP8_SUPPORTED: + return super().forward(x) + if not torch.is_grad_enabled(): + out = _fp8_matmul(x, self.weight) + else: + out = FP8LinearFunction.apply(x, self.weight) + if self.bias is not None: + out = out + self.bias + return out + + class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() @@ -522,9 +598,10 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None: class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0): + def __init__(self, dim: int, max_seq_len: int, p: float = 0.5, base: float = 10000.0): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.rotary_dim = (int(dim * p) // 2) * 2 + inv_freq = 1.0 / (base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) t = torch.arange(max_seq_len, dtype=torch.float32) freqs = torch.outer(t, inv_freq) cos = freqs.cos()[None, None, :, :].to(torch.bfloat16) @@ -533,27 +610,29 @@ def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0): self.register_buffer("sin", sin, persistent=False) def forward(self): - # Return the pre-calculated, pre-casted buffers return self.cos, self.sin def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1) + rotary_dim = cos.shape[-1] * 2 + x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:] + x1, x2 = x_rot.chunk(2, dim=-1) + x_rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1) + return torch.cat((x_rotated, x_pass), dim=-1) class CausalSelfAttention(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True, rope_proportion=0.5): super().__init__() self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads self.kv_dim = self.num_kv_heads * self.head_dim - self.c_qk = nn.Linear(dim, dim + self.kv_dim, bias=False) - self.c_v = nn.Linear(dim, self.kv_dim, bias=False) - self.proj = nn.Linear(dim, dim, bias=False) + self.c_qk = FP8Linear(dim, dim + self.kv_dim, bias=False) + self.c_v = FP8Linear(dim, self.kv_dim, bias=False) + self.proj = FP8Linear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, base=rope_base) + self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) self.use_rope = use_rope def forward(self, x: Tensor, emb: Tensor) -> Tensor: @@ -592,8 +671,8 @@ class MLP(nn.Module): def __init__(self, dim: int, mlp_mult: int): super().__init__() self.hidden = calculate_hidden(mlp_mult, dim) - self.fused_fc = nn.Linear(dim, 2 * self.hidden, bias=False) - self.proj = nn.Linear(self.hidden, dim, bias=False) + self.fused_fc = FP8Linear(dim, 2 * self.hidden, bias=False) + self.proj = FP8Linear(self.hidden, dim, bias=False) self.proj._zero_init = True def forward(self, x: Tensor) -> Tensor: @@ -614,11 +693,12 @@ def __init__( seq_len: int=1024, use_rope: bool=True, resid_scale: float=0.125, + rope_proportion: float=0.5, ): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope) + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope, rope_proportion) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(resid_scale)) @@ -646,6 +726,12 @@ def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): return int(max(2, kv_heads)) +def get_rope_p_smooth(i: int, num_layers: int, p_min=0.25, p_max=0.75) -> float: + if num_layers <= 1: + return p_min + progress = i / (num_layers - 1) + scale = math.sin(progress * math.pi) + return p_min + (p_max - p_min) * scale class GPT(nn.Module): def __init__( @@ -677,7 +763,8 @@ def __init__( qk_gain_init, seq_len=seq_len, use_rope=(i % 2 == 1), - resid_scale=1/math.sqrt(2 * num_layers) + resid_scale=1/math.sqrt(2 * num_layers), + rope_proportion=get_rope_p_smooth(i, num_layers) ) for i in range(num_layers) ]) self.final_norm = RMSNorm() From faced9e35e809d410ad83d1b59862505fa28763f Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 14 Apr 2026 15:07:50 -0700 Subject: [PATCH 24/80] Revert fp8 and try new architecture --- train_gpt.py | 118 ++++++++++++--------------------------------------- 1 file changed, 27 insertions(+), 91 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 9e2f0acd4c..1b96b86910 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -53,11 +53,11 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 256)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.2)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) @@ -92,10 +92,10 @@ class Hyperparameters: # As borrowed from modded-nanogpt # Background on Muon: https://kellerjordan.github.io/posts/muon/ -@torch.compile(mode="max-autotune", fullgraph=True, dynamic=True) -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. +@torch.compile +def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7): + if torch.isnan(G).any() or torch.isinf(G).any(): + return torch.zeros_like(G) a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() X /= X.norm() + eps @@ -504,82 +504,6 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> # TRANSFORMER MODULES # ----------------------------- -def _is_fp8_supported(): - if not torch.cuda.is_available(): - return False - cc = torch.cuda.get_device_capability() - # FP8 is supported on compute capability 8.9 (Ada) and 9.0+ (Hopper) - return cc[0] >= 9 or (cc[0] == 8 and cc[1] >= 9) - -FP8_SUPPORTED = _is_fp8_supported() - -def _fp8_matmul(x: Tensor, weight: Tensor) -> Tensor: - x_2d = x.reshape(-1, x.size(-1)) - x_scale = (x_2d.abs().max().clamp(min=1e-12) / 448.0).float() - w_scale = (weight.abs().max().clamp(min=1e-12) / 448.0).float() - x_fp8 = (x_2d / x_scale).to(torch.float8_e4m3fn) - w_fp8 = (weight / w_scale).to(torch.float8_e4m3fn) - # weight is (Out, In), so w_fp8.t() is (In, Out) and is column-major - out = torch._scaled_mm(x_fp8, w_fp8.t(), scale_a=x_scale, scale_b=w_scale, out_dtype=x.dtype) - return out.reshape(*x.shape[:-1], weight.size(0)) - -class FP8LinearFunction(torch.autograd.Function): - @staticmethod - @torch.amp.custom_fwd(device_type="cuda") - def forward(ctx, x, weight): - ctx.x_shape = x.shape - ctx.dtype = x.dtype - ctx.weight_dtype = weight.dtype - - x_2d = x.reshape(-1, x.size(-1)) - x_scale = (x_2d.abs().max().clamp(min=1e-12) / 448.0).float() - w_scale = (weight.abs().max().clamp(min=1e-12) / 448.0).float() - - x_fp8 = (x_2d / x_scale).to(torch.float8_e4m3fn) - w_fp8 = (weight / w_scale).to(torch.float8_e4m3fn) - - ctx.save_for_backward(x_fp8, w_fp8, x_scale, w_scale) - - # mat2 must be col-major: w_fp8.t() is col-major because w_fp8 is row-major - out = torch._scaled_mm(x_fp8, w_fp8.t(), scale_a=x_scale, scale_b=w_scale, out_dtype=ctx.dtype) - return out.reshape(*ctx.x_shape[:-1], weight.size(0)) - - @staticmethod - @torch.amp.custom_bwd(device_type="cuda") - def backward(ctx, grad_output): - x_fp8, w_fp8, x_scale, w_scale = ctx.saved_tensors - grad_output_2d = grad_output.reshape(-1, grad_output.size(-1)) - - go_scale = (grad_output_2d.abs().max().clamp(min=1e-12) / 448.0).float() - go_fp8 = (grad_output_2d / go_scale).to(torch.float8_e4m3fn) - - # 1. Grad_x = Grad_output @ Weight - # To make Weight (mat2) col-major, we transpose it, make it contiguous, and transpose back. - # This creates a tensor with the same shape but (1, Rows) strides. - w_fp8_col_major = w_fp8.t().contiguous().t() - grad_x = torch._scaled_mm(go_fp8, w_fp8_col_major, scale_a=go_scale, scale_b=w_scale, out_dtype=ctx.dtype) - - # 2. Grad_weight = Grad_output^T @ x - # mat1 must be row-major, mat2 must be col-major. - go_fp8_t = go_fp8.t().contiguous() - x_fp8_col_major = x_fp8.t().contiguous().t() - grad_weight = torch._scaled_mm(go_fp8_t, x_fp8_col_major, scale_a=go_scale, scale_b=x_scale, out_dtype=ctx.weight_dtype) - - return grad_x.reshape(*ctx.x_shape), grad_weight - -class FP8Linear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - if not FP8_SUPPORTED: - return super().forward(x) - if not torch.is_grad_enabled(): - out = _fp8_matmul(x, self.weight) - else: - out = FP8LinearFunction.apply(x, self.weight) - if self.bias is not None: - out = out + self.bias - return out - - class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() @@ -627,9 +551,10 @@ def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_le self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads self.kv_dim = self.num_kv_heads * self.head_dim - self.c_qk = FP8Linear(dim, dim + self.kv_dim, bias=False) - self.c_v = FP8Linear(dim, self.kv_dim, bias=False) - self.proj = FP8Linear(dim, dim, bias=False) + self.c_qk = nn.Linear(dim, dim + self.kv_dim, bias=False) + self.c_v = nn.Linear(dim, self.kv_dim, bias=False) + self.v_mix = nn.Parameter(torch.zeros(dim)) + self.proj = nn.Linear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) @@ -639,7 +564,10 @@ def forward(self, x: Tensor, emb: Tensor) -> Tensor: bsz, seqlen, dim = x.shape qk = self.c_qk(x) q, k = qk.split([dim, self.kv_dim], dim=-1) - v = self.c_v(emb) + + mix = self.v_mix[None, None, :] + v_input = (mix * x) + ((1.0-mix) * emb) + v = self.c_v(v_input) q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) @@ -671,8 +599,8 @@ class MLP(nn.Module): def __init__(self, dim: int, mlp_mult: int): super().__init__() self.hidden = calculate_hidden(mlp_mult, dim) - self.fused_fc = FP8Linear(dim, 2 * self.hidden, bias=False) - self.proj = FP8Linear(self.hidden, dim, bias=False) + self.fused_fc = nn.Linear(dim, 2 * self.hidden, bias=False) + self.proj = nn.Linear(self.hidden, dim, bias=False) self.proj._zero_init = True def forward(self, x: Tensor) -> Tensor: @@ -733,6 +661,13 @@ def get_rope_p_smooth(i: int, num_layers: int, p_min=0.25, p_max=0.75) -> float: scale = math.sin(progress * math.pi) return p_min + (p_max - p_min) * scale +def get_linear_progression_mlp_mult(layer_idx: int, total_layers: int, base_mult: int) -> float: + # If base_mult is 2, this progresses from 1.0 (Layer 0) to 3.0 (Final Layer) + min_mult = float(base_mult) * 0.5 + max_mult = float(base_mult) * 1.5 + fraction = layer_idx / max(1, total_layers - 1) + return min_mult + (max_mult - min_mult) * fraction + class GPT(nn.Module): def __init__( self, @@ -758,7 +693,7 @@ def __init__( model_dim, num_heads, get_linear_progression_kv_heads(i, num_layers, num_kv_heads), - mlp_mult, + get_linear_progression_mlp_mult(i, num_layers, mlp_mult), rope_base, qk_gain_init, seq_len=seq_len, @@ -794,8 +729,9 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: logits = F.linear(x, self.tok_emb.weight) else: logits = self.lm_head(x) - - return F.cross_entropy(logits.float(), targets) + + logits = 30.0 * torch.tanh(logits.float() / 30.0) + return F.cross_entropy(logits, targets) # ----------------------------- From 79303f6bd536217e1a0f1b1d0e13ceb5a50299ac Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 15 Apr 2026 14:36:17 -0700 Subject: [PATCH 25/80] ZerO --- train_gpt.py | 91 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 30 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 1b96b86910..895159d2cb 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -52,30 +52,30 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 256)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 256)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.2)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.1)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = float(os.environ.get("MLP_MULT", 1.5)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.4)) + embed_lr = float(os.environ.get("EMBED_LR", 0.15)) head_lr = float(os.environ.get("HEAD_LR", 0.01)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.035)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.9)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) @@ -83,7 +83,7 @@ class Hyperparameters: beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) # ----------------------------- # MUON OPTIMIZER @@ -108,6 +108,47 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7): X = a * X + B @ X return X.T if transposed else X +@torch.no_grad() +def get_hadamard_matrix(n, device): + """Generates a deterministic, orthonormal Hadamard matrix.""" + p2 = 2**math.ceil(math.log2(n)) + H = torch.tensor([[1.0]], device=device) + while H.shape[0] < p2: + H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0) + return H[:n, :n] / math.sqrt(p2) + +@torch.no_grad() +def apply_zero_init(model, std=0.02): + """ + Unified ZerO Init: + - Embeddings: Hadamard + - Linear Layers: Last in any sub-module is 0 (Exit), others are Hadamard (Internal). + """ + for m in model.modules(): + # 1. Handle Embeddings (Identity-like initialization) + if isinstance(m, nn.Embedding): + d_out, d_in = m.weight.shape + H = get_hadamard_matrix(max(d_out, d_in), m.weight.device) + m.weight.copy_(H[:d_out, :d_in]*std) + + # 2. Handle Linear Layers by inspecting the module's direct children + # We look for direct Linear children to identify "Branches" + linears = [sub for sub in m.children() if isinstance(sub, nn.Linear)] + if linears: + for i, l in enumerate(linears): + d_out, d_in = l.weight.shape + + # The 'Exit' layer is the last Linear in this specific module + if i == len(linears) - 1: + nn.init.zeros_(l.weight) + # 'Internal' layers get Hadamard symmetry breaking + else: + H = get_hadamard_matrix(max(d_out, d_in), l.weight.device) + l.weight.copy_(H[:d_out, :d_in]) + + if l.bias is not None: + nn.init.zeros_(l.bias) + class Muon(torch.optim.Optimizer): def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): super().__init__( @@ -589,24 +630,22 @@ def forward(self, x: Tensor, emb: Tensor) -> Tensor: y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) -def calculate_hidden(mlp_mult: int, dim: int): +def calculate_hidden(mlp_mult: float, dim: int): raw_hidden = int(mlp_mult * dim // 1.5) multiplier = raw_hidden / 64 return 64 if multiplier == 0 else (2**round(math.log2(multiplier))) * 64 class MLP(nn.Module): - # Using SwiGLU as introduced in Shazeer (2020) - def __init__(self, dim: int, mlp_mult: int): + def __init__(self, dim: int, mlp_mult: float): super().__init__() - self.hidden = calculate_hidden(mlp_mult, dim) - self.fused_fc = nn.Linear(dim, 2 * self.hidden, bias=False) - self.proj = nn.Linear(self.hidden, dim, bias=False) - self.proj._zero_init = True + self.fused_down = nn.Linear(dim, 2 * calculate_hidden(mlp_mult, dim), bias=False) + self.w_u = nn.Linear(calculate_hidden(mlp_mult, dim), dim, bias=False) + self.w_u._zero_init = True # For ZerO Init def forward(self, x: Tensor) -> Tensor: - fused_out = self.fused_fc(x) - x1, x2 = fused_out.chunk(2, dim=-1) - return self.proj(F.silu(x1) * x2) + gate, val = self.fused_down(x).chunk(2, dim=-1) + hidden = F.silu(gate) * val + return self.w_u(hidden) class Block(nn.Module): @@ -615,7 +654,7 @@ def __init__( dim: int, num_heads: int, num_kv_heads: int, - mlp_mult: int, + mlp_mult: float, rope_base: float, qk_gain_init: float, seq_len: int=1024, @@ -676,7 +715,7 @@ def __init__( model_dim: int, num_heads: int, num_kv_heads: int, - mlp_mult: int, + mlp_mult: float, tie_embeddings: bool, tied_embed_init_std: float, rope_base: float, @@ -706,14 +745,7 @@ def __init__( self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) + apply_zero_init(self) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: emb = self.tok_emb(input_ids) @@ -1137,6 +1169,5 @@ def lr_mul(step: int, elapsed_ms: float) -> float: if distributed: dist.destroy_process_group() - if __name__ == "__main__": main() \ No newline at end of file From b03634494f6c2835b578b3aa9d6d3e2b9fb3af72 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 15 Apr 2026 17:24:27 -0700 Subject: [PATCH 26/80] Hyperparameter tuning --- train_gpt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 895159d2cb..9f1cfc76f2 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -65,17 +65,17 @@ class Hyperparameters: num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 1.5)) + mlp_mult = float(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.15)) - head_lr = float(os.environ.get("HEAD_LR", 0.01)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) + embed_lr = float(os.environ.get("EMBED_LR", 0.025)) + head_lr = float(os.environ.get("HEAD_LR", 0.03)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.01)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.035)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.045)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.9)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) From b69df5e88bd8e03c07addc4393e2d7ea7c33f1b8 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 15 Apr 2026 22:15:20 -0700 Subject: [PATCH 27/80] Just one more layer --- train_gpt.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 9f1cfc76f2..60ccc8aa0e 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -61,16 +61,16 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) mlp_mult = float(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 2048.0)) + rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.025)) + embed_lr = float(os.environ.get("EMBED_LR", 0.03)) head_lr = float(os.environ.get("HEAD_LR", 0.03)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.01)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) @@ -1143,7 +1143,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: dist.barrier() with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob)), map_location="cpu") + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) torch.cuda.synchronize() t_qeval = time.perf_counter() From d83e32a72c2d5189278c4063cff01a31f5c3c395 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 16 Apr 2026 09:22:40 -0700 Subject: [PATCH 28/80] Whoops fixed too many parameters --- check.py | 17 +++++++++++++++++ train_gpt.py | 12 ++++++------ 2 files changed, 23 insertions(+), 6 deletions(-) create mode 100644 check.py diff --git a/check.py b/check.py new file mode 100644 index 0000000000..1e3c6490c6 --- /dev/null +++ b/check.py @@ -0,0 +1,17 @@ +from train_gpt import Hyperparameters, GPT +args = Hyperparameters() +base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + seq_len=args.train_seq_len + ) + +print(sum([p.numel() for p in base_model.parameters()])) \ No newline at end of file diff --git a/train_gpt.py b/train_gpt.py index 60ccc8aa0e..c0e84621e0 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -61,11 +61,11 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_layers = int(os.environ.get("NUM_LAYERS", 8)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 2)) + mlp_mult = float(os.environ.get("MLP_MULT", 2.5)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) @@ -74,10 +74,10 @@ class Hyperparameters: head_lr = float(os.environ.get("HEAD_LR", 0.03)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.01)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.045)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.03)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.9)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 192)) beta1 = float(os.environ.get("BETA1", 0.9)) @@ -881,7 +881,7 @@ def log0(msg: str, console: bool = True) -> None: module.float() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) if distributed else compiled_model # Optimizer split: # - token embedding (Adam) uses EMBED_LR From f74eda1d5b90c64c6a62bb0ea74b98eb9afe982a Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 16 Apr 2026 10:01:12 -0700 Subject: [PATCH 29/80] DDP BEFORE Compile + Pin memory --- train_gpt.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index c0e84621e0..acd75878f8 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -290,7 +290,7 @@ def eval_val( # Pre-load the first batch raw_start = seq_start * args.train_seq_len raw_end = min(raw_start + local_batch_seqs * args.train_seq_len + 1, val_tokens.numel()) - next_batch = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + next_batch = val_tokens[raw_start:raw_end].pin_memory().to(device=device, dtype=torch.int64, non_blocking=True) with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): torch.cuda.current_stream().wait_stream(torch.cuda.default_stream()) @@ -299,7 +299,7 @@ def eval_val( if next_seq_start < seq_end: n_raw_start = next_seq_start * args.train_seq_len n_raw_end = min(n_raw_start + local_batch_seqs * args.train_seq_len + 1, val_tokens.numel()) - next_batch = val_tokens[n_raw_start:n_raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + next_batch = val_tokens[n_raw_start:n_raw_end].pin_memory().to(device=device, dtype=torch.int64, non_blocking=True) x = local[:-1].reshape(-1, args.train_seq_len) y = local[1:].reshape(-1, args.train_seq_len) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): @@ -536,7 +536,7 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> chunk = self.stream.take(per_rank_span * self.world_size) start = self.rank * per_rank_span # Combine dtype conversion + device transfer in single operation - local = chunk[start : start + per_rank_span].to(device=self.device, dtype=torch.int64, non_blocking=True) + local = chunk[start : start + per_rank_span].pin_memory().to(device=self.device, dtype=torch.int64, non_blocking=True) x = local[:-1].reshape(-1, seq_len) y = local[1:].reshape(-1, seq_len) return x, y @@ -880,8 +880,11 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, (nn.Linear, nn.Embedding)): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) if distributed else compiled_model + if distributed: + model = DDP(base_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) + else: + model = base_model + model = torch.compile(model) # Optimizer split: # - token embedding (Adam) uses EMBED_LR From 850e81b98b0a5f73b6eb0fa57c6468ae3c0ce048 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 16 Apr 2026 10:07:44 -0700 Subject: [PATCH 30/80] Actually compile before ddp --- train_gpt.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index acd75878f8..52851e5586 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -880,11 +880,10 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, (nn.Linear, nn.Embedding)): module.float() restore_low_dim_params_to_fp32(base_model) + + model = torch.compile(base_model, mode="reduce-overhead") if distributed: - model = DDP(base_model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) - else: - model = base_model - model = torch.compile(model) + model = DDP(model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) # Optimizer split: # - token embedding (Adam) uses EMBED_LR From 24b71e612b4b64bd1d15a8ac65ace0a6fb8d587b Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 16 Apr 2026 10:12:12 -0700 Subject: [PATCH 31/80] Maybe no static graph? --- train_gpt.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 52851e5586..58b6b5a83c 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -880,10 +880,8 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, (nn.Linear, nn.Embedding)): module.float() restore_low_dim_params_to_fp32(base_model) - - model = torch.compile(base_model, mode="reduce-overhead") - if distributed: - model = DDP(model, device_ids=[local_rank], broadcast_buffers=False, static_graph=True) + compiled_model = torch.compile(base_model) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model # Optimizer split: # - token embedding (Adam) uses EMBED_LR From 6736fe46f7dbbb72a2cf8186bb28c5d970b19cdc Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 20 Apr 2026 13:06:00 -0700 Subject: [PATCH 32/80] Optimizations attempt 1 --- train_gpt.py | 141 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 105 insertions(+), 36 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 58b6b5a83c..aee5b2239c 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -92,20 +92,34 @@ class Hyperparameters: # As borrowed from modded-nanogpt # Background on Muon: https://kellerjordan.github.io/posts/muon/ -@torch.compile +@torch.compile(fullgraph=True) def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7): - if torch.isnan(G).any() or torch.isinf(G).any(): - return torch.zeros_like(G) - a, b, c = (3.4445, -4.7750, 2.0315) + # 1. Eliminate CPU-GPU synchronization (Graph Break) + # isfinite().all() creates a scalar tensor on the device. + # Multiplying by this mask eliminates NaNs/Infs without a host-sync. + valid_mask = torch.isfinite(G).all().to(G.dtype) + G = G * valid_mask + + a, b, c = 3.4445, -4.7750, 2.0315 X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) + + # 2. Multiply by reciprocal (faster than tensor division) + X.mul_(1.0 / (X.norm() + eps)) + + transposed = X.size(0) > X.size(1) if transposed: X = X.T + for _ in range(steps): A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X + + # 3. Fuse pointwise operations into the GEMM kernel via addmm + # Equivalent to: B = b * A + c * (A @ A) + B = torch.addmm(A, A, A, beta=b, alpha=c) + + # Equivalent to: X = a * X + 1.0 * (B @ X) + X = torch.addmm(X, B, X, beta=a, alpha=1.0) + return X.T if transposed else X @torch.no_grad() @@ -155,6 +169,46 @@ def __init__(self, params, lr: float, momentum: float, backend_steps: int, neste params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), ) + self._is_initialized = False + + def _init_state(self): + """Pre-allocates buffers and computes static distribution logic once.""" + if self._is_initialized: + return + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + + # 1. Pre-allocate the communication buffer ONCE + total_params = sum(p.numel() for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + group["updates_flat"] = updates_flat + group["param_views"] = [] + group["rank_params"] = [] + group["rank_param_views"] = [] + + curr = 0 + for i, p in enumerate(params): + numel = p.numel() + # 2. Pre-compute views to avoid slicing in the hot loop + view = updates_flat[curr : curr + numel].view_as(p) + group["param_views"].append(view) + + # 3. Statically determine which params belong to this rank + if i % world_size == rank: + group["rank_params"].append(p) + group["rank_param_views"].append(view) + + curr += numel + + self._is_initialized = True @torch.no_grad() def step(self, closure=None): @@ -163,50 +217,65 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() + self._init_state() distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 for group in self.param_groups: - params = group["params"] - if not params: + if not group["params"]: continue + lr = group["lr"] momentum = group["momentum"] backend_steps = group["backend_steps"] nesterov = group["nesterov"] - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + # Clear the shared buffer efficiently + updates_flat = group["updates_flat"] + updates_flat.zero_() - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() + # Only iterate over parameters assigned to this rank + for p, view in zip(group["rank_params"], group["rank_param_views"]): + if p.grad is None: + continue + g = p.grad + state = self.state[p] + + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + # In-place momentum update + buf.mul_(momentum).add_(g) + + update = g.add(buf, alpha=momentum) if nesterov else g + + # Assuming zeropower_via_newtonschulz5 returns a new tensor + g_ns = zeropower_via_newtonschulz5(update, steps=backend_steps) + + # In-place scaling + g_ns.mul_(max(1, g_ns.size(0) / g_ns.size(1)) ** 0.5) + + # Direct copy into the pre-sliced view of updates_flat + view.copy_(g_ns) + + # Synchronize updates across all GPUs if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() + # 4. Kernel Fusion: Update all parameters simultaneously via foreach + params = group["params"] + views = group["param_views"] - return loss + # Avoid casting if the parameter is already bfloat16 + if updates_flat.dtype == params[0].dtype: + torch._foreach_add_(params, views, alpha=-lr) + else: + # Cast quickly and apply fused update + casted_views = [v.to(dtype=p.dtype, non_blocking=True) for v, p in zip(views, params)] + torch._foreach_add_(params, casted_views, alpha=-lr) + return loss # ----------------------------- # TOKENIZER-AGNOSTIC EVALUATION SETUP From 2c103aa76b7a8d1267645ef60b3c6262a845179f Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 20 Apr 2026 13:51:55 -0700 Subject: [PATCH 33/80] Optims pt 2 --- train_gpt.py | 163 +++++++++++++++++++++++++++++---------------------- 1 file changed, 93 insertions(+), 70 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index aee5b2239c..543ed391f0 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -53,7 +53,7 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 256)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 256)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -72,14 +72,14 @@ class Hyperparameters: # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.03)) head_lr = float(os.environ.get("HEAD_LR", 0.03)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.01)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.03)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.9)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.06)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.01)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 192)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) @@ -164,29 +164,26 @@ def apply_zero_init(model, std=0.02): nn.init.zeros_(l.bias) class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) + def __init__(self, params, lr: float, momentum: float = 0.95, backend_steps: int = 5, nesterov: bool = True): + defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov) + super().__init__(params, defaults) self._is_initialized = False def _init_state(self): - """Pre-allocates buffers and computes static distribution logic once.""" - if self._is_initialized: - return - + """Pre-allocates buffers for each group individually.""" distributed = dist.is_available() and dist.is_initialized() world_size = dist.get_world_size() if distributed else 1 rank = dist.get_rank() if distributed else 0 for group in self.param_groups: - params = group["params"] - if not params: + # Check if THIS specific group needs initialization + if "updates_flat" in group or not group["params"]: continue - # 1. Pre-allocate the communication buffer ONCE + params = group["params"] total_params = sum(p.numel() for p in params) + + # Pre-allocate the communication buffer for this group updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) group["updates_flat"] = updates_flat @@ -197,28 +194,19 @@ def _init_state(self): curr = 0 for i, p in enumerate(params): numel = p.numel() - # 2. Pre-compute views to avoid slicing in the hot loop view = updates_flat[curr : curr + numel].view_as(p) group["param_views"].append(view) - # 3. Statically determine which params belong to this rank if i % world_size == rank: group["rank_params"].append(p) group["rank_param_views"].append(view) curr += numel - self._is_initialized = True - @torch.no_grad() def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - + loss = closure() if closure is not None else None self._init_state() - distributed = dist.is_available() and dist.is_initialized() for group in self.param_groups: if not group["params"]: @@ -228,12 +216,11 @@ def step(self, closure=None): momentum = group["momentum"] backend_steps = group["backend_steps"] nesterov = group["nesterov"] - - # Clear the shared buffer efficiently updates_flat = group["updates_flat"] + updates_flat.zero_() - # Only iterate over parameters assigned to this rank + # 1. Local Computation for p, view in zip(group["rank_params"], group["rank_param_views"]): if p.grad is None: continue @@ -243,37 +230,38 @@ def step(self, closure=None): if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] - # In-place momentum update + # Momentum update: m = m * momentum + g buf.mul_(momentum).add_(g) - update = g.add(buf, alpha=momentum) if nesterov else g + # Nesterov logic: use (g + momentum * m) or just m + u = g.add(buf, alpha=momentum) if nesterov else buf - # Assuming zeropower_via_newtonschulz5 returns a new tensor - g_ns = zeropower_via_newtonschulz5(update, steps=backend_steps) + # Apply Newton-Schulz (on 2D flattened version if ndim > 2) + original_shape = u.shape + if u.ndim > 2: + u = u.view(u.size(0), -1) + + g_ns = zeropower_via_newtonschulz5(u, steps=backend_steps) - # In-place scaling - g_ns.mul_(max(1, g_ns.size(0) / g_ns.size(1)) ** 0.5) + # Scaling factor (standard for Muon) + g_ns.mul_(max(1, g_ns.size(0) / g_ns.size(1))**0.5) - # Direct copy into the pre-sliced view of updates_flat - view.copy_(g_ns) + view.copy_(g_ns.view(original_shape)) - # Synchronize updates across all GPUs - if distributed: + # 2. Distributed Sync + if dist.is_available() and dist.is_initialized(): dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - # 4. Kernel Fusion: Update all parameters simultaneously via foreach + # 3. Parameter Update (Kernel Fusion) + # Use foreach_add with a list of tensors to avoid multiple kernel launches params = group["params"] views = group["param_views"] - - # Avoid casting if the parameter is already bfloat16 - if updates_flat.dtype == params[0].dtype: - torch._foreach_add_(params, views, alpha=-lr) - else: - # Cast quickly and apply fused update - casted_views = [v.to(dtype=p.dtype, non_blocking=True) for v, p in zip(views, params)] - torch._foreach_add_(params, casted_views, alpha=-lr) + if updates_flat.dtype != params[0].dtype: + views = [v.to(dtype=p.dtype, non_blocking=True) for v, p in zip(views, params)] + torch._foreach_add_(params, views, alpha=-lr) return loss @@ -591,23 +579,37 @@ def take(self, n: int) -> Tensor: class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device + self.rank, self.world_size, self.device = rank, world_size, device self.stream = TokenStream(pattern) + self.next_x, self.next_y = None, None + # Create a separate CUDA stream for background transfers + self.transfer_stream = torch.cuda.Stream(device) + + def preload(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + with torch.cuda.stream(self.transfer_stream): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + + # This transfer now happens entirely in the background + local = chunk[start : start + per_rank_span].pin_memory().to( + device=self.device, dtype=torch.int64, non_blocking=True + ) + self.next_x = local[:-1].reshape(-1, seq_len) + self.next_y = local[1:].reshape(-1, seq_len) def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - # Combine dtype conversion + device transfer in single operation - local = chunk[start : start + per_rank_span].pin_memory().to(device=self.device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) + if self.next_x is None: + self.preload(global_tokens, seq_len, grad_accum_steps) + + # Ensure the default stream waits for the background transfer to finish + torch.cuda.current_stream().wait_stream(self.transfer_stream) + x, y = self.next_x, self.next_y + + # Immediately queue up the next batch in the background + self.preload(global_tokens, seq_len, grad_accum_steps) return x, y # ----------------------------- @@ -648,10 +650,24 @@ def forward(self): def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: rotary_dim = cos.shape[-1] * 2 - x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:] - x1, x2 = x_rot.chunk(2, dim=-1) - x_rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1) - return torch.cat((x_rotated, x_pass), dim=-1) + half_dim = rotary_dim // 2 + + # 1. Allocate the final tensor exactly once + out = torch.empty_like(x) + + # 2. View extraction (essentially zero cost) + x1 = x[..., :half_dim] + x2 = x[..., half_dim:rotary_dim] + + # 3. Write math directly into the output tensor's memory + out[..., :half_dim] = x1 * cos - x2 * sin + out[..., half_dim:rotary_dim] = x1 * sin + x2 * cos + + # 4. Copy pass-through features (if using partial RoPE) + if x.shape[-1] > rotary_dim: + out[..., rotary_dim:] = x[..., rotary_dim:] + + return out class CausalSelfAttention(nn.Module): @@ -742,8 +758,8 @@ def __init__( def forward(self, x: Tensor, emb: Tensor) -> Tensor: attn_out = self.attn(self.attn_norm(x), emb) - y = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - return self.resid_scale.to(dtype=x.dtype)[None, None, :] * x + y + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(y)) + y = x + self.attn_scale[None, None, :] * attn_out + return self.resid_scale[None, None, :] * x + y + self.mlp_scale[None, None, :] * self.mlp(self.mlp_norm(y)) def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): @@ -814,7 +830,7 @@ def __init__( self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True - apply_zero_init(self) + apply_zero_init(self, std=self.tied_embed_init_std) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: emb = self.tok_emb(input_ids) @@ -861,6 +877,13 @@ def main() -> None: raise RuntimeError("CUDA is required") device = torch.device("cuda", local_rank) torch.cuda.set_device(device) + # 1. Force the PyTorch dispatcher to use Tensor Cores for all matmuls + torch.set_float32_matmul_precision('high') + import torch._inductor.config as inductor_config + inductor_config.fx_graph_cache = True # Caches compiled kernels to disk (saves 5+ minutes on restart) + inductor_config.triton.unique_kernel_names = True # Prevents Triton kernel namespace collisions in DDP + inductor_config.freezing = True # Aggressive constant-folding for inference/eval + if distributed: dist.init_process_group(backend="nccl", device_id=device) dist.barrier() From eb41e01a9e2833865b13ca26a53037077c304b5a Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:33:49 -0700 Subject: [PATCH 34/80] data --- train_gpt.py | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 543ed391f0..be97f37a56 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -17,6 +17,7 @@ import time import uuid import zlib +import concurrent.futures from pathlib import Path import numpy as np @@ -72,10 +73,10 @@ class Hyperparameters: # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.03)) head_lr = float(os.environ.get("HEAD_LR", 0.03)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.0075)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.06)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.01)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.03)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) @@ -154,7 +155,7 @@ def apply_zero_init(model, std=0.02): # The 'Exit' layer is the last Linear in this specific module if i == len(linears) - 1: - nn.init.zeros_(l.weight) + nn.init.normal_(l.weight, std=1e-5) # 'Internal' layers get Hadamard symmetry breaking else: H = get_hadamard_matrix(max(d_out, d_in), l.weight.device) @@ -548,20 +549,30 @@ def load_data_shard(file: Path) -> Tensor: class TokenStream: - # Reads shards sequentially and wraps around forever. The training loop therefore - # has deterministic, simple streaming behavior with no sampling or workers. def __init__(self, pattern: str): self.files = [Path(p) for p in sorted(glob.glob(pattern))] if not self.files: raise FileNotFoundError(f"No files found for pattern: {pattern}") self.file_idx = 0 + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + # Initial load and pre-queue the NEXT file self.tokens = load_data_shard(self.files[0]) + self.future_tokens = self._queue_next_file() self.pos = 0 + def _queue_next_file(self): + next_idx = (self.file_idx + 1) % len(self.files) + return self.executor.submit(load_data_shard, self.files[next_idx]) + def _advance_file(self) -> None: self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) + # result() will block ONLY if the disk is slower than your training + # (unlikely with NVMe), otherwise it returns instantly. + self.tokens = self.future_tokens.result() self.pos = 0 + # Start loading the one AFTER this + self.future_tokens = self._queue_next_file() def take(self, n: int) -> Tensor: chunks: list[Tensor] = [] @@ -583,7 +594,6 @@ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.devic self.rank, self.world_size, self.device = rank, world_size, device self.stream = TokenStream(pattern) self.next_x, self.next_y = None, None - # Create a separate CUDA stream for background transfers self.transfer_stream = torch.cuda.Stream(device) def preload(self, global_tokens: int, seq_len: int, grad_accum_steps: int): @@ -593,12 +603,13 @@ def preload(self, global_tokens: int, seq_len: int, grad_accum_steps: int): chunk = self.stream.take(per_rank_span * self.world_size) start = self.rank * per_rank_span - # This transfer now happens entirely in the background + # OPTIMIZATION: Move to device as int32 (4 bytes) instead of int64 (8 bytes) local = chunk[start : start + per_rank_span].pin_memory().to( - device=self.device, dtype=torch.int64, non_blocking=True + device=self.device, dtype=torch.int32, non_blocking=True ) + # Embedding layer handles int32 perfectly. self.next_x = local[:-1].reshape(-1, seq_len) - self.next_y = local[1:].reshape(-1, seq_len) + self.next_y = local[1:].reshape(-1, seq_len).to(torch.long) # CrossEntropy needs long def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: if self.next_x is None: @@ -752,9 +763,9 @@ def __init__( self.mlp_norm = RMSNorm() self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope, rope_proportion) self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.1)) self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(resid_scale)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.1)) def forward(self, x: Tensor, emb: Tensor) -> Tensor: attn_out = self.attn(self.attn_norm(x), emb) From 00cf003bb90860419f3a1274728a65ed9067fe5f Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:44:29 -0700 Subject: [PATCH 35/80] Readd skips --- train_gpt.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index be97f37a56..955ffc0ad0 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -53,7 +53,7 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 256)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 512)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 256)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) @@ -782,12 +782,10 @@ def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): fraction = layer_idx / (total_layers - 1) raw_kv = min_kv + (max_kv - min_kv) * fraction - # Constraints: Must be power of 2 and divide num_heads - kv_heads = 2 ** round(math.log2(raw_kv)) - while num_heads % kv_heads != 0: - kv_heads //= 2 - - return int(max(2, kv_heads)) + valid_kvs = [i for i in range(1, num_heads + 1) if num_heads % i == 0] + kv_heads = min(valid_kvs, key=lambda x: abs(x - raw_kv)) + + return int(max(1, kv_heads)) def get_rope_p_smooth(i: int, num_layers: int, p_min=0.25, p_max=0.75) -> float: if num_layers <= 1: @@ -847,8 +845,14 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: emb = self.tok_emb(input_ids) x = F.rms_norm(emb, (emb.size(-1),)) - for block in self.blocks: - x = block(x, emb) + skips = [] + for i, block in enumerate(self.blocks): + if i < len(self.blocks): # Layers 0, 1, 2, 3 + skips.append(x) # Just saves a reference to the tensor in memory + elif i >= len(self.blocks): # Layers 4, 5, 6, 7 + x = x + skips.pop() # A simple element-wise addition + + x = block(x, emb) # The actual Transformer Block computation x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) From 472be7eb5bd5be427773045271c8ecfa6edc0bb0 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:11:21 -0700 Subject: [PATCH 36/80] Step things up --- train_gpt.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 955ffc0ad0..bcb28ac852 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -49,12 +49,12 @@ class Hyperparameters: # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 40)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 512)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 256)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -62,11 +62,11 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 8)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 2.5)) + mlp_mult = float(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) @@ -74,11 +74,11 @@ class Hyperparameters: embed_lr = float(os.environ.get("EMBED_LR", 0.03)) head_lr = float(os.environ.get("HEAD_LR", 0.03)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.05)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.03)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.25)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.03)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) beta1 = float(os.environ.get("BETA1", 0.9)) From 441ecca4101467c96c83c669b2526d88194764c1 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:25:12 -0700 Subject: [PATCH 37/80] Skip correction --- train_gpt.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index bcb28ac852..53fe8fd30a 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -49,7 +49,7 @@ class Hyperparameters: # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 20)) # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) @@ -62,7 +62,7 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_layers = int(os.environ.get("NUM_LAYERS", 8)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) @@ -839,6 +839,7 @@ def __init__( self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True + self.skip_scales = nn.Parameter(torch.zeros(num_layers // 2)) apply_zero_init(self, std=self.tied_embed_init_std) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: @@ -846,13 +847,15 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = F.rms_norm(emb, (emb.size(-1),)) skips = [] + half = len(self.blocks) // 2 for i, block in enumerate(self.blocks): - if i < len(self.blocks): # Layers 0, 1, 2, 3 - skips.append(x) # Just saves a reference to the tensor in memory - elif i >= len(self.blocks): # Layers 4, 5, 6, 7 - x = x + skips.pop() # A simple element-wise addition + if i < half: + skips.append(x) + else: + scale = self.skip_scales[i - half] + x = x + scale * skips.pop() - x = block(x, emb) # The actual Transformer Block computation + x = block(x, emb) x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) From adf8b121a3d0459789d5b5605b05627563bb2548 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:57:45 -0700 Subject: [PATCH 38/80] rope and reduce parameters --- train_gpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 53fe8fd30a..ed2802c0a6 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -58,7 +58,7 @@ class Hyperparameters: train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.1)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.25)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) @@ -785,9 +785,9 @@ def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): valid_kvs = [i for i in range(1, num_heads + 1) if num_heads % i == 0] kv_heads = min(valid_kvs, key=lambda x: abs(x - raw_kv)) - return int(max(1, kv_heads)) + return int(max(min_kv, kv_heads)) -def get_rope_p_smooth(i: int, num_layers: int, p_min=0.25, p_max=0.75) -> float: +def get_rope_p_smooth(i: int, num_layers: int, p_min=0.5, p_max=0.75) -> float: if num_layers <= 1: return p_min progress = i / (num_layers - 1) From de6769633016ea6ef87d7959e180fa73477d154b Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:49:35 -0700 Subject: [PATCH 39/80] ZerO stuff and redo ReLU^2 --- train_gpt.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index ed2802c0a6..48a0f1be08 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -62,11 +62,11 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 8)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 2)) + mlp_mult = float(os.environ.get("MLP_MULT", 1.4)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) @@ -692,7 +692,6 @@ def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_le self.c_v = nn.Linear(dim, self.kv_dim, bias=False) self.v_mix = nn.Parameter(torch.zeros(dim)) self.proj = nn.Linear(dim, dim, bias=False) - self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) self.use_rope = use_rope @@ -726,22 +725,16 @@ def forward(self, x: Tensor, emb: Tensor) -> Tensor: y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) -def calculate_hidden(mlp_mult: float, dim: int): - raw_hidden = int(mlp_mult * dim // 1.5) - multiplier = raw_hidden / 64 - return 64 if multiplier == 0 else (2**round(math.log2(multiplier))) * 64 - class MLP(nn.Module): def __init__(self, dim: int, mlp_mult: float): super().__init__() - self.fused_down = nn.Linear(dim, 2 * calculate_hidden(mlp_mult, dim), bias=False) - self.w_u = nn.Linear(calculate_hidden(mlp_mult, dim), dim, bias=False) - self.w_u._zero_init = True # For ZerO Init + hidden_dim = int(mlp_mult * dim // 64) * 64 + self.input = nn.Linear(dim, hidden_dim, bias=False) + self.out = nn.Linear(hidden_dim, dim, bias=False) def forward(self, x: Tensor) -> Tensor: - gate, val = self.fused_down(x).chunk(2, dim=-1) - hidden = F.silu(gate) * val - return self.w_u(hidden) + x = torch.relu(self.input(x)) + return self.out(x * x) class Block(nn.Module): @@ -837,8 +830,6 @@ def __init__( ]) self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True self.skip_scales = nn.Parameter(torch.zeros(num_layers // 2)) apply_zero_init(self, std=self.tied_embed_init_std) From 81bde6f0309e484888f775588b12c2e896ddb232 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 21 Apr 2026 14:40:03 -0700 Subject: [PATCH 40/80] Attempt optimizations --- train_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gpt.py b/train_gpt.py index 48a0f1be08..328e911f88 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -982,7 +982,7 @@ def log0(msg: str, console: bool = True) -> None: module.float() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) if distributed else compiled_model # Optimizer split: # - token embedding (Adam) uses EMBED_LR From e62da0d4ff7fb10e9200424af270c14a7c7a3a7f Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Tue, 21 Apr 2026 15:43:10 -0700 Subject: [PATCH 41/80] warmup plus slight increase of qk_gain and mlp_mult --- train_gpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 328e911f88..786b168472 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -54,11 +54,11 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 512)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 128)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 192)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.25)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 2)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) @@ -66,7 +66,7 @@ class Hyperparameters: num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 1.4)) + mlp_mult = float(os.environ.get("MLP_MULT", 1.5)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) From c55d3f16cb5d3c83f7336afbc0d84fc93b9a1133 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 23 Apr 2026 13:05:20 -0700 Subject: [PATCH 42/80] Reducing sequential steps --- train_gpt.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 786b168472..e06939d37d 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -696,13 +696,13 @@ def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_le self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) self.use_rope = use_rope - def forward(self, x: Tensor, emb: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - qk = self.c_qk(x) + def forward(self, x_unnorm: Tensor, x_norm: Tensor, emb: Tensor) -> Tensor: + bsz, seqlen, dim = x_norm.shape + qk = self.c_qk(x_norm) q, k = qk.split([dim, self.kv_dim], dim=-1) mix = self.v_mix[None, None, :] - v_input = (mix * x) + ((1.0-mix) * emb) + v_input = (mix * x_unnorm) + ((1.0-mix) * emb) v = self.c_v(v_input) q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) @@ -761,10 +761,11 @@ def __init__( self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.1)) def forward(self, x: Tensor, emb: Tensor) -> Tensor: - attn_out = self.attn(self.attn_norm(x), emb) - y = x + self.attn_scale[None, None, :] * attn_out - return self.resid_scale[None, None, :] * x + y + self.mlp_scale[None, None, :] * self.mlp(self.mlp_norm(y)) - + attn_out = self.attn(x, self.attn_norm(x), emb) + mlp_out = self.mlp(self.mlp_norm(x)) + + # Parallel residual additions folded mathematically using (1.0 + resid_scale) + return (1.0 + self.resid_scale[None, None, :]) * x + self.attn_scale[None, None, :] * attn_out + self.mlp_scale[None, None, :] * mlp_out def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): # Progresses from 2 heads at layer 0 to num_heads at the final layer From 2cae1e7d972267483b6ab7cb4c0701a7595a4b26 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 23 Apr 2026 13:30:19 -0700 Subject: [PATCH 43/80] Do compile --- train_gpt.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index e06939d37d..7b65472877 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -734,7 +734,8 @@ def __init__(self, dim: int, mlp_mult: float): def forward(self, x: Tensor) -> Tensor: x = torch.relu(self.input(x)) - return self.out(x * x) + x = x.square() + return self.out(x) class Block(nn.Module): @@ -892,7 +893,6 @@ def main() -> None: import torch._inductor.config as inductor_config inductor_config.fx_graph_cache = True # Caches compiled kernels to disk (saves 5+ minutes on restart) inductor_config.triton.unique_kernel_names = True # Prevents Triton kernel namespace collisions in DDP - inductor_config.freezing = True # Aggressive constant-folding for inference/eval if distributed: dist.init_process_group(backend="nccl", device_id=device) @@ -982,7 +982,7 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, (nn.Linear, nn.Embedding)): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model) + compiled_model = torch.compile(base_model, mode="reduce-overhead", fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) if distributed else compiled_model # Optimizer split: @@ -1090,6 +1090,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: model.train() for warmup_step in range(args.warmup_steps): zero_grad_all() + torch.compiler.cudagraph_mark_step_begin() for micro_step in range(grad_accum_steps): if distributed: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 @@ -1158,6 +1159,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: scale = lr_mul(step, elapsed_ms) zero_grad_all() train_loss = torch.zeros((), device=device) + torch.compiler.cudagraph_mark_step_begin() for micro_step in range(grad_accum_steps): if distributed: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 From f60e102d16b2a129a1d45e05499e944882c3890d Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 23 Apr 2026 14:48:08 -0700 Subject: [PATCH 44/80] Revert + Fix --- train_gpt.py | 55 +++++++++++++++++++++++----------------------------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 7b65472877..15b7604a01 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -661,25 +661,20 @@ def forward(self): def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: rotary_dim = cos.shape[-1] * 2 - half_dim = rotary_dim // 2 + x_ro = x[..., :rotary_dim] + x_pass = x[..., rotary_dim:] # Handles partial RoPE - # 1. Allocate the final tensor exactly once - out = torch.empty_like(x) + x1, x2 = x_ro.chunk(2, dim=-1) - # 2. View extraction (essentially zero cost) - x1 = x[..., :half_dim] - x2 = x[..., half_dim:rotary_dim] + # Use concatenation instead of slice assignment + rotated = torch.cat([ + x1 * cos - x2 * sin, + x1 * sin + x2 * cos + ], dim=-1) - # 3. Write math directly into the output tensor's memory - out[..., :half_dim] = x1 * cos - x2 * sin - out[..., half_dim:rotary_dim] = x1 * sin + x2 * cos - - # 4. Copy pass-through features (if using partial RoPE) - if x.shape[-1] > rotary_dim: - out[..., rotary_dim:] = x[..., rotary_dim:] - - return out - + if x_pass.numel() > 0: + return torch.cat([rotated, x_pass], dim=-1) + return rotated class CausalSelfAttention(nn.Module): def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True, rope_proportion=0.5): @@ -696,13 +691,13 @@ def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_le self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) self.use_rope = use_rope - def forward(self, x_unnorm: Tensor, x_norm: Tensor, emb: Tensor) -> Tensor: - bsz, seqlen, dim = x_norm.shape - qk = self.c_qk(x_norm) + def forward(self, x: Tensor, emb: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qk = self.c_qk(x) q, k = qk.split([dim, self.kv_dim], dim=-1) mix = self.v_mix[None, None, :] - v_input = (mix * x_unnorm) + ((1.0-mix) * emb) + v_input = (self.v_mix * x) + ((1.0 - self.v_mix) * emb) v = self.c_v(v_input) q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) @@ -716,7 +711,7 @@ def forward(self, x_unnorm: Tensor, x_norm: Tensor, emb: Tensor) -> Tensor: q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + q = q * self.q_gain.to(dtype=q.dtype).view(1, -1, 1, 1) y = F.scaled_dot_product_attention( q, k, v, is_causal=True, @@ -734,8 +729,7 @@ def __init__(self, dim: int, mlp_mult: float): def forward(self, x: Tensor) -> Tensor: x = torch.relu(self.input(x)) - x = x.square() - return self.out(x) + return self.out(x * x) class Block(nn.Module): @@ -761,12 +755,12 @@ def __init__( self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(resid_scale)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.1)) + # Inside Block.forward def forward(self, x: Tensor, emb: Tensor) -> Tensor: - attn_out = self.attn(x, self.attn_norm(x), emb) - mlp_out = self.mlp(self.mlp_norm(x)) - - # Parallel residual additions folded mathematically using (1.0 + resid_scale) - return (1.0 + self.resid_scale[None, None, :]) * x + self.attn_scale[None, None, :] * attn_out + self.mlp_scale[None, None, :] * mlp_out + attn_out = self.attn(self.attn_norm(x), emb) + y = x + self.attn_scale * attn_out + return self.resid_scale * x + y + self.mlp_scale * self.mlp(self.mlp_norm(y)) + def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): # Progresses from 2 heads at layer 0 to num_heads at the final layer @@ -893,6 +887,7 @@ def main() -> None: import torch._inductor.config as inductor_config inductor_config.fx_graph_cache = True # Caches compiled kernels to disk (saves 5+ minutes on restart) inductor_config.triton.unique_kernel_names = True # Prevents Triton kernel namespace collisions in DDP + inductor_config.freezing = True # Aggressive constant-folding for inference/eval if distributed: dist.init_process_group(backend="nccl", device_id=device) @@ -982,7 +977,7 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, (nn.Linear, nn.Embedding)): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, mode="reduce-overhead", fullgraph=True) + compiled_model = torch.compile(base_model, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) if distributed else compiled_model # Optimizer split: @@ -1090,7 +1085,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: model.train() for warmup_step in range(args.warmup_steps): zero_grad_all() - torch.compiler.cudagraph_mark_step_begin() for micro_step in range(grad_accum_steps): if distributed: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 @@ -1159,7 +1153,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: scale = lr_mul(step, elapsed_ms) zero_grad_all() train_loss = torch.zeros((), device=device) - torch.compiler.cudagraph_mark_step_begin() for micro_step in range(grad_accum_steps): if distributed: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 From 42beafafead05b230582c9dcc3eb77d75b441f4e Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 23 Apr 2026 16:07:01 -0700 Subject: [PATCH 45/80] SwiGLU again --- train_gpt.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 15b7604a01..559691ffa5 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -723,13 +723,14 @@ def forward(self, x: Tensor, emb: Tensor) -> Tensor: class MLP(nn.Module): def __init__(self, dim: int, mlp_mult: float): super().__init__() - hidden_dim = int(mlp_mult * dim // 64) * 64 - self.input = nn.Linear(dim, hidden_dim, bias=False) - self.out = nn.Linear(hidden_dim, dim, bias=False) + self.hidden_dim = int(mlp_mult * 2/3 * dim // 64) * 64 + self.c_fc = nn.Linear(dim, 2 * self.hidden_dim, bias=False) + self.c_proj = nn.Linear(self.hidden_dim, dim, bias=False) def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.input(x)) - return self.out(x * x) + fused_x = self.c_fc(x) + gate, value = fused_x.chunk(2, dim=-1) + return self.c_proj(F.silu(gate) * value) class Block(nn.Module): From 3c518d48447589b05ad2bef3b1e2b254ed180665 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 23 Apr 2026 16:39:25 -0700 Subject: [PATCH 46/80] Hyperparam tuning --- train_gpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 559691ffa5..96a382912d 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -58,7 +58,7 @@ class Hyperparameters: train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 2)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) @@ -66,7 +66,7 @@ class Hyperparameters: num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 1.5)) + mlp_mult = float(os.environ.get("MLP_MULT", 1.6)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) @@ -75,7 +75,7 @@ class Hyperparameters: head_lr = float(os.environ.get("HEAD_LR", 0.03)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.25)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.03)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) From ecfac553ccdcceab7abad698510f30c2e7a85229 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 23 Apr 2026 16:53:06 -0700 Subject: [PATCH 47/80] RoPE changes --- train_gpt.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 96a382912d..84654ce92a 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -659,22 +659,16 @@ def __init__(self, dim: int, max_seq_len: int, p: float = 0.5, base: float = 100 def forward(self): return self.cos, self.sin +def rotate_half(x: Tensor): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: rotary_dim = cos.shape[-1] * 2 x_ro = x[..., :rotary_dim] - x_pass = x[..., rotary_dim:] # Handles partial RoPE - - x1, x2 = x_ro.chunk(2, dim=-1) - - # Use concatenation instead of slice assignment - rotated = torch.cat([ - x1 * cos - x2 * sin, - x1 * sin + x2 * cos - ], dim=-1) - - if x_pass.numel() > 0: - return torch.cat([rotated, x_pass], dim=-1) - return rotated + x_pass = x[..., rotary_dim:] + x_rotated = (x_ro * cos.repeat_interleave(2, dim=-1)) + (rotate_half(x_ro) * sin.repeat_interleave(2, dim=-1)) + return torch.cat((x_rotated, x_pass), dim=-1) class CausalSelfAttention(nn.Module): def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True, rope_proportion=0.5): @@ -695,8 +689,6 @@ def forward(self, x: Tensor, emb: Tensor) -> Tensor: bsz, seqlen, dim = x.shape qk = self.c_qk(x) q, k = qk.split([dim, self.kv_dim], dim=-1) - - mix = self.v_mix[None, None, :] v_input = (self.v_mix * x) + ((1.0 - self.v_mix) * emb) v = self.c_v(v_input) @@ -820,7 +812,7 @@ def __init__( rope_base, qk_gain_init, seq_len=seq_len, - use_rope=(i % 2 == 1), + use_rope=True, resid_scale=1/math.sqrt(2 * num_layers), rope_proportion=get_rope_p_smooth(i, num_layers) ) for i in range(num_layers) @@ -974,9 +966,6 @@ def log0(msg: str, console: bool = True) -> None: qk_gain_init=args.qk_gain_init, seq_len=args.train_seq_len ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, (nn.Linear, nn.Embedding)): - module.float() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) if distributed else compiled_model From 15527e5be8587638dd19889f5dc0596675c3a81d Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 23 Apr 2026 17:26:16 -0700 Subject: [PATCH 48/80] Reduce learning rate --- train_gpt.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 84654ce92a..e1790852b3 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -54,7 +54,7 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 512)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 192)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 256)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -71,12 +71,12 @@ class Hyperparameters: rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.03)) - head_lr = float(os.environ.get("HEAD_LR", 0.03)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.25)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.03)) + embed_lr = float(os.environ.get("EMBED_LR", 0.025)) + head_lr = float(os.environ.get("HEAD_LR", 0.025)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.0125)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.125)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.015)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) From bd450288054d3864be7ff0f5889a4061b2a1c693 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 29 Apr 2026 21:18:21 -0700 Subject: [PATCH 49/80] Hail mary --- train_gpt.py | 45 ++++++++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index e1790852b3..cdb9cb1a16 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -69,6 +69,7 @@ class Hyperparameters: mlp_mult = float(os.environ.get("MLP_MULT", 1.6)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.025)) @@ -85,6 +86,7 @@ class Hyperparameters: beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) + keep_prob = float(os.environ.get("KEEP_PROB", 0.8)) # ----------------------------- # MUON OPTIMIZER @@ -647,7 +649,7 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None: class Rotary(nn.Module): def __init__(self, dim: int, max_seq_len: int, p: float = 0.5, base: float = 10000.0): super().__init__() - self.rotary_dim = (int(dim * p) // 2) * 2 + self.rotary_dim = (int(dim * p) // 8) * 8 inv_freq = 1.0 / (base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) t = torch.arange(max_seq_len, dtype=torch.float32) freqs = torch.outer(t, inv_freq) @@ -794,15 +796,19 @@ def __init__( mlp_mult: float, tie_embeddings: bool, tied_embed_init_std: float, + logit_softcap: float, rope_base: float, qk_gain_init: float, seq_len: int=1024, + keep_prob: float=0.8 ): super().__init__() self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std self.tok_emb = nn.Embedding(vocab_size, model_dim) self.num_layers = num_layers + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers self.blocks = nn.ModuleList([ Block( model_dim, @@ -818,8 +824,11 @@ def __init__( ) for i in range(num_layers) ]) self.final_norm = RMSNorm() + self.logit_softcap = logit_softcap + self.keep_prob = keep_prob self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) - self.skip_scales = nn.Parameter(torch.zeros(num_layers // 2)) + self.skip_scales = nn.Parameter(torch.full((self.num_encoder_layers,), 0.02, dtype=torch.float32)) + apply_zero_init(self, std=self.tied_embed_init_std) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: @@ -827,16 +836,17 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = F.rms_norm(emb, (emb.size(-1),)) skips = [] - half = len(self.blocks) // 2 - for i, block in enumerate(self.blocks): - if i < half: - skips.append(x) - else: - scale = self.skip_scales[i - half] - x = x + scale * skips.pop() + for i in range(self.num_encoder_layers): + skips.append(x) + x = self.blocks[i](x, emb) + for j in range(self.num_decoder_layers): + block_idx = j + self.num_encoder_layers + skip_idx = self.num_encoder_layers - 1 - j + if skip_idx >= 0: + scale = self.skip_scales[skip_idx] + x = x + (scale * skips[skip_idx]) - x = block(x, emb) - + x = self.blocks[block_idx](x, emb) x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) @@ -845,7 +855,12 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: else: logits = self.lm_head(x) - logits = 30.0 * torch.tanh(logits.float() / 30.0) + logits = self.logit_softcap * torch.tanh(logits.float() / self.logit_softcap) + if self.training: + loss = F.cross_entropy(logits, targets, reduction='none') + mask = torch.bernoulli(torch.full_like(loss, self.keep_prob)) + loss = (loss * mask) / self.keep_prob + return loss.mean() return F.cross_entropy(logits, targets) @@ -962,9 +977,11 @@ def log0(msg: str, console: bool = True) -> None: mlp_mult=args.mlp_mult, tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - seq_len=args.train_seq_len + seq_len=args.train_seq_len, + keep_prob=args.keep_prob ).to(device).bfloat16() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model, fullgraph=True) @@ -986,6 +1003,8 @@ def log0(msg: str, console: bool = True) -> None: for name, p in block_named_params if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] + gpt_scalars = [p for name, p in base_model.named_parameters() if "skip_scales" in name or "skip_gain" in name] + scalar_params.extend(gpt_scalars) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr optimizer_tok = torch.optim.Adam( [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], From d34a510772e9a51942fc29c5058197c1b957221a Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 29 Apr 2026 22:25:43 -0700 Subject: [PATCH 50/80] One can only hope --- train_gpt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index cdb9cb1a16..c9e85f8e4e 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -49,12 +49,12 @@ class Hyperparameters: # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 20)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 40)) # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 512)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 256)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 384)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -827,7 +827,7 @@ def __init__( self.logit_softcap = logit_softcap self.keep_prob = keep_prob self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) - self.skip_scales = nn.Parameter(torch.full((self.num_encoder_layers,), 0.02, dtype=torch.float32)) + self.skip_scales = nn.Parameter(torch.zeros(self.num_encoder_layers, model_dim, dtype=torch.float32)) apply_zero_init(self, std=self.tied_embed_init_std) @@ -843,8 +843,8 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: block_idx = j + self.num_encoder_layers skip_idx = self.num_encoder_layers - 1 - j if skip_idx >= 0: - scale = self.skip_scales[skip_idx] - x = x + (scale * skips[skip_idx]) + skip_x = F.rms_norm(skips[skip_idx], (skips[skip_idx].size(-1),), eps=self.final_norm.eps) + x = x + (self.skip_scales[skip_idx] * skip_x) x = self.blocks[block_idx](x, emb) x = self.final_norm(x).reshape(-1, x.size(-1)) From ac0a7282319d6c141c5c8086de7e324aa527b29b Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Wed, 29 Apr 2026 23:06:12 -0700 Subject: [PATCH 51/80] Restore original eval_val implementation to ensure correctness --- train_gpt.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index c9e85f8e4e..48b576f19d 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -347,19 +347,12 @@ def eval_val( val_byte_count = torch.zeros((), device=device, dtype=torch.float64) model.eval() - # Pre-load the first batch - raw_start = seq_start * args.train_seq_len - raw_end = min(raw_start + local_batch_seqs * args.train_seq_len + 1, val_tokens.numel()) - next_batch = val_tokens[raw_start:raw_end].pin_memory().to(device=device, dtype=torch.int64, non_blocking=True) with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - torch.cuda.current_stream().wait_stream(torch.cuda.default_stream()) - local = next_batch - next_seq_start = batch_seq_start + local_batch_seqs - if next_seq_start < seq_end: - n_raw_start = next_seq_start * args.train_seq_len - n_raw_end = min(n_raw_start + local_batch_seqs * args.train_seq_len + 1, val_tokens.numel()) - next_batch = val_tokens[n_raw_start:n_raw_end].pin_memory().to(device=device, dtype=torch.int64, non_blocking=True) + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) x = local[:-1].reshape(-1, args.train_seq_len) y = local[1:].reshape(-1, args.train_seq_len) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): From effba2a04be36dbd2f58b830c17b6ca84f992d6e Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 07:21:11 -0700 Subject: [PATCH 52/80] AI-based implementation optimizations --- train_gpt.py | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 48b576f19d..4475d49d3f 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -646,24 +646,21 @@ def __init__(self, dim: int, max_seq_len: int, p: float = 0.5, base: float = 100 inv_freq = 1.0 / (base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) t = torch.arange(max_seq_len, dtype=torch.float32) freqs = torch.outer(t, inv_freq) - cos = freqs.cos()[None, None, :, :].to(torch.bfloat16) - sin = freqs.sin()[None, None, :, :].to(torch.bfloat16) - self.register_buffer("cos", cos, persistent=False) - self.register_buffer("sin", sin, persistent=False) + self.register_buffer("cos", freqs.cos().view(1, 1, max_seq_len, -1), persistent=False) + self.register_buffer("sin", freqs.sin().view(1, 1, max_seq_len, -1), persistent=False) - def forward(self): - return self.cos, self.sin - -def rotate_half(x: Tensor): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + T = x.size(2) + return self.cos[:, :, :T, :].to(x.dtype), self.sin[:, :, :T, :].to(x.dtype) def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - rotary_dim = cos.shape[-1] * 2 - x_ro = x[..., :rotary_dim] - x_pass = x[..., rotary_dim:] - x_rotated = (x_ro * cos.repeat_interleave(2, dim=-1)) + (rotate_half(x_ro) * sin.repeat_interleave(2, dim=-1)) - return torch.cat((x_rotated, x_pass), dim=-1) + d = cos.shape[-1] * 2 + x_rop = x[..., :d] + x_pass = x[..., d:] + x_rop = x_rop.view(*x_rop.shape[:-1], -1, 2) + x0, x1 = x_rop.unbind(-1) + res = torch.stack([x0 * cos - x1 * sin, x1 * cos + x0 * sin], dim=-1) + return torch.cat([res.flatten(-2), x_pass], dim=-1) class CausalSelfAttention(nn.Module): def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True, rope_proportion=0.5): @@ -684,7 +681,7 @@ def forward(self, x: Tensor, emb: Tensor) -> Tensor: bsz, seqlen, dim = x.shape qk = self.c_qk(x) q, k = qk.split([dim, self.kv_dim], dim=-1) - v_input = (self.v_mix * x) + ((1.0 - self.v_mix) * emb) + v_input = torch.lerp(emb, x, self.v_mix.to(x.dtype)) v = self.c_v(v_input) q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) @@ -698,7 +695,7 @@ def forward(self, x: Tensor, emb: Tensor) -> Tensor: q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype).view(1, -1, 1, 1) + q = q * self.q_gain.view(1, -1, 1, 1).to(q.dtype) y = F.scaled_dot_product_attention( q, k, v, is_causal=True, @@ -745,10 +742,13 @@ def __init__( # Inside Block.forward def forward(self, x: Tensor, emb: Tensor) -> Tensor: - attn_out = self.attn(self.attn_norm(x), emb) - y = x + self.attn_scale * attn_out - return self.resid_scale * x + y + self.mlp_scale * self.mlp(self.mlp_norm(y)) - + dtype = x.dtype + normed_x = self.attn_norm(x) + attn_out = self.attn(normed_x, emb) + y = x + self.attn_scale.to(dtype) * attn_out + normed_y = self.mlp_norm(y) + mlp_out = self.mlp(normed_y) + return self.resid_scale.to(dtype) * x + y + self.mlp_scale.to(dtype) * mlp_out def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): # Progresses from 2 heads at layer 0 to num_heads at the final layer From 412a17eb45899d985c6aa86da97e21fc6a1fd15c Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 07:45:10 -0700 Subject: [PATCH 53/80] Delete masking --- train_gpt.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 4475d49d3f..ac36584bbc 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -646,19 +646,25 @@ def __init__(self, dim: int, max_seq_len: int, p: float = 0.5, base: float = 100 inv_freq = 1.0 / (base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) t = torch.arange(max_seq_len, dtype=torch.float32) freqs = torch.outer(t, inv_freq) + # Register as float32; shapes: (1, 1, max_seq_len, rotary_dim // 2) self.register_buffer("cos", freqs.cos().view(1, 1, max_seq_len, -1), persistent=False) self.register_buffer("sin", freqs.sin().view(1, 1, max_seq_len, -1), persistent=False) - def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: - T = x.size(2) - return self.cos[:, :, :T, :].to(x.dtype), self.sin[:, :, :T, :].to(x.dtype) + def forward(self, x: Tensor): + t = x.size(2) + return self.cos[:, :, :t, :].to(x.dtype), self.sin[:, :, :t, :].to(x.dtype) + +def rotate_half(x: Tensor): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: d = cos.shape[-1] * 2 x_rop = x[..., :d] x_pass = x[..., d:] x_rop = x_rop.view(*x_rop.shape[:-1], -1, 2) - x0, x1 = x_rop.unbind(-1) + x0 = x_rop[..., 0] + x1 = x_rop[..., 1] res = torch.stack([x0 * cos - x1 * sin, x1 * cos + x0 * sin], dim=-1) return torch.cat([res.flatten(-2), x_pass], dim=-1) @@ -691,7 +697,7 @@ def forward(self, x: Tensor, emb: Tensor) -> Tensor: q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) if self.use_rope: - cos, sin = self.rotary() + cos, sin = self.rotary(q) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) @@ -849,11 +855,6 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: logits = self.lm_head(x) logits = self.logit_softcap * torch.tanh(logits.float() / self.logit_softcap) - if self.training: - loss = F.cross_entropy(logits, targets, reduction='none') - mask = torch.bernoulli(torch.full_like(loss, self.keep_prob)) - loss = (loss * mask) / self.keep_prob - return loss.mean() return F.cross_entropy(logits, targets) From d6ce05e5d64fa7bb7e0219fba5bf1a52f45ddd0a Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 09:41:12 -0700 Subject: [PATCH 54/80] Tuning to squeeze every last parameter --- train_gpt.py | 64 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index ac36584bbc..c10c5d2f9a 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -53,8 +53,8 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1024)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 384)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2048)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 512)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -62,21 +62,22 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 1.6)) + mlp_mult = float(os.environ.get("MLP_MULT", 1.85)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 4096.0)) + rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 4096.0)) + rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 256.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.025)) - head_lr = float(os.environ.get("HEAD_LR", 0.025)) + embed_lr = float(os.environ.get("EMBED_LR", 0.0225)) + head_lr = float(os.environ.get("HEAD_LR", 0.0225)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.0125)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.125)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.0225)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.015)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) @@ -86,7 +87,6 @@ class Hyperparameters: beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) - keep_prob = float(os.environ.get("KEEP_PROB", 0.8)) # ----------------------------- # MUON OPTIMIZER @@ -777,6 +777,12 @@ def get_rope_p_smooth(i: int, num_layers: int, p_min=0.5, p_max=0.75) -> float: scale = math.sin(progress * math.pi) return p_min + (p_max - p_min) * scale +def get_rope_base_progression(layer_idx: int, total_layers: int, min_base: float, max_base: float) -> float: + if total_layers <= 1: + return max_base + fraction = layer_idx / (total_layers - 1) + return min_base * ((max_base / min_base) ** fraction) + def get_linear_progression_mlp_mult(layer_idx: int, total_layers: int, base_mult: int) -> float: # If base_mult is 2, this progresses from 1.0 (Layer 0) to 3.0 (Final Layer) min_mult = float(base_mult) * 0.5 @@ -796,10 +802,10 @@ def __init__( tie_embeddings: bool, tied_embed_init_std: float, logit_softcap: float, - rope_base: float, + rope_max_base: float, + rope_min_base: float, qk_gain_init: float, seq_len: int=1024, - keep_prob: float=0.8 ): super().__init__() self.tie_embeddings = tie_embeddings @@ -814,7 +820,7 @@ def __init__( num_heads, get_linear_progression_kv_heads(i, num_layers, num_kv_heads), get_linear_progression_mlp_mult(i, num_layers, mlp_mult), - rope_base, + get_rope_base_progression(i, num_layers, rope_min_base, rope_max_base), qk_gain_init, seq_len=seq_len, use_rope=True, @@ -824,7 +830,6 @@ def __init__( ]) self.final_norm = RMSNorm() self.logit_softcap = logit_softcap - self.keep_prob = keep_prob self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) self.skip_scales = nn.Parameter(torch.zeros(self.num_encoder_layers, model_dim, dtype=torch.float32)) @@ -890,6 +895,9 @@ def main() -> None: inductor_config.fx_graph_cache = True # Caches compiled kernels to disk (saves 5+ minutes on restart) inductor_config.triton.unique_kernel_names = True # Prevents Triton kernel namespace collisions in DDP inductor_config.freezing = True # Aggressive constant-folding for inference/eval + inductor_config.shape_padding = True + inductor_config.coordinate_descent_tuning = True + inductor_config.epilogue_fusion = True if distributed: dist.init_process_group(backend="nccl", device_id=device) @@ -972,10 +980,10 @@ def log0(msg: str, console: bool = True) -> None: tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, logit_softcap=args.logit_softcap, - rope_base=args.rope_base, + rope_max_base=args.rope_max_base, + rope_min_base=args.rope_min_base, qk_gain_init=args.qk_gain_init, - seq_len=args.train_seq_len, - keep_prob=args.keep_prob + seq_len=args.train_seq_len ).to(device).bfloat16() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model, fullgraph=True) @@ -1269,5 +1277,27 @@ def lr_mul(step: int, elapsed_ms: float) -> float: if distributed: dist.destroy_process_group() +def main_params(): + args = Hyperparameters() + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_max_base=args.rope_max_base, + rope_min_base=args.rope_min_base, + qk_gain_init=args.qk_gain_init, + seq_len=args.train_seq_len + ) + print(sum([p.numel() for p in base_model.parameters()])) + if __name__ == "__main__": - main() \ No newline at end of file + if not torch.cuda.is_available(): + main_params() + else: + main() \ No newline at end of file From 8a95441f79e254234e84f7302398216c1987a308 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 10:03:23 -0700 Subject: [PATCH 55/80] Squeeze part 2 --- train_gpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index c10c5d2f9a..0b53350678 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -66,7 +66,7 @@ class Hyperparameters: num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 1.85)) + mlp_mult = float(os.environ.get("MLP_MULT", 1.8625)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 4096.0)) rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 256.0)) @@ -713,7 +713,7 @@ def forward(self, x: Tensor, emb: Tensor) -> Tensor: class MLP(nn.Module): def __init__(self, dim: int, mlp_mult: float): super().__init__() - self.hidden_dim = int(mlp_mult * 2/3 * dim // 64) * 64 + self.hidden_dim = int(mlp_mult * 2/3 * dim // 32) * 32 self.c_fc = nn.Linear(dim, 2 * self.hidden_dim, bias=False) self.c_proj = nn.Linear(self.hidden_dim, dim, bias=False) From 25f583cfb15f9e182a9bde35abcc41bcc343bc0a Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 10:36:29 -0700 Subject: [PATCH 56/80] Too many params --- train_gpt.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 0b53350678..e3935d3d11 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -66,19 +66,19 @@ class Hyperparameters: num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 1.8625)) + mlp_mult = float(os.environ.get("MLP_MULT", 1.75)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 4096.0)) - rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 256.0)) + rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 8192.0)) + rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 512.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.0225)) - head_lr = float(os.environ.get("HEAD_LR", 0.0225)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.0125)) + embed_lr = float(os.environ.get("EMBED_LR", 0.025)) + head_lr = float(os.environ.get("HEAD_LR", 0.025)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.125)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.0225)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.015)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.0275)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.0175)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) @@ -742,9 +742,9 @@ def __init__( self.mlp_norm = RMSNorm() self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope, rope_proportion) self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.1)) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.05)) self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(resid_scale)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.1)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.05)) # Inside Block.forward def forward(self, x: Tensor, emb: Tensor) -> Tensor: @@ -770,7 +770,7 @@ def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): return int(max(min_kv, kv_heads)) -def get_rope_p_smooth(i: int, num_layers: int, p_min=0.5, p_max=0.75) -> float: +def get_rope_p_smooth(i: int, num_layers: int, p_min=0.375, p_max=0.875) -> float: if num_layers <= 1: return p_min progress = i / (num_layers - 1) @@ -896,7 +896,6 @@ def main() -> None: inductor_config.triton.unique_kernel_names = True # Prevents Triton kernel namespace collisions in DDP inductor_config.freezing = True # Aggressive constant-folding for inference/eval inductor_config.shape_padding = True - inductor_config.coordinate_descent_tuning = True inductor_config.epilogue_fusion = True if distributed: From cd72e7511dee635ee11deb560df52b7957474749 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 12:15:29 -0700 Subject: [PATCH 57/80] Longer --- train_gpt.py | 49 ++++++++++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index e3935d3d11..84d95a48d0 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -62,31 +62,31 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 1.75)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 14)) + model_dim = int(os.environ.get("MODEL_DIM", 448)) + num_heads = int(os.environ.get("NUM_HEADS", 14)) + mlp_mult = float(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 8192.0)) rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 512.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.025)) - head_lr = float(os.environ.get("HEAD_LR", 0.025)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015)) + embed_lr = float(os.environ.get("EMBED_LR", 0.03)) + head_lr = float(os.environ.get("HEAD_LR", 0.03)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.125)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.0275)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.0175)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.03)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 3.0)) # ----------------------------- # MUON OPTIMIZER @@ -675,20 +675,16 @@ def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_le self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads self.kv_dim = self.num_kv_heads * self.head_dim - self.c_qk = nn.Linear(dim, dim + self.kv_dim, bias=False) - self.c_v = nn.Linear(dim, self.kv_dim, bias=False) - self.v_mix = nn.Parameter(torch.zeros(dim)) + self.c_qkv = nn.Linear(dim, dim + 2 * self.kv_dim, bias=False) self.proj = nn.Linear(dim, dim, bias=False) self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) self.use_rope = use_rope - def forward(self, x: Tensor, emb: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: bsz, seqlen, dim = x.shape - qk = self.c_qk(x) - q, k = qk.split([dim, self.kv_dim], dim=-1) - v_input = torch.lerp(emb, x, self.v_mix.to(x.dtype)) - v = self.c_v(v_input) + qkv = self.c_qkv(x) + q, k, v = qkv.split([dim, self.kv_dim, self.kv_dim], dim=-1) q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) @@ -747,10 +743,10 @@ def __init__( self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.05)) # Inside Block.forward - def forward(self, x: Tensor, emb: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: dtype = x.dtype normed_x = self.attn_norm(x) - attn_out = self.attn(normed_x, emb) + attn_out = self.attn(normed_x) y = x + self.attn_scale.to(dtype) * attn_out normed_y = self.mlp_norm(y) mlp_out = self.mlp(normed_y) @@ -831,8 +827,7 @@ def __init__( self.final_norm = RMSNorm() self.logit_softcap = logit_softcap self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) - self.skip_scales = nn.Parameter(torch.zeros(self.num_encoder_layers, model_dim, dtype=torch.float32)) - + self.skip_gate = nn.Parameter(torch.ones(self.num_encoder_layers, model_dim) * 1e-4) apply_zero_init(self, std=self.tied_embed_init_std) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: @@ -842,15 +837,15 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: skips = [] for i in range(self.num_encoder_layers): skips.append(x) - x = self.blocks[i](x, emb) + x = self.blocks[i](x) for j in range(self.num_decoder_layers): block_idx = j + self.num_encoder_layers skip_idx = self.num_encoder_layers - 1 - j if skip_idx >= 0: skip_x = F.rms_norm(skips[skip_idx], (skips[skip_idx].size(-1),), eps=self.final_norm.eps) - x = x + (self.skip_scales[skip_idx] * skip_x) + x = x + (self.skip_gate[skip_idx] * skip_x) - x = self.blocks[block_idx](x, emb) + x = self.blocks[block_idx](x) x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) From bf03d007cd2e1bebf295cd427563e1944e87491d Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 12:31:54 -0700 Subject: [PATCH 58/80] Unique hyperparams + geometric rope progression --- train_gpt.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 84d95a48d0..7701cdbfdb 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -66,21 +66,21 @@ class Hyperparameters: num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 14)) model_dim = int(os.environ.get("MODEL_DIM", 448)) num_heads = int(os.environ.get("NUM_HEADS", 14)) - mlp_mult = float(os.environ.get("MLP_MULT", 2)) + mlp_mult = float(os.environ.get("MLP_MULT", 2.05)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 8192.0)) rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 512.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.03)) - head_lr = float(os.environ.get("HEAD_LR", 0.03)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.02)) + embed_lr = float(os.environ.get("EMBED_LR", 0.4)) + head_lr = float(os.environ.get("HEAD_LR", 0.01)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.125)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.03)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) beta1 = float(os.environ.get("BETA1", 0.9)) @@ -777,7 +777,7 @@ def get_rope_base_progression(layer_idx: int, total_layers: int, min_base: float if total_layers <= 1: return max_base fraction = layer_idx / (total_layers - 1) - return min_base * ((max_base / min_base) ** fraction) + return min_base * (max_base / min_base) ** fraction def get_linear_progression_mlp_mult(layer_idx: int, total_layers: int, base_mult: int) -> float: # If base_mult is 2, this progresses from 1.0 (Layer 0) to 3.0 (Final Layer) @@ -827,7 +827,7 @@ def __init__( self.final_norm = RMSNorm() self.logit_softcap = logit_softcap self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) - self.skip_gate = nn.Parameter(torch.ones(self.num_encoder_layers, model_dim) * 1e-4) + self.skip_gate = nn.Parameter(torch.ones(self.num_encoder_layers, model_dim) * 1e-3) apply_zero_init(self, std=self.tied_embed_init_std) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: From a2905ff5a55c800b3ed7b138b217ee5ede5aaa25 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:08:23 -0700 Subject: [PATCH 59/80] It's unique for sure --- train_gpt.py | 2112 +++++++++++++++++++++----------------------------- 1 file changed, 878 insertions(+), 1234 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 7701cdbfdb..00a06238f5 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1,11 +1,4 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - from __future__ import annotations - import copy import glob import io @@ -19,7 +12,6 @@ import zlib import concurrent.futures from pathlib import Path - import numpy as np import sentencepiece as spm import torch @@ -27,1271 +19,923 @@ import torch.nn.functional as F from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run (not actually reflected within Hyperparameters): -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 40)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2048)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 512)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 14)) - model_dim = int(os.environ.get("MODEL_DIM", 448)) - num_heads = int(os.environ.get("NUM_HEADS", 14)) - mlp_mult = float(os.environ.get("MLP_MULT", 2.05)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 8192.0)) - rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 512.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.4)) - head_lr = float(os.environ.get("HEAD_LR", 0.01)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.125)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 3.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -@torch.compile(fullgraph=True) + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 5000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2048)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 512)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 8192.0)) + rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 512.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.3)) + head_lr = float(os.environ.get("HEAD_LR", 0.04)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.125)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) +@torch.compile(mode="max-autotune-no-cudagraphs", fullgraph=True) def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7): - # 1. Eliminate CPU-GPU synchronization (Graph Break) - # isfinite().all() creates a scalar tensor on the device. - # Multiplying by this mask eliminates NaNs/Infs without a host-sync. - valid_mask = torch.isfinite(G).all().to(G.dtype) - G = G * valid_mask - - a, b, c = 3.4445, -4.7750, 2.0315 - X = G.bfloat16() - - # 2. Multiply by reciprocal (faster than tensor division) - X.mul_(1.0 / (X.norm() + eps)) - - transposed = X.size(0) > X.size(1) - if transposed: - X = X.T - - for _ in range(steps): - A = X @ X.T - - # 3. Fuse pointwise operations into the GEMM kernel via addmm - # Equivalent to: B = b * A + c * (A @ A) - B = torch.addmm(A, A, A, beta=b, alpha=c) - - # Equivalent to: X = a * X + 1.0 * (B @ X) - X = torch.addmm(X, B, X, beta=a, alpha=1.0) - - return X.T if transposed else X - + X = torch.where(torch.isfinite(G), G, 0.0).to(torch.bfloat16) + norm = torch.linalg.matrix_norm(X) + eps + X = X / norm + transposed = X.size(0) > X.size(1) + if transposed: + X = X.T + a, b, c = 3.4445, -4.7750, 2.0315 + for _ in range(steps): + A = X @ X.T + B = torch.addmm(A, A, A, beta=b, alpha=c) + X = torch.addmm(X, B, X, beta=a, alpha=1.0) + return X.T if transposed else X @torch.no_grad() def get_hadamard_matrix(n, device): - """Generates a deterministic, orthonormal Hadamard matrix.""" - p2 = 2**math.ceil(math.log2(n)) - H = torch.tensor([[1.0]], device=device) - while H.shape[0] < p2: - H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0) - return H[:n, :n] / math.sqrt(p2) - + p2 = 2**math.ceil(math.log2(n)) + H = torch.tensor([[1.0]], device=device) + while H.shape[0] < p2: + H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0) + return H[:n, :n] / math.sqrt(p2) @torch.no_grad() def apply_zero_init(model, std=0.02): - """ - Unified ZerO Init: - - Embeddings: Hadamard - - Linear Layers: Last in any sub-module is 0 (Exit), others are Hadamard (Internal). - """ - for m in model.modules(): - # 1. Handle Embeddings (Identity-like initialization) - if isinstance(m, nn.Embedding): - d_out, d_in = m.weight.shape - H = get_hadamard_matrix(max(d_out, d_in), m.weight.device) - m.weight.copy_(H[:d_out, :d_in]*std) - - # 2. Handle Linear Layers by inspecting the module's direct children - # We look for direct Linear children to identify "Branches" - linears = [sub for sub in m.children() if isinstance(sub, nn.Linear)] - if linears: - for i, l in enumerate(linears): - d_out, d_in = l.weight.shape - - # The 'Exit' layer is the last Linear in this specific module - if i == len(linears) - 1: - nn.init.normal_(l.weight, std=1e-5) - # 'Internal' layers get Hadamard symmetry breaking - else: - H = get_hadamard_matrix(max(d_out, d_in), l.weight.device) - l.weight.copy_(H[:d_out, :d_in]) - - if l.bias is not None: - nn.init.zeros_(l.bias) - + for m in model.modules(): + if isinstance(m, nn.Embedding): + d_out, d_in = m.weight.shape + H = get_hadamard_matrix(max(d_out, d_in), m.weight.device) + m.weight.copy_(H[:d_out, :d_in]*std) + linears = [sub for sub in m.children() if isinstance(sub, nn.Linear)] + if linears: + for i, l in enumerate(linears): + d_out, d_in = l.weight.shape + if i == len(linears) - 1: + nn.init.normal_(l.weight, std=1e-5) + else: + H = get_hadamard_matrix(max(d_out, d_in), l.weight.device) + l.weight.copy_(H[:d_out, :d_in]) + if l.bias is not None: + nn.init.zeros_(l.bias) class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float = 0.95, backend_steps: int = 5, nesterov: bool = True): - defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov) - super().__init__(params, defaults) - self._is_initialized = False - - def _init_state(self): - """Pre-allocates buffers for each group individually.""" - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - # Check if THIS specific group needs initialization - if "updates_flat" in group or not group["params"]: - continue - - params = group["params"] - total_params = sum(p.numel() for p in params) - - # Pre-allocate the communication buffer for this group - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - group["updates_flat"] = updates_flat - group["param_views"] = [] - group["rank_params"] = [] - group["rank_param_views"] = [] - - curr = 0 - for i, p in enumerate(params): - numel = p.numel() - view = updates_flat[curr : curr + numel].view_as(p) - group["param_views"].append(view) - - if i % world_size == rank: - group["rank_params"].append(p) - group["rank_param_views"].append(view) - - curr += numel - - @torch.no_grad() - def step(self, closure=None): - loss = closure() if closure is not None else None - self._init_state() - - for group in self.param_groups: - if not group["params"]: - continue - - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - updates_flat = group["updates_flat"] - - updates_flat.zero_() - - # 1. Local Computation - for p, view in zip(group["rank_params"], group["rank_param_views"]): - if p.grad is None: - continue - - g = p.grad - state = self.state[p] - - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - - buf = state["momentum_buffer"] - - # Momentum update: m = m * momentum + g - buf.mul_(momentum).add_(g) - - # Nesterov logic: use (g + momentum * m) or just m - u = g.add(buf, alpha=momentum) if nesterov else buf - - # Apply Newton-Schulz (on 2D flattened version if ndim > 2) - original_shape = u.shape - if u.ndim > 2: - u = u.view(u.size(0), -1) - - g_ns = zeropower_via_newtonschulz5(u, steps=backend_steps) - - # Scaling factor (standard for Muon) - g_ns.mul_(max(1, g_ns.size(0) / g_ns.size(1))**0.5) - - view.copy_(g_ns.view(original_shape)) - - # 2. Distributed Sync - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - # 3. Parameter Update (Kernel Fusion) - # Use foreach_add with a list of tensors to avoid multiple kernel launches - params = group["params"] - views = group["param_views"] - if updates_flat.dtype != params[0].dtype: - views = [v.to(dtype=p.dtype, non_blocking=True) for v, p in zip(views, params)] - torch._foreach_add_(params, views, alpha=-lr) - - return loss - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - + def __init__(self, params, lr: float, momentum: float = 0.95, backend_steps: int = 5, nesterov: bool = True): + defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov) + super().__init__(params, defaults) + self._is_initialized = False + def _init_state(self): + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + if "updates_flat" in group or not group["params"]: + continue + params = group["params"] + total_params = sum(p.numel() for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + group["updates_flat"] = updates_flat + group["param_views"] = [] + group["rank_params"] = [] + group["rank_param_views"] = [] + curr = 0 + for i, p in enumerate(params): + numel = p.numel() + view = updates_flat[curr : curr + numel].view_as(p) + group["param_views"].append(view) + if i % world_size == rank: + group["rank_params"].append(p) + group["rank_param_views"].append(view) + curr += numel + @torch.no_grad() + def step(self, closure=None): + loss = closure() if closure is not None else None + self._init_state() + for group in self.param_groups: + if not group["params"]: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + updates_flat = group["updates_flat"] + updates_flat.zero_() + for p, view in zip(group["rank_params"], group["rank_param_views"]): + if p.grad is None: + continue + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + u = g.add(buf, alpha=momentum) if nesterov else buf + original_shape = u.shape + if u.ndim > 2: + u = u.view(u.size(0), -1) + u_prepared = u.to(dtype=torch.bfloat16, memory_format=torch.contiguous_format, non_blocking=True) + g_ns = zeropower_via_newtonschulz5(u_prepared, steps=backend_steps) + g_ns.mul_(max(1, g_ns.size(0) / g_ns.size(1))**0.5) + view.copy_(g_ns.view(original_shape)) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + params = group["params"] + views = group["param_views"] + if updates_flat.dtype != params[0].dtype: + views = [v.to(dtype=p.dtype, non_blocking=True) for v, p in zip(views, params)] + torch._foreach_add_(params, views, alpha=-lr) + return loss def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device ) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, ) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain", - ).split(",") - if pattern + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain", + ).split(",") + if pattern ) INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern ) INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - + return int(t.numel()) * int(t.element_size()) def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - # result() will block ONLY if the disk is slower than your training - # (unlikely with NVMe), otherwise it returns instantly. - self.tokens = self.future_tokens.result() - self.pos = 0 - # Start loading the one AFTER this - self.future_tokens = self._queue_next_file() - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self.tokens = load_data_shard(self.files[0]) + self.future_tokens = self._queue_next_file() + self.pos = 0 + def _queue_next_file(self): + next_idx = (self.file_idx + 1) % len(self.files) + return self.executor.submit(load_data_shard, self.files[next_idx]) + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = self.future_tokens.result() + self.pos = 0 + self.future_tokens = self._queue_next_file() + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank, self.world_size, self.device = rank, world_size, device - self.stream = TokenStream(pattern) - self.next_x, self.next_y = None, None - self.transfer_stream = torch.cuda.Stream(device) - - def preload(self, global_tokens: int, seq_len: int, grad_accum_steps: int): - with torch.cuda.stream(self.transfer_stream): - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - - # OPTIMIZATION: Move to device as int32 (4 bytes) instead of int64 (8 bytes) - local = chunk[start : start + per_rank_span].pin_memory().to( - device=self.device, dtype=torch.int32, non_blocking=True - ) - # Embedding layer handles int32 perfectly. - self.next_x = local[:-1].reshape(-1, seq_len) - self.next_y = local[1:].reshape(-1, seq_len).to(torch.long) # CrossEntropy needs long - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - if self.next_x is None: - self.preload(global_tokens, seq_len, grad_accum_steps) - - # Ensure the default stream waits for the background transfer to finish - torch.cuda.current_stream().wait_stream(self.transfer_stream) - x, y = self.next_x, self.next_y - - # Immediately queue up the next batch in the background - self.preload(global_tokens, seq_len, grad_accum_steps) - return x, y - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + self.next_x, self.next_y = None, None + self.transfer_stream = torch.cuda.Stream(device) + def preload(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + with torch.cuda.stream(self.transfer_stream): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].pin_memory().to( + device=self.device, dtype=torch.int32, non_blocking=True + ) + self.next_x = local[:-1].reshape(-1, seq_len) + self.next_y = local[1:].reshape(-1, seq_len).to(torch.long) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + if self.next_x is None: + self.preload(global_tokens, seq_len, grad_accum_steps) + torch.cuda.current_stream().wait_stream(self.transfer_stream) + x, y = self.next_x, self.next_y + self.preload(global_tokens, seq_len, grad_accum_steps) + return x, y class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int, p: float = 0.5, base: float = 10000.0): - super().__init__() - self.rotary_dim = (int(dim * p) // 8) * 8 - inv_freq = 1.0 / (base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) - t = torch.arange(max_seq_len, dtype=torch.float32) - freqs = torch.outer(t, inv_freq) - # Register as float32; shapes: (1, 1, max_seq_len, rotary_dim // 2) - self.register_buffer("cos", freqs.cos().view(1, 1, max_seq_len, -1), persistent=False) - self.register_buffer("sin", freqs.sin().view(1, 1, max_seq_len, -1), persistent=False) - - def forward(self, x: Tensor): - t = x.size(2) - return self.cos[:, :, :t, :].to(x.dtype), self.sin[:, :, :t, :].to(x.dtype) - + def __init__(self, dim: int, max_seq_len: int, p: float = 0.5, base: float = 10000.0): + super().__init__() + self.rotary_dim = (int(dim * p) // 8) * 8 + inv_freq = 1.0 / (base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) + t = torch.arange(max_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("cos", freqs.cos().view(1, 1, max_seq_len, -1), persistent=False) + self.register_buffer("sin", freqs.sin().view(1, 1, max_seq_len, -1), persistent=False) + def forward(self, x: Tensor): + t = x.size(2) + return self.cos[:, :, :t, :].to(x.dtype), self.sin[:, :, :t, :].to(x.dtype) def rotate_half(x: Tensor): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - d = cos.shape[-1] * 2 - x_rop = x[..., :d] - x_pass = x[..., d:] - x_rop = x_rop.view(*x_rop.shape[:-1], -1, 2) - x0 = x_rop[..., 0] - x1 = x_rop[..., 1] - res = torch.stack([x0 * cos - x1 * sin, x1 * cos + x0 * sin], dim=-1) - return torch.cat([res.flatten(-2), x_pass], dim=-1) - + d = cos.shape[-1] * 2 + x_rop = x[..., :d] + x_pass = x[..., d:] + x_rop = x_rop.view(*x_rop.shape[:-1], -1, 2) + x0 = x_rop[..., 0] + x1 = x_rop[..., 1] + res = torch.stack([x0 * cos - x1 * sin, x1 * cos + x0 * sin], dim=-1) + return torch.cat([res.flatten(-2), x_pass], dim=-1) class CausalSelfAttention(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True, rope_proportion=0.5): - super().__init__() - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - self.kv_dim = self.num_kv_heads * self.head_dim - self.c_qkv = nn.Linear(dim, dim + 2 * self.kv_dim, bias=False) - self.proj = nn.Linear(dim, dim, bias=False) - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) - self.use_rope = use_rope - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - qkv = self.c_qkv(x) - q, k, v = qkv.split([dim, self.kv_dim, self.kv_dim], dim=-1) - - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - if self.use_rope: - cos, sin = self.rotary(q) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - - q = q * self.q_gain.view(1, -1, 1, 1).to(q.dtype) - y = F.scaled_dot_product_attention( - q, k, v, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads) - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True, rope_proportion=0.5): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.c_qkv = nn.Linear(dim, dim + 2 * self.kv_dim, bias=False) + self.proj = nn.Linear(dim, dim, bias=False) + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) + self.use_rope = use_rope + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv = self.c_qkv(x) + q, k, v = qkv.split([dim, self.kv_dim, self.kv_dim], dim=-1) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + if self.use_rope: + cos, sin = self.rotary(q) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.view(1, -1, 1, 1).to(q.dtype) + y = F.scaled_dot_product_attention( + q, k, v, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads) + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - self.hidden_dim = int(mlp_mult * 2/3 * dim // 32) * 32 - self.c_fc = nn.Linear(dim, 2 * self.hidden_dim, bias=False) - self.c_proj = nn.Linear(self.hidden_dim, dim, bias=False) - - def forward(self, x: Tensor) -> Tensor: - fused_x = self.c_fc(x) - gate, value = fused_x.chunk(2, dim=-1) - return self.c_proj(F.silu(gate) * value) - - + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + self.hidden_dim = int(mlp_mult * 2/3 * dim // 64) * 64 + self.c_fc = nn.Linear(dim, 2 * self.hidden_dim, bias=False) + self.c_proj = nn.Linear(self.hidden_dim, dim, bias=False) + def forward(self, x: Tensor) -> Tensor: + fused_x = self.c_fc(x) + gate, value = fused_x.chunk(2, dim=-1) + return self.c_proj(F.silu(gate) * value) class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - rope_base: float, - qk_gain_init: float, - seq_len: int=1024, - use_rope: bool=True, - resid_scale: float=0.125, - rope_proportion: float=0.5, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope, rope_proportion) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.05)) - self.resid_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(resid_scale)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32).mul(0.05)) - - # Inside Block.forward - def forward(self, x: Tensor) -> Tensor: - dtype = x.dtype - normed_x = self.attn_norm(x) - attn_out = self.attn(normed_x) - y = x + self.attn_scale.to(dtype) * attn_out - normed_y = self.mlp_norm(y) - mlp_out = self.mlp(normed_y) - return self.resid_scale.to(dtype) * x + y + self.mlp_scale.to(dtype) * mlp_out - + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + rope_base: float, + qk_gain_init: float, + seq_len: int=1024, + use_rope: bool=True, + rope_proportion: float=0.5, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope, rope_proportion) + self.mlp = MLP(dim, mlp_mult) + def forward(self, x: Tensor) -> Tensor: + x = x + self.attn(self.attn_norm(x)) + x = x + self.mlp(self.mlp_norm(x)) + return x def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): - # Progresses from 2 heads at layer 0 to num_heads at the final layer - min_kv = 2 - max_kv = num_heads - - # Simple linear interpolation from layer 0 to total_layers - fraction = layer_idx / (total_layers - 1) - raw_kv = min_kv + (max_kv - min_kv) * fraction - - valid_kvs = [i for i in range(1, num_heads + 1) if num_heads % i == 0] - kv_heads = min(valid_kvs, key=lambda x: abs(x - raw_kv)) - - return int(max(min_kv, kv_heads)) - -def get_rope_p_smooth(i: int, num_layers: int, p_min=0.375, p_max=0.875) -> float: - if num_layers <= 1: - return p_min - progress = i / (num_layers - 1) - scale = math.sin(progress * math.pi) - return p_min + (p_max - p_min) * scale - + min_kv = 1 + max_kv = num_heads + fraction = layer_idx / (total_layers - 1) + raw_kv = min_kv + (max_kv - min_kv) * fraction + valid_kvs = [i for i in range(1, num_heads + 1) if num_heads % i == 0] + kv_heads = min(valid_kvs, key=lambda x: abs(x - raw_kv)) + return int(max(min_kv, kv_heads)) +def get_rope_p_smooth(i: int, num_layers: int, p_min=0.25, p_max=0.75) -> float: + if num_layers <= 1: + return p_min + progress = i / (num_layers - 1) + scale = math.sin(progress * math.pi) + return p_min + (p_max - p_min) * scale def get_rope_base_progression(layer_idx: int, total_layers: int, min_base: float, max_base: float) -> float: - if total_layers <= 1: - return max_base - fraction = layer_idx / (total_layers - 1) - return min_base * (max_base / min_base) ** fraction - + if total_layers <= 1: + return max_base + fraction = layer_idx / (total_layers - 1) + return min_base * (max_base / min_base) ** fraction def get_linear_progression_mlp_mult(layer_idx: int, total_layers: int, base_mult: int) -> float: - # If base_mult is 2, this progresses from 1.0 (Layer 0) to 3.0 (Final Layer) - min_mult = float(base_mult) * 0.5 - max_mult = float(base_mult) * 1.5 - fraction = layer_idx / max(1, total_layers - 1) - return min_mult + (max_mult - min_mult) * fraction - + min_mult = float(base_mult) * 0.75 + max_mult = float(base_mult) * 1.25 + fraction = layer_idx / max(1, total_layers - 1) + return min_mult + (max_mult - min_mult) * fraction class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_max_base: float, - rope_min_base: float, - qk_gain_init: float, - seq_len: int=1024, - ): - super().__init__() - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_layers = num_layers - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.blocks = nn.ModuleList([ - Block( - model_dim, - num_heads, - get_linear_progression_kv_heads(i, num_layers, num_kv_heads), - get_linear_progression_mlp_mult(i, num_layers, mlp_mult), - get_rope_base_progression(i, num_layers, rope_min_base, rope_max_base), - qk_gain_init, - seq_len=seq_len, - use_rope=True, - resid_scale=1/math.sqrt(2 * num_layers), - rope_proportion=get_rope_p_smooth(i, num_layers) - ) for i in range(num_layers) - ]) - self.final_norm = RMSNorm() - self.logit_softcap = logit_softcap - self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) - self.skip_gate = nn.Parameter(torch.ones(self.num_encoder_layers, model_dim) * 1e-3) - apply_zero_init(self, std=self.tied_embed_init_std) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - emb = self.tok_emb(input_ids) - x = F.rms_norm(emb, (emb.size(-1),)) - - skips = [] - for i in range(self.num_encoder_layers): - skips.append(x) - x = self.blocks[i](x) - for j in range(self.num_decoder_layers): - block_idx = j + self.num_encoder_layers - skip_idx = self.num_encoder_layers - 1 - j - if skip_idx >= 0: - skip_x = F.rms_norm(skips[skip_idx], (skips[skip_idx].size(-1),), eps=self.final_norm.eps) - x = x + (self.skip_gate[skip_idx] * skip_x) - - x = self.blocks[block_idx](x) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - - if self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight) - else: - logits = self.lm_head(x) - - logits = self.logit_softcap * torch.tanh(logits.float() / self.logit_softcap) - return F.cross_entropy(logits, targets) - - -# ----------------------------- -# TRAINING -# ----------------------------- - + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_max_base: float, + rope_min_base: float, + qk_gain_init: float, + seq_len: int=1024, + ): + super().__init__() + self.tie_embeddings = tie_embeddings + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.blocks = nn.ModuleList([ + Block( + model_dim, + num_heads, + get_linear_progression_kv_heads(i, num_layers, num_kv_heads), + get_linear_progression_mlp_mult(i, num_layers, mlp_mult), + get_rope_base_progression(i, num_layers, rope_min_base, rope_max_base), + qk_gain_init, + seq_len=seq_len, + use_rope=True, + rope_proportion=get_rope_p_smooth(i, num_layers) + ) for i in range(num_layers) + ]) + self.final_norm = RMSNorm() + self.logit_softcap = logit_softcap + self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) + apply_zero_init(self, std=tied_embed_init_std) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + emb = self.tok_emb(input_ids) + x = F.rms_norm(emb, (emb.size(-1),)) + for block in self.blocks: + x = block(x) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits.float() / self.logit_softcap) + return F.cross_entropy(logits, targets) def main() -> None: - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - # 1. Force the PyTorch dispatcher to use Tensor Cores for all matmuls - torch.set_float32_matmul_precision('high') - import torch._inductor.config as inductor_config - inductor_config.fx_graph_cache = True # Caches compiled kernels to disk (saves 5+ minutes on restart) - inductor_config.triton.unique_kernel_names = True # Prevents Triton kernel namespace collisions in DDP - inductor_config.freezing = True # Aggressive constant-folding for inference/eval - inductor_config.shape_padding = True - inductor_config.epilogue_fusion = True - - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - @torch.compiler.disable - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_max_base=args.rope_max_base, - rope_min_base=args.rope_min_base, - qk_gain_init=args.qk_gain_init, - seq_len=args.train_seq_len - ).to(device).bfloat16() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - gpt_scalars = [p for name, p in base_model.named_parameters() if "skip_scales" in name or "skip_gain" in name] - scalar_params.extend(gpt_scalars) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - warmup_mul = min(step / args.warmup_steps, 1.0) if args.warmup_steps > 0 else 1.0 - - # --- Warmdown Logic --- - if args.warmdown_iters <= 0: - return warmup_mul - - if max_wallclock_ms is None: - # Iteration-based warmdown - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - warmdown_mul = max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if step >= warmdown_start else 1.0 - else: - # Time-based warmdown - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - warmdown_mul = remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # The effective multiplier is the intersection of warmup and warmdown - return min(warmup_mul, warmdown_mul) - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) # ← Use len() instead of getsize() - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - -def main_params(): - args = Hyperparameters() - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_max_base=args.rope_max_base, - rope_min_base=args.rope_min_base, - qk_gain_init=args.qk_gain_init, - seq_len=args.train_seq_len - ) - print(sum([p.numel() for p in base_model.parameters()])) - -if __name__ == "__main__": - if not torch.cuda.is_available(): - main_params() - else: - main() \ No newline at end of file + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + torch.set_float32_matmul_precision('high') + import torch._inductor.config as inductor_config + inductor_config.fx_graph_cache = True + inductor_config.triton.unique_kernel_names = True + inductor_config.freezing = True + inductor_config.shape_padding = True + inductor_config.epilogue_fusion = True + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + @torch.compiler.disable + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_max_base=args.rope_max_base, + rope_min_base=args.rope_min_base, + qk_gain_init=args.qk_gain_init, + seq_len=args.train_seq_len + ).to(device).bfloat16() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, fullgraph=True, mode="reduce-overhead") + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + gpt_scalars = [p for name, p in base_model.named_parameters() if "skip" in name] + scalar_params.extend(gpt_scalars) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": main() \ No newline at end of file From 528747f323fbc90e33a3b627c4f0f085ad219bcd Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:36:11 -0700 Subject: [PATCH 60/80] Add submission --- .../56869d79-f695-4ef4-a7ee-2ff630965978.txt | 1236 +++++++++++++++++ .../README.md | 47 + .../requirements.txt | 10 + .../submission.json | 11 + 4 files changed, 1304 insertions(+) create mode 100644 records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/56869d79-f695-4ef4-a7ee-2ff630965978.txt create mode 100644 records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md create mode 100644 records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/requirements.txt create mode 100644 records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/submission.json diff --git a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/56869d79-f695-4ef4-a7ee-2ff630965978.txt b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/56869d79-f695-4ef4-a7ee-2ff630965978.txt new file mode 100644 index 0000000000..f925fe4e15 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/56869d79-f695-4ef4-a7ee-2ff630965978.txt @@ -0,0 +1,1236 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +import concurrent.futures +from pathlib import Path +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 5000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2048)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 512)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 2.125)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 8192.0)) + rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 512.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.3)) + head_lr = float(os.environ.get("HEAD_LR", 0.04)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.125)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) +@torch.compile(mode="max-autotune-no-cudagraphs", fullgraph=True) +def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7): + X = torch.where(torch.isfinite(G), G, 0.0).to(torch.bfloat16) + norm = torch.linalg.matrix_norm(X) + eps + X = X / norm + transposed = X.size(0) > X.size(1) + if transposed: + X = X.T + a, b, c = 3.4445, -4.7750, 2.0315 + for _ in range(steps): + A = X @ X.T + B = torch.addmm(A, A, A, beta=b, alpha=c) + X = torch.addmm(X, B, X, beta=a, alpha=1.0) + return X.T if transposed else X +@torch.no_grad() +def get_hadamard_matrix(n, device): + p2 = 2**math.ceil(math.log2(n)) + H = torch.tensor([[1.0]], device=device) + while H.shape[0] < p2: + H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0) + return H[:n, :n] / math.sqrt(p2) +@torch.no_grad() +def apply_zero_init(model, std=0.02): + for m in model.modules(): + if isinstance(m, nn.Embedding): + d_out, d_in = m.weight.shape + H = get_hadamard_matrix(max(d_out, d_in), m.weight.device) + m.weight.copy_(H[:d_out, :d_in]*std) + linears = [sub for sub in m.children() if isinstance(sub, nn.Linear)] + if linears: + for i, l in enumerate(linears): + d_out, d_in = l.weight.shape + if i == len(linears) - 1: + nn.init.normal_(l.weight, std=1e-5) + else: + H = get_hadamard_matrix(max(d_out, d_in), l.weight.device) + l.weight.copy_(H[:d_out, :d_in]) + if l.bias is not None: + nn.init.zeros_(l.bias) +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float = 0.95, backend_steps: int = 5, nesterov: bool = True): + defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov) + super().__init__(params, defaults) + self._is_initialized = False + def _init_state(self): + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + if "updates_flat" in group or not group["params"]: + continue + params = group["params"] + total_params = sum(p.numel() for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + group["updates_flat"] = updates_flat + group["param_views"] = [] + group["rank_params"] = [] + group["rank_param_views"] = [] + curr = 0 + for i, p in enumerate(params): + numel = p.numel() + view = updates_flat[curr : curr + numel].view_as(p) + group["param_views"].append(view) + if i % world_size == rank: + group["rank_params"].append(p) + group["rank_param_views"].append(view) + curr += numel + @torch.no_grad() + def step(self, closure=None): + loss = closure() if closure is not None else None + self._init_state() + for group in self.param_groups: + if not group["params"]: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + updates_flat = group["updates_flat"] + updates_flat.zero_() + for p, view in zip(group["rank_params"], group["rank_param_views"]): + if p.grad is None: + continue + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + u = g.add(buf, alpha=momentum) if nesterov else buf + original_shape = u.shape + if u.ndim > 2: + u = u.view(u.size(0), -1) + u_prepared = u.to(dtype=torch.bfloat16, memory_format=torch.contiguous_format, non_blocking=True) + g_ns = zeropower_via_newtonschulz5(u_prepared, steps=backend_steps) + g_ns.mul_(max(1, g_ns.size(0) / g_ns.size(1))**0.5) + view.copy_(g_ns.view(original_shape)) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + params = group["params"] + views = group["param_views"] + if updates_flat.dtype != params[0].dtype: + views = [v.to(dtype=p.dtype, non_blocking=True) for v, p in zip(views, params)] + torch._foreach_add_(params, views, alpha=-lr) + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = self.future_tokens.result() + self.pos = 0 + self.future_tokens = self._queue_next_file() + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + self.next_x, self.next_y = None, None + self.transfer_stream = torch.cuda.Stream(device) + def preload(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + with torch.cuda.stream(self.transfer_stream): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].pin_memory().to( + device=self.device, dtype=torch.int32, non_blocking=True + ) + self.next_x = local[:-1].reshape(-1, seq_len) + self.next_y = local[1:].reshape(-1, seq_len).to(torch.long) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + if self.next_x is None: + self.preload(global_tokens, seq_len, grad_accum_steps) + torch.cuda.current_stream().wait_stream(self.transfer_stream) + x, y = self.next_x, self.next_y + self.preload(global_tokens, seq_len, grad_accum_steps) + return x, y +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int, p: float = 0.5, base: float = 10000.0): + super().__init__() + self.rotary_dim = (int(dim * p) // 8) * 8 + inv_freq = 1.0 / (base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) + t = torch.arange(max_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("cos", freqs.cos().view(1, 1, max_seq_len, -1), persistent=False) + self.register_buffer("sin", freqs.sin().view(1, 1, max_seq_len, -1), persistent=False) + def forward(self, x: Tensor): + t = x.size(2) + return self.cos[:, :, :t, :].to(x.dtype), self.sin[:, :, :t, :].to(x.dtype) +def rotate_half(x: Tensor): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + d = cos.shape[-1] * 2 + x_rop = x[..., :d] + x_pass = x[..., d:] + x_rop = x_rop.view(*x_rop.shape[:-1], -1, 2) + x0 = x_rop[..., 0] + x1 = x_rop[..., 1] + res = torch.stack([x0 * cos - x1 * sin, x1 * cos + x0 * sin], dim=-1) + return torch.cat([res.flatten(-2), x_pass], dim=-1) +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True, rope_proportion=0.5): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.c_qkv = nn.Linear(dim, dim + 2 * self.kv_dim, bias=False) + self.proj = nn.Linear(dim, dim, bias=False) + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) + self.use_rope = use_rope + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv = self.c_qkv(x) + q, k, v = qkv.split([dim, self.kv_dim, self.kv_dim], dim=-1) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + if self.use_rope: + cos, sin = self.rotary(q) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.view(1, -1, 1, 1).to(q.dtype) + y = F.scaled_dot_product_attention( + q, k, v, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads) + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + self.hidden_dim = int(mlp_mult * 2/3 * dim // 64) * 64 + self.c_fc = nn.Linear(dim, 2 * self.hidden_dim, bias=False) + self.c_proj = nn.Linear(self.hidden_dim, dim, bias=False) + def forward(self, x: Tensor) -> Tensor: + fused_x = self.c_fc(x) + gate, value = fused_x.chunk(2, dim=-1) + return self.c_proj(F.silu(gate) * value) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + rope_base: float, + qk_gain_init: float, + seq_len: int=1024, + use_rope: bool=True, + rope_proportion: float=0.5, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope, rope_proportion) + self.mlp = MLP(dim, mlp_mult) + def forward(self, x: Tensor) -> Tensor: + x = x + self.attn(self.attn_norm(x)) + x = x + self.mlp(self.mlp_norm(x)) + return x +def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): + min_kv = 1 + max_kv = num_heads + fraction = layer_idx / (total_layers - 1) + raw_kv = min_kv + (max_kv - min_kv) * fraction + valid_kvs = [i for i in range(1, num_heads + 1) if num_heads % i == 0] + kv_heads = min(valid_kvs, key=lambda x: abs(x - raw_kv)) + return int(max(min_kv, kv_heads)) +def get_rope_p_smooth(i: int, num_layers: int, p_min=0.25, p_max=0.75) -> float: + if num_layers <= 1: + return p_min + progress = i / (num_layers - 1) + scale = math.sin(progress * math.pi) + return p_min + (p_max - p_min) * scale +def get_rope_base_progression(layer_idx: int, total_layers: int, min_base: float, max_base: float) -> float: + if total_layers <= 1: + return max_base + fraction = layer_idx / (total_layers - 1) + return min_base * (max_base / min_base) ** fraction +def get_linear_progression_mlp_mult(layer_idx: int, total_layers: int, base_mult: int) -> float: + min_mult = float(base_mult) * 0.75 + max_mult = float(base_mult) * 1.25 + fraction = layer_idx / max(1, total_layers - 1) + return min_mult + (max_mult - min_mult) * fraction +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_max_base: float, + rope_min_base: float, + qk_gain_init: float, + seq_len: int=1024, + ): + super().__init__() + self.tie_embeddings = tie_embeddings + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.blocks = nn.ModuleList([ + Block( + model_dim, + num_heads, + get_linear_progression_kv_heads(i, num_layers, num_kv_heads), + get_linear_progression_mlp_mult(i, num_layers, mlp_mult), + get_rope_base_progression(i, num_layers, rope_min_base, rope_max_base), + qk_gain_init, + seq_len=seq_len, + use_rope=True, + rope_proportion=get_rope_p_smooth(i, num_layers) + ) for i in range(num_layers) + ]) + self.final_norm = RMSNorm() + self.logit_softcap = logit_softcap + self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) + apply_zero_init(self, std=tied_embed_init_std) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + emb = self.tok_emb(input_ids) + x = F.rms_norm(emb, (emb.size(-1),)) + for block in self.blocks: + x = block(x) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits.float() / self.logit_softcap) + return F.cross_entropy(logits, targets) +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + torch.set_float32_matmul_precision('high') + import torch._inductor.config as inductor_config + inductor_config.fx_graph_cache = True + inductor_config.triton.unique_kernel_names = True + inductor_config.freezing = True + inductor_config.shape_padding = True + inductor_config.epilogue_fusion = True + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + @torch.compiler.disable + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_max_base=args.rope_max_base, + rope_min_base=args.rope_min_base, + qk_gain_init=args.qk_gain_init, + seq_len=args.train_seq_len + ).to(device).bfloat16() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, fullgraph=True, mode="reduce-overhead") + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + gpt_scalars = [p for name, p in base_model.named_parameters() if "skip" in name] + scalar_params.extend(gpt_scalars) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": main() +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.11.0+cu130 +Fri May 1 06:22:08 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 43C P0 118W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 37C P0 121W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 43C P0 125W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 45C P0 121W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 43C P0 121W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:16318536 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:512 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:10/512 +warmup_step:20/512 +warmup_step:30/512 +warmup_step:40/512 +warmup_step:50/512 +warmup_step:60/512 +warmup_step:70/512 +warmup_step:80/512 +warmup_step:90/512 +warmup_step:100/512 +warmup_step:110/512 +warmup_step:120/512 +warmup_step:130/512 +warmup_step:140/512 +warmup_step:150/512 +warmup_step:160/512 +warmup_step:170/512 +warmup_step:180/512 +warmup_step:190/512 +warmup_step:200/512 +warmup_step:210/512 +warmup_step:220/512 +warmup_step:230/512 +warmup_step:240/512 +warmup_step:250/512 +warmup_step:260/512 +warmup_step:270/512 +warmup_step:280/512 +warmup_step:290/512 +warmup_step:300/512 +warmup_step:310/512 +warmup_step:320/512 +warmup_step:330/512 +warmup_step:340/512 +warmup_step:350/512 +warmup_step:360/512 +warmup_step:370/512 +warmup_step:380/512 +warmup_step:390/512 +warmup_step:400/512 +warmup_step:410/512 +warmup_step:420/512 +warmup_step:430/512 +warmup_step:440/512 +warmup_step:450/512 +warmup_step:460/512 +warmup_step:470/512 +warmup_step:480/512 +warmup_step:490/512 +warmup_step:500/512 +warmup_step:510/512 +warmup_step:512/512 +step:0/20000 val_loss:6.9326 val_bpb:4.1059 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9321 train_time:96ms step_avg:96.29ms +step:2/20000 train_loss:12.7577 train_time:108ms step_avg:53.86ms +step:3/20000 train_loss:9.7676 train_time:146ms step_avg:48.78ms +step:4/20000 train_loss:7.0821 train_time:183ms step_avg:45.79ms +step:5/20000 train_loss:6.3340 train_time:221ms step_avg:44.14ms +step:6/20000 train_loss:6.9930 train_time:258ms step_avg:43.00ms +step:7/20000 train_loss:6.4102 train_time:295ms step_avg:42.19ms +step:8/20000 train_loss:6.2008 train_time:333ms step_avg:41.62ms +step:9/20000 train_loss:5.8136 train_time:368ms step_avg:40.86ms +step:10/20000 train_loss:5.5661 train_time:404ms step_avg:40.42ms +step:100/20000 train_loss:3.3085 train_time:3832ms step_avg:38.32ms +step:200/20000 train_loss:2.8155 train_time:7645ms step_avg:38.23ms +step:300/20000 train_loss:2.4686 train_time:11462ms step_avg:38.21ms +step:400/20000 train_loss:2.3356 train_time:15294ms step_avg:38.23ms +step:500/20000 train_loss:2.4812 train_time:19125ms step_avg:38.25ms +step:600/20000 train_loss:2.5382 train_time:22957ms step_avg:38.26ms +step:700/20000 train_loss:2.4505 train_time:26789ms step_avg:38.27ms +step:800/20000 train_loss:2.3076 train_time:30627ms step_avg:38.28ms +step:900/20000 train_loss:2.3598 train_time:34462ms step_avg:38.29ms +step:1000/20000 train_loss:2.3954 train_time:38301ms step_avg:38.30ms +step:1100/20000 train_loss:2.2898 train_time:42136ms step_avg:38.31ms +step:1200/20000 train_loss:2.4160 train_time:45976ms step_avg:38.31ms +step:1300/20000 train_loss:2.3857 train_time:49816ms step_avg:38.32ms +step:1400/20000 train_loss:2.4578 train_time:53655ms step_avg:38.33ms +step:1500/20000 train_loss:2.2588 train_time:57498ms step_avg:38.33ms +step:1600/20000 train_loss:2.1302 train_time:61350ms step_avg:38.34ms +step:1700/20000 train_loss:2.2090 train_time:65194ms step_avg:38.35ms +step:1800/20000 train_loss:2.2405 train_time:69038ms step_avg:38.35ms +step:1900/20000 train_loss:2.2255 train_time:72889ms step_avg:38.36ms +step:2000/20000 train_loss:2.2935 train_time:76735ms step_avg:38.37ms +step:2100/20000 train_loss:2.3010 train_time:80584ms step_avg:38.37ms +step:2200/20000 train_loss:2.1174 train_time:84447ms step_avg:38.38ms +step:2300/20000 train_loss:2.4189 train_time:88291ms step_avg:38.39ms +step:2400/20000 train_loss:2.2397 train_time:92140ms step_avg:38.39ms +step:2500/20000 train_loss:2.1675 train_time:95981ms step_avg:38.39ms +step:2600/20000 train_loss:2.4472 train_time:99816ms step_avg:38.39ms +step:2700/20000 train_loss:2.1975 train_time:103656ms step_avg:38.39ms +step:2800/20000 train_loss:2.2769 train_time:107490ms step_avg:38.39ms +step:2900/20000 train_loss:2.2256 train_time:111335ms step_avg:38.39ms +step:3000/20000 train_loss:2.2668 train_time:115172ms step_avg:38.39ms +step:3100/20000 train_loss:2.2495 train_time:119006ms step_avg:38.39ms +step:3200/20000 train_loss:2.2295 train_time:122836ms step_avg:38.39ms +step:3300/20000 train_loss:2.2828 train_time:126673ms step_avg:38.39ms +step:3400/20000 train_loss:2.2037 train_time:130507ms step_avg:38.38ms +step:3500/20000 train_loss:2.2975 train_time:134337ms step_avg:38.38ms +step:3600/20000 train_loss:2.1646 train_time:138169ms step_avg:38.38ms +step:3700/20000 train_loss:2.1945 train_time:142005ms step_avg:38.38ms +step:3800/20000 train_loss:2.2685 train_time:145833ms step_avg:38.38ms +step:3900/20000 train_loss:2.0385 train_time:149664ms step_avg:38.38ms +step:4000/20000 train_loss:2.2095 train_time:153493ms step_avg:38.37ms +step:4100/20000 train_loss:2.2325 train_time:157327ms step_avg:38.37ms +step:4200/20000 train_loss:2.2192 train_time:161158ms step_avg:38.37ms +step:4300/20000 train_loss:2.0673 train_time:164984ms step_avg:38.37ms +step:4400/20000 train_loss:2.1623 train_time:168813ms step_avg:38.37ms +step:4500/20000 train_loss:2.3004 train_time:172642ms step_avg:38.36ms +step:4600/20000 train_loss:2.0217 train_time:176472ms step_avg:38.36ms +step:4700/20000 train_loss:2.3170 train_time:180297ms step_avg:38.36ms +step:4800/20000 train_loss:2.3109 train_time:184125ms step_avg:38.36ms +step:4900/20000 train_loss:2.2106 train_time:187953ms step_avg:38.36ms +step:5000/20000 train_loss:2.0821 train_time:191778ms step_avg:38.36ms +step:5000/20000 val_loss:2.2000 val_bpb:1.3030 train_time:191807ms step_avg:38.36ms +step:5100/20000 train_loss:2.0779 train_time:195601ms step_avg:38.35ms +step:5200/20000 train_loss:2.2210 train_time:199425ms step_avg:38.35ms +step:5300/20000 train_loss:2.2597 train_time:203248ms step_avg:38.35ms +step:5400/20000 train_loss:2.2354 train_time:207075ms step_avg:38.35ms +step:5500/20000 train_loss:2.1865 train_time:210897ms step_avg:38.34ms +step:5600/20000 train_loss:2.2334 train_time:214723ms step_avg:38.34ms +step:5700/20000 train_loss:2.2238 train_time:218545ms step_avg:38.34ms +step:5800/20000 train_loss:2.1977 train_time:222373ms step_avg:38.34ms +step:5900/20000 train_loss:2.1414 train_time:226194ms step_avg:38.34ms +step:6000/20000 train_loss:2.2692 train_time:230018ms step_avg:38.34ms +step:6100/20000 train_loss:2.1666 train_time:233843ms step_avg:38.33ms +step:6200/20000 train_loss:2.1356 train_time:237665ms step_avg:38.33ms +step:6300/20000 train_loss:2.0845 train_time:241489ms step_avg:38.33ms +step:6400/20000 train_loss:2.2149 train_time:245315ms step_avg:38.33ms +step:6500/20000 train_loss:2.1277 train_time:249139ms step_avg:38.33ms +step:6600/20000 train_loss:2.1770 train_time:252960ms step_avg:38.33ms +step:6700/20000 train_loss:2.2093 train_time:256789ms step_avg:38.33ms +step:6800/20000 train_loss:2.2402 train_time:260611ms step_avg:38.33ms +step:6900/20000 train_loss:2.1462 train_time:264433ms step_avg:38.32ms +step:7000/20000 train_loss:2.2669 train_time:268252ms step_avg:38.32ms +step:7100/20000 train_loss:2.1219 train_time:272076ms step_avg:38.32ms +step:7200/20000 train_loss:2.2491 train_time:275902ms step_avg:38.32ms +step:7300/20000 train_loss:2.1476 train_time:279726ms step_avg:38.32ms +step:7400/20000 train_loss:2.1668 train_time:283547ms step_avg:38.32ms +step:7500/20000 train_loss:2.1642 train_time:287369ms step_avg:38.32ms +step:7600/20000 train_loss:2.0476 train_time:291191ms step_avg:38.31ms +step:7700/20000 train_loss:2.1384 train_time:295013ms step_avg:38.31ms +step:7800/20000 train_loss:2.1988 train_time:298835ms step_avg:38.31ms +step:7900/20000 train_loss:2.1842 train_time:302657ms step_avg:38.31ms +step:8000/20000 train_loss:2.1684 train_time:306481ms step_avg:38.31ms +step:8100/20000 train_loss:2.1974 train_time:310303ms step_avg:38.31ms +step:8200/20000 train_loss:2.2376 train_time:314135ms step_avg:38.31ms +step:8300/20000 train_loss:2.1766 train_time:317962ms step_avg:38.31ms +step:8400/20000 train_loss:2.1880 train_time:321787ms step_avg:38.31ms +step:8500/20000 train_loss:2.1806 train_time:325605ms step_avg:38.31ms +step:8600/20000 train_loss:2.1877 train_time:329429ms step_avg:38.31ms +step:8700/20000 train_loss:2.0807 train_time:333247ms step_avg:38.30ms +step:8800/20000 train_loss:2.1613 train_time:337070ms step_avg:38.30ms +step:8900/20000 train_loss:2.2540 train_time:340886ms step_avg:38.30ms +step:9000/20000 train_loss:2.0775 train_time:344708ms step_avg:38.30ms +step:9100/20000 train_loss:2.3561 train_time:348527ms step_avg:38.30ms +step:9200/20000 train_loss:2.1404 train_time:352349ms step_avg:38.30ms +step:9300/20000 train_loss:2.1881 train_time:356167ms step_avg:38.30ms +step:9400/20000 train_loss:2.1932 train_time:359989ms step_avg:38.30ms +step:9500/20000 train_loss:2.3075 train_time:363807ms step_avg:38.30ms +step:9600/20000 train_loss:2.1974 train_time:367631ms step_avg:38.29ms +step:9700/20000 train_loss:2.1641 train_time:371446ms step_avg:38.29ms +step:9800/20000 train_loss:2.1332 train_time:375266ms step_avg:38.29ms +step:9900/20000 train_loss:2.1992 train_time:379078ms step_avg:38.29ms +step:10000/20000 train_loss:2.1724 train_time:382904ms step_avg:38.29ms +step:10000/20000 val_loss:2.1669 val_bpb:1.2834 train_time:382933ms step_avg:38.29ms +step:10100/20000 train_loss:2.1487 train_time:386718ms step_avg:38.29ms +step:10200/20000 train_loss:2.1234 train_time:390539ms step_avg:38.29ms +step:10300/20000 train_loss:2.2396 train_time:394413ms step_avg:38.29ms +step:10400/20000 train_loss:2.1579 train_time:398257ms step_avg:38.29ms +step:10500/20000 train_loss:2.0369 train_time:402076ms step_avg:38.29ms +step:10600/20000 train_loss:2.0318 train_time:405895ms step_avg:38.29ms +step:10700/20000 train_loss:2.1239 train_time:409714ms step_avg:38.29ms +step:10800/20000 train_loss:2.2393 train_time:413529ms step_avg:38.29ms +step:10900/20000 train_loss:2.2075 train_time:417349ms step_avg:38.29ms +step:11000/20000 train_loss:2.1734 train_time:421168ms step_avg:38.29ms +step:11100/20000 train_loss:2.1196 train_time:424996ms step_avg:38.29ms +step:11200/20000 train_loss:2.1242 train_time:428815ms step_avg:38.29ms +step:11300/20000 train_loss:2.0554 train_time:432636ms step_avg:38.29ms +step:11400/20000 train_loss:2.1112 train_time:436455ms step_avg:38.29ms +step:11500/20000 train_loss:2.1653 train_time:440278ms step_avg:38.29ms +step:11600/20000 train_loss:2.1158 train_time:444098ms step_avg:38.28ms +step:11700/20000 train_loss:2.2628 train_time:447919ms step_avg:38.28ms +step:11800/20000 train_loss:2.1500 train_time:451741ms step_avg:38.28ms +step:11900/20000 train_loss:2.1074 train_time:455568ms step_avg:38.28ms +step:12000/20000 train_loss:2.1282 train_time:459388ms step_avg:38.28ms +step:12100/20000 train_loss:2.1586 train_time:463213ms step_avg:38.28ms +step:12200/20000 train_loss:2.2711 train_time:467030ms step_avg:38.28ms +step:12300/20000 train_loss:2.1612 train_time:470852ms step_avg:38.28ms +step:12400/20000 train_loss:1.9223 train_time:474675ms step_avg:38.28ms +step:12500/20000 train_loss:2.4385 train_time:478492ms step_avg:38.28ms +step:12600/20000 train_loss:2.1458 train_time:482311ms step_avg:38.28ms +step:12700/20000 train_loss:2.1435 train_time:486131ms step_avg:38.28ms +step:12800/20000 train_loss:2.1655 train_time:489949ms step_avg:38.28ms +step:12900/20000 train_loss:2.2066 train_time:493769ms step_avg:38.28ms +step:13000/20000 train_loss:2.2481 train_time:497594ms step_avg:38.28ms +step:13100/20000 train_loss:2.0823 train_time:501411ms step_avg:38.28ms +step:13200/20000 train_loss:2.2604 train_time:505241ms step_avg:38.28ms +step:13300/20000 train_loss:2.1139 train_time:509060ms step_avg:38.28ms +step:13400/20000 train_loss:2.1364 train_time:512883ms step_avg:38.27ms +step:13500/20000 train_loss:2.0993 train_time:516702ms step_avg:38.27ms +step:13600/20000 train_loss:2.0107 train_time:520525ms step_avg:38.27ms +step:13700/20000 train_loss:2.1780 train_time:524342ms step_avg:38.27ms +step:13800/20000 train_loss:2.0868 train_time:528172ms step_avg:38.27ms +step:13900/20000 train_loss:2.1781 train_time:532013ms step_avg:38.27ms +step:14000/20000 train_loss:2.1436 train_time:535841ms step_avg:38.27ms +step:14100/20000 train_loss:2.1131 train_time:539663ms step_avg:38.27ms +step:14200/20000 train_loss:2.2242 train_time:543489ms step_avg:38.27ms +step:14300/20000 train_loss:2.1785 train_time:547308ms step_avg:38.27ms +step:14400/20000 train_loss:2.1152 train_time:551130ms step_avg:38.27ms +step:14500/20000 train_loss:2.0416 train_time:554954ms step_avg:38.27ms +step:14600/20000 train_loss:2.1680 train_time:558772ms step_avg:38.27ms +step:14700/20000 train_loss:2.1258 train_time:562594ms step_avg:38.27ms +step:14800/20000 train_loss:1.9416 train_time:566414ms step_avg:38.27ms +step:14900/20000 train_loss:1.8653 train_time:570235ms step_avg:38.27ms +step:15000/20000 train_loss:2.0555 train_time:574054ms step_avg:38.27ms +step:15000/20000 val_loss:2.1115 val_bpb:1.2506 train_time:574083ms step_avg:38.27ms +step:15100/20000 train_loss:2.1918 train_time:577878ms step_avg:38.27ms +step:15200/20000 train_loss:2.1684 train_time:581696ms step_avg:38.27ms +step:15300/20000 train_loss:2.1991 train_time:585519ms step_avg:38.27ms +step:15400/20000 train_loss:2.0722 train_time:589338ms step_avg:38.27ms +step:15500/20000 train_loss:2.0847 train_time:593162ms step_avg:38.27ms +step:15600/20000 train_loss:2.0858 train_time:596980ms step_avg:38.27ms +step:15680/20000 val_loss:2.1038 val_bpb:1.2460 train_time:600065ms step_avg:38.27ms +stopping_early: wallclock_cap train_time:600065ms step:15680/20000 +peak memory allocated: 8612 MiB reserved: 8974 MiB +Serialized model: 32656041 bytes +Code size: 39067 bytes +Total submission size: 32695108 bytes +Serialized model int8+zlib: 15664859 bytes (payload:16379168 raw_torch:16404714 payload_ratio:1.99x) +Total submission size int8+zlib: 15703926 bytes +final_int8_zlib_roundtrip val_loss:2.1142 val_bpb:1.2521 eval_time:1392ms +final_int8_zlib_roundtrip_exact val_loss:2.11416887 val_bpb:1.25212989 diff --git a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md new file mode 100644 index 0000000000..6561ff1378 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md @@ -0,0 +1,47 @@ +# Constrained by Time +*(Not in the traditional sense)* + +Between schoolwork and other commitments, this was definitely an on-and-off type of project. +- Prior to this, I had never trained any language models of any kind from scratch like this, let alone in a speedrunning context. +- I was planning on full-scale testing but had to wait until 4/30/2026 for a sufficient compute grant for full H100 testing. With only a day to really experiment, there wasn't much time for me to validate anything. + +There were many (frankly) strange implementations that I tried experimenting with with the help of AI. +- Extreme depth/low dim +- Low depth/high dim +- Rolling data inside each batch (give the model one more chance at "harder" data) +- Dynamic loss scaling +- You'll find more by going through my commit history. + +Although this probably isn't going to shatter any records (at all), I do hope that this at least shines an interesting light at some potential ideas that may be integrated into future GPT training/speedruns. +- Even if this ultimately goes nowhere, it would be nice if some of the ideas in this implementation were explored. Perhaps there's too much going on in this implementation, and that together they're conflicting each other. + +My focus is not on Test-Time Training (TTT) or any implementation-specifc optimizations (e.g. fp8 training). Rather, my focus was on the underlying architecture itself, and really (trying towards) pushing the limits of what a conventional transformer can do. +- I may continue experiments even after this competition, since research isn't just one and done! + +## ZerO initalization +I was intersted in this paper: https://arxiv.org/pdf/2110.12661 +- Zhao et. al. describes how performance of deep networks can be both better and more reproducable through a more deterministic initialization method involving Hadaramard/Identity-like matrices + +The actual usage of this initialization does involve some level of non-determinism, but it's less pronounced than fully random initalization. + +## Progression of Various Model Hyperparameters +Throughout each layer, I utilized a progression of KV head count, rope proportion, and MLP multiple, all of which increase as the layer's depth w.r.t. the model increases. The rationale is as follows: +1. Earlier layers most likely focus on nearby context and shouldn't worry about long-range dependencies. +2. As tokens go further into the model, more information is going to be needed. + +## So, could it work? +Well, maybe if there was more time (to experiment + training time) and a better implementation. +- Of course, hyperparameter choice also plays a role. However, I did not have too much time to really test anything. + +## What would I have done if I had more time? +1. Tuning hyperparameters. +2. More experimentation with other strange hypotheses. +3. Optimizing the implementation. + +If I wasn't constrained by the constraints, I would also test it on larger-scale models (e.g. 100 million+ parameters). +- These ideas may perform better at larger scales, but not so at such a small scale (~16M params). + +## Why just one Result? +Time... +- If I had more time, I would submit more results. +- The result should (hopefully) be reproducable across seeds due to the [ZerO intialization](#zero-initalization) being used with a very small deviation. \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/requirements.txt b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/requirements.txt new file mode 100644 index 0000000000..911b0e52f0 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/requirements.txt @@ -0,0 +1,10 @@ +numpy +tqdm +torch +huggingface-hub +kernels +setuptools +typing-extensions==4.15.0 +datasets +tiktoken +sentencepiece \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/submission.json b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/submission.json new file mode 100644 index 0000000000..35be36e82c --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/submission.json @@ -0,0 +1,11 @@ +{ + "track": "non_record_16mb", + "date": "2026-04-30", + "name": "ZerO Initalization + Progressive KV, RoPE proportion and base, and MLP mult", + "author": "Alston Tang", + "github_id": "AlstonTang", + "val_bpb": 1.25212989, + "val_loss": 2.11416887, + "bytes_total": 15703926 + } + \ No newline at end of file From 9993ab140ce23539304e7b762af4884c3f0d15f2 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:36:49 -0700 Subject: [PATCH 61/80] Restore original train_gpt.py --- train_gpt.py | 1961 +++++++++++++++++++++++++++----------------------- 1 file changed, 1073 insertions(+), 888 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 00a06238f5..7f5125ac56 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1,4 +1,11 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + from __future__ import annotations + import copy import glob import io @@ -10,8 +17,8 @@ import time import uuid import zlib -import concurrent.futures from pathlib import Path + import numpy as np import sentencepiece as spm import torch @@ -19,923 +26,1101 @@ import torch.nn.functional as F from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 5000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2048)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 512)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 8192.0)) - rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 512.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - embed_lr = float(os.environ.get("EMBED_LR", 0.3)) - head_lr = float(os.environ.get("HEAD_LR", 0.04)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.125)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) -@torch.compile(mode="max-autotune-no-cudagraphs", fullgraph=True) -def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7): - X = torch.where(torch.isfinite(G), G, 0.0).to(torch.bfloat16) - norm = torch.linalg.matrix_norm(X) + eps - X = X / norm - transposed = X.size(0) > X.size(1) - if transposed: - X = X.T - a, b, c = 3.4445, -4.7750, 2.0315 - for _ in range(steps): - A = X @ X.T - B = torch.addmm(A, A, A, beta=b, alpha=c) - X = torch.addmm(X, B, X, beta=a, alpha=1.0) - return X.T if transposed else X -@torch.no_grad() -def get_hadamard_matrix(n, device): - p2 = 2**math.ceil(math.log2(n)) - H = torch.tensor([[1.0]], device=device) - while H.shape[0] < p2: - H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0) - return H[:n, :n] / math.sqrt(p2) -@torch.no_grad() -def apply_zero_init(model, std=0.02): - for m in model.modules(): - if isinstance(m, nn.Embedding): - d_out, d_in = m.weight.shape - H = get_hadamard_matrix(max(d_out, d_in), m.weight.device) - m.weight.copy_(H[:d_out, :d_in]*std) - linears = [sub for sub in m.children() if isinstance(sub, nn.Linear)] - if linears: - for i, l in enumerate(linears): - d_out, d_in = l.weight.shape - if i == len(linears) - 1: - nn.init.normal_(l.weight, std=1e-5) - else: - H = get_hadamard_matrix(max(d_out, d_in), l.weight.device) - l.weight.copy_(H[:d_out, :d_in]) - if l.bias is not None: - nn.init.zeros_(l.bias) + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float = 0.95, backend_steps: int = 5, nesterov: bool = True): - defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov) - super().__init__(params, defaults) - self._is_initialized = False - def _init_state(self): - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - for group in self.param_groups: - if "updates_flat" in group or not group["params"]: - continue - params = group["params"] - total_params = sum(p.numel() for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - group["updates_flat"] = updates_flat - group["param_views"] = [] - group["rank_params"] = [] - group["rank_param_views"] = [] - curr = 0 - for i, p in enumerate(params): - numel = p.numel() - view = updates_flat[curr : curr + numel].view_as(p) - group["param_views"].append(view) - if i % world_size == rank: - group["rank_params"].append(p) - group["rank_param_views"].append(view) - curr += numel - @torch.no_grad() - def step(self, closure=None): - loss = closure() if closure is not None else None - self._init_state() - for group in self.param_groups: - if not group["params"]: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - updates_flat = group["updates_flat"] - updates_flat.zero_() - for p, view in zip(group["rank_params"], group["rank_param_views"]): - if p.grad is None: - continue - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - u = g.add(buf, alpha=momentum) if nesterov else buf - original_shape = u.shape - if u.ndim > 2: - u = u.view(u.size(0), -1) - u_prepared = u.to(dtype=torch.bfloat16, memory_format=torch.contiguous_format, non_blocking=True) - g_ns = zeropower_via_newtonschulz5(u_prepared, steps=backend_steps) - g_ns.mul_(max(1, g_ns.size(0) / g_ns.size(1))**0.5) - view.copy_(g_ns.view(original_shape)) - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - params = group["params"] - views = group["param_views"] - if updates_flat.dtype != params[0].dtype: - views = [v.to(dtype=p.dtype, non_blocking=True) for v, p in zip(views, params)] - torch._foreach_add_(params, views, alpha=-lr) - return loss + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device ) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, ) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain", - ).split(",") - if pattern + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern ) INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern ) INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) + return int(t.numel()) * int(t.element_size()) + def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = self.future_tokens.result() - self.pos = 0 - self.future_tokens = self._queue_next_file() - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + # Reads shards sequentially and wraps around forever. The training loop therefore + # has deterministic, simple streaming behavior with no sampling or workers. + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank, self.world_size, self.device = rank, world_size, device - self.stream = TokenStream(pattern) - self.next_x, self.next_y = None, None - self.transfer_stream = torch.cuda.Stream(device) - def preload(self, global_tokens: int, seq_len: int, grad_accum_steps: int): - with torch.cuda.stream(self.transfer_stream): - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].pin_memory().to( - device=self.device, dtype=torch.int32, non_blocking=True - ) - self.next_x = local[:-1].reshape(-1, seq_len) - self.next_y = local[1:].reshape(-1, seq_len).to(torch.long) - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - if self.next_x is None: - self.preload(global_tokens, seq_len, grad_accum_steps) - torch.cuda.current_stream().wait_stream(self.transfer_stream) - x, y = self.next_x, self.next_y - self.preload(global_tokens, seq_len, grad_accum_steps) - return x, y + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int, p: float = 0.5, base: float = 10000.0): - super().__init__() - self.rotary_dim = (int(dim * p) // 8) * 8 - inv_freq = 1.0 / (base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) - t = torch.arange(max_seq_len, dtype=torch.float32) - freqs = torch.outer(t, inv_freq) - self.register_buffer("cos", freqs.cos().view(1, 1, max_seq_len, -1), persistent=False) - self.register_buffer("sin", freqs.sin().view(1, 1, max_seq_len, -1), persistent=False) - def forward(self, x: Tensor): - t = x.size(2) - return self.cos[:, :, :t, :].to(x.dtype), self.sin[:, :, :t, :].to(x.dtype) -def rotate_half(x: Tensor): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - d = cos.shape[-1] * 2 - x_rop = x[..., :d] - x_pass = x[..., d:] - x_rop = x_rop.view(*x_rop.shape[:-1], -1, 2) - x0 = x_rop[..., 0] - x1 = x_rop[..., 1] - res = torch.stack([x0 * cos - x1 * sin, x1 * cos + x0 * sin], dim=-1) - return torch.cat([res.flatten(-2), x_pass], dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + class CausalSelfAttention(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True, rope_proportion=0.5): - super().__init__() - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - self.kv_dim = self.num_kv_heads * self.head_dim - self.c_qkv = nn.Linear(dim, dim + 2 * self.kv_dim, bias=False) - self.proj = nn.Linear(dim, dim, bias=False) - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) - self.use_rope = use_rope - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - qkv = self.c_qkv(x) - q, k, v = qkv.split([dim, self.kv_dim, self.kv_dim], dim=-1) - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - if self.use_rope: - cos, sin = self.rotary(q) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.view(1, -1, 1, 1).to(q.dtype) - y = F.scaled_dot_product_attention( - q, k, v, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads) - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - self.hidden_dim = int(mlp_mult * 2/3 * dim // 64) * 64 - self.c_fc = nn.Linear(dim, 2 * self.hidden_dim, bias=False) - self.c_proj = nn.Linear(self.hidden_dim, dim, bias=False) - def forward(self, x: Tensor) -> Tensor: - fused_x = self.c_fc(x) - gate, value = fused_x.chunk(2, dim=-1) - return self.c_proj(F.silu(gate) * value) + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - rope_base: float, - qk_gain_init: float, - seq_len: int=1024, - use_rope: bool=True, - rope_proportion: float=0.5, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope, rope_proportion) - self.mlp = MLP(dim, mlp_mult) - def forward(self, x: Tensor) -> Tensor: - x = x + self.attn(self.attn_norm(x)) - x = x + self.mlp(self.mlp_norm(x)) - return x -def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): - min_kv = 1 - max_kv = num_heads - fraction = layer_idx / (total_layers - 1) - raw_kv = min_kv + (max_kv - min_kv) * fraction - valid_kvs = [i for i in range(1, num_heads + 1) if num_heads % i == 0] - kv_heads = min(valid_kvs, key=lambda x: abs(x - raw_kv)) - return int(max(min_kv, kv_heads)) -def get_rope_p_smooth(i: int, num_layers: int, p_min=0.25, p_max=0.75) -> float: - if num_layers <= 1: - return p_min - progress = i / (num_layers - 1) - scale = math.sin(progress * math.pi) - return p_min + (p_max - p_min) * scale -def get_rope_base_progression(layer_idx: int, total_layers: int, min_base: float, max_base: float) -> float: - if total_layers <= 1: - return max_base - fraction = layer_idx / (total_layers - 1) - return min_base * (max_base / min_base) ** fraction -def get_linear_progression_mlp_mult(layer_idx: int, total_layers: int, base_mult: int) -> float: - min_mult = float(base_mult) * 0.75 - max_mult = float(base_mult) * 1.25 - fraction = layer_idx / max(1, total_layers - 1) - return min_mult + (max_mult - min_mult) * fraction + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_max_base: float, - rope_min_base: float, - qk_gain_init: float, - seq_len: int=1024, - ): - super().__init__() - self.tie_embeddings = tie_embeddings - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.blocks = nn.ModuleList([ - Block( - model_dim, - num_heads, - get_linear_progression_kv_heads(i, num_layers, num_kv_heads), - get_linear_progression_mlp_mult(i, num_layers, mlp_mult), - get_rope_base_progression(i, num_layers, rope_min_base, rope_max_base), - qk_gain_init, - seq_len=seq_len, - use_rope=True, - rope_proportion=get_rope_p_smooth(i, num_layers) - ) for i in range(num_layers) - ]) - self.final_norm = RMSNorm() - self.logit_softcap = logit_softcap - self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) - apply_zero_init(self, std=tied_embed_init_std) - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - emb = self.tok_emb(input_ids) - x = F.rms_norm(emb, (emb.size(-1),)) - for block in self.blocks: - x = block(x) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight) - else: - logits = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits.float() / self.logit_softcap) - return F.cross_entropy(logits, targets) + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + def main() -> None: - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - torch.set_float32_matmul_precision('high') - import torch._inductor.config as inductor_config - inductor_config.fx_graph_cache = True - inductor_config.triton.unique_kernel_names = True - inductor_config.freezing = True - inductor_config.shape_padding = True - inductor_config.epilogue_fusion = True - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - @torch.compiler.disable - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_max_base=args.rope_max_base, - rope_min_base=args.rope_min_base, - qk_gain_init=args.qk_gain_init, - seq_len=args.train_seq_len - ).to(device).bfloat16() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, fullgraph=True, mode="reduce-overhead") - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) if distributed else compiled_model - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - gpt_scalars = [p for name, p in base_model.named_parameters() if "skip" in name] - scalar_params.extend(gpt_scalars) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - if distributed: - dist.destroy_process_group() -if __name__ == "__main__": main() \ No newline at end of file + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file From 31f2ebbb66a27382af9328de6e515ac1f5cf3b59 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:46:38 -0700 Subject: [PATCH 62/80] Get train_gpt and update readme --- .../README.md | 20 +- .../train_gpt.py | 941 ++++++++++++++++++ 2 files changed, 953 insertions(+), 8 deletions(-) create mode 100644 records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/train_gpt.py diff --git a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md index 6561ff1378..f967bcf6c3 100644 --- a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md +++ b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md @@ -2,21 +2,22 @@ *(Not in the traditional sense)* Between schoolwork and other commitments, this was definitely an on-and-off type of project. -- Prior to this, I had never trained any language models of any kind from scratch like this, let alone in a speedrunning context. -- I was planning on full-scale testing but had to wait until 4/30/2026 for a sufficient compute grant for full H100 testing. With only a day to really experiment, there wasn't much time for me to validate anything. +- Prior to this, I have never trained any language models of any kind, let alone in a speedrunning context. +- I was planning on full-scale testing but had to wait until 4/30/2026 for a sufficient compute grant for full H100 testing. With only a day to really experiment, there wasn't much time for me to validate anything or find very good ideas. + - Smaller-scale testing was done on Google Colab and the RTX 5090s provided by RunPod, but these were slow and it was hard to really know what worked when moving so slowly. There were many (frankly) strange implementations that I tried experimenting with with the help of AI. - Extreme depth/low dim - Low depth/high dim - Rolling data inside each batch (give the model one more chance at "harder" data) - Dynamic loss scaling -- You'll find more by going through my commit history. +- You'll find more by going through [my repo's](https://github.com/AlstonTang/parameter-golf) commit history. Although this probably isn't going to shatter any records (at all), I do hope that this at least shines an interesting light at some potential ideas that may be integrated into future GPT training/speedruns. -- Even if this ultimately goes nowhere, it would be nice if some of the ideas in this implementation were explored. Perhaps there's too much going on in this implementation, and that together they're conflicting each other. +- Even if this ultimately doesn't go very far alone, it would be nice if some of the ideas in this implementation were explored. Perhaps there's too much going on in this implementation, and that together they're conflicting each other. Or perhaps's it's merely a hyperparameter configuration away from getting solid results. My focus is not on Test-Time Training (TTT) or any implementation-specifc optimizations (e.g. fp8 training). Rather, my focus was on the underlying architecture itself, and really (trying towards) pushing the limits of what a conventional transformer can do. -- I may continue experiments even after this competition, since research isn't just one and done! +- I may continue experiments even after this competition, since research isn't just one and done! I may take some ideas from my implementation and iteratively add it to new language models I may train in the future as I both learn more about LLMs and advanced deep learning in general. ## ZerO initalization I was intersted in this paper: https://arxiv.org/pdf/2110.12661 @@ -29,6 +30,8 @@ Throughout each layer, I utilized a progression of KV head count, rope proportio 1. Earlier layers most likely focus on nearby context and shouldn't worry about long-range dependencies. 2. As tokens go further into the model, more information is going to be needed. +Whether these should all be scaled linearly, geometrically, or something else is a question for another day. + ## So, could it work? Well, maybe if there was more time (to experiment + training time) and a better implementation. - Of course, hyperparameter choice also plays a role. However, I did not have too much time to really test anything. @@ -39,9 +42,10 @@ Well, maybe if there was more time (to experiment + training time) and a better 3. Optimizing the implementation. If I wasn't constrained by the constraints, I would also test it on larger-scale models (e.g. 100 million+ parameters). -- These ideas may perform better at larger scales, but not so at such a small scale (~16M params). +- Perhaps the model was too small to really realize the potential gains of my proposed implementation. +- Going beyong 16 MB would be nice to see if my ideas could potentially fly! ## Why just one Result? Time... -- If I had more time, I would submit more results. -- The result should (hopefully) be reproducable across seeds due to the [ZerO intialization](#zero-initalization) being used with a very small deviation. \ No newline at end of file +- If I had more time, I would submit more results (and probably more refined ones). +- The result should (hopefully) be reproducable across seeds due to the more deterministic [ZerO intialization](#zero-initalization) being used. \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/train_gpt.py b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/train_gpt.py new file mode 100644 index 0000000000..09bcd6fcb5 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/train_gpt.py @@ -0,0 +1,941 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +import concurrent.futures +from pathlib import Path +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 5000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2048)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 512)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 2.125)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_max_base = float(os.environ.get("ROPE_MAX_BASE", 8192.0)) + rope_min_base = float(os.environ.get("ROPE_MIN_BASE", 512.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.3)) + head_lr = float(os.environ.get("HEAD_LR", 0.04)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.125)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 256)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) +@torch.compile(mode="max-autotune-no-cudagraphs", fullgraph=True) +def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7): + X = torch.where(torch.isfinite(G), G, 0.0).to(torch.bfloat16) + norm = torch.linalg.matrix_norm(X) + eps + X = X / norm + transposed = X.size(0) > X.size(1) + if transposed: + X = X.T + a, b, c = 3.4445, -4.7750, 2.0315 + for _ in range(steps): + A = X @ X.T + B = torch.addmm(A, A, A, beta=b, alpha=c) + X = torch.addmm(X, B, X, beta=a, alpha=1.0) + return X.T if transposed else X +@torch.no_grad() +def get_hadamard_matrix(n, device): + p2 = 2**math.ceil(math.log2(n)) + H = torch.tensor([[1.0]], device=device) + while H.shape[0] < p2: + H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0) + return H[:n, :n] / math.sqrt(p2) +@torch.no_grad() +def apply_zero_init(model, std=0.02): + for m in model.modules(): + if isinstance(m, nn.Embedding): + d_out, d_in = m.weight.shape + H = get_hadamard_matrix(max(d_out, d_in), m.weight.device) + m.weight.copy_(H[:d_out, :d_in]*std) + linears = [sub for sub in m.children() if isinstance(sub, nn.Linear)] + if linears: + for i, l in enumerate(linears): + d_out, d_in = l.weight.shape + if i == len(linears) - 1: + nn.init.normal_(l.weight, std=1e-5) + else: + H = get_hadamard_matrix(max(d_out, d_in), l.weight.device) + l.weight.copy_(H[:d_out, :d_in]) + if l.bias is not None: + nn.init.zeros_(l.bias) +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float = 0.95, backend_steps: int = 5, nesterov: bool = True): + defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov) + super().__init__(params, defaults) + self._is_initialized = False + def _init_state(self): + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + if "updates_flat" in group or not group["params"]: + continue + params = group["params"] + total_params = sum(p.numel() for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + group["updates_flat"] = updates_flat + group["param_views"] = [] + group["rank_params"] = [] + group["rank_param_views"] = [] + curr = 0 + for i, p in enumerate(params): + numel = p.numel() + view = updates_flat[curr : curr + numel].view_as(p) + group["param_views"].append(view) + if i % world_size == rank: + group["rank_params"].append(p) + group["rank_param_views"].append(view) + curr += numel + @torch.no_grad() + def step(self, closure=None): + loss = closure() if closure is not None else None + self._init_state() + for group in self.param_groups: + if not group["params"]: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + updates_flat = group["updates_flat"] + updates_flat.zero_() + for p, view in zip(group["rank_params"], group["rank_param_views"]): + if p.grad is None: + continue + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + u = g.add(buf, alpha=momentum) if nesterov else buf + original_shape = u.shape + if u.ndim > 2: + u = u.view(u.size(0), -1) + u_prepared = u.to(dtype=torch.bfloat16, memory_format=torch.contiguous_format, non_blocking=True) + g_ns = zeropower_via_newtonschulz5(u_prepared, steps=backend_steps) + g_ns.mul_(max(1, g_ns.size(0) / g_ns.size(1))**0.5) + view.copy_(g_ns.view(original_shape)) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + params = group["params"] + views = group["param_views"] + if updates_flat.dtype != params[0].dtype: + views = [v.to(dtype=p.dtype, non_blocking=True) for v, p in zip(views, params)] + torch._foreach_add_(params, views, alpha=-lr) + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = self.future_tokens.result() + self.pos = 0 + self.future_tokens = self._queue_next_file() + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + self.next_x, self.next_y = None, None + self.transfer_stream = torch.cuda.Stream(device) + def preload(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + with torch.cuda.stream(self.transfer_stream): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].pin_memory().to( + device=self.device, dtype=torch.int32, non_blocking=True + ) + self.next_x = local[:-1].reshape(-1, seq_len) + self.next_y = local[1:].reshape(-1, seq_len).to(torch.long) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + if self.next_x is None: + self.preload(global_tokens, seq_len, grad_accum_steps) + torch.cuda.current_stream().wait_stream(self.transfer_stream) + x, y = self.next_x, self.next_y + self.preload(global_tokens, seq_len, grad_accum_steps) + return x, y +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int, p: float = 0.5, base: float = 10000.0): + super().__init__() + self.rotary_dim = (int(dim * p) // 8) * 8 + inv_freq = 1.0 / (base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) + t = torch.arange(max_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("cos", freqs.cos().view(1, 1, max_seq_len, -1), persistent=False) + self.register_buffer("sin", freqs.sin().view(1, 1, max_seq_len, -1), persistent=False) + def forward(self, x: Tensor): + t = x.size(2) + return self.cos[:, :, :t, :].to(x.dtype), self.sin[:, :, :t, :].to(x.dtype) +def rotate_half(x: Tensor): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + d = cos.shape[-1] * 2 + x_rop = x[..., :d] + x_pass = x[..., d:] + x_rop = x_rop.view(*x_rop.shape[:-1], -1, 2) + x0 = x_rop[..., 0] + x1 = x_rop[..., 1] + res = torch.stack([x0 * cos - x1 * sin, x1 * cos + x0 * sin], dim=-1) + return torch.cat([res.flatten(-2), x_pass], dim=-1) +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len=1024, use_rope=True, rope_proportion=0.5): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.c_qkv = nn.Linear(dim, dim + 2 * self.kv_dim, bias=False) + self.proj = nn.Linear(dim, dim, bias=False) + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, max_seq_len=seq_len, p=rope_proportion, base=rope_base) + self.use_rope = use_rope + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv = self.c_qkv(x) + q, k, v = qkv.split([dim, self.kv_dim, self.kv_dim], dim=-1) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + if self.use_rope: + cos, sin = self.rotary(q) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.view(1, -1, 1, 1).to(q.dtype) + y = F.scaled_dot_product_attention( + q, k, v, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads) + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + self.hidden_dim = int(mlp_mult * 2/3 * dim // 64) * 64 + self.c_fc = nn.Linear(dim, 2 * self.hidden_dim, bias=False) + self.c_proj = nn.Linear(self.hidden_dim, dim, bias=False) + def forward(self, x: Tensor) -> Tensor: + fused_x = self.c_fc(x) + gate, value = fused_x.chunk(2, dim=-1) + return self.c_proj(F.silu(gate) * value) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + rope_base: float, + qk_gain_init: float, + seq_len: int=1024, + use_rope: bool=True, + rope_proportion: float=0.5, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, seq_len, use_rope, rope_proportion) + self.mlp = MLP(dim, mlp_mult) + def forward(self, x: Tensor) -> Tensor: + x = x + self.attn(self.attn_norm(x)) + x = x + self.mlp(self.mlp_norm(x)) + return x +def get_linear_progression_kv_heads(layer_idx, total_layers, num_heads): + min_kv = 1 + max_kv = num_heads + fraction = layer_idx / (total_layers - 1) + raw_kv = min_kv + (max_kv - min_kv) * fraction + valid_kvs = [i for i in range(1, num_heads + 1) if num_heads % i == 0] + kv_heads = min(valid_kvs, key=lambda x: abs(x - raw_kv)) + return int(max(min_kv, kv_heads)) +def get_rope_p_smooth(i: int, num_layers: int, p_min=0.25, p_max=0.75) -> float: + if num_layers <= 1: + return p_min + progress = i / (num_layers - 1) + scale = math.sin(progress * math.pi) + return p_min + (p_max - p_min) * scale +def get_rope_base_progression(layer_idx: int, total_layers: int, min_base: float, max_base: float) -> float: + if total_layers <= 1: + return max_base + fraction = layer_idx / (total_layers - 1) + return min_base * (max_base / min_base) ** fraction +def get_linear_progression_mlp_mult(layer_idx: int, total_layers: int, base_mult: int) -> float: + min_mult = float(base_mult) * 0.75 + max_mult = float(base_mult) * 1.25 + fraction = layer_idx / max(1, total_layers - 1) + return min_mult + (max_mult - min_mult) * fraction +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_max_base: float, + rope_min_base: float, + qk_gain_init: float, + seq_len: int=1024, + ): + super().__init__() + self.tie_embeddings = tie_embeddings + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.blocks = nn.ModuleList([ + Block( + model_dim, + num_heads, + get_linear_progression_kv_heads(i, num_layers, num_kv_heads), + get_linear_progression_mlp_mult(i, num_layers, mlp_mult), + get_rope_base_progression(i, num_layers, rope_min_base, rope_max_base), + qk_gain_init, + seq_len=seq_len, + use_rope=True, + rope_proportion=get_rope_p_smooth(i, num_layers) + ) for i in range(num_layers) + ]) + self.final_norm = RMSNorm() + self.logit_softcap = logit_softcap + self.lm_head = None if tie_embeddings else nn.Linear(model_dim, vocab_size, bias=False) + apply_zero_init(self, std=tied_embed_init_std) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + emb = self.tok_emb(input_ids) + x = F.rms_norm(emb, (emb.size(-1),)) + for block in self.blocks: + x = block(x) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits.float() / self.logit_softcap) + return F.cross_entropy(logits, targets) +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + torch.set_float32_matmul_precision('high') + import torch._inductor.config as inductor_config + inductor_config.fx_graph_cache = True + inductor_config.triton.unique_kernel_names = True + inductor_config.freezing = True + inductor_config.shape_padding = True + inductor_config.epilogue_fusion = True + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + @torch.compiler.disable + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_max_base=args.rope_max_base, + rope_min_base=args.rope_min_base, + qk_gain_init=args.qk_gain_init, + seq_len=args.train_seq_len + ).to(device).bfloat16() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, fullgraph=True, mode="reduce-overhead") + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + gpt_scalars = [p for name, p in base_model.named_parameters() if "skip" in name] + scalar_params.extend(gpt_scalars) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": main() \ No newline at end of file From 4221127e8cd3a3b1ec655b77fcf18d386b84575e Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:50:44 -0700 Subject: [PATCH 63/80] Delete check.py --- check.py | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 check.py diff --git a/check.py b/check.py deleted file mode 100644 index 1e3c6490c6..0000000000 --- a/check.py +++ /dev/null @@ -1,17 +0,0 @@ -from train_gpt import Hyperparameters, GPT -args = Hyperparameters() -base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - seq_len=args.train_seq_len - ) - -print(sum([p.numel() for p in base_model.parameters()])) \ No newline at end of file From 9a2be884aa0ff8e2c5f039215810ceb674351151 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:52:09 -0700 Subject: [PATCH 64/80] Rectify trailing line in original train_gpt.py --- train_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gpt.py b/train_gpt.py index 7f5125ac56..651beb2b89 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1123,4 +1123,4 @@ def lr_mul(step: int, elapsed_ms: float) -> float: if __name__ == "__main__": - main() \ No newline at end of file + main() From 3f4fc4540e8989061c529f1e614171a9c219eefe Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:54:31 -0700 Subject: [PATCH 65/80] Update README.md --- .../README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md index f967bcf6c3..2c45c28c89 100644 --- a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md +++ b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md @@ -16,8 +16,9 @@ There were many (frankly) strange implementations that I tried experimenting wit Although this probably isn't going to shatter any records (at all), I do hope that this at least shines an interesting light at some potential ideas that may be integrated into future GPT training/speedruns. - Even if this ultimately doesn't go very far alone, it would be nice if some of the ideas in this implementation were explored. Perhaps there's too much going on in this implementation, and that together they're conflicting each other. Or perhaps's it's merely a hyperparameter configuration away from getting solid results. -My focus is not on Test-Time Training (TTT) or any implementation-specifc optimizations (e.g. fp8 training). Rather, my focus was on the underlying architecture itself, and really (trying towards) pushing the limits of what a conventional transformer can do. +My focus is not on Test-Time Training (TTT) or any significant implementation-specifc optimizations (e.g. fp8 training). Rather, my focus was on the underlying architecture itself, and really (trying towards) pushing the limits of what a conventional transformer can do. - I may continue experiments even after this competition, since research isn't just one and done! I may take some ideas from my implementation and iteratively add it to new language models I may train in the future as I both learn more about LLMs and advanced deep learning in general. +- Some commits may show some significant implementation-specific optimizations (e.g. my attempt at fp8 training), but these usually either failed or lead to training instability during experimentation. ## ZerO initalization I was intersted in this paper: https://arxiv.org/pdf/2110.12661 @@ -48,4 +49,4 @@ If I wasn't constrained by the constraints, I would also test it on larger-scale ## Why just one Result? Time... - If I had more time, I would submit more results (and probably more refined ones). -- The result should (hopefully) be reproducable across seeds due to the more deterministic [ZerO intialization](#zero-initalization) being used. \ No newline at end of file +- The result should (hopefully) be reproducable across seeds due to the more deterministic [ZerO intialization](#zero-initalization) being used. From 25e98564ed18b4a38dfb7894ae5920284cf7c130 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:59:36 -0700 Subject: [PATCH 66/80] Update README.md Minor clarifications and minor grammar fixes --- .../README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md index 2c45c28c89..3acfde1d34 100644 --- a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md +++ b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md @@ -16,9 +16,9 @@ There were many (frankly) strange implementations that I tried experimenting wit Although this probably isn't going to shatter any records (at all), I do hope that this at least shines an interesting light at some potential ideas that may be integrated into future GPT training/speedruns. - Even if this ultimately doesn't go very far alone, it would be nice if some of the ideas in this implementation were explored. Perhaps there's too much going on in this implementation, and that together they're conflicting each other. Or perhaps's it's merely a hyperparameter configuration away from getting solid results. -My focus is not on Test-Time Training (TTT) or any significant implementation-specifc optimizations (e.g. fp8 training). Rather, my focus was on the underlying architecture itself, and really (trying towards) pushing the limits of what a conventional transformer can do. +My focus ended up not being on Test-Time Training (TTT) or any significant implementation-specifc optimizations (e.g. fp8 training). Rather, my focus was on the underlying architecture itself, and really (trying towards) pushing the limits of what a conventional transformer can do. - I may continue experiments even after this competition, since research isn't just one and done! I may take some ideas from my implementation and iteratively add it to new language models I may train in the future as I both learn more about LLMs and advanced deep learning in general. -- Some commits may show some significant implementation-specific optimizations (e.g. my attempt at fp8 training), but these usually either failed or lead to training instability during experimentation. +- Some commits may show attempts implementation-specific optimizations (e.g. my attempt at fp8 training), but these usually either failed or led to training instability during experimentation. ## ZerO initalization I was intersted in this paper: https://arxiv.org/pdf/2110.12661 From fef7edc8c96ce169f31754f8deae1334a76e0fba Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 08:49:20 -0700 Subject: [PATCH 67/80] Touch up README.md --- .../README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md index 3acfde1d34..e90d0680f2 100644 --- a/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md +++ b/records/track_non_record_16mb/2026-04-30_ZerO _init+progressive_kv_rope_and_mlp_mult/README.md @@ -21,17 +21,17 @@ My focus ended up not being on Test-Time Training (TTT) or any significant imple - Some commits may show attempts implementation-specific optimizations (e.g. my attempt at fp8 training), but these usually either failed or led to training instability during experimentation. ## ZerO initalization -I was intersted in this paper: https://arxiv.org/pdf/2110.12661 -- Zhao et. al. describes how performance of deep networks can be both better and more reproducable through a more deterministic initialization method involving Hadaramard/Identity-like matrices +I was interested in this paper: https://arxiv.org/pdf/2110.12661 +- Zhao et. al. describes how performance of deep networks can be both better and more reproducable through a more deterministic initialization method involving Hadamard/Identity-like matrices -The actual usage of this initialization does involve some level of non-determinism, but it's less pronounced than fully random initalization. +The actual usage of this initialization does involve some level of non-determinism (and likely somewhat deviates from the original algorithm), but it's less pronounced than fully random initalization. ## Progression of Various Model Hyperparameters -Throughout each layer, I utilized a progression of KV head count, rope proportion, and MLP multiple, all of which increase as the layer's depth w.r.t. the model increases. The rationale is as follows: +Throughout each layer, I utilized a progression of KV head count, rope proportion, and MLP expansion factor, all of which increase as the layer's depth within the model increases. The rationale is as follows: 1. Earlier layers most likely focus on nearby context and shouldn't worry about long-range dependencies. 2. As tokens go further into the model, more information is going to be needed. -Whether these should all be scaled linearly, geometrically, or something else is a question for another day. +Whether these should all be scaled linearly, geometrically, inversely, or something else is a question for another day. ## So, could it work? Well, maybe if there was more time (to experiment + training time) and a better implementation. From 1ca050366fe5f7dc2d8080f267ec2867f3b4c2af Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 08:52:53 -0700 Subject: [PATCH 68/80] Utilize the more original version of ZerO --- train_gpt.py | 143 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/train_gpt.py b/train_gpt.py index 651beb2b89..90637ab6e0 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -497,6 +497,148 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> # TRANSFORMER MODULES # ----------------------------- +def apply_zero_init(model: nn.Module) -> nn.Module: + """ + Applies ZerO initialization to a given PyTorch nn.Module in-place. + + This function implements: + - Algorithm 1 (Linear/FFN Layers) + - Algorithm 2 (Convolution Layers) + - Section 4.1 Transformer specific initialization (W_Q=Identity, W_K/W_V=Zeros) + + Args: + model (nn.Module): The PyTorch model to initialize. + + Returns: + nn.Module: The model with initialized weights. + """ + + @torch.no_grad() + def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): + """Core logic corresponding to Algorithm 1 in the paper.""" + tensor.zero_() + + if out_f <= in_f: + # P_l == Q_l: Identity mapping + # P_l < Q_l: Partial identity (Propagates first P_l dimensions) + # torch.eye intrinsically handles rectangular partial-identities. + tensor.copy_(torch.eye(out_f, in_f, dtype=tensor.dtype, device=tensor.device)) + else: + # P_l > Q_l: Hadamard mapping + m = math.ceil(math.log2(out_f)) + m = max(m, 1) # Ensure at least m=1 + + # Recursively generate Hadamard matrix H_m + H = torch.tensor([[1.0]], dtype=tensor.dtype, device=tensor.device) + for _ in range(m): + H1 = torch.cat([H, H], dim=1) + H2 = torch.cat([H, -H], dim=1) + H = torch.cat([H1, H2], dim=0) + + # Normalization factor defined in the paper: c = 2^{-(m-1)/2} + c = 2.0 ** (-(m - 1) / 2.0) + + # Apply scaled Hadamard top-left submatrix (I* H_m I*) + tensor.copy_(c * H[:out_f, :in_f]) + + @torch.no_grad() + def _init_mha_zero(module: nn.MultiheadAttention): + """Initializes PyTorch's native MultiheadAttention according to Sec 4.1""" + embed_dim = module.embed_dim + + # in_proj_weight groups Q, K, and V + if module.in_proj_weight is not None: + module.in_proj_weight.zero_() + # W_Q is identity + module.in_proj_weight[:embed_dim, :].copy_( + torch.eye(embed_dim, dtype=module.in_proj_weight.dtype, device=module.in_proj_weight.device) + ) + # W_K, W_V remain zero (as handled by .zero_()) + else: + # If instantiated with separate weights + if getattr(module, 'q_proj_weight', None) is not None: + module.q_proj_weight.copy_( + torch.eye(embed_dim, dtype=module.q_proj_weight.dtype, device=module.q_proj_weight.device) + ) + if getattr(module, 'k_proj_weight', None) is not None: + module.k_proj_weight.zero_() + if getattr(module, 'v_proj_weight', None) is not None: + module.v_proj_weight.zero_() + + if getattr(module, 'in_proj_bias', None) is not None: + module.in_proj_bias.zero_() + + # The output projection relies on the standard algorithm (P_l = Q_l) -> Identity + if getattr(module, 'out_proj', None) is not None and getattr(module.out_proj, 'weight', None) is not None: + _zero_init_tensor(module.out_proj.weight, module.out_proj.out_features, module.out_proj.in_features) + if getattr(module.out_proj, 'bias', None) is not None: + module.out_proj.bias.zero_() + + @torch.no_grad() + def _init_conv_zero(module): + """Initializes Convolutions according to Algorithm 2""" + out_c, in_c = module.out_channels, module.in_channels + module.weight.zero_() + + # Find center index, taking n <- floor(k / 2) for each dimension + centers = tuple(k // 2 for k in module.kernel_size) + + # Create an out_c x in_c block utilizing the base ZerO logic + block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) + _zero_init_tensor(block, out_c, in_c) + + # Place block in the center slice of the kernel + if len(centers) == 1: # Conv1D + module.weight[:, :, centers[0]] = block + elif len(centers) == 2: # Conv2D + module.weight[:, :, centers[0], centers[1]] = block + elif len(centers) == 3: # Conv3D + module.weight[:, :, centers[0], centers[1], centers[2]] = block + + if getattr(module, 'bias', None) is not None: + module.bias.zero_() + + # Iterate through and parse all network submodules + for name, module in model.named_modules(): + + # 1. Transformers MultiheadAttention (PyTorch Native MHA) + if isinstance(module, nn.MultiheadAttention): + _init_mha_zero(module) + + # 2. Linear Layers (Applies to Feed-Forward Networks or HuggingFace Custom QKV Projections) + elif isinstance(module, nn.Linear): + name_lower = name.split('.')[-1].lower() + + # W_K, W_V at zero (Fallback heuristics for custom attention modules) + if any(k in name_lower for k in ['k_proj', 'v_proj', 'key', 'value']): + module.weight.data.zero_() + + # W_Q as identity (Fallback heuristics for custom attention modules) + elif any(q in name_lower for q in ['q_proj', 'query']): + module.weight.data.copy_( + torch.eye(module.out_features, module.in_features, dtype=module.weight.dtype, device=module.weight.device) + ) + + # Generic linear layer (e.g., Dimension expanding/shrinking FFN) + else: + _zero_init_tensor(module.weight, module.out_features, module.in_features) + + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 3. Convolutional Layers + elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + _init_conv_zero(module) + + # 4. Normalization Layers (Standard practice is Scale=1, Bias=0) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + if getattr(module, 'weight', None) is not None: + module.weight.data.fill_(1.0) + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + return model + class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() @@ -836,6 +978,7 @@ def log0(msg: str, console: bool = True) -> None: rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, ).to(device).bfloat16() + base_model = apply_zero_init(base_model) for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() From 22d8e519a40352c78b55bdfa2ac736fa61e8f74c Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 08:56:38 -0700 Subject: [PATCH 69/80] Compilation flags --- train_gpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 90637ab6e0..1f495beffb 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -875,7 +875,7 @@ def main() -> None: code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, fullgraph=True, mode="reduce-overhead") # ----------------------------- # DISTRIBUTED + CUDA SETUP @@ -983,7 +983,7 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True, mode="reduce-overhead") model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model # Optimizer split: From 5685ab3893cd13d8c9539ac35f1d3a8a4e6ed393 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 09:07:02 -0700 Subject: [PATCH 70/80] No norms? --- train_gpt.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 1f495beffb..53adea7236 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -47,13 +47,13 @@ class Hyperparameters: # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 2000)) train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 100)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -639,14 +639,6 @@ def _init_conv_zero(module): return model -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - class CastedLinear(nn.Linear): # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. @@ -727,8 +719,6 @@ def forward(self, x: Tensor) -> Tensor: q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) @@ -770,8 +760,6 @@ def __init__( qk_gain_init: float, ): super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) @@ -781,9 +769,9 @@ def __init__( def forward(self, x: Tensor, x0: Tensor) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) + attn_out = self.attn(x) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(x) return x @@ -826,7 +814,6 @@ def __init__( for i in range(num_layers) ] ) - self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True @@ -841,7 +828,6 @@ def _init_weights(self) -> None: def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) x0 = x skips: list[Tensor] = [] @@ -854,7 +840,7 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) + x = x.reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) if self.tie_embeddings: logits_proj = F.linear(x, self.tok_emb.weight) From 93b44fb65a9a9ca2f066702175689c71a2d6daf3 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 09:14:01 -0700 Subject: [PATCH 71/80] Revert no norms since instability --- train_gpt.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 53adea7236..7a45bf8142 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -47,13 +47,13 @@ class Hyperparameters: # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 2000)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 100)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -75,7 +75,7 @@ class Hyperparameters: head_lr = float(os.environ.get("HEAD_LR", 0.008)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) @@ -639,6 +639,14 @@ def _init_conv_zero(module): return model +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + class CastedLinear(nn.Linear): # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. @@ -719,6 +727,8 @@ def forward(self, x: Tensor) -> Tensor: q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) @@ -760,6 +770,8 @@ def __init__( qk_gain_init: float, ): super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) @@ -769,9 +781,9 @@ def __init__( def forward(self, x: Tensor, x0: Tensor) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(x) + attn_out = self.attn(self.attn_norm(x)) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(x) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) return x @@ -814,6 +826,7 @@ def __init__( for i in range(num_layers) ] ) + self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True @@ -828,6 +841,7 @@ def _init_weights(self) -> None: def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) x0 = x skips: list[Tensor] = [] @@ -840,7 +854,7 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() x = self.blocks[self.num_encoder_layers + i](x, x0) - x = x.reshape(-1, x.size(-1)) + x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) if self.tie_embeddings: logits_proj = F.linear(x, self.tok_emb.weight) From 014c14c5f7a3a161329b6965775f3c6a938b9eb1 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 09:16:43 -0700 Subject: [PATCH 72/80] Can't reduce overhead either --- train_gpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 7a45bf8142..07c8829d36 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -53,7 +53,7 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 100)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -875,7 +875,7 @@ def main() -> None: code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, fullgraph=True, mode="reduce-overhead") + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) # ----------------------------- # DISTRIBUTED + CUDA SETUP @@ -983,7 +983,7 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True, mode="reduce-overhead") + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model # Optimizer split: From a93e1622e56a28b35fafe893ec9d349cdae8454c Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 09:33:34 -0700 Subject: [PATCH 73/80] Fix up initialization function --- train_gpt.py | 130 +++++++++++++++++---------------------------------- 1 file changed, 42 insertions(+), 88 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 07c8829d36..b143c24c67 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -502,10 +502,9 @@ def apply_zero_init(model: nn.Module) -> nn.Module: Applies ZerO initialization to a given PyTorch nn.Module in-place. This function implements: - - Algorithm 1 (Linear/FFN Layers) + - Algorithm 1 (Linear/FFN Layers using Identity, Partial Identity, and Hadamard) - Algorithm 2 (Convolution Layers) - - Section 4.1 Transformer specific initialization (W_Q=Identity, W_K/W_V=Zeros) - + Args: model (nn.Module): The PyTorch model to initialize. @@ -520,11 +519,10 @@ def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): if out_f <= in_f: # P_l == Q_l: Identity mapping - # P_l < Q_l: Partial identity (Propagates first P_l dimensions) - # torch.eye intrinsically handles rectangular partial-identities. + # P_l < Q_l: Partial identity (Propagates first P_l dimensions) tensor.copy_(torch.eye(out_f, in_f, dtype=tensor.dtype, device=tensor.device)) else: - # P_l > Q_l: Hadamard mapping + # P_l > Q_l: Hadamard mapping (e.g., expanding MLPs) m = math.ceil(math.log2(out_f)) m = max(m, 1) # Ensure at least m=1 @@ -541,100 +539,56 @@ def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): # Apply scaled Hadamard top-left submatrix (I* H_m I*) tensor.copy_(c * H[:out_f, :in_f]) - @torch.no_grad() - def _init_mha_zero(module: nn.MultiheadAttention): - """Initializes PyTorch's native MultiheadAttention according to Sec 4.1""" - embed_dim = module.embed_dim - - # in_proj_weight groups Q, K, and V - if module.in_proj_weight is not None: - module.in_proj_weight.zero_() - # W_Q is identity - module.in_proj_weight[:embed_dim, :].copy_( - torch.eye(embed_dim, dtype=module.in_proj_weight.dtype, device=module.in_proj_weight.device) - ) - # W_K, W_V remain zero (as handled by .zero_()) - else: - # If instantiated with separate weights - if getattr(module, 'q_proj_weight', None) is not None: - module.q_proj_weight.copy_( - torch.eye(embed_dim, dtype=module.q_proj_weight.dtype, device=module.q_proj_weight.device) - ) - if getattr(module, 'k_proj_weight', None) is not None: - module.k_proj_weight.zero_() - if getattr(module, 'v_proj_weight', None) is not None: - module.v_proj_weight.zero_() - - if getattr(module, 'in_proj_bias', None) is not None: - module.in_proj_bias.zero_() - - # The output projection relies on the standard algorithm (P_l = Q_l) -> Identity - if getattr(module, 'out_proj', None) is not None and getattr(module.out_proj, 'weight', None) is not None: - _zero_init_tensor(module.out_proj.weight, module.out_proj.out_features, module.out_proj.in_features) - if getattr(module.out_proj, 'bias', None) is not None: - module.out_proj.bias.zero_() - - @torch.no_grad() - def _init_conv_zero(module): - """Initializes Convolutions according to Algorithm 2""" - out_c, in_c = module.out_channels, module.in_channels - module.weight.zero_() - - # Find center index, taking n <- floor(k / 2) for each dimension - centers = tuple(k // 2 for k in module.kernel_size) - - # Create an out_c x in_c block utilizing the base ZerO logic - block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) - _zero_init_tensor(block, out_c, in_c) - - # Place block in the center slice of the kernel - if len(centers) == 1: # Conv1D - module.weight[:, :, centers[0]] = block - elif len(centers) == 2: # Conv2D - module.weight[:, :, centers[0], centers[1]] = block - elif len(centers) == 3: # Conv3D - module.weight[:, :, centers[0], centers[1], centers[2]] = block - - if getattr(module, 'bias', None) is not None: - module.bias.zero_() - # Iterate through and parse all network submodules for name, module in model.named_modules(): - # 1. Transformers MultiheadAttention (PyTorch Native MHA) - if isinstance(module, nn.MultiheadAttention): - _init_mha_zero(module) - - # 2. Linear Layers (Applies to Feed-Forward Networks or HuggingFace Custom QKV Projections) - elif isinstance(module, nn.Linear): - name_lower = name.split('.')[-1].lower() - - # W_K, W_V at zero (Fallback heuristics for custom attention modules) - if any(k in name_lower for k in ['k_proj', 'v_proj', 'key', 'value']): + # 1. Respect the script's architectural _zero_init flags (e.g., `proj` & `lm_head`). + # This achieves the paper's goal of dynamical isometry (Identity residual pass) + # and avoids the gradient deadlock. + if getattr(module, '_zero_init', False): + if hasattr(module, 'weight') and module.weight is not None: module.weight.data.zero_() - - # W_Q as identity (Fallback heuristics for custom attention modules) - elif any(q in name_lower for q in ['q_proj', 'query']): - module.weight.data.copy_( - torch.eye(module.out_features, module.in_features, dtype=module.weight.dtype, device=module.weight.device) - ) - - # Generic linear layer (e.g., Dimension expanding/shrinking FFN) - else: - _zero_init_tensor(module.weight, module.out_features, module.in_features) - + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + continue + + # 2. Linear Layers (c_q, c_k, c_v, fc) + # c_q (512->512) becomes Identity. + # c_k/c_v (512->256) become Partial Identities. + # fc (512->1024) becomes a Hadamard transform. + if isinstance(module, nn.Linear): + _zero_init_tensor(module.weight.data, module.out_features, module.in_features) if getattr(module, 'bias', None) is not None: module.bias.data.zero_() - # 3. Convolutional Layers + # 3. Convolutional Layers (Implementation of Algorithm 2) elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): - _init_conv_zero(module) + out_c, in_c = module.out_channels, module.in_channels + module.weight.data.zero_() + # Find center index, taking n <- floor(k / 2) for each dimension + centers = tuple(k // 2 for k in module.kernel_size) + + # Create an out_c x in_c block utilizing the base ZerO logic + block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) + _zero_init_tensor(block, out_c, in_c) + + # Place block in the center slice of the kernel + if len(centers) == 1: + module.weight.data[:, :, centers[0]] = block + elif len(centers) == 2: + module.weight.data[:, :, centers[0], centers[1]] = block + elif len(centers) == 3: + module.weight.data[:, :, centers[0], centers[1], centers[2]] = block + + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + # 4. Normalization Layers (Standard practice is Scale=1, Bias=0) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): - if getattr(module, 'weight', None) is not None: + if hasattr(module, 'weight') and module.weight is not None: module.weight.data.fill_(1.0) - if getattr(module, 'bias', None) is not None: + if hasattr(module, 'bias') and module.bias is not None: module.bias.data.zero_() return model From 02b4d377cebe53f52b47116c929bb085859dadc0 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 09:53:27 -0700 Subject: [PATCH 74/80] One more layer + adam --- train_gpt.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index b143c24c67..ff284e0fe0 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -53,7 +53,7 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 100)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 500)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -61,7 +61,7 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) @@ -75,7 +75,7 @@ class Hyperparameters: head_lr = float(os.environ.get("HEAD_LR", 0.008)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.008)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) @@ -965,13 +965,15 @@ def log0(msg: str, console: bool = True) -> None: eps=args.adam_eps, fused=True, ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, + # Use standard Adam for the matrix params instead of Muon + adam_matrix_lr = 0.005 # 0.04 from Muon is way too high for Adam + optimizer_matrix = torch.optim.Adam( + [{"params": matrix_params, "lr": adam_matrix_lr, "base_lr": adam_matrix_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, ) - for group in optimizer_muon.param_groups: + for group in optimizer_matrix.param_groups: group["base_lr"] = args.matrix_lr optimizer_scalar = torch.optim.Adam( [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], @@ -979,7 +981,7 @@ def log0(msg: str, console: bool = True) -> None: eps=args.adam_eps, fused=True, ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_matrix, optimizer_scalar] if base_model.lm_head is not None: optimizer_head = torch.optim.Adam( [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], @@ -1115,10 +1117,10 @@ def lr_mul(step: int, elapsed_ms: float) -> float: (loss * grad_scale).backward() train_loss /= grad_accum_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum + #frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + #muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + #for group in optimizer_muon.param_groups: + # group["momentum"] = muon_momentum for opt in optimizers: for group in opt.param_groups: From 41a38abc8e0e8502db96f75c99ab2f1ae85ce7a1 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 09:59:05 -0700 Subject: [PATCH 75/80] Restore muon --- train_gpt.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index ff284e0fe0..56934a1dcb 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -75,7 +75,7 @@ class Hyperparameters: head_lr = float(os.environ.get("HEAD_LR", 0.008)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.008)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) @@ -965,15 +965,13 @@ def log0(msg: str, console: bool = True) -> None: eps=args.adam_eps, fused=True, ) - # Use standard Adam for the matrix params instead of Muon - adam_matrix_lr = 0.005 # 0.04 from Muon is way too high for Adam - optimizer_matrix = torch.optim.Adam( - [{"params": matrix_params, "lr": adam_matrix_lr, "base_lr": adam_matrix_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, ) - for group in optimizer_matrix.param_groups: + for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr optimizer_scalar = torch.optim.Adam( [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], @@ -981,7 +979,7 @@ def log0(msg: str, console: bool = True) -> None: eps=args.adam_eps, fused=True, ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_matrix, optimizer_scalar] + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] if base_model.lm_head is not None: optimizer_head = torch.optim.Adam( [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], @@ -1117,10 +1115,10 @@ def lr_mul(step: int, elapsed_ms: float) -> float: (loss * grad_scale).backward() train_loss /= grad_accum_steps - #frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - #muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - #for group in optimizer_muon.param_groups: - # group["momentum"] = muon_momentum + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum for opt in optimizers: for group in opt.param_groups: From 907f5bfce074cccf1f5175b9093347affb701323 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 11:27:30 -0700 Subject: [PATCH 76/80] Gear up for another submission --- .../2026-05-01_Follow_up_to_PR_2104/README.md | 60 + .../requirements.txt | 10 + .../submission.json | 11 + .../train_gpt.py | 1223 +++++++++++++++++ train_gpt.py | 105 +- 5 files changed, 1308 insertions(+), 101 deletions(-) create mode 100644 records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md create mode 100644 records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/requirements.txt create mode 100644 records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/submission.json create mode 100644 records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/train_gpt.py diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md new file mode 100644 index 0000000000..70797caa78 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md @@ -0,0 +1,60 @@ +# A Follow Up To PR 2104 +The original writeup in [PR 2104](https://github.com/AlstonTang/parameter-golf/blob/fef7edc8c96ce169f31754f8deae1334a76e0fba/records/track_non_record_16mb/2026-04-30_ZerO%20_init%2Bprogressive_kv_rope_and_mlp_mult/README.md) was written semi-hastily in order to try and submit it in time with the deadline. However, after a bit more experimentation, there were some surprising results! +- The ZerO implementation from generative AI used in PR 2104 was flawed. + +In this submission I dump various training logs, plus the final train_gpt.py used in the final run. +- Keep in mind that each training log has a copy of train_gpt.py used during the run. + +I do not explore whether progression of various hyperparameters within each layer will have worked. +- for more details + implementation, see the [pull request](https://github.com/openai/parameter-golf/pull/2104). + +## Correcting ZerO +Previously, I had generative AI try to implement the [ZerO](https://arxiv.org/pdf/2110.12661) initialization function. However, due to issues in prompting and visibility of the document (a URL was sent instead of the underlying PDF document), the model most likely hallucinated the implementation. +- It did take inspiration from the usage of Hadamard matrices, though this could also have been a hallucination or perhaps was not fully memorized from the training corpus that the prompted LLM went through. +- However, due to various other ideas being tested concurrently, validation of this implementation was largely overlooked during the challenge. + - Admittedly, it isn't a really good idea to test too many things at once. + +However, post-challenge, I wanted to see what ZerO could really do given a fresh start. I had previously suspected that the implementation was perhaps somewhat off due to non-determinisism being used, but I hadn't really thought of generating a fresh implementation until now. + +When feeding generative AI the actual PDF, it was able to more accurately generate a valid implementation of ZerO. +- The implementation used in this submission does include usage beyond transformers (e.g. Convolutions). +- The implementation also takes into account the existing transformer implementation in the original train_gpt.py. + +Note that, although I tried to set $W_k$ and $W_v$ as zeroed matrices, since the projection after attention was also zeroed out due to the existing implementation, the model trained very poorly. +- Hence, there is still a slight deviation from zero, where $W_q$ + +## Reimplementation +The implementation is based on the original train_gpt.py, not the current one in the pull request. +- This means that we can more accurately see whether or not ZerO works and is doing its thing instead of it being potentially masked by the other hypotheses concurrently tested within the submitted train_gpt.py in PR 2104. + +## So was ZerO the Bottleneck? +It would seem so, at least for GPT speedrunning. Although more testing is needed, when reading through section 4.1, you will see that the standard initialization actually beats out ZerO with a small number of layers. +- Perhaps ZerO may perform better with more layers and/or more training time. + +Interesting things to note are that: +1. Experimentation of ZerO within the paper was primarily focused on Convolutional Neural Networks (specifically ResNet). +2. Much of the paper focuses on the math instead of emperical testing. This could explain why the initialization makes the model slightly worse despite sounding better on paper. + +## Experimentation +Without the constant knowledge of having to get something submitted, this went surprisingly smoothly +- Less focus on tuning hyperparameters to get something in one day, more focus on getting things right. + +With a much more accurate and deterministic intialization, it could be likely that one training run is all that is needed when testing hypotheses moving forward. + +Initially, there was one invalid implementation due to the naming conventions used in the train_gpt.py file. This was resolved in later training runs, but for convenience, the following list shows the names of log(s) containing invalid implementations: +- e31e596e-e21a-4c48-829c-78233c992cc8.txt +- fe927335-4827-415a-b543-8b5d2706de4c.txt + +These are still included so that you can kind of see the experiemntation process. All other logs contain the correct implementation of ZerO. + +## The surprise +As it turns out, due to the fundamental nature of ZerO, we actually get a much better compression result. + +Within logs/3c25e790-2f8a-4c36-881c-67066cf1e465.txt, although there are 17,059,912 parameters, the serialized model only reaches a size of 13,695,416 bytes. Although this could be partly attributed to a higher bpb (suggesting that the model learned less), I strongly suspect that this is due to the low-rank learning trajectory that the model goes through within ZerO initialization. Because of this, the model is naturally more-easily compressible. + +Adding one more layer (so that the model has 10 layers, 8 query heads, 4 kv heads, and a dim of 512 with mlp_mult of 2), and running it for the full 20,000 steps (see logs/24d3334a-0ddf-45ca-8be8-9dbd470f8866.txt), we get a final bpb of ~1.2494 with the submission size still only taking up 15,221,665 despite a parameter count of 18,897,488. + +## Future Work +My plans for this consist of the following: +1. Testing ZerO with more step counts +2. Investigating how to best alter hyperparameters using ZerO initialization diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/requirements.txt b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/requirements.txt new file mode 100644 index 0000000000..911b0e52f0 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/requirements.txt @@ -0,0 +1,10 @@ +numpy +tqdm +torch +huggingface-hub +kernels +setuptools +typing-extensions==4.15.0 +datasets +tiktoken +sentencepiece \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/submission.json b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/submission.json new file mode 100644 index 0000000000..279d5ba326 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/submission.json @@ -0,0 +1,11 @@ +{ + "track": "non_record_16mb", + "date": "2026-05-01", + "name": "Non-record: Corrected ZerO initialization + Follow up to PR 2104", + "author": "Alston Tang", + "github_id": "AlstonTang", + "val_bpb": 1.24938225, + "val_loss": 2.10952960, + "bytes_total": 15221665 + } + \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/train_gpt.py b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/train_gpt.py new file mode 100644 index 0000000000..1e111ea556 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/train_gpt.py @@ -0,0 +1,1223 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 500)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +def apply_zero_init(model: nn.Module) -> nn.Module: + """ + Applies ZerO initialization to a given PyTorch nn.Module in-place. + + This function implements: + - Algorithm 1 (Linear/FFN Layers using Identity, Partial Identity, and Hadamard) + - Algorithm 2 (Convolution Layers) + + Args: + model (nn.Module): The PyTorch model to initialize. + + Returns: + nn.Module: The model with initialized weights. + """ + + @torch.no_grad() + def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): + """Core logic corresponding to Algorithm 1 in the paper.""" + tensor.zero_() + + if out_f <= in_f: + # P_l == Q_l: Identity mapping + # P_l < Q_l: Partial identity (Propagates first P_l dimensions) + tensor.copy_(torch.eye(out_f, in_f, dtype=tensor.dtype, device=tensor.device)) + else: + # P_l > Q_l: Hadamard mapping (e.g., expanding MLPs) + m = math.ceil(math.log2(out_f)) + m = max(m, 1) # Ensure at least m=1 + + # Recursively generate Hadamard matrix H_m + H = torch.tensor([[1.0]], dtype=tensor.dtype, device=tensor.device) + for _ in range(m): + H1 = torch.cat([H, H], dim=1) + H2 = torch.cat([H, -H], dim=1) + H = torch.cat([H1, H2], dim=0) + + # Normalization factor defined in the paper: c = 2^{-(m-1)/2} + c = 2.0 ** (-(m - 1) / 2.0) + + # Apply scaled Hadamard top-left submatrix (I* H_m I*) + tensor.copy_(c * H[:out_f, :in_f]) + + # Iterate through and parse all network submodules + for name, module in model.named_modules(): + + # 1. Respect the script's architectural _zero_init flags (e.g., `proj` & `lm_head`). + # This achieves the paper's goal of dynamical isometry (Identity residual pass) + # and avoids the gradient deadlock. + if getattr(module, '_zero_init', False): + if hasattr(module, 'weight') and module.weight is not None: + module.weight.data.zero_() + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + continue + + # 2. Linear Layers (c_q, c_k, c_v, fc) + # c_q (512->512) becomes Identity. + # c_k/c_v (512->256) become Partial Identities. + # fc (512->1024) becomes a Hadamard transform. + if isinstance(module, nn.Linear): + _zero_init_tensor(module.weight.data, module.out_features, module.in_features) + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 3. Convolutional Layers (Implementation of Algorithm 2) + elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + out_c, in_c = module.out_channels, module.in_channels + module.weight.data.zero_() + + # Find center index, taking n <- floor(k / 2) for each dimension + centers = tuple(k // 2 for k in module.kernel_size) + + # Create an out_c x in_c block utilizing the base ZerO logic + block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) + _zero_init_tensor(block, out_c, in_c) + + # Place block in the center slice of the kernel + if len(centers) == 1: + module.weight.data[:, :, centers[0]] = block + elif len(centers) == 2: + module.weight.data[:, :, centers[0], centers[1]] = block + elif len(centers) == 3: + module.weight.data[:, :, centers[0], centers[1], centers[2]] = block + + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 4. Normalization Layers (Standard practice is Scale=1, Bias=0) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + if hasattr(module, 'weight') and module.weight is not None: + module.weight.data.fill_(1.0) + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + + return model + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + base_model = apply_zero_init(base_model) + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/train_gpt.py b/train_gpt.py index 56934a1dcb..651beb2b89 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -53,7 +53,7 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -61,7 +61,7 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) @@ -75,7 +75,7 @@ class Hyperparameters: head_lr = float(os.environ.get("HEAD_LR", 0.008)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) @@ -497,102 +497,6 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> # TRANSFORMER MODULES # ----------------------------- -def apply_zero_init(model: nn.Module) -> nn.Module: - """ - Applies ZerO initialization to a given PyTorch nn.Module in-place. - - This function implements: - - Algorithm 1 (Linear/FFN Layers using Identity, Partial Identity, and Hadamard) - - Algorithm 2 (Convolution Layers) - - Args: - model (nn.Module): The PyTorch model to initialize. - - Returns: - nn.Module: The model with initialized weights. - """ - - @torch.no_grad() - def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): - """Core logic corresponding to Algorithm 1 in the paper.""" - tensor.zero_() - - if out_f <= in_f: - # P_l == Q_l: Identity mapping - # P_l < Q_l: Partial identity (Propagates first P_l dimensions) - tensor.copy_(torch.eye(out_f, in_f, dtype=tensor.dtype, device=tensor.device)) - else: - # P_l > Q_l: Hadamard mapping (e.g., expanding MLPs) - m = math.ceil(math.log2(out_f)) - m = max(m, 1) # Ensure at least m=1 - - # Recursively generate Hadamard matrix H_m - H = torch.tensor([[1.0]], dtype=tensor.dtype, device=tensor.device) - for _ in range(m): - H1 = torch.cat([H, H], dim=1) - H2 = torch.cat([H, -H], dim=1) - H = torch.cat([H1, H2], dim=0) - - # Normalization factor defined in the paper: c = 2^{-(m-1)/2} - c = 2.0 ** (-(m - 1) / 2.0) - - # Apply scaled Hadamard top-left submatrix (I* H_m I*) - tensor.copy_(c * H[:out_f, :in_f]) - - # Iterate through and parse all network submodules - for name, module in model.named_modules(): - - # 1. Respect the script's architectural _zero_init flags (e.g., `proj` & `lm_head`). - # This achieves the paper's goal of dynamical isometry (Identity residual pass) - # and avoids the gradient deadlock. - if getattr(module, '_zero_init', False): - if hasattr(module, 'weight') and module.weight is not None: - module.weight.data.zero_() - if hasattr(module, 'bias') and module.bias is not None: - module.bias.data.zero_() - continue - - # 2. Linear Layers (c_q, c_k, c_v, fc) - # c_q (512->512) becomes Identity. - # c_k/c_v (512->256) become Partial Identities. - # fc (512->1024) becomes a Hadamard transform. - if isinstance(module, nn.Linear): - _zero_init_tensor(module.weight.data, module.out_features, module.in_features) - if getattr(module, 'bias', None) is not None: - module.bias.data.zero_() - - # 3. Convolutional Layers (Implementation of Algorithm 2) - elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): - out_c, in_c = module.out_channels, module.in_channels - module.weight.data.zero_() - - # Find center index, taking n <- floor(k / 2) for each dimension - centers = tuple(k // 2 for k in module.kernel_size) - - # Create an out_c x in_c block utilizing the base ZerO logic - block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) - _zero_init_tensor(block, out_c, in_c) - - # Place block in the center slice of the kernel - if len(centers) == 1: - module.weight.data[:, :, centers[0]] = block - elif len(centers) == 2: - module.weight.data[:, :, centers[0], centers[1]] = block - elif len(centers) == 3: - module.weight.data[:, :, centers[0], centers[1], centers[2]] = block - - if getattr(module, 'bias', None) is not None: - module.bias.data.zero_() - - # 4. Normalization Layers (Standard practice is Scale=1, Bias=0) - elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): - if hasattr(module, 'weight') and module.weight is not None: - module.weight.data.fill_(1.0) - if hasattr(module, 'bias') and module.bias is not None: - module.bias.data.zero_() - - return model - class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() @@ -829,7 +733,7 @@ def main() -> None: code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) # ----------------------------- # DISTRIBUTED + CUDA SETUP @@ -932,7 +836,6 @@ def log0(msg: str, console: bool = True) -> None: rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, ).to(device).bfloat16() - base_model = apply_zero_init(base_model) for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() From b6eabb81adc247984ed8858fda31e7897d447ee2 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 11:43:05 -0700 Subject: [PATCH 77/80] Update README.md --- .../2026-05-01_Follow_up_to_PR_2104/README.md | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md index 70797caa78..e131654451 100644 --- a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md @@ -1,6 +1,7 @@ -# A Follow Up To PR 2104 -The original writeup in [PR 2104](https://github.com/AlstonTang/parameter-golf/blob/fef7edc8c96ce169f31754f8deae1334a76e0fba/records/track_non_record_16mb/2026-04-30_ZerO%20_init%2Bprogressive_kv_rope_and_mlp_mult/README.md) was written semi-hastily in order to try and submit it in time with the deadline. However, after a bit more experimentation, there were some surprising results! +# Redoing ZerO and A Follow Up To PR 2104 +The original [writeup](https://github.com/AlstonTang/parameter-golf/blob/fef7edc8c96ce169f31754f8deae1334a76e0fba/records/track_non_record_16mb/2026-04-30_ZerO%20_init%2Bprogressive_kv_rope_and_mlp_mult/README.md) in [PR 2104](https://github.com/openai/parameter-golf/pull/2104) was written semi-hastily in order to at least submit something in time within the deadline. However, after a bit more experimentation, there were some surprising results! - The ZerO implementation from generative AI used in PR 2104 was flawed. +- If you just want to see results, either view logs or see [here](#the-surprise) for a high-level summary. In this submission I dump various training logs, plus the final train_gpt.py used in the final run. - Keep in mind that each training log has a copy of train_gpt.py used during the run. @@ -20,11 +21,11 @@ When feeding generative AI the actual PDF, it was able to more accurately genera - The implementation used in this submission does include usage beyond transformers (e.g. Convolutions). - The implementation also takes into account the existing transformer implementation in the original train_gpt.py. -Note that, although I tried to set $W_k$ and $W_v$ as zeroed matrices, since the projection after attention was also zeroed out due to the existing implementation, the model trained very poorly. -- Hence, there is still a slight deviation from zero, where $W_q$ +Note that, although I tried to set $W_k$ and $W_v$ as zeroed matrices as described in the paper, since the projection right after attention was also zeroed out due to the existing implementation, the model trained very poorly. +- Hence, there is still a slight deviation from the canonical ZerO implementation. ## Reimplementation -The implementation is based on the original train_gpt.py, not the current one in the pull request. +The implementation is based on the original train_gpt.py, not the submitted one in PR 2104. - This means that we can more accurately see whether or not ZerO works and is doing its thing instead of it being potentially masked by the other hypotheses concurrently tested within the submitted train_gpt.py in PR 2104. ## So was ZerO the Bottleneck? @@ -33,7 +34,9 @@ It would seem so, at least for GPT speedrunning. Although more testing is needed Interesting things to note are that: 1. Experimentation of ZerO within the paper was primarily focused on Convolutional Neural Networks (specifically ResNet). -2. Much of the paper focuses on the math instead of emperical testing. This could explain why the initialization makes the model slightly worse despite sounding better on paper. +2. Much of the paper focuses on the math (and proofs) instead of emperical testing. This could explain why the initialization makes the model slightly worse in practice despite sounding better on paper. + +I do have a [few plans](#future-work) to see if ZerO really is a bottleneck in a more realistic settimg beyond speedrunning. ## Experimentation Without the constant knowledge of having to get something submitted, this went surprisingly smoothly @@ -41,7 +44,7 @@ Without the constant knowledge of having to get something submitted, this went s With a much more accurate and deterministic intialization, it could be likely that one training run is all that is needed when testing hypotheses moving forward. -Initially, there was one invalid implementation due to the naming conventions used in the train_gpt.py file. This was resolved in later training runs, but for convenience, the following list shows the names of log(s) containing invalid implementations: +Initially, there was an invalid implementation due to the naming conventions used in the train_gpt.py file. This was resolved in later training runs, but for convenience, the following list shows the names of log(s) containing invalid implementations: - e31e596e-e21a-4c48-829c-78233c992cc8.txt - fe927335-4827-415a-b543-8b5d2706de4c.txt @@ -54,7 +57,10 @@ Within logs/3c25e790-2f8a-4c36-881c-67066cf1e465.txt, although there are 17,059, Adding one more layer (so that the model has 10 layers, 8 query heads, 4 kv heads, and a dim of 512 with mlp_mult of 2), and running it for the full 20,000 steps (see logs/24d3334a-0ddf-45ca-8be8-9dbd470f8866.txt), we get a final bpb of ~1.2494 with the submission size still only taking up 15,221,665 despite a parameter count of 18,897,488. +I initially thought that ZerO would at least yield better loss. Although it didn't, the improvements in compression were very surprising, and I'm interesting in further increasing parameter efficiency with ZerO. + ## Future Work My plans for this consist of the following: 1. Testing ZerO with more step counts 2. Investigating how to best alter hyperparameters using ZerO initialization +3. Using ZerO with much larger (and deeper!) models. From 3d1a1c152a029412bdb584b7e9b0d699e9563233 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 11:45:17 -0700 Subject: [PATCH 78/80] Update README.md --- .../2026-05-01_Follow_up_to_PR_2104/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md index e131654451..c8a758833a 100644 --- a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md @@ -11,7 +11,7 @@ I do not explore whether progression of various hyperparameters within each laye ## Correcting ZerO Previously, I had generative AI try to implement the [ZerO](https://arxiv.org/pdf/2110.12661) initialization function. However, due to issues in prompting and visibility of the document (a URL was sent instead of the underlying PDF document), the model most likely hallucinated the implementation. -- It did take inspiration from the usage of Hadamard matrices, though this could also have been a hallucination or perhaps was not fully memorized from the training corpus that the prompted LLM went through. +- It did take inspiration from the usage of Hadamard matrices, though this could also have been a hallucination or perhaps was not fully memorized from the training corpus that the LLM I prompted went through. - However, due to various other ideas being tested concurrently, validation of this implementation was largely overlooked during the challenge. - Admittedly, it isn't a really good idea to test too many things at once. From 48e24c633ddfb202b47eabafbdf8a11274a580f4 Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 12:06:07 -0700 Subject: [PATCH 79/80] Include logs and update readme.md --- .../2026-05-01_Follow_up_to_PR_2104/README.md | 6 +- .../24d3334a-0ddf-45ca-8be8-9dbd470f8866.txt | 1476 +++++++++++++++++ .../3c25e790-2f8a-4c36-881c-67066cf1e465.txt | 1399 ++++++++++++++++ .../8c0bf12c-29d7-4dad-8cb7-6fa04b89b309.txt | 1431 ++++++++++++++++ .../bf9a8898-d4e8-40dd-8786-320b140eb700.txt | 1431 ++++++++++++++++ .../e31e596e-e21a-4c48-829c-78233c992cc8.txt | 1455 ++++++++++++++++ .../fe927335-4827-415a-b543-8b5d2706de4c.txt | 1445 ++++++++++++++++ 7 files changed, 8640 insertions(+), 3 deletions(-) create mode 100644 records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/24d3334a-0ddf-45ca-8be8-9dbd470f8866.txt create mode 100644 records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/3c25e790-2f8a-4c36-881c-67066cf1e465.txt create mode 100644 records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/8c0bf12c-29d7-4dad-8cb7-6fa04b89b309.txt create mode 100644 records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/bf9a8898-d4e8-40dd-8786-320b140eb700.txt create mode 100644 records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/e31e596e-e21a-4c48-829c-78233c992cc8.txt create mode 100644 records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/fe927335-4827-415a-b543-8b5d2706de4c.txt diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md index c8a758833a..3b90491122 100644 --- a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md @@ -36,7 +36,7 @@ Interesting things to note are that: 1. Experimentation of ZerO within the paper was primarily focused on Convolutional Neural Networks (specifically ResNet). 2. Much of the paper focuses on the math (and proofs) instead of emperical testing. This could explain why the initialization makes the model slightly worse in practice despite sounding better on paper. -I do have a [few plans](#future-work) to see if ZerO really is a bottleneck in a more realistic settimg beyond speedrunning. +I do have a [few plans](#future-work) to see if ZerO really is a bottleneck in a more realistic setting beyond speedrunning. ## Experimentation Without the constant knowledge of having to get something submitted, this went surprisingly smoothly @@ -53,9 +53,9 @@ These are still included so that you can kind of see the experiemntation process ## The surprise As it turns out, due to the fundamental nature of ZerO, we actually get a much better compression result. -Within logs/3c25e790-2f8a-4c36-881c-67066cf1e465.txt, although there are 17,059,912 parameters, the serialized model only reaches a size of 13,695,416 bytes. Although this could be partly attributed to a higher bpb (suggesting that the model learned less), I strongly suspect that this is due to the low-rank learning trajectory that the model goes through within ZerO initialization. Because of this, the model is naturally more-easily compressible. +Within run_logs/3c25e790-2f8a-4c36-881c-67066cf1e465.txt, although the model contained 17,059,912 parameters, the serialized model only reaches a size of 13,695,416 bytes. Although this could be partly attributed to a higher bpb (suggesting that the model learned less within the 10 minutes of training time), I strongly suspect that this is due to the low-rank learning trajectory that the model goes through within ZerO initialization. Because of this, the model is naturally more compressible. -Adding one more layer (so that the model has 10 layers, 8 query heads, 4 kv heads, and a dim of 512 with mlp_mult of 2), and running it for the full 20,000 steps (see logs/24d3334a-0ddf-45ca-8be8-9dbd470f8866.txt), we get a final bpb of ~1.2494 with the submission size still only taking up 15,221,665 despite a parameter count of 18,897,488. +Adding one more layer (so that the model has 10 layers, 8 query heads, 4 kv heads, and a dim of 512 with mlp_mult of 2), and running it for the full 20,000 steps (see run_logs/24d3334a-0ddf-45ca-8be8-9dbd470f8866.txt), we get a final bpb of ~1.2494 with the submission size still only taking up 15,221,665 despite a parameter count of 18,897,488. I initially thought that ZerO would at least yield better loss. Although it didn't, the improvements in compression were very surprising, and I'm interesting in further increasing parameter efficiency with ZerO. diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/24d3334a-0ddf-45ca-8be8-9dbd470f8866.txt b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/24d3334a-0ddf-45ca-8be8-9dbd470f8866.txt new file mode 100644 index 0000000000..fe4e1f95c0 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/24d3334a-0ddf-45ca-8be8-9dbd470f8866.txt @@ -0,0 +1,1476 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 500)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +def apply_zero_init(model: nn.Module) -> nn.Module: + """ + Applies ZerO initialization to a given PyTorch nn.Module in-place. + + This function implements: + - Algorithm 1 (Linear/FFN Layers using Identity, Partial Identity, and Hadamard) + - Algorithm 2 (Convolution Layers) + + Args: + model (nn.Module): The PyTorch model to initialize. + + Returns: + nn.Module: The model with initialized weights. + """ + + @torch.no_grad() + def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): + """Core logic corresponding to Algorithm 1 in the paper.""" + tensor.zero_() + + if out_f <= in_f: + # P_l == Q_l: Identity mapping + # P_l < Q_l: Partial identity (Propagates first P_l dimensions) + tensor.copy_(torch.eye(out_f, in_f, dtype=tensor.dtype, device=tensor.device)) + else: + # P_l > Q_l: Hadamard mapping (e.g., expanding MLPs) + m = math.ceil(math.log2(out_f)) + m = max(m, 1) # Ensure at least m=1 + + # Recursively generate Hadamard matrix H_m + H = torch.tensor([[1.0]], dtype=tensor.dtype, device=tensor.device) + for _ in range(m): + H1 = torch.cat([H, H], dim=1) + H2 = torch.cat([H, -H], dim=1) + H = torch.cat([H1, H2], dim=0) + + # Normalization factor defined in the paper: c = 2^{-(m-1)/2} + c = 2.0 ** (-(m - 1) / 2.0) + + # Apply scaled Hadamard top-left submatrix (I* H_m I*) + tensor.copy_(c * H[:out_f, :in_f]) + + # Iterate through and parse all network submodules + for name, module in model.named_modules(): + + # 1. Respect the script's architectural _zero_init flags (e.g., `proj` & `lm_head`). + # This achieves the paper's goal of dynamical isometry (Identity residual pass) + # and avoids the gradient deadlock. + if getattr(module, '_zero_init', False): + if hasattr(module, 'weight') and module.weight is not None: + module.weight.data.zero_() + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + continue + + # 2. Linear Layers (c_q, c_k, c_v, fc) + # c_q (512->512) becomes Identity. + # c_k/c_v (512->256) become Partial Identities. + # fc (512->1024) becomes a Hadamard transform. + if isinstance(module, nn.Linear): + _zero_init_tensor(module.weight.data, module.out_features, module.in_features) + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 3. Convolutional Layers (Implementation of Algorithm 2) + elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + out_c, in_c = module.out_channels, module.in_channels + module.weight.data.zero_() + + # Find center index, taking n <- floor(k / 2) for each dimension + centers = tuple(k // 2 for k in module.kernel_size) + + # Create an out_c x in_c block utilizing the base ZerO logic + block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) + _zero_init_tensor(block, out_c, in_c) + + # Place block in the center slice of the kernel + if len(centers) == 1: + module.weight.data[:, :, centers[0]] = block + elif len(centers) == 2: + module.weight.data[:, :, centers[0], centers[1]] = block + elif len(centers) == 3: + module.weight.data[:, :, centers[0], centers[1], centers[2]] = block + + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 4. Normalization Layers (Standard practice is Scale=1, Bias=0) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + if hasattr(module, 'weight') and module.weight is not None: + module.weight.data.fill_(1.0) + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + + return model + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + base_model = apply_zero_init(base_model) + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri May 1 17:33:22 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 38C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 40C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:18897488 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:500 max_wallclock_seconds:0.000 +seed:1337 +warmup_step:10/500 +warmup_step:20/500 +warmup_step:30/500 +warmup_step:40/500 +warmup_step:50/500 +warmup_step:60/500 +warmup_step:70/500 +warmup_step:80/500 +warmup_step:90/500 +warmup_step:100/500 +warmup_step:110/500 +warmup_step:120/500 +warmup_step:130/500 +warmup_step:140/500 +warmup_step:150/500 +warmup_step:160/500 +warmup_step:170/500 +warmup_step:180/500 +warmup_step:190/500 +warmup_step:200/500 +warmup_step:210/500 +warmup_step:220/500 +warmup_step:230/500 +warmup_step:240/500 +warmup_step:250/500 +warmup_step:260/500 +warmup_step:270/500 +warmup_step:280/500 +warmup_step:290/500 +warmup_step:300/500 +warmup_step:310/500 +warmup_step:320/500 +warmup_step:330/500 +warmup_step:340/500 +warmup_step:350/500 +warmup_step:360/500 +warmup_step:370/500 +warmup_step:380/500 +warmup_step:390/500 +warmup_step:400/500 +warmup_step:410/500 +warmup_step:420/500 +warmup_step:430/500 +warmup_step:440/500 +warmup_step:450/500 +warmup_step:460/500 +warmup_step:470/500 +warmup_step:480/500 +warmup_step:490/500 +warmup_step:500/500 +step:0/20000 val_loss:6.9363 val_bpb:4.1080 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9355 train_time:32ms step_avg:32.15ms +step:2/20000 train_loss:15.7807 train_time:80ms step_avg:40.10ms +step:3/20000 train_loss:7.3458 train_time:128ms step_avg:42.72ms +step:4/20000 train_loss:7.2066 train_time:176ms step_avg:43.97ms +step:5/20000 train_loss:7.2272 train_time:224ms step_avg:44.72ms +step:6/20000 train_loss:7.3731 train_time:271ms step_avg:45.23ms +step:7/20000 train_loss:6.4855 train_time:319ms step_avg:45.61ms +step:8/20000 train_loss:6.1886 train_time:367ms step_avg:45.90ms +step:9/20000 train_loss:5.9110 train_time:415ms step_avg:46.13ms +step:10/20000 train_loss:5.7178 train_time:463ms step_avg:46.30ms +step:200/20000 train_loss:2.9555 train_time:9534ms step_avg:47.67ms +step:400/20000 train_loss:2.4261 train_time:19087ms step_avg:47.72ms +step:600/20000 train_loss:2.5914 train_time:28659ms step_avg:47.76ms +step:800/20000 train_loss:2.3420 train_time:38249ms step_avg:47.81ms +step:1000/20000 train_loss:2.4166 train_time:47846ms step_avg:47.85ms +step:1000/20000 val_loss:2.3827 val_bpb:1.4112 train_time:47869ms step_avg:47.87ms +step:1200/20000 train_loss:2.4324 train_time:57449ms step_avg:47.87ms +step:1400/20000 train_loss:2.4740 train_time:67057ms step_avg:47.90ms +step:1600/20000 train_loss:2.1459 train_time:76658ms step_avg:47.91ms +step:1800/20000 train_loss:2.2476 train_time:86252ms step_avg:47.92ms +step:2000/20000 train_loss:2.3005 train_time:95845ms step_avg:47.92ms +step:2000/20000 val_loss:2.2831 val_bpb:1.3522 train_time:95869ms step_avg:47.93ms +step:2200/20000 train_loss:2.1276 train_time:105434ms step_avg:47.92ms +step:2400/20000 train_loss:2.2487 train_time:115017ms step_avg:47.92ms +step:2600/20000 train_loss:2.4535 train_time:124600ms step_avg:47.92ms +step:2800/20000 train_loss:2.2854 train_time:134172ms step_avg:47.92ms +step:3000/20000 train_loss:2.2749 train_time:143733ms step_avg:47.91ms +step:3000/20000 val_loss:2.2428 val_bpb:1.3283 train_time:143757ms step_avg:47.92ms +step:3200/20000 train_loss:2.2357 train_time:153292ms step_avg:47.90ms +step:3400/20000 train_loss:2.2073 train_time:162849ms step_avg:47.90ms +step:3600/20000 train_loss:2.1720 train_time:172407ms step_avg:47.89ms +step:3800/20000 train_loss:2.2673 train_time:181960ms step_avg:47.88ms +step:4000/20000 train_loss:2.2074 train_time:191513ms step_avg:47.88ms +step:4000/20000 val_loss:2.2175 val_bpb:1.3133 train_time:191537ms step_avg:47.88ms +step:4200/20000 train_loss:2.2220 train_time:201101ms step_avg:47.88ms +step:4400/20000 train_loss:2.1612 train_time:210646ms step_avg:47.87ms +step:4600/20000 train_loss:2.0193 train_time:220198ms step_avg:47.87ms +step:4800/20000 train_loss:2.3137 train_time:229750ms step_avg:47.86ms +step:5000/20000 train_loss:2.0826 train_time:239293ms step_avg:47.86ms +step:5000/20000 val_loss:2.2017 val_bpb:1.3040 train_time:239317ms step_avg:47.86ms +step:5200/20000 train_loss:2.2194 train_time:248831ms step_avg:47.85ms +step:5400/20000 train_loss:2.2337 train_time:258370ms step_avg:47.85ms +step:5600/20000 train_loss:2.2370 train_time:267904ms step_avg:47.84ms +step:5800/20000 train_loss:2.1881 train_time:277438ms step_avg:47.83ms +step:6000/20000 train_loss:2.2698 train_time:286976ms step_avg:47.83ms +step:6000/20000 val_loss:2.1927 val_bpb:1.2986 train_time:287000ms step_avg:47.83ms +step:6200/20000 train_loss:2.1367 train_time:296509ms step_avg:47.82ms +step:6400/20000 train_loss:2.2112 train_time:306038ms step_avg:47.82ms +step:6600/20000 train_loss:2.1733 train_time:315569ms step_avg:47.81ms +step:6800/20000 train_loss:2.2418 train_time:325104ms step_avg:47.81ms +step:7000/20000 train_loss:2.2779 train_time:334637ms step_avg:47.81ms +step:7000/20000 val_loss:2.1807 val_bpb:1.2915 train_time:334660ms step_avg:47.81ms +step:7200/20000 train_loss:2.2507 train_time:344211ms step_avg:47.81ms +step:7400/20000 train_loss:2.1676 train_time:353742ms step_avg:47.80ms +step:7600/20000 train_loss:2.0482 train_time:363268ms step_avg:47.80ms +step:7800/20000 train_loss:2.1976 train_time:372790ms step_avg:47.79ms +step:8000/20000 train_loss:2.1654 train_time:382312ms step_avg:47.79ms +step:8000/20000 val_loss:2.1717 val_bpb:1.2862 train_time:382336ms step_avg:47.79ms +step:8200/20000 train_loss:2.2361 train_time:391848ms step_avg:47.79ms +step:8400/20000 train_loss:2.1826 train_time:401416ms step_avg:47.79ms +step:8600/20000 train_loss:2.1872 train_time:410941ms step_avg:47.78ms +step:8800/20000 train_loss:2.1560 train_time:420468ms step_avg:47.78ms +step:9000/20000 train_loss:2.0747 train_time:429995ms step_avg:47.78ms +step:9000/20000 val_loss:2.1668 val_bpb:1.2833 train_time:430018ms step_avg:47.78ms +step:9200/20000 train_loss:2.1297 train_time:439526ms step_avg:47.77ms +step:9400/20000 train_loss:2.1854 train_time:449052ms step_avg:47.77ms +step:9600/20000 train_loss:2.1940 train_time:458577ms step_avg:47.77ms +step:9800/20000 train_loss:2.1239 train_time:468102ms step_avg:47.77ms +step:10000/20000 train_loss:2.1634 train_time:477629ms step_avg:47.76ms +step:10000/20000 val_loss:2.1610 val_bpb:1.2799 train_time:477652ms step_avg:47.77ms +step:10200/20000 train_loss:2.1188 train_time:487160ms step_avg:47.76ms +step:10400/20000 train_loss:2.1489 train_time:496693ms step_avg:47.76ms +step:10600/20000 train_loss:2.0260 train_time:506220ms step_avg:47.76ms +step:10800/20000 train_loss:2.2336 train_time:515746ms step_avg:47.75ms +step:11000/20000 train_loss:2.1634 train_time:525274ms step_avg:47.75ms +step:11000/20000 val_loss:2.1551 val_bpb:1.2764 train_time:525298ms step_avg:47.75ms +step:11200/20000 train_loss:2.1209 train_time:534796ms step_avg:47.75ms +step:11400/20000 train_loss:2.1049 train_time:544318ms step_avg:47.75ms +step:11600/20000 train_loss:2.1093 train_time:553842ms step_avg:47.74ms +step:11800/20000 train_loss:2.1475 train_time:563367ms step_avg:47.74ms +step:12000/20000 train_loss:2.1222 train_time:572897ms step_avg:47.74ms +step:12000/20000 val_loss:2.1500 val_bpb:1.2734 train_time:572921ms step_avg:47.74ms +step:12200/20000 train_loss:2.2634 train_time:582423ms step_avg:47.74ms +step:12400/20000 train_loss:1.9120 train_time:591994ms step_avg:47.74ms +step:12600/20000 train_loss:2.1398 train_time:601521ms step_avg:47.74ms +step:12800/20000 train_loss:2.1643 train_time:611055ms step_avg:47.74ms +step:13000/20000 train_loss:2.2421 train_time:620585ms step_avg:47.74ms +step:13000/20000 val_loss:2.1498 val_bpb:1.2732 train_time:620609ms step_avg:47.74ms +step:13200/20000 train_loss:2.2502 train_time:630114ms step_avg:47.74ms +step:13400/20000 train_loss:2.1292 train_time:639639ms step_avg:47.73ms +step:13600/20000 train_loss:2.0075 train_time:649167ms step_avg:47.73ms +step:13800/20000 train_loss:2.0835 train_time:658698ms step_avg:47.73ms +step:14000/20000 train_loss:2.1477 train_time:668223ms step_avg:47.73ms +step:14000/20000 val_loss:2.1431 val_bpb:1.2693 train_time:668246ms step_avg:47.73ms +step:14200/20000 train_loss:2.2300 train_time:677747ms step_avg:47.73ms +step:14400/20000 train_loss:2.1275 train_time:687272ms step_avg:47.73ms +step:14600/20000 train_loss:2.1833 train_time:696795ms step_avg:47.73ms +step:14800/20000 train_loss:1.9672 train_time:706320ms step_avg:47.72ms +step:15000/20000 train_loss:2.0831 train_time:715860ms step_avg:47.72ms +step:15000/20000 val_loss:2.1385 val_bpb:1.2666 train_time:715884ms step_avg:47.73ms +step:15200/20000 train_loss:2.1946 train_time:725389ms step_avg:47.72ms +step:15400/20000 train_loss:2.0976 train_time:734914ms step_avg:47.72ms +step:15600/20000 train_loss:2.1188 train_time:744440ms step_avg:47.72ms +step:15800/20000 train_loss:1.9666 train_time:753967ms step_avg:47.72ms +step:16000/20000 train_loss:2.1792 train_time:763496ms step_avg:47.72ms +step:16000/20000 val_loss:2.1364 val_bpb:1.2653 train_time:763520ms step_avg:47.72ms +step:16200/20000 train_loss:2.0572 train_time:773024ms step_avg:47.72ms +step:16400/20000 train_loss:2.0803 train_time:782549ms step_avg:47.72ms +step:16600/20000 train_loss:2.0230 train_time:792111ms step_avg:47.72ms +step:16800/20000 train_loss:2.2437 train_time:801632ms step_avg:47.72ms +step:17000/20000 train_loss:2.1546 train_time:811162ms step_avg:47.72ms +step:17000/20000 val_loss:2.1339 val_bpb:1.2638 train_time:811185ms step_avg:47.72ms +step:17200/20000 train_loss:2.1488 train_time:820691ms step_avg:47.71ms +step:17400/20000 train_loss:2.0464 train_time:830216ms step_avg:47.71ms +step:17600/20000 train_loss:2.1303 train_time:839742ms step_avg:47.71ms +step:17800/20000 train_loss:2.2072 train_time:849266ms step_avg:47.71ms +step:18000/20000 train_loss:2.1204 train_time:858792ms step_avg:47.71ms +step:18000/20000 val_loss:2.1331 val_bpb:1.2633 train_time:858815ms step_avg:47.71ms +step:18200/20000 train_loss:2.3465 train_time:868315ms step_avg:47.71ms +step:18400/20000 train_loss:2.1184 train_time:877843ms step_avg:47.71ms +step:18600/20000 train_loss:2.1484 train_time:887374ms step_avg:47.71ms +step:18800/20000 train_loss:2.1997 train_time:896902ms step_avg:47.71ms +step:19000/20000 train_loss:2.1245 train_time:906423ms step_avg:47.71ms +step:19000/20000 val_loss:2.1240 val_bpb:1.2580 train_time:906447ms step_avg:47.71ms +step:19200/20000 train_loss:1.9684 train_time:916010ms step_avg:47.71ms +step:19400/20000 train_loss:2.1891 train_time:925577ms step_avg:47.71ms +step:19600/20000 train_loss:2.2643 train_time:935103ms step_avg:47.71ms +step:19800/20000 train_loss:1.9416 train_time:944630ms step_avg:47.71ms +step:20000/20000 train_loss:2.1303 train_time:954159ms step_avg:47.71ms +step:20000/20000 val_loss:2.0938 val_bpb:1.2400 train_time:954183ms step_avg:47.71ms +peak memory allocated: 11263 MiB reserved: 11320 MiB +Serialized model: 74578915 bytes +Code size: 52042 bytes +Total submission size: 74630957 bytes +Serialized model int8+zlib: 15169623 bytes (payload:19030336 raw_torch:19080377 payload_ratio:3.92x) +Total submission size int8+zlib: 15221665 bytes +final_int8_zlib_roundtrip val_loss:2.1095 val_bpb:1.2494 eval_time:1566ms +final_int8_zlib_roundtrip_exact val_loss:2.10952960 val_bpb:1.24938225 diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/3c25e790-2f8a-4c36-881c-67066cf1e465.txt b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/3c25e790-2f8a-4c36-881c-67066cf1e465.txt new file mode 100644 index 0000000000..37e760da50 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/3c25e790-2f8a-4c36-881c-67066cf1e465.txt @@ -0,0 +1,1399 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 100)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +def apply_zero_init(model: nn.Module) -> nn.Module: + """ + Applies ZerO initialization to a given PyTorch nn.Module in-place. + + This function implements: + - Algorithm 1 (Linear/FFN Layers using Identity, Partial Identity, and Hadamard) + - Algorithm 2 (Convolution Layers) + + Args: + model (nn.Module): The PyTorch model to initialize. + + Returns: + nn.Module: The model with initialized weights. + """ + + @torch.no_grad() + def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): + """Core logic corresponding to Algorithm 1 in the paper.""" + tensor.zero_() + + if out_f <= in_f: + # P_l == Q_l: Identity mapping + # P_l < Q_l: Partial identity (Propagates first P_l dimensions) + tensor.copy_(torch.eye(out_f, in_f, dtype=tensor.dtype, device=tensor.device)) + else: + # P_l > Q_l: Hadamard mapping (e.g., expanding MLPs) + m = math.ceil(math.log2(out_f)) + m = max(m, 1) # Ensure at least m=1 + + # Recursively generate Hadamard matrix H_m + H = torch.tensor([[1.0]], dtype=tensor.dtype, device=tensor.device) + for _ in range(m): + H1 = torch.cat([H, H], dim=1) + H2 = torch.cat([H, -H], dim=1) + H = torch.cat([H1, H2], dim=0) + + # Normalization factor defined in the paper: c = 2^{-(m-1)/2} + c = 2.0 ** (-(m - 1) / 2.0) + + # Apply scaled Hadamard top-left submatrix (I* H_m I*) + tensor.copy_(c * H[:out_f, :in_f]) + + # Iterate through and parse all network submodules + for name, module in model.named_modules(): + + # 1. Respect the script's architectural _zero_init flags (e.g., `proj` & `lm_head`). + # This achieves the paper's goal of dynamical isometry (Identity residual pass) + # and avoids the gradient deadlock. + if getattr(module, '_zero_init', False): + if hasattr(module, 'weight') and module.weight is not None: + module.weight.data.zero_() + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + continue + + # 2. Linear Layers (c_q, c_k, c_v, fc) + # c_q (512->512) becomes Identity. + # c_k/c_v (512->256) become Partial Identities. + # fc (512->1024) becomes a Hadamard transform. + if isinstance(module, nn.Linear): + _zero_init_tensor(module.weight.data, module.out_features, module.in_features) + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 3. Convolutional Layers (Implementation of Algorithm 2) + elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + out_c, in_c = module.out_channels, module.in_channels + module.weight.data.zero_() + + # Find center index, taking n <- floor(k / 2) for each dimension + centers = tuple(k // 2 for k in module.kernel_size) + + # Create an out_c x in_c block utilizing the base ZerO logic + block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) + _zero_init_tensor(block, out_c, in_c) + + # Place block in the center slice of the kernel + if len(centers) == 1: + module.weight.data[:, :, centers[0]] = block + elif len(centers) == 2: + module.weight.data[:, :, centers[0], centers[1]] = block + elif len(centers) == 3: + module.weight.data[:, :, centers[0], centers[1], centers[2]] = block + + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 4. Normalization Layers (Standard practice is Scale=1, Bias=0) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + if hasattr(module, 'weight') and module.weight is not None: + module.weight.data.fill_(1.0) + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + + return model + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + base_model = apply_zero_init(base_model) + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri May 1 16:39:39 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 38C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 42C P0 129W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 34C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.05 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:100 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:10/100 +warmup_step:20/100 +warmup_step:30/100 +warmup_step:40/100 +warmup_step:50/100 +warmup_step:60/100 +warmup_step:70/100 +warmup_step:80/100 +warmup_step:90/100 +warmup_step:100/100 +step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9370 train_time:27ms step_avg:26.58ms +step:2/20000 train_loss:15.6862 train_time:73ms step_avg:36.37ms +step:3/20000 train_loss:7.2524 train_time:118ms step_avg:39.23ms +step:4/20000 train_loss:7.3808 train_time:164ms step_avg:41.02ms +step:5/20000 train_loss:7.5702 train_time:208ms step_avg:41.53ms +step:6/20000 train_loss:7.7805 train_time:267ms step_avg:44.49ms +step:7/20000 train_loss:6.6300 train_time:302ms step_avg:43.21ms +step:8/20000 train_loss:6.1898 train_time:345ms step_avg:43.10ms +step:9/20000 train_loss:5.9347 train_time:387ms step_avg:42.96ms +step:10/20000 train_loss:5.6995 train_time:428ms step_avg:42.79ms +step:200/20000 train_loss:2.9527 train_time:8780ms step_avg:43.90ms +step:400/20000 train_loss:2.4409 train_time:17552ms step_avg:43.88ms +step:600/20000 train_loss:2.6121 train_time:26339ms step_avg:43.90ms +step:800/20000 train_loss:2.3574 train_time:35129ms step_avg:43.91ms +step:1000/20000 train_loss:2.4316 train_time:43937ms step_avg:43.94ms +step:1000/20000 val_loss:2.3994 val_bpb:1.4211 train_time:43950ms step_avg:43.95ms +step:1200/20000 train_loss:2.4514 train_time:52775ms step_avg:43.98ms +step:1400/20000 train_loss:2.4901 train_time:61598ms step_avg:44.00ms +step:1600/20000 train_loss:2.1602 train_time:70441ms step_avg:44.03ms +step:1800/20000 train_loss:2.2684 train_time:79269ms step_avg:44.04ms +step:2000/20000 train_loss:2.3156 train_time:88094ms step_avg:44.05ms +step:2000/20000 val_loss:2.3011 val_bpb:1.3628 train_time:88109ms step_avg:44.05ms +step:2200/20000 train_loss:2.1455 train_time:96924ms step_avg:44.06ms +step:2400/20000 train_loss:2.2675 train_time:105742ms step_avg:44.06ms +step:2600/20000 train_loss:2.4753 train_time:114561ms step_avg:44.06ms +step:2800/20000 train_loss:2.3080 train_time:123367ms step_avg:44.06ms +step:3000/20000 train_loss:2.2924 train_time:132173ms step_avg:44.06ms +step:3000/20000 val_loss:2.2622 val_bpb:1.3398 train_time:132200ms step_avg:44.07ms +step:3200/20000 train_loss:2.2553 train_time:140984ms step_avg:44.06ms +step:3400/20000 train_loss:2.2284 train_time:149782ms step_avg:44.05ms +step:3600/20000 train_loss:2.1919 train_time:158602ms step_avg:44.06ms +step:3800/20000 train_loss:2.2899 train_time:167394ms step_avg:44.05ms +step:4000/20000 train_loss:2.2306 train_time:176184ms step_avg:44.05ms +step:4000/20000 val_loss:2.2379 val_bpb:1.3254 train_time:176204ms step_avg:44.05ms +step:4200/20000 train_loss:2.2384 train_time:185050ms step_avg:44.06ms +step:4400/20000 train_loss:2.1830 train_time:193836ms step_avg:44.05ms +step:4600/20000 train_loss:2.0421 train_time:202628ms step_avg:44.05ms +step:4800/20000 train_loss:2.3297 train_time:211418ms step_avg:44.05ms +step:5000/20000 train_loss:2.0996 train_time:220196ms step_avg:44.04ms +step:5000/20000 val_loss:2.2224 val_bpb:1.3163 train_time:220210ms step_avg:44.04ms +step:5200/20000 train_loss:2.2382 train_time:228983ms step_avg:44.04ms +step:5400/20000 train_loss:2.2570 train_time:237772ms step_avg:44.03ms +step:5600/20000 train_loss:2.2551 train_time:246556ms step_avg:44.03ms +step:5800/20000 train_loss:2.2178 train_time:255344ms step_avg:44.02ms +step:6000/20000 train_loss:2.2837 train_time:264127ms step_avg:44.02ms +step:6000/20000 val_loss:2.2127 val_bpb:1.3105 train_time:264141ms step_avg:44.02ms +step:6200/20000 train_loss:2.1591 train_time:272909ms step_avg:44.02ms +step:6400/20000 train_loss:2.2351 train_time:281692ms step_avg:44.01ms +step:6600/20000 train_loss:2.1986 train_time:290473ms step_avg:44.01ms +step:6800/20000 train_loss:2.2593 train_time:299250ms step_avg:44.01ms +step:7000/20000 train_loss:2.2935 train_time:308036ms step_avg:44.01ms +step:7000/20000 val_loss:2.2025 val_bpb:1.3044 train_time:308051ms step_avg:44.01ms +step:7200/20000 train_loss:2.2722 train_time:316819ms step_avg:44.00ms +step:7400/20000 train_loss:2.1888 train_time:325594ms step_avg:44.00ms +step:7600/20000 train_loss:2.0719 train_time:334388ms step_avg:44.00ms +step:7800/20000 train_loss:2.2174 train_time:343173ms step_avg:44.00ms +step:8000/20000 train_loss:2.1836 train_time:351961ms step_avg:44.00ms +step:8000/20000 val_loss:2.1938 val_bpb:1.2993 train_time:351974ms step_avg:44.00ms +step:8200/20000 train_loss:2.2516 train_time:360751ms step_avg:43.99ms +step:8400/20000 train_loss:2.2039 train_time:369595ms step_avg:44.00ms +step:8600/20000 train_loss:2.2087 train_time:378377ms step_avg:44.00ms +step:8800/20000 train_loss:2.1797 train_time:387173ms step_avg:44.00ms +step:9000/20000 train_loss:2.0962 train_time:395957ms step_avg:44.00ms +step:9000/20000 val_loss:2.1889 val_bpb:1.2964 train_time:395971ms step_avg:44.00ms +step:9200/20000 train_loss:2.1523 train_time:404751ms step_avg:43.99ms +step:9400/20000 train_loss:2.2119 train_time:413533ms step_avg:43.99ms +step:9600/20000 train_loss:2.2201 train_time:422325ms step_avg:43.99ms +step:9800/20000 train_loss:2.1550 train_time:431106ms step_avg:43.99ms +step:10000/20000 train_loss:2.1850 train_time:439893ms step_avg:43.99ms +step:10000/20000 val_loss:2.1846 val_bpb:1.2938 train_time:439908ms step_avg:43.99ms +step:10200/20000 train_loss:2.1424 train_time:448683ms step_avg:43.99ms +step:10400/20000 train_loss:2.1723 train_time:457466ms step_avg:43.99ms +step:10600/20000 train_loss:2.0463 train_time:466256ms step_avg:43.99ms +step:10800/20000 train_loss:2.2569 train_time:475038ms step_avg:43.98ms +step:11000/20000 train_loss:2.1925 train_time:483830ms step_avg:43.98ms +step:11000/20000 val_loss:2.1781 val_bpb:1.2900 train_time:483844ms step_avg:43.99ms +step:11200/20000 train_loss:2.1436 train_time:492619ms step_avg:43.98ms +step:11400/20000 train_loss:2.1248 train_time:501403ms step_avg:43.98ms +step:11600/20000 train_loss:2.1365 train_time:510193ms step_avg:43.98ms +step:11800/20000 train_loss:2.1706 train_time:518971ms step_avg:43.98ms +step:12000/20000 train_loss:2.1453 train_time:527759ms step_avg:43.98ms +step:12000/20000 val_loss:2.1730 val_bpb:1.2870 train_time:527774ms step_avg:43.98ms +step:12200/20000 train_loss:2.2864 train_time:536547ms step_avg:43.98ms +step:12400/20000 train_loss:1.9287 train_time:545394ms step_avg:43.98ms +step:12600/20000 train_loss:2.1581 train_time:554180ms step_avg:43.98ms +step:12800/20000 train_loss:2.1741 train_time:562952ms step_avg:43.98ms +step:13000/20000 train_loss:2.2488 train_time:571722ms step_avg:43.98ms +step:13000/20000 val_loss:2.1538 val_bpb:1.2756 train_time:571747ms step_avg:43.98ms +step:13200/20000 train_loss:2.2517 train_time:580590ms step_avg:43.98ms +step:13400/20000 train_loss:2.1198 train_time:589373ms step_avg:43.98ms +step:13600/20000 train_loss:1.9812 train_time:598159ms step_avg:43.98ms +step:13643/20000 val_loss:2.1313 val_bpb:1.2623 train_time:600049ms step_avg:43.98ms +stopping_early: wallclock_cap train_time:600049ms step:13643/20000 +peak memory allocated: 10119 MiB reserved: 10438 MiB +Serialized model: 67224983 bytes +Code size: 52041 bytes +Total submission size: 67277024 bytes +Serialized model int8+zlib: 13695416 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) +Total submission size int8+zlib: 13747457 bytes +final_int8_zlib_roundtrip val_loss:2.1480 val_bpb:1.2722 eval_time:1414ms +final_int8_zlib_roundtrip_exact val_loss:2.14799565 val_bpb:1.27216402 diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/8c0bf12c-29d7-4dad-8cb7-6fa04b89b309.txt b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/8c0bf12c-29d7-4dad-8cb7-6fa04b89b309.txt new file mode 100644 index 0000000000..4044fd2419 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/8c0bf12c-29d7-4dad-8cb7-6fa04b89b309.txt @@ -0,0 +1,1431 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 500)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +def apply_zero_init(model: nn.Module) -> nn.Module: + """ + Applies ZerO initialization to a given PyTorch nn.Module in-place. + + This function implements: + - Algorithm 1 (Linear/FFN Layers using Identity, Partial Identity, and Hadamard) + - Algorithm 2 (Convolution Layers) + + Args: + model (nn.Module): The PyTorch model to initialize. + + Returns: + nn.Module: The model with initialized weights. + """ + + @torch.no_grad() + def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): + """Core logic corresponding to Algorithm 1 in the paper.""" + tensor.zero_() + + if out_f <= in_f: + # P_l == Q_l: Identity mapping + # P_l < Q_l: Partial identity (Propagates first P_l dimensions) + tensor.copy_(torch.eye(out_f, in_f, dtype=tensor.dtype, device=tensor.device)) + else: + # P_l > Q_l: Hadamard mapping (e.g., expanding MLPs) + m = math.ceil(math.log2(out_f)) + m = max(m, 1) # Ensure at least m=1 + + # Recursively generate Hadamard matrix H_m + H = torch.tensor([[1.0]], dtype=tensor.dtype, device=tensor.device) + for _ in range(m): + H1 = torch.cat([H, H], dim=1) + H2 = torch.cat([H, -H], dim=1) + H = torch.cat([H1, H2], dim=0) + + # Normalization factor defined in the paper: c = 2^{-(m-1)/2} + c = 2.0 ** (-(m - 1) / 2.0) + + # Apply scaled Hadamard top-left submatrix (I* H_m I*) + tensor.copy_(c * H[:out_f, :in_f]) + + # Iterate through and parse all network submodules + for name, module in model.named_modules(): + + # 1. Respect the script's architectural _zero_init flags (e.g., `proj` & `lm_head`). + # This achieves the paper's goal of dynamical isometry (Identity residual pass) + # and avoids the gradient deadlock. + if getattr(module, '_zero_init', False): + if hasattr(module, 'weight') and module.weight is not None: + module.weight.data.zero_() + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + continue + + # 2. Linear Layers (c_q, c_k, c_v, fc) + # c_q (512->512) becomes Identity. + # c_k/c_v (512->256) become Partial Identities. + # fc (512->1024) becomes a Hadamard transform. + if isinstance(module, nn.Linear): + _zero_init_tensor(module.weight.data, module.out_features, module.in_features) + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 3. Convolutional Layers (Implementation of Algorithm 2) + elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + out_c, in_c = module.out_channels, module.in_channels + module.weight.data.zero_() + + # Find center index, taking n <- floor(k / 2) for each dimension + centers = tuple(k // 2 for k in module.kernel_size) + + # Create an out_c x in_c block utilizing the base ZerO logic + block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) + _zero_init_tensor(block, out_c, in_c) + + # Place block in the center slice of the kernel + if len(centers) == 1: + module.weight.data[:, :, centers[0]] = block + elif len(centers) == 2: + module.weight.data[:, :, centers[0], centers[1]] = block + elif len(centers) == 3: + module.weight.data[:, :, centers[0], centers[1], centers[2]] = block + + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 4. Normalization Layers (Standard practice is Scale=1, Bias=0) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + if hasattr(module, 'weight') and module.weight is not None: + module.weight.data.fill_(1.0) + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + + return model + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + base_model = apply_zero_init(base_model) + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri May 1 16:59:34 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 33C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 39C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 42C P0 127W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 39C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:18897488 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.05 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:500 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:10/500 +warmup_step:20/500 +warmup_step:30/500 +warmup_step:40/500 +warmup_step:50/500 +warmup_step:60/500 +warmup_step:70/500 +warmup_step:80/500 +warmup_step:90/500 +warmup_step:100/500 +warmup_step:110/500 +warmup_step:120/500 +warmup_step:130/500 +warmup_step:140/500 +warmup_step:150/500 +warmup_step:160/500 +warmup_step:170/500 +warmup_step:180/500 +warmup_step:190/500 +warmup_step:200/500 +warmup_step:210/500 +warmup_step:220/500 +warmup_step:230/500 +warmup_step:240/500 +warmup_step:250/500 +warmup_step:260/500 +warmup_step:270/500 +warmup_step:280/500 +warmup_step:290/500 +warmup_step:300/500 +warmup_step:310/500 +warmup_step:320/500 +warmup_step:330/500 +warmup_step:340/500 +warmup_step:350/500 +warmup_step:360/500 +warmup_step:370/500 +warmup_step:380/500 +warmup_step:390/500 +warmup_step:400/500 +warmup_step:410/500 +warmup_step:420/500 +warmup_step:430/500 +warmup_step:440/500 +warmup_step:450/500 +warmup_step:460/500 +warmup_step:470/500 +warmup_step:480/500 +warmup_step:490/500 +warmup_step:500/500 +step:0/20000 val_loss:6.9363 val_bpb:4.1080 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9355 train_time:34ms step_avg:33.81ms +step:2/20000 train_loss:15.4054 train_time:88ms step_avg:43.76ms +step:3/20000 train_loss:7.0580 train_time:135ms step_avg:45.11ms +step:4/20000 train_loss:7.5619 train_time:184ms step_avg:45.92ms +step:5/20000 train_loss:7.6829 train_time:232ms step_avg:46.32ms +step:6/20000 train_loss:7.7755 train_time:279ms step_avg:46.55ms +step:7/20000 train_loss:6.6480 train_time:327ms step_avg:46.78ms +step:8/20000 train_loss:6.1307 train_time:376ms step_avg:47.03ms +step:9/20000 train_loss:5.8335 train_time:425ms step_avg:47.21ms +step:10/20000 train_loss:5.6066 train_time:472ms step_avg:47.25ms +step:200/20000 train_loss:2.9260 train_time:9718ms step_avg:48.59ms +step:400/20000 train_loss:2.4167 train_time:19466ms step_avg:48.66ms +step:600/20000 train_loss:2.5982 train_time:29229ms step_avg:48.71ms +step:800/20000 train_loss:2.3428 train_time:38994ms step_avg:48.74ms +step:1000/20000 train_loss:2.4160 train_time:48764ms step_avg:48.76ms +step:1000/20000 val_loss:2.3810 val_bpb:1.4101 train_time:48780ms step_avg:48.78ms +step:1200/20000 train_loss:2.4330 train_time:58537ms step_avg:48.78ms +step:1400/20000 train_loss:2.4768 train_time:68306ms step_avg:48.79ms +step:1600/20000 train_loss:2.1452 train_time:78069ms step_avg:48.79ms +step:1800/20000 train_loss:2.2518 train_time:87833ms step_avg:48.80ms +step:2000/20000 train_loss:2.2958 train_time:97588ms step_avg:48.79ms +step:2000/20000 val_loss:2.2838 val_bpb:1.3526 train_time:97603ms step_avg:48.80ms +step:2200/20000 train_loss:2.1265 train_time:107348ms step_avg:48.79ms +step:2400/20000 train_loss:2.2490 train_time:117089ms step_avg:48.79ms +step:2600/20000 train_loss:2.4605 train_time:126824ms step_avg:48.78ms +step:2800/20000 train_loss:2.2921 train_time:136549ms step_avg:48.77ms +step:3000/20000 train_loss:2.2760 train_time:146271ms step_avg:48.76ms +step:3000/20000 val_loss:2.2444 val_bpb:1.3293 train_time:146287ms step_avg:48.76ms +step:3200/20000 train_loss:2.2376 train_time:155990ms step_avg:48.75ms +step:3400/20000 train_loss:2.2064 train_time:165700ms step_avg:48.74ms +step:3600/20000 train_loss:2.1746 train_time:175408ms step_avg:48.72ms +step:3800/20000 train_loss:2.2689 train_time:185099ms step_avg:48.71ms +step:4000/20000 train_loss:2.2102 train_time:194790ms step_avg:48.70ms +step:4000/20000 val_loss:2.2196 val_bpb:1.3146 train_time:194806ms step_avg:48.70ms +step:4200/20000 train_loss:2.2205 train_time:204545ms step_avg:48.70ms +step:4400/20000 train_loss:2.1643 train_time:214231ms step_avg:48.69ms +step:4600/20000 train_loss:2.0250 train_time:223919ms step_avg:48.68ms +step:4800/20000 train_loss:2.3136 train_time:233603ms step_avg:48.67ms +step:5000/20000 train_loss:2.0813 train_time:243284ms step_avg:48.66ms +step:5000/20000 val_loss:2.2034 val_bpb:1.3050 train_time:243300ms step_avg:48.66ms +step:5200/20000 train_loss:2.2222 train_time:252970ms step_avg:48.65ms +step:5400/20000 train_loss:2.2378 train_time:262653ms step_avg:48.64ms +step:5600/20000 train_loss:2.2314 train_time:272333ms step_avg:48.63ms +step:5800/20000 train_loss:2.1918 train_time:282020ms step_avg:48.62ms +step:6000/20000 train_loss:2.2623 train_time:291701ms step_avg:48.62ms +step:6000/20000 val_loss:2.1940 val_bpb:1.2994 train_time:291716ms step_avg:48.62ms +step:6200/20000 train_loss:2.1408 train_time:301380ms step_avg:48.61ms +step:6400/20000 train_loss:2.2189 train_time:311065ms step_avg:48.60ms +step:6600/20000 train_loss:2.1810 train_time:320736ms step_avg:48.60ms +step:6800/20000 train_loss:2.2474 train_time:330420ms step_avg:48.59ms +step:7000/20000 train_loss:2.2785 train_time:340102ms step_avg:48.59ms +step:7000/20000 val_loss:2.1836 val_bpb:1.2933 train_time:340118ms step_avg:48.59ms +step:7200/20000 train_loss:2.2525 train_time:349777ms step_avg:48.58ms +step:7400/20000 train_loss:2.1674 train_time:359453ms step_avg:48.57ms +step:7600/20000 train_loss:2.0531 train_time:369126ms step_avg:48.57ms +step:7800/20000 train_loss:2.1992 train_time:378797ms step_avg:48.56ms +step:8000/20000 train_loss:2.1641 train_time:388473ms step_avg:48.56ms +step:8000/20000 val_loss:2.1747 val_bpb:1.2880 train_time:388489ms step_avg:48.56ms +step:8200/20000 train_loss:2.2369 train_time:398161ms step_avg:48.56ms +step:8400/20000 train_loss:2.1847 train_time:407900ms step_avg:48.56ms +step:8600/20000 train_loss:2.1873 train_time:417583ms step_avg:48.56ms +step:8800/20000 train_loss:2.1570 train_time:427272ms step_avg:48.55ms +step:9000/20000 train_loss:2.0783 train_time:436952ms step_avg:48.55ms +step:9000/20000 val_loss:2.1697 val_bpb:1.2850 train_time:436968ms step_avg:48.55ms +step:9200/20000 train_loss:2.1343 train_time:446634ms step_avg:48.55ms +step:9400/20000 train_loss:2.1895 train_time:456322ms step_avg:48.54ms +step:9600/20000 train_loss:2.2026 train_time:466001ms step_avg:48.54ms +step:9800/20000 train_loss:2.1330 train_time:475675ms step_avg:48.54ms +step:10000/20000 train_loss:2.1662 train_time:485356ms step_avg:48.54ms +step:10000/20000 val_loss:2.1644 val_bpb:1.2819 train_time:485372ms step_avg:48.54ms +step:10200/20000 train_loss:2.1213 train_time:495028ms step_avg:48.53ms +step:10400/20000 train_loss:2.1505 train_time:504709ms step_avg:48.53ms +step:10600/20000 train_loss:2.0258 train_time:514399ms step_avg:48.53ms +step:10800/20000 train_loss:2.2354 train_time:524090ms step_avg:48.53ms +step:11000/20000 train_loss:2.1705 train_time:533770ms step_avg:48.52ms +step:11000/20000 val_loss:2.1580 val_bpb:1.2781 train_time:533786ms step_avg:48.53ms +step:11200/20000 train_loss:2.1251 train_time:543447ms step_avg:48.52ms +step:11400/20000 train_loss:2.1018 train_time:553123ms step_avg:48.52ms +step:11600/20000 train_loss:2.1047 train_time:562918ms step_avg:48.53ms +step:11800/20000 train_loss:2.1291 train_time:572608ms step_avg:48.53ms +step:12000/20000 train_loss:2.0992 train_time:582295ms step_avg:48.52ms +step:12000/20000 val_loss:2.1257 val_bpb:1.2589 train_time:582311ms step_avg:48.53ms +step:12200/20000 train_loss:2.2375 train_time:591987ms step_avg:48.52ms +step:12366/20000 val_loss:2.1141 val_bpb:1.2521 train_time:600048ms step_avg:48.52ms +stopping_early: wallclock_cap train_time:600048ms step:12366/20000 +peak memory allocated: 11263 MiB reserved: 11320 MiB +Serialized model: 74578915 bytes +Code size: 52042 bytes +Total submission size: 74630957 bytes +Serialized model int8+zlib: 15165177 bytes (payload:19030336 raw_torch:19080377 payload_ratio:3.92x) +Total submission size int8+zlib: 15217219 bytes +final_int8_zlib_roundtrip val_loss:2.1298 val_bpb:1.2614 eval_time:1561ms +final_int8_zlib_roundtrip_exact val_loss:2.12975066 val_bpb:1.26135830 diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/bf9a8898-d4e8-40dd-8786-320b140eb700.txt b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/bf9a8898-d4e8-40dd-8786-320b140eb700.txt new file mode 100644 index 0000000000..3822203d52 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/bf9a8898-d4e8-40dd-8786-320b140eb700.txt @@ -0,0 +1,1431 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 500)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +def apply_zero_init(model: nn.Module) -> nn.Module: + """ + Applies ZerO initialization to a given PyTorch nn.Module in-place. + + This function implements: + - Algorithm 1 (Linear/FFN Layers using Identity, Partial Identity, and Hadamard) + - Algorithm 2 (Convolution Layers) + + Args: + model (nn.Module): The PyTorch model to initialize. + + Returns: + nn.Module: The model with initialized weights. + """ + + @torch.no_grad() + def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): + """Core logic corresponding to Algorithm 1 in the paper.""" + tensor.zero_() + + if out_f <= in_f: + # P_l == Q_l: Identity mapping + # P_l < Q_l: Partial identity (Propagates first P_l dimensions) + tensor.copy_(torch.eye(out_f, in_f, dtype=tensor.dtype, device=tensor.device)) + else: + # P_l > Q_l: Hadamard mapping (e.g., expanding MLPs) + m = math.ceil(math.log2(out_f)) + m = max(m, 1) # Ensure at least m=1 + + # Recursively generate Hadamard matrix H_m + H = torch.tensor([[1.0]], dtype=tensor.dtype, device=tensor.device) + for _ in range(m): + H1 = torch.cat([H, H], dim=1) + H2 = torch.cat([H, -H], dim=1) + H = torch.cat([H1, H2], dim=0) + + # Normalization factor defined in the paper: c = 2^{-(m-1)/2} + c = 2.0 ** (-(m - 1) / 2.0) + + # Apply scaled Hadamard top-left submatrix (I* H_m I*) + tensor.copy_(c * H[:out_f, :in_f]) + + # Iterate through and parse all network submodules + for name, module in model.named_modules(): + + # 1. Respect the script's architectural _zero_init flags (e.g., `proj` & `lm_head`). + # This achieves the paper's goal of dynamical isometry (Identity residual pass) + # and avoids the gradient deadlock. + if getattr(module, '_zero_init', False): + if hasattr(module, 'weight') and module.weight is not None: + module.weight.data.zero_() + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + continue + + # 2. Linear Layers (c_q, c_k, c_v, fc) + # c_q (512->512) becomes Identity. + # c_k/c_v (512->256) become Partial Identities. + # fc (512->1024) becomes a Hadamard transform. + if isinstance(module, nn.Linear): + _zero_init_tensor(module.weight.data, module.out_features, module.in_features) + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 3. Convolutional Layers (Implementation of Algorithm 2) + elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + out_c, in_c = module.out_channels, module.in_channels + module.weight.data.zero_() + + # Find center index, taking n <- floor(k / 2) for each dimension + centers = tuple(k // 2 for k in module.kernel_size) + + # Create an out_c x in_c block utilizing the base ZerO logic + block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) + _zero_init_tensor(block, out_c, in_c) + + # Place block in the center slice of the kernel + if len(centers) == 1: + module.weight.data[:, :, centers[0]] = block + elif len(centers) == 2: + module.weight.data[:, :, centers[0], centers[1]] = block + elif len(centers) == 3: + module.weight.data[:, :, centers[0], centers[1], centers[2]] = block + + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 4. Normalization Layers (Standard practice is Scale=1, Bias=0) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + if hasattr(module, 'weight') and module.weight is not None: + module.weight.data.fill_(1.0) + if hasattr(module, 'bias') and module.bias is not None: + module.bias.data.zero_() + + return model + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + base_model = apply_zero_init(base_model) + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri May 1 17:21:24 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 43C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 34C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 41C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 44C P0 130W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 34C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 42C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 33C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:18897488 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:500 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:10/500 +warmup_step:20/500 +warmup_step:30/500 +warmup_step:40/500 +warmup_step:50/500 +warmup_step:60/500 +warmup_step:70/500 +warmup_step:80/500 +warmup_step:90/500 +warmup_step:100/500 +warmup_step:110/500 +warmup_step:120/500 +warmup_step:130/500 +warmup_step:140/500 +warmup_step:150/500 +warmup_step:160/500 +warmup_step:170/500 +warmup_step:180/500 +warmup_step:190/500 +warmup_step:200/500 +warmup_step:210/500 +warmup_step:220/500 +warmup_step:230/500 +warmup_step:240/500 +warmup_step:250/500 +warmup_step:260/500 +warmup_step:270/500 +warmup_step:280/500 +warmup_step:290/500 +warmup_step:300/500 +warmup_step:310/500 +warmup_step:320/500 +warmup_step:330/500 +warmup_step:340/500 +warmup_step:350/500 +warmup_step:360/500 +warmup_step:370/500 +warmup_step:380/500 +warmup_step:390/500 +warmup_step:400/500 +warmup_step:410/500 +warmup_step:420/500 +warmup_step:430/500 +warmup_step:440/500 +warmup_step:450/500 +warmup_step:460/500 +warmup_step:470/500 +warmup_step:480/500 +warmup_step:490/500 +warmup_step:500/500 +step:0/20000 val_loss:6.9363 val_bpb:4.1080 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9355 train_time:34ms step_avg:33.69ms +step:2/20000 train_loss:15.7807 train_time:86ms step_avg:42.82ms +step:3/20000 train_loss:7.3458 train_time:134ms step_avg:44.55ms +step:4/20000 train_loss:7.2066 train_time:183ms step_avg:45.85ms +step:5/20000 train_loss:7.2275 train_time:230ms step_avg:46.08ms +step:6/20000 train_loss:7.3732 train_time:278ms step_avg:46.37ms +step:7/20000 train_loss:6.4856 train_time:326ms step_avg:46.59ms +step:8/20000 train_loss:6.1885 train_time:375ms step_avg:46.82ms +step:9/20000 train_loss:5.9111 train_time:423ms step_avg:46.98ms +step:10/20000 train_loss:5.7179 train_time:471ms step_avg:47.09ms +step:200/20000 train_loss:2.9473 train_time:9698ms step_avg:48.49ms +step:400/20000 train_loss:2.4203 train_time:19401ms step_avg:48.50ms +step:600/20000 train_loss:2.5958 train_time:29130ms step_avg:48.55ms +step:800/20000 train_loss:2.3421 train_time:38857ms step_avg:48.57ms +step:1000/20000 train_loss:2.4133 train_time:48584ms step_avg:48.58ms +step:1000/20000 val_loss:2.3795 val_bpb:1.4093 train_time:48601ms step_avg:48.60ms +step:1200/20000 train_loss:2.4334 train_time:58329ms step_avg:48.61ms +step:1400/20000 train_loss:2.4710 train_time:68065ms step_avg:48.62ms +step:1600/20000 train_loss:2.1405 train_time:77806ms step_avg:48.63ms +step:1800/20000 train_loss:2.2448 train_time:87536ms step_avg:48.63ms +step:2000/20000 train_loss:2.2988 train_time:97267ms step_avg:48.63ms +step:2000/20000 val_loss:2.2813 val_bpb:1.3511 train_time:97284ms step_avg:48.64ms +step:2200/20000 train_loss:2.1245 train_time:106998ms step_avg:48.64ms +step:2400/20000 train_loss:2.2451 train_time:116722ms step_avg:48.63ms +step:2600/20000 train_loss:2.4558 train_time:126440ms step_avg:48.63ms +step:2800/20000 train_loss:2.2855 train_time:136153ms step_avg:48.63ms +step:3000/20000 train_loss:2.2724 train_time:145860ms step_avg:48.62ms +step:3000/20000 val_loss:2.2407 val_bpb:1.3271 train_time:145876ms step_avg:48.63ms +step:3200/20000 train_loss:2.2333 train_time:155570ms step_avg:48.62ms +step:3400/20000 train_loss:2.2067 train_time:165270ms step_avg:48.61ms +step:3600/20000 train_loss:2.1651 train_time:174973ms step_avg:48.60ms +step:3800/20000 train_loss:2.2628 train_time:184671ms step_avg:48.60ms +step:4000/20000 train_loss:2.2087 train_time:194366ms step_avg:48.59ms +step:4000/20000 val_loss:2.2175 val_bpb:1.3133 train_time:194382ms step_avg:48.60ms +step:4200/20000 train_loss:2.2175 train_time:204136ms step_avg:48.60ms +step:4400/20000 train_loss:2.1617 train_time:213826ms step_avg:48.60ms +step:4600/20000 train_loss:2.0184 train_time:223526ms step_avg:48.59ms +step:4800/20000 train_loss:2.3111 train_time:233218ms step_avg:48.59ms +step:5000/20000 train_loss:2.0792 train_time:242920ms step_avg:48.58ms +step:5000/20000 val_loss:2.2005 val_bpb:1.3033 train_time:242944ms step_avg:48.59ms +step:5200/20000 train_loss:2.2193 train_time:252632ms step_avg:48.58ms +step:5400/20000 train_loss:2.2357 train_time:262335ms step_avg:48.58ms +step:5600/20000 train_loss:2.2336 train_time:272027ms step_avg:48.58ms +step:5800/20000 train_loss:2.1926 train_time:281718ms step_avg:48.57ms +step:6000/20000 train_loss:2.2634 train_time:291419ms step_avg:48.57ms +step:6000/20000 val_loss:2.1906 val_bpb:1.2974 train_time:291435ms step_avg:48.57ms +step:6200/20000 train_loss:2.1317 train_time:301112ms step_avg:48.57ms +step:6400/20000 train_loss:2.2120 train_time:310805ms step_avg:48.56ms +step:6600/20000 train_loss:2.1746 train_time:320490ms step_avg:48.56ms +step:6800/20000 train_loss:2.2376 train_time:330175ms step_avg:48.56ms +step:7000/20000 train_loss:2.2767 train_time:339874ms step_avg:48.55ms +step:7000/20000 val_loss:2.1799 val_bpb:1.2911 train_time:339889ms step_avg:48.56ms +step:7200/20000 train_loss:2.2503 train_time:349600ms step_avg:48.56ms +step:7400/20000 train_loss:2.1664 train_time:359291ms step_avg:48.55ms +step:7600/20000 train_loss:2.0471 train_time:368976ms step_avg:48.55ms +step:7800/20000 train_loss:2.1940 train_time:378661ms step_avg:48.55ms +step:8000/20000 train_loss:2.1614 train_time:388350ms step_avg:48.54ms +step:8000/20000 val_loss:2.1705 val_bpb:1.2855 train_time:388367ms step_avg:48.55ms +step:8200/20000 train_loss:2.2335 train_time:398040ms step_avg:48.54ms +step:8400/20000 train_loss:2.1793 train_time:407787ms step_avg:48.55ms +step:8600/20000 train_loss:2.1869 train_time:417469ms step_avg:48.54ms +step:8800/20000 train_loss:2.1529 train_time:427166ms step_avg:48.54ms +step:9000/20000 train_loss:2.0746 train_time:436845ms step_avg:48.54ms +step:9000/20000 val_loss:2.1663 val_bpb:1.2830 train_time:436862ms step_avg:48.54ms +step:9200/20000 train_loss:2.1343 train_time:446535ms step_avg:48.54ms +step:9400/20000 train_loss:2.1866 train_time:456312ms step_avg:48.54ms +step:9600/20000 train_loss:2.1989 train_time:466006ms step_avg:48.54ms +step:9800/20000 train_loss:2.1228 train_time:475685ms step_avg:48.54ms +step:10000/20000 train_loss:2.1626 train_time:485374ms step_avg:48.54ms +step:10000/20000 val_loss:2.1603 val_bpb:1.2794 train_time:485390ms step_avg:48.54ms +step:10200/20000 train_loss:2.1160 train_time:495062ms step_avg:48.54ms +step:10400/20000 train_loss:2.1504 train_time:504743ms step_avg:48.53ms +step:10600/20000 train_loss:2.0270 train_time:514434ms step_avg:48.53ms +step:10800/20000 train_loss:2.2328 train_time:524140ms step_avg:48.53ms +step:11000/20000 train_loss:2.1689 train_time:533835ms step_avg:48.53ms +step:11000/20000 val_loss:2.1548 val_bpb:1.2762 train_time:533852ms step_avg:48.53ms +step:11200/20000 train_loss:2.1158 train_time:543529ms step_avg:48.53ms +step:11400/20000 train_loss:2.0994 train_time:553225ms step_avg:48.53ms +step:11600/20000 train_loss:2.0988 train_time:562917ms step_avg:48.53ms +step:11800/20000 train_loss:2.1232 train_time:572599ms step_avg:48.53ms +step:12000/20000 train_loss:2.0914 train_time:582281ms step_avg:48.52ms +step:12000/20000 val_loss:2.1221 val_bpb:1.2569 train_time:582305ms step_avg:48.53ms +step:12200/20000 train_loss:2.2333 train_time:591982ms step_avg:48.52ms +step:12366/20000 val_loss:2.1106 val_bpb:1.2500 train_time:600049ms step_avg:48.52ms +stopping_early: wallclock_cap train_time:600049ms step:12366/20000 +peak memory allocated: 11263 MiB reserved: 11320 MiB +Serialized model: 74578915 bytes +Code size: 52042 bytes +Total submission size: 74630957 bytes +Serialized model int8+zlib: 15164905 bytes (payload:19030336 raw_torch:19080377 payload_ratio:3.92x) +Total submission size int8+zlib: 15216947 bytes +final_int8_zlib_roundtrip val_loss:2.1255 val_bpb:1.2588 eval_time:1570ms +final_int8_zlib_roundtrip_exact val_loss:2.12551013 val_bpb:1.25884682 diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/e31e596e-e21a-4c48-829c-78233c992cc8.txt b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/e31e596e-e21a-4c48-829c-78233c992cc8.txt new file mode 100644 index 0000000000..2e48843808 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/e31e596e-e21a-4c48-829c-78233c992cc8.txt @@ -0,0 +1,1455 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +def apply_zero_init(model: nn.Module) -> nn.Module: + """ + Applies ZerO initialization to a given PyTorch nn.Module in-place. + + This function implements: + - Algorithm 1 (Linear/FFN Layers) + - Algorithm 2 (Convolution Layers) + - Section 4.1 Transformer specific initialization (W_Q=Identity, W_K/W_V=Zeros) + + Args: + model (nn.Module): The PyTorch model to initialize. + + Returns: + nn.Module: The model with initialized weights. + """ + + @torch.no_grad() + def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): + """Core logic corresponding to Algorithm 1 in the paper.""" + tensor.zero_() + + if out_f <= in_f: + # P_l == Q_l: Identity mapping + # P_l < Q_l: Partial identity (Propagates first P_l dimensions) + # torch.eye intrinsically handles rectangular partial-identities. + tensor.copy_(torch.eye(out_f, in_f, dtype=tensor.dtype, device=tensor.device)) + else: + # P_l > Q_l: Hadamard mapping + m = math.ceil(math.log2(out_f)) + m = max(m, 1) # Ensure at least m=1 + + # Recursively generate Hadamard matrix H_m + H = torch.tensor([[1.0]], dtype=tensor.dtype, device=tensor.device) + for _ in range(m): + H1 = torch.cat([H, H], dim=1) + H2 = torch.cat([H, -H], dim=1) + H = torch.cat([H1, H2], dim=0) + + # Normalization factor defined in the paper: c = 2^{-(m-1)/2} + c = 2.0 ** (-(m - 1) / 2.0) + + # Apply scaled Hadamard top-left submatrix (I* H_m I*) + tensor.copy_(c * H[:out_f, :in_f]) + + @torch.no_grad() + def _init_mha_zero(module: nn.MultiheadAttention): + """Initializes PyTorch's native MultiheadAttention according to Sec 4.1""" + embed_dim = module.embed_dim + + # in_proj_weight groups Q, K, and V + if module.in_proj_weight is not None: + module.in_proj_weight.zero_() + # W_Q is identity + module.in_proj_weight[:embed_dim, :].copy_( + torch.eye(embed_dim, dtype=module.in_proj_weight.dtype, device=module.in_proj_weight.device) + ) + # W_K, W_V remain zero (as handled by .zero_()) + else: + # If instantiated with separate weights + if getattr(module, 'q_proj_weight', None) is not None: + module.q_proj_weight.copy_( + torch.eye(embed_dim, dtype=module.q_proj_weight.dtype, device=module.q_proj_weight.device) + ) + if getattr(module, 'k_proj_weight', None) is not None: + module.k_proj_weight.zero_() + if getattr(module, 'v_proj_weight', None) is not None: + module.v_proj_weight.zero_() + + if getattr(module, 'in_proj_bias', None) is not None: + module.in_proj_bias.zero_() + + # The output projection relies on the standard algorithm (P_l = Q_l) -> Identity + if getattr(module, 'out_proj', None) is not None and getattr(module.out_proj, 'weight', None) is not None: + _zero_init_tensor(module.out_proj.weight, module.out_proj.out_features, module.out_proj.in_features) + if getattr(module.out_proj, 'bias', None) is not None: + module.out_proj.bias.zero_() + + @torch.no_grad() + def _init_conv_zero(module): + """Initializes Convolutions according to Algorithm 2""" + out_c, in_c = module.out_channels, module.in_channels + module.weight.zero_() + + # Find center index, taking n <- floor(k / 2) for each dimension + centers = tuple(k // 2 for k in module.kernel_size) + + # Create an out_c x in_c block utilizing the base ZerO logic + block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) + _zero_init_tensor(block, out_c, in_c) + + # Place block in the center slice of the kernel + if len(centers) == 1: # Conv1D + module.weight[:, :, centers[0]] = block + elif len(centers) == 2: # Conv2D + module.weight[:, :, centers[0], centers[1]] = block + elif len(centers) == 3: # Conv3D + module.weight[:, :, centers[0], centers[1], centers[2]] = block + + if getattr(module, 'bias', None) is not None: + module.bias.zero_() + + # Iterate through and parse all network submodules + for name, module in model.named_modules(): + + # 1. Transformers MultiheadAttention (PyTorch Native MHA) + if isinstance(module, nn.MultiheadAttention): + _init_mha_zero(module) + + # 2. Linear Layers (Applies to Feed-Forward Networks or HuggingFace Custom QKV Projections) + elif isinstance(module, nn.Linear): + name_lower = name.split('.')[-1].lower() + + # W_K, W_V at zero (Fallback heuristics for custom attention modules) + if any(k in name_lower for k in ['k_proj', 'v_proj', 'key', 'value']): + module.weight.data.zero_() + + # W_Q as identity (Fallback heuristics for custom attention modules) + elif any(q in name_lower for q in ['q_proj', 'query']): + module.weight.data.copy_( + torch.eye(module.out_features, module.in_features, dtype=module.weight.dtype, device=module.weight.device) + ) + + # Generic linear layer (e.g., Dimension expanding/shrinking FFN) + else: + _zero_init_tensor(module.weight, module.out_features, module.in_features) + + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 3. Convolutional Layers + elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + _init_conv_zero(module) + + # 4. Normalization Layers (Standard practice is Scale=1, Bias=0) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + if getattr(module, 'weight', None) is not None: + module.weight.data.fill_(1.0) + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + return model + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + base_model = apply_zero_init(base_model) + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri May 1 15:54:01 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 35C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 30C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9315 val_bpb:4.1053 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9313 train_time:35ms step_avg:34.95ms +step:2/20000 train_loss:11.5793 train_time:74ms step_avg:36.76ms +step:3/20000 train_loss:7.3305 train_time:118ms step_avg:39.18ms +step:4/20000 train_loss:6.4159 train_time:160ms step_avg:40.09ms +step:5/20000 train_loss:6.0377 train_time:205ms step_avg:41.09ms +step:6/20000 train_loss:6.2271 train_time:249ms step_avg:41.46ms +step:7/20000 train_loss:5.9458 train_time:293ms step_avg:41.86ms +step:8/20000 train_loss:5.8528 train_time:336ms step_avg:42.02ms +step:9/20000 train_loss:5.7863 train_time:380ms step_avg:42.23ms +step:10/20000 train_loss:5.7175 train_time:424ms step_avg:42.35ms +step:200/20000 train_loss:3.0837 train_time:8740ms step_avg:43.70ms +step:400/20000 train_loss:2.4340 train_time:17910ms step_avg:44.77ms +step:600/20000 train_loss:2.5968 train_time:26686ms step_avg:44.48ms +step:800/20000 train_loss:2.3361 train_time:35450ms step_avg:44.31ms +step:1000/20000 train_loss:2.4013 train_time:44257ms step_avg:44.26ms +step:1000/20000 val_loss:2.3684 val_bpb:1.4027 train_time:44272ms step_avg:44.27ms +step:1200/20000 train_loss:2.4126 train_time:53068ms step_avg:44.22ms +step:1400/20000 train_loss:2.4552 train_time:61880ms step_avg:44.20ms +step:1600/20000 train_loss:2.1209 train_time:70707ms step_avg:44.19ms +step:1800/20000 train_loss:2.2201 train_time:79530ms step_avg:44.18ms +step:2000/20000 train_loss:2.2712 train_time:88349ms step_avg:44.17ms +step:2000/20000 val_loss:2.2565 val_bpb:1.3364 train_time:88364ms step_avg:44.18ms +step:2200/20000 train_loss:2.0940 train_time:97181ms step_avg:44.17ms +step:2400/20000 train_loss:2.2174 train_time:105993ms step_avg:44.16ms +step:2600/20000 train_loss:2.4347 train_time:114808ms step_avg:44.16ms +step:2800/20000 train_loss:2.2522 train_time:123633ms step_avg:44.15ms +step:3000/20000 train_loss:2.2455 train_time:132437ms step_avg:44.15ms +step:3000/20000 val_loss:2.2117 val_bpb:1.3099 train_time:132451ms step_avg:44.15ms +step:3200/20000 train_loss:2.2056 train_time:141233ms step_avg:44.14ms +step:3400/20000 train_loss:2.1791 train_time:150025ms step_avg:44.13ms +step:3600/20000 train_loss:2.1333 train_time:158824ms step_avg:44.12ms +step:3800/20000 train_loss:2.2386 train_time:167613ms step_avg:44.11ms +step:4000/20000 train_loss:2.1768 train_time:176396ms step_avg:44.10ms +step:4000/20000 val_loss:2.1844 val_bpb:1.2937 train_time:176410ms step_avg:44.10ms +step:4200/20000 train_loss:2.1867 train_time:185285ms step_avg:44.12ms +step:4400/20000 train_loss:2.1291 train_time:194065ms step_avg:44.11ms +step:4600/20000 train_loss:1.9870 train_time:202854ms step_avg:44.10ms +step:4800/20000 train_loss:2.2776 train_time:211640ms step_avg:44.09ms +step:5000/20000 train_loss:2.0421 train_time:220421ms step_avg:44.08ms +step:5000/20000 val_loss:2.1662 val_bpb:1.2829 train_time:220434ms step_avg:44.09ms +step:5200/20000 train_loss:2.1840 train_time:229207ms step_avg:44.08ms +step:5400/20000 train_loss:2.1990 train_time:237993ms step_avg:44.07ms +step:5600/20000 train_loss:2.1972 train_time:246769ms step_avg:44.07ms +step:5800/20000 train_loss:2.1559 train_time:255545ms step_avg:44.06ms +step:6000/20000 train_loss:2.2323 train_time:264331ms step_avg:44.06ms +step:6000/20000 val_loss:2.1560 val_bpb:1.2769 train_time:264344ms step_avg:44.06ms +step:6200/20000 train_loss:2.0976 train_time:273199ms step_avg:44.06ms +step:6400/20000 train_loss:2.1702 train_time:281981ms step_avg:44.06ms +step:6600/20000 train_loss:2.1342 train_time:290758ms step_avg:44.05ms +step:6800/20000 train_loss:2.2053 train_time:299542ms step_avg:44.05ms +step:7000/20000 train_loss:2.2348 train_time:308319ms step_avg:44.05ms +step:7000/20000 val_loss:2.1433 val_bpb:1.2694 train_time:308334ms step_avg:44.05ms +step:7200/20000 train_loss:2.2099 train_time:317103ms step_avg:44.04ms +step:7400/20000 train_loss:2.1284 train_time:325890ms step_avg:44.04ms +step:7600/20000 train_loss:2.0098 train_time:334673ms step_avg:44.04ms +step:7800/20000 train_loss:2.1552 train_time:343460ms step_avg:44.03ms +step:8000/20000 train_loss:2.1270 train_time:352237ms step_avg:44.03ms +step:8000/20000 val_loss:2.1333 val_bpb:1.2634 train_time:352251ms step_avg:44.03ms +step:8200/20000 train_loss:2.1960 train_time:361006ms step_avg:44.03ms +step:8400/20000 train_loss:2.1513 train_time:369845ms step_avg:44.03ms +step:8600/20000 train_loss:2.1514 train_time:378621ms step_avg:44.03ms +step:8800/20000 train_loss:2.1110 train_time:387420ms step_avg:44.03ms +step:9000/20000 train_loss:2.0373 train_time:396203ms step_avg:44.02ms +step:9000/20000 val_loss:2.1278 val_bpb:1.2602 train_time:396217ms step_avg:44.02ms +step:9200/20000 train_loss:2.0951 train_time:404984ms step_avg:44.02ms +step:9400/20000 train_loss:2.1393 train_time:413762ms step_avg:44.02ms +step:9600/20000 train_loss:2.1571 train_time:422546ms step_avg:44.02ms +step:9800/20000 train_loss:2.0810 train_time:431318ms step_avg:44.01ms +step:10000/20000 train_loss:2.1207 train_time:440096ms step_avg:44.01ms +step:10000/20000 val_loss:2.1216 val_bpb:1.2565 train_time:440110ms step_avg:44.01ms +step:10200/20000 train_loss:2.0766 train_time:448869ms step_avg:44.01ms +step:10400/20000 train_loss:2.1118 train_time:457642ms step_avg:44.00ms +step:10600/20000 train_loss:1.9861 train_time:466421ms step_avg:44.00ms +step:10800/20000 train_loss:2.1940 train_time:475206ms step_avg:44.00ms +step:11000/20000 train_loss:2.1234 train_time:483991ms step_avg:44.00ms +step:11000/20000 val_loss:2.1153 val_bpb:1.2528 train_time:484005ms step_avg:44.00ms +step:11200/20000 train_loss:2.0787 train_time:492763ms step_avg:44.00ms +step:11400/20000 train_loss:2.0635 train_time:501539ms step_avg:43.99ms +step:11600/20000 train_loss:2.0741 train_time:510320ms step_avg:43.99ms +step:11800/20000 train_loss:2.1060 train_time:519094ms step_avg:43.99ms +step:12000/20000 train_loss:2.0810 train_time:527876ms step_avg:43.99ms +step:12000/20000 val_loss:2.1095 val_bpb:1.2494 train_time:527890ms step_avg:43.99ms +step:12200/20000 train_loss:2.2265 train_time:536662ms step_avg:43.99ms +step:12400/20000 train_loss:1.8727 train_time:545506ms step_avg:43.99ms +step:12600/20000 train_loss:2.0998 train_time:554286ms step_avg:43.99ms +step:12800/20000 train_loss:2.1115 train_time:563070ms step_avg:43.99ms +step:13000/20000 train_loss:2.1838 train_time:571849ms step_avg:43.99ms +step:13000/20000 val_loss:2.0895 val_bpb:1.2375 train_time:571862ms step_avg:43.99ms +step:13200/20000 train_loss:2.1871 train_time:580628ms step_avg:43.99ms +step:13400/20000 train_loss:2.0566 train_time:589407ms step_avg:43.99ms +step:13600/20000 train_loss:1.9241 train_time:598195ms step_avg:43.98ms +step:13642/20000 val_loss:2.0669 val_bpb:1.2242 train_time:600038ms step_avg:43.98ms +stopping_early: wallclock_cap train_time:600038ms step:13642/20000 +peak memory allocated: 10119 MiB reserved: 10294 MiB +Serialized model: 67224983 bytes +Code size: 54138 bytes +Total submission size: 67279121 bytes +Serialized model int8+zlib: 15585390 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) +Total submission size int8+zlib: 15639528 bytes +final_int8_zlib_roundtrip val_loss:2.0810 val_bpb:1.2325 eval_time:1410ms +final_int8_zlib_roundtrip_exact val_loss:2.08096715 val_bpb:1.23246596 diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/fe927335-4827-415a-b543-8b5d2706de4c.txt b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/fe927335-4827-415a-b543-8b5d2706de4c.txt new file mode 100644 index 0000000000..262271b0eb --- /dev/null +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/run_logs/fe927335-4827-415a-b543-8b5d2706de4c.txt @@ -0,0 +1,1445 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 100)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.05)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +def apply_zero_init(model: nn.Module) -> nn.Module: + """ + Applies ZerO initialization to a given PyTorch nn.Module in-place. + + This function implements: + - Algorithm 1 (Linear/FFN Layers) + - Algorithm 2 (Convolution Layers) + - Section 4.1 Transformer specific initialization (W_Q=Identity, W_K/W_V=Zeros) + + Args: + model (nn.Module): The PyTorch model to initialize. + + Returns: + nn.Module: The model with initialized weights. + """ + + @torch.no_grad() + def _zero_init_tensor(tensor: torch.Tensor, out_f: int, in_f: int): + """Core logic corresponding to Algorithm 1 in the paper.""" + tensor.zero_() + + if out_f <= in_f: + # P_l == Q_l: Identity mapping + # P_l < Q_l: Partial identity (Propagates first P_l dimensions) + # torch.eye intrinsically handles rectangular partial-identities. + tensor.copy_(torch.eye(out_f, in_f, dtype=tensor.dtype, device=tensor.device)) + else: + # P_l > Q_l: Hadamard mapping + m = math.ceil(math.log2(out_f)) + m = max(m, 1) # Ensure at least m=1 + + # Recursively generate Hadamard matrix H_m + H = torch.tensor([[1.0]], dtype=tensor.dtype, device=tensor.device) + for _ in range(m): + H1 = torch.cat([H, H], dim=1) + H2 = torch.cat([H, -H], dim=1) + H = torch.cat([H1, H2], dim=0) + + # Normalization factor defined in the paper: c = 2^{-(m-1)/2} + c = 2.0 ** (-(m - 1) / 2.0) + + # Apply scaled Hadamard top-left submatrix (I* H_m I*) + tensor.copy_(c * H[:out_f, :in_f]) + + @torch.no_grad() + def _init_mha_zero(module: nn.MultiheadAttention): + """Initializes PyTorch's native MultiheadAttention according to Sec 4.1""" + embed_dim = module.embed_dim + + # in_proj_weight groups Q, K, and V + if module.in_proj_weight is not None: + module.in_proj_weight.zero_() + # W_Q is identity + module.in_proj_weight[:embed_dim, :].copy_( + torch.eye(embed_dim, dtype=module.in_proj_weight.dtype, device=module.in_proj_weight.device) + ) + # W_K, W_V remain zero (as handled by .zero_()) + else: + # If instantiated with separate weights + if getattr(module, 'q_proj_weight', None) is not None: + module.q_proj_weight.copy_( + torch.eye(embed_dim, dtype=module.q_proj_weight.dtype, device=module.q_proj_weight.device) + ) + if getattr(module, 'k_proj_weight', None) is not None: + module.k_proj_weight.zero_() + if getattr(module, 'v_proj_weight', None) is not None: + module.v_proj_weight.zero_() + + if getattr(module, 'in_proj_bias', None) is not None: + module.in_proj_bias.zero_() + + # The output projection relies on the standard algorithm (P_l = Q_l) -> Identity + if getattr(module, 'out_proj', None) is not None and getattr(module.out_proj, 'weight', None) is not None: + _zero_init_tensor(module.out_proj.weight, module.out_proj.out_features, module.out_proj.in_features) + if getattr(module.out_proj, 'bias', None) is not None: + module.out_proj.bias.zero_() + + @torch.no_grad() + def _init_conv_zero(module): + """Initializes Convolutions according to Algorithm 2""" + out_c, in_c = module.out_channels, module.in_channels + module.weight.zero_() + + # Find center index, taking n <- floor(k / 2) for each dimension + centers = tuple(k // 2 for k in module.kernel_size) + + # Create an out_c x in_c block utilizing the base ZerO logic + block = torch.zeros(out_c, in_c, dtype=module.weight.dtype, device=module.weight.device) + _zero_init_tensor(block, out_c, in_c) + + # Place block in the center slice of the kernel + if len(centers) == 1: # Conv1D + module.weight[:, :, centers[0]] = block + elif len(centers) == 2: # Conv2D + module.weight[:, :, centers[0], centers[1]] = block + elif len(centers) == 3: # Conv3D + module.weight[:, :, centers[0], centers[1], centers[2]] = block + + if getattr(module, 'bias', None) is not None: + module.bias.zero_() + + # Iterate through and parse all network submodules + for name, module in model.named_modules(): + + # 1. Transformers MultiheadAttention (PyTorch Native MHA) + if isinstance(module, nn.MultiheadAttention): + _init_mha_zero(module) + + # 2. Linear Layers (Applies to Feed-Forward Networks or HuggingFace Custom QKV Projections) + elif isinstance(module, nn.Linear): + name_lower = name.split('.')[-1].lower() + + # W_K, W_V at zero (Fallback heuristics for custom attention modules) + if any(k in name_lower for k in ['k_proj', 'v_proj', 'key', 'value']): + module.weight.data.zero_() + + # W_Q as identity (Fallback heuristics for custom attention modules) + elif any(q in name_lower for q in ['q_proj', 'query']): + module.weight.data.copy_( + torch.eye(module.out_features, module.in_features, dtype=module.weight.dtype, device=module.weight.device) + ) + + # Generic linear layer (e.g., Dimension expanding/shrinking FFN) + else: + _zero_init_tensor(module.weight, module.out_features, module.in_features) + + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + # 3. Convolutional Layers + elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + _init_conv_zero(module) + + # 4. Normalization Layers (Standard practice is Scale=1, Bias=0) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + if getattr(module, 'weight', None) is not None: + module.weight.data.fill_(1.0) + if getattr(module, 'bias', None) is not None: + module.bias.data.zero_() + + return model + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + base_model = apply_zero_init(base_model) + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri May 1 16:17:15 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 33C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 36C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.05 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:100 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:10/100 +warmup_step:20/100 +warmup_step:30/100 +warmup_step:40/100 +warmup_step:50/100 +warmup_step:60/100 +warmup_step:70/100 +warmup_step:80/100 +warmup_step:90/100 +warmup_step:100/100 +step:0/20000 val_loss:6.9315 val_bpb:4.1053 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9313 train_time:32ms step_avg:31.75ms +step:2/20000 train_loss:11.5735 train_time:78ms step_avg:38.84ms +step:3/20000 train_loss:7.1929 train_time:123ms step_avg:40.89ms +step:4/20000 train_loss:6.3270 train_time:167ms step_avg:41.81ms +step:5/20000 train_loss:6.0549 train_time:211ms step_avg:42.24ms +step:6/20000 train_loss:6.1655 train_time:256ms step_avg:42.66ms +step:7/20000 train_loss:5.8802 train_time:299ms step_avg:42.66ms +step:8/20000 train_loss:5.8191 train_time:342ms step_avg:42.74ms +step:9/20000 train_loss:5.7510 train_time:384ms step_avg:42.64ms +step:10/20000 train_loss:5.6883 train_time:429ms step_avg:42.88ms +step:200/20000 train_loss:3.0501 train_time:8729ms step_avg:43.64ms +step:400/20000 train_loss:2.4298 train_time:17491ms step_avg:43.73ms +step:600/20000 train_loss:2.5824 train_time:26245ms step_avg:43.74ms +step:800/20000 train_loss:2.3211 train_time:35007ms step_avg:43.76ms +step:1000/20000 train_loss:2.3946 train_time:43784ms step_avg:43.78ms +step:1000/20000 val_loss:2.3563 val_bpb:1.3955 train_time:43800ms step_avg:43.80ms +step:1200/20000 train_loss:2.4059 train_time:52566ms step_avg:43.80ms +step:1400/20000 train_loss:2.4449 train_time:61368ms step_avg:43.83ms +step:1600/20000 train_loss:2.1165 train_time:70173ms step_avg:43.86ms +step:1800/20000 train_loss:2.2146 train_time:78960ms step_avg:43.87ms +step:2000/20000 train_loss:2.2682 train_time:87756ms step_avg:43.88ms +step:2000/20000 val_loss:2.2519 val_bpb:1.3337 train_time:87770ms step_avg:43.89ms +step:2200/20000 train_loss:2.0919 train_time:96553ms step_avg:43.89ms +step:2400/20000 train_loss:2.2119 train_time:105340ms step_avg:43.89ms +step:2600/20000 train_loss:2.4323 train_time:114118ms step_avg:43.89ms +step:2800/20000 train_loss:2.2503 train_time:122918ms step_avg:43.90ms +step:3000/20000 train_loss:2.2446 train_time:131762ms step_avg:43.92ms +step:3000/20000 val_loss:2.2075 val_bpb:1.3074 train_time:131777ms step_avg:43.93ms +step:3200/20000 train_loss:2.2037 train_time:140550ms step_avg:43.92ms +step:3400/20000 train_loss:2.1692 train_time:149329ms step_avg:43.92ms +step:3600/20000 train_loss:2.1296 train_time:158110ms step_avg:43.92ms +step:3800/20000 train_loss:2.2317 train_time:166864ms step_avg:43.91ms +step:4000/20000 train_loss:2.1715 train_time:175643ms step_avg:43.91ms +step:4000/20000 val_loss:2.1805 val_bpb:1.2914 train_time:175658ms step_avg:43.91ms +step:4200/20000 train_loss:2.1877 train_time:184467ms step_avg:43.92ms +step:4400/20000 train_loss:2.1236 train_time:193243ms step_avg:43.92ms +step:4600/20000 train_loss:1.9859 train_time:202015ms step_avg:43.92ms +step:4800/20000 train_loss:2.2733 train_time:210787ms step_avg:43.91ms +step:5000/20000 train_loss:2.0361 train_time:219552ms step_avg:43.91ms +step:5000/20000 val_loss:2.1631 val_bpb:1.2811 train_time:219567ms step_avg:43.91ms +step:5200/20000 train_loss:2.1834 train_time:228313ms step_avg:43.91ms +step:5400/20000 train_loss:2.1937 train_time:237082ms step_avg:43.90ms +step:5600/20000 train_loss:2.1897 train_time:245839ms step_avg:43.90ms +step:5800/20000 train_loss:2.1550 train_time:254602ms step_avg:43.90ms +step:6000/20000 train_loss:2.2332 train_time:263360ms step_avg:43.89ms +step:6000/20000 val_loss:2.1526 val_bpb:1.2749 train_time:263375ms step_avg:43.90ms +step:6200/20000 train_loss:2.0989 train_time:272117ms step_avg:43.89ms +step:6400/20000 train_loss:2.1725 train_time:280883ms step_avg:43.89ms +step:6600/20000 train_loss:2.1283 train_time:289631ms step_avg:43.88ms +step:6800/20000 train_loss:2.2035 train_time:298379ms step_avg:43.88ms +step:7000/20000 train_loss:2.2323 train_time:307142ms step_avg:43.88ms +step:7000/20000 val_loss:2.1415 val_bpb:1.2683 train_time:307157ms step_avg:43.88ms +step:7200/20000 train_loss:2.2079 train_time:315894ms step_avg:43.87ms +step:7400/20000 train_loss:2.1283 train_time:324655ms step_avg:43.87ms +step:7600/20000 train_loss:2.0099 train_time:333407ms step_avg:43.87ms +step:7800/20000 train_loss:2.1548 train_time:342160ms step_avg:43.87ms +step:8000/20000 train_loss:2.1226 train_time:350916ms step_avg:43.86ms +step:8000/20000 val_loss:2.1314 val_bpb:1.2623 train_time:350930ms step_avg:43.87ms +step:8200/20000 train_loss:2.1968 train_time:359678ms step_avg:43.86ms +step:8400/20000 train_loss:2.1435 train_time:368495ms step_avg:43.87ms +step:8600/20000 train_loss:2.1418 train_time:377246ms step_avg:43.87ms +step:8800/20000 train_loss:2.1074 train_time:386013ms step_avg:43.87ms +step:9000/20000 train_loss:2.0360 train_time:394770ms step_avg:43.86ms +step:9000/20000 val_loss:2.1262 val_bpb:1.2593 train_time:394784ms step_avg:43.86ms +step:9200/20000 train_loss:2.0950 train_time:403535ms step_avg:43.86ms +step:9400/20000 train_loss:2.1434 train_time:412289ms step_avg:43.86ms +step:9600/20000 train_loss:2.1575 train_time:421036ms step_avg:43.86ms +step:9800/20000 train_loss:2.0780 train_time:429798ms step_avg:43.86ms +step:10000/20000 train_loss:2.1219 train_time:438556ms step_avg:43.86ms +step:10000/20000 val_loss:2.1198 val_bpb:1.2555 train_time:438582ms step_avg:43.86ms +step:10200/20000 train_loss:2.0745 train_time:447326ms step_avg:43.86ms +step:10400/20000 train_loss:2.1079 train_time:456077ms step_avg:43.85ms +step:10600/20000 train_loss:1.9842 train_time:464835ms step_avg:43.85ms +step:10800/20000 train_loss:2.1929 train_time:473598ms step_avg:43.85ms +step:11000/20000 train_loss:2.1213 train_time:482368ms step_avg:43.85ms +step:11000/20000 val_loss:2.1136 val_bpb:1.2518 train_time:482383ms step_avg:43.85ms +step:11200/20000 train_loss:2.0772 train_time:491132ms step_avg:43.85ms +step:11400/20000 train_loss:2.0655 train_time:499891ms step_avg:43.85ms +step:11600/20000 train_loss:2.0719 train_time:508649ms step_avg:43.85ms +step:11800/20000 train_loss:2.0984 train_time:517410ms step_avg:43.85ms +step:12000/20000 train_loss:2.0816 train_time:526173ms step_avg:43.85ms +step:12000/20000 val_loss:2.1085 val_bpb:1.2488 train_time:526188ms step_avg:43.85ms +step:12200/20000 train_loss:2.2273 train_time:534932ms step_avg:43.85ms +step:12400/20000 train_loss:1.8696 train_time:543766ms step_avg:43.85ms +step:12600/20000 train_loss:2.0964 train_time:552523ms step_avg:43.85ms +step:12800/20000 train_loss:2.1153 train_time:561272ms step_avg:43.85ms +step:13000/20000 train_loss:2.1838 train_time:570022ms step_avg:43.85ms +step:13000/20000 val_loss:2.0897 val_bpb:1.2376 train_time:570037ms step_avg:43.85ms +step:13200/20000 train_loss:2.1885 train_time:578784ms step_avg:43.85ms +step:13400/20000 train_loss:2.0564 train_time:587538ms step_avg:43.85ms +step:13600/20000 train_loss:1.9180 train_time:596292ms step_avg:43.85ms +step:13686/20000 val_loss:2.0654 val_bpb:1.2233 train_time:600051ms step_avg:43.84ms +stopping_early: wallclock_cap train_time:600051ms step:13686/20000 +peak memory allocated: 10119 MiB reserved: 10438 MiB +Serialized model: 67224983 bytes +Code size: 54155 bytes +Total submission size: 67279138 bytes +Serialized model int8+zlib: 15645913 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) +Total submission size int8+zlib: 15700068 bytes +final_int8_zlib_roundtrip val_loss:2.0792 val_bpb:1.2314 eval_time:1407ms +final_int8_zlib_roundtrip_exact val_loss:2.07922175 val_bpb:1.23143224 From 3e35d0c39656ef836d7fc4ecdc20f12679e8634e Mon Sep 17 00:00:00 2001 From: Alston Tang <145715945+AlstonTang@users.noreply.github.com> Date: Fri, 1 May 2026 12:49:48 -0700 Subject: [PATCH 80/80] Update README.md --- .../2026-05-01_Follow_up_to_PR_2104/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md index 3b90491122..cf6e8abf15 100644 --- a/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md +++ b/records/track_non_record_16mb/2026-05-01_Follow_up_to_PR_2104/README.md @@ -57,7 +57,7 @@ Within run_logs/3c25e790-2f8a-4c36-881c-67066cf1e465.txt, although the model con Adding one more layer (so that the model has 10 layers, 8 query heads, 4 kv heads, and a dim of 512 with mlp_mult of 2), and running it for the full 20,000 steps (see run_logs/24d3334a-0ddf-45ca-8be8-9dbd470f8866.txt), we get a final bpb of ~1.2494 with the submission size still only taking up 15,221,665 despite a parameter count of 18,897,488. -I initially thought that ZerO would at least yield better loss. Although it didn't, the improvements in compression were very surprising, and I'm interesting in further increasing parameter efficiency with ZerO. +I initially thought that ZerO would at least yield better loss. Although it didn't, the improvements in compression were very surprising, and I'm interested in further increasing parameter efficiency with ZerO. ## Future Work My plans for this consist of the following: