README.md (2 changes: 1 addition & 1 deletion)
@@ -271,7 +271,7 @@ There are still many TODOs which may offer significant performance gains...

- [ ] Try `RoPE`/`AliBi` Position Embeddings
- [ ] Add more datasets (Terraria, Street Fighter, \<your favorite retro videogame\>)
- [ ] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter)
- [x] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter)
- [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT)
- [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean`
- [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters
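For context on the newly checked-off item: FiLM modulates normalized activations with a conditioning-derived scale and shift, while AdaLN-Zero adds a third, zero-initialized gate on the residual branch (the "pre-scale parameter" the TODO mentions), so every block starts out as the identity. A minimal sketch of the two update rules, with illustrative function names (the repo's actual modules are `AdaptiveNormalizer` and the new `AdaLNZeroNorm`):

```python
import torch

def film_update(x, sublayer, norm, scale, shift):
    # FiLM-style post-norm conditioning (the existing AdaptiveNormalizer path):
    # normalize after the residual add, then scale and shift
    return norm(x + sublayer(x)) * (1 + scale) + shift

def adaln_zero_update(x, sublayer, norm, scale, shift, gate):
    # AdaLN-Zero: modulate the pre-normed input, then gate the residual branch;
    # gate is zero-initialized, so the block begins as the identity map
    return x + gate * sublayer(norm(x) * (1 + scale) + shift)
```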
configs/inference.yaml (5 changes: 4 additions & 1 deletion)
@@ -23,4 +23,7 @@ use_interactive_mode: true # use user-inputted actions
# inference acceleration
amp: false
tf32: false
compile: false
compile: false

# AdaLN-Zero: must match the training config used for loaded checkpoints
use_adaln_zero: false
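Since the two conditioning schemes produce different parameter sets, here is an illustrative guard (not code from this repo) showing why the flags must agree before loading a checkpoint:

```python
import yaml

with open("configs/training.yaml") as f:   # the config the checkpoint was trained with
    train_cfg = yaml.safe_load(f)
with open("configs/inference.yaml") as f:
    infer_cfg = yaml.safe_load(f)

# A FiLM checkpoint carries AdaptiveNormalizer weights; an AdaLN-Zero checkpoint
# carries the zero-init to_params MLP instead, so strict state_dict loading
# fails when the flags disagree.
assert infer_cfg["use_adaln_zero"] == train_cfg["use_adaln_zero"], \
    "use_adaln_zero must match the training config of the loaded checkpoint"
```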
configs/training.yaml (3 changes: 3 additions & 0 deletions)
@@ -42,6 +42,9 @@ optimizer: "adamw"
muon_momentum: 0.95
muon_backend_steps: 5

# AdaLN-Zero: pre-norm conditioning with zero-init gate (alternative to FiLM)
use_adaln_zero: false

# MoE (dynamics model only): replaces SwiGLU FFN with top-k routed experts
use_moe: false
num_experts: 4
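These MoE settings feed the `MoESwiGLUFFN` touched later in this diff; its routing math is mostly elided here, so the following is a generic top-k routing step for orientation only (a common scheme, assumed rather than lifted from the repo):

```python
import torch
import torch.nn.functional as F

# Each token scores all experts, keeps its top_k, and renormalizes the
# selected gate weights; expert_idx then drives the per-expert dispatch.
num_experts, top_k, E = 4, 2, 128
tokens = torch.randn(10, E)                  # flattened tokens [N, E]
router = torch.nn.Linear(E, num_experts)
logits = router(tokens)                      # [N, num_experts]
topk_logits, expert_idx = logits.topk(top_k, dim=-1)
weights = F.softmax(topk_logits, dim=-1)     # mixture weights over chosen experts
```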
models/dynamics.py (4 changes: 3 additions & 1 deletion)
@@ -9,7 +9,8 @@
class DynamicsModel(nn.Module):
def __init__(self, frame_size=(128, 128), patch_size=4, embed_dim=128, num_heads=8,
hidden_dim=128, num_blocks=4, num_bins=4, n_actions=8, conditioning_dim=3, latent_dim=5,
use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01):
use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01,
use_adaln_zero=False):
super().__init__()
H, W = frame_size
codebook_size = num_bins**latent_dim
@@ -20,6 +21,7 @@ def __init__(self, frame_size=(128, 128), patch_size=4, embed_dim=128, num_heads=8,
conditioning_dim=conditioning_dim,
use_moe=use_moe, num_experts=num_experts,
top_k_experts=top_k_experts, moe_aux_loss_coeff=moe_aux_loss_coeff,
use_adaln_zero=use_adaln_zero,
)
self.output_mlp = nn.Linear(embed_dim, codebook_size)

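Threading the flag through is mechanical; a construction sketch using the defaults visible in the signature above (the forward call is omitted since its signature is not part of this diff):

```python
from models.dynamics import DynamicsModel

# use_adaln_zero simply propagates into the underlying STTransformer blocks
model = DynamicsModel(frame_size=(128, 128), embed_dim=128, num_heads=8,
                      hidden_dim=128, num_blocks=4, use_adaln_zero=True)
```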
models/latent_actions.py (22 changes: 12 additions & 10 deletions)
@@ -11,11 +11,12 @@
NUM_LATENT_ACTIONS_BINS = 2

class LatentActionsEncoder(nn.Module):
def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
hidden_dim=256, num_blocks=4, action_dim=3):
def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
hidden_dim=256, num_blocks=4, action_dim=3, use_adaln_zero=False):
super().__init__()
self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True)
self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True,
use_adaln_zero=use_adaln_zero)

# embeddings to discrete latent bottleneck actions
self.action_head = nn.Sequential(
@@ -50,10 +51,11 @@ def forward(self, frames):

class LatentActionsDecoder(nn.Module):
def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
hidden_dim=256, num_blocks=4, conditioning_dim=3):
hidden_dim=256, num_blocks=4, conditioning_dim=3, use_adaln_zero=False):
super().__init__()
self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True, conditioning_dim=conditioning_dim)
self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True,
conditioning_dim=conditioning_dim, use_adaln_zero=use_adaln_zero)

# embeddings to mixed frame output patches
self.frame_head = nn.Sequential(
Expand Down Expand Up @@ -97,14 +99,14 @@ def forward(self, frames, actions, training=True):
return pred_frames # [B, T-1, C, H, W]

class LatentActionModel(nn.Module):
def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128,
num_heads=8, hidden_dim=256, num_blocks=4):
def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128,
num_heads=8, hidden_dim=256, num_blocks=4, use_adaln_zero=False):
super().__init__()
assert math.log(n_actions, NUM_LATENT_ACTIONS_BINS).is_integer(), f"n_actions must be a power of {NUM_LATENT_ACTIONS_BINS}"
self.action_dim=int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS))
self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim)
self.action_dim = int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS))
self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim, use_adaln_zero=use_adaln_zero)
self.quantizer = FiniteScalarQuantizer(latent_dim=self.action_dim, num_bins=NUM_LATENT_ACTIONS_BINS)
self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim)
self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim, use_adaln_zero=use_adaln_zero)
self.var_target = 0.01
self.var_lambda = 100.0

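The assert above ties `n_actions` to the binary latent bottleneck: with `NUM_LATENT_ACTIONS_BINS = 2`, `action_dim = log2(n_actions)`, so `n_actions` must be a power of two. A quick construction sketch:

```python
from models.latent_actions import LatentActionModel

# n_actions=8 -> action_dim = log2(8) = 3 binary latent action dimensions
lam = LatentActionModel(n_actions=8, use_adaln_zero=True)
# n_actions=6 would trip the assert: log2(6) is not an integer
```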
models/norms.py (34 changes: 34 additions & 0 deletions)
@@ -2,6 +2,40 @@
import torch.nn as nn
from einops import repeat

class AdaLNZeroNorm(nn.Module):
"""Pre-norm with zero-init conditioning MLP producing (scale, shift, gate).

Gate alpha is zero-initialized so residual paths start as identity.
Returns (normed_x, gate); caller computes: x + gate * sublayer(normed_x).
When unconditioned, gate is None and caller computes: x + sublayer(normed_x).
"""
def __init__(self, embed_dim, conditioning_dim=None):
super().__init__()
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
self.to_params = None
if conditioning_dim is not None:
self.to_params = nn.Sequential(
nn.SiLU(),
nn.Linear(conditioning_dim, 3 * embed_dim), # scale, shift, gate
)
nn.init.zeros_(self.to_params[-1].weight)
nn.init.zeros_(self.to_params[-1].bias)

def forward(self, x, conditioning=None):
# x: [B, T, P, E]; returns (normed_x: [B,T,P,E], gate: [B,T,P,E] or None)
x_normed = self.norm(x)
if self.to_params is None or conditioning is None:
return x_normed, None
B, T, P, E = x_normed.shape
params = self.to_params(conditioning) # [B, T, 3E]
params = repeat(params, 'b t e -> b t p e', p=P) # [B, T, P, 3E]
scale, shift, gate = params.chunk(3, dim=-1) # each [B, T, P, E]
if scale.shape[1] == x.shape[1] - 1: # conditioning may cover T-1 transitions (e.g. actions between frames); zero-pad the first timestep
pad = lambda t: torch.cat([torch.zeros_like(t[:, :1]), t], dim=1)
scale, shift, gate = pad(scale), pad(shift), pad(gate)
return x_normed * (1 + scale) + shift, gate


class RMSNorm(nn.Module):
def __init__(self, embed_dim, eps=1e-5):
super().__init__()
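A usage sketch of the caller contract described in the docstring, with shapes following the comments in `forward`:

```python
import torch
from models.norms import AdaLNZeroNorm

B, T, P, E, C = 2, 4, 16, 128, 3
norm = AdaLNZeroNorm(E, conditioning_dim=C)
x = torch.randn(B, T, P, E)
cond = torch.randn(B, T, C)           # per-timestep conditioning

x_normed, gate = norm(x, cond)
sub_out = x_normed                    # stand-in for an attention/FFN output
out = x + gate * sub_out              # gated residual; gate == 0 at init, so out == x

# [B, T-1, C] conditioning (e.g. actions between frames) is zero-padded at t=0
x_normed, gate = norm(x, cond[:, 1:])

# unconditioned: gate is None and the caller adds the plain residual
x_normed, gate = norm(x)
assert gate is None
```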
models/st_transformer.py (120 changes: 82 additions & 38 deletions)
@@ -2,13 +2,13 @@
import torch.nn as nn
from einops import rearrange
from models.positional_encoding import build_spatial_only_pe, sincos_time
from models.norms import AdaptiveNormalizer
from models.norms import AdaptiveNormalizer, AdaLNZeroNorm
from models.patch_embed import PatchEmbedding
import math
import torch.nn.functional as F

class SpatialAttention(nn.Module):
def __init__(self, embed_dim, num_heads, conditioning_dim=None):
def __init__(self, embed_dim, num_heads, conditioning_dim=None, use_adaln_zero=False):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
@@ -20,16 +20,25 @@ def __init__(self, embed_dim, num_heads, conditioning_dim=None):
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)

self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
self.use_adaln_zero = use_adaln_zero
if use_adaln_zero:
self.norm = AdaLNZeroNorm(embed_dim, conditioning_dim)
else:
self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)

def forward(self, x, conditioning=None):
B, T, P, E = x.shape

# project to Q, K, V and split into heads: [B, T, P, E] -> [(B*T), H, P, E/H]
if self.use_adaln_zero:
x_in, gate = self.norm(x, conditioning) # pre-norm + conditioning params
else:
x_in = x

# project to Q, K, V and split into heads: [B, T, P, E] -> [(B*T), H, P, E/H]
# (4 dims to work with torch compile attention)
q = rearrange(self.q_proj(x), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
k = rearrange(self.k_proj(x), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
v = rearrange(self.v_proj(x), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
q = rearrange(self.q_proj(x_in), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
k = rearrange(self.k_proj(x_in), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
v = rearrange(self.v_proj(x_in), 'B T P (H D) -> (B T) H P D', H=self.num_heads)

k_t = k.transpose(-2, -1) # [(B*T), H, D, P]

@@ -42,35 +51,46 @@ def forward(self, x, conditioning=None):
# out proj to mix head information
attn_out = self.out_proj(attn_output) # [B, T, P, E]

# residual and optionally conditioned norm
out = self.norm(x + attn_out, conditioning) # [B, T, P, E]

return out # [B, T, P, E]
if self.use_adaln_zero:
# gated residual; when unconditioned, gate is None and the sublayer output is added unmodulated
return x + (gate * attn_out if gate is not None else attn_out) # [B, T, P, E]
else:
# residual and optionally conditioned post-norm (FiLM)
return self.norm(x + attn_out, conditioning) # [B, T, P, E]

class TemporalAttention(nn.Module):
def __init__(self, embed_dim, num_heads, causal=True, conditioning_dim=None):
def __init__(self, embed_dim, num_heads, causal=True, conditioning_dim=None, use_adaln_zero=False):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim

self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)

self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)

self.use_adaln_zero = use_adaln_zero
if use_adaln_zero:
self.norm = AdaLNZeroNorm(embed_dim, conditioning_dim)
else:
self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
self.causal = causal

def forward(self, x, conditioning=None):
B, T, P, E = x.shape

# project to Q, K, V and split into heads: [B, T, P, E] -> [(B*P), H, T, D]

if self.use_adaln_zero:
x_in, gate = self.norm(x, conditioning) # pre-norm + conditioning params
else:
x_in = x

# project to Q, K, V and split into heads: [B, T, P, E] -> [(B*P), H, T, D]
# (4 dims to work with torch compile attention)
q = rearrange(self.q_proj(x), 'b t p (h d) -> (b p) h t d', h=self.num_heads)
k = rearrange(self.k_proj(x), 'b t p (h d) -> (b p) h t d', h=self.num_heads)
v = rearrange(self.v_proj(x), 'b t p (h d) -> (b p) h t d', h=self.num_heads) # [B, P, H, T, D]
q = rearrange(self.q_proj(x_in), 'b t p (h d) -> (b p) h t d', h=self.num_heads)
k = rearrange(self.k_proj(x_in), 'b t p (h d) -> (b p) h t d', h=self.num_heads)
v = rearrange(self.v_proj(x_in), 'b t p (h d) -> (b p) h t d', h=self.num_heads)

k_t = k.transpose(-2, -1) # [(B*P), H, D, T]

@@ -89,26 +109,37 @@ def forward(self, x, conditioning=None):
# out proj to mix head information
attn_out = self.out_proj(attn_output) # [B, T, P, E]

# residual and optionally conditioned norm
out = self.norm(x + attn_out, conditioning) # [B, T, P, E]

return out # [B, T, P, E]
if self.use_adaln_zero:
return x + (gate * attn_out if gate is not None else attn_out) # [B, T, P, E]
else:
# residual and optionally conditioned post-norm (FiLM)
return self.norm(x + attn_out, conditioning) # [B, T, P, E]

class SwiGLUFFN(nn.Module):
# swiglu(x) = W3(silu(W1(x) + b1) * (W2(x) + b2)) + b3
def __init__(self, embed_dim, hidden_dim, conditioning_dim=None):
def __init__(self, embed_dim, hidden_dim, conditioning_dim=None, use_adaln_zero=False):
super().__init__()
h = math.floor(2 * hidden_dim / 3)
self.w_v = nn.Linear(embed_dim, h)
self.w_g = nn.Linear(embed_dim, h)
self.w_o = nn.Linear(h, embed_dim)
self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
self.use_adaln_zero = use_adaln_zero
if use_adaln_zero:
self.norm = AdaLNZeroNorm(embed_dim, conditioning_dim)
else:
self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)

def forward(self, x, conditioning=None):
v = F.silu(self.w_v(x)) # [B, T, P, h]
g = self.w_g(x) # [B, T, P, h]
out = self.w_o(v * g) # [B, T, P, E]
return self.norm(x + out, conditioning) # [B, T, P, E]
if self.use_adaln_zero:
x_in, gate = self.norm(x, conditioning)
else:
x_in = x
v = F.silu(self.w_v(x_in)) # [B, T, P, h]
g = self.w_g(x_in) # [B, T, P, h]
out = self.w_o(v * g) # [B, T, P, E]
if self.use_adaln_zero:
return x + (gate * out if gate is not None else out) # [B, T, P, E]
return self.norm(x + out, conditioning) # [B, T, P, E]


class SwiGLUExpert(nn.Module):
@@ -125,7 +156,7 @@ def forward(self, x):

class MoESwiGLUFFN(nn.Module):
def __init__(self, embed_dim, hidden_dim, num_experts=4, top_k=2,
aux_loss_coeff=0.01, conditioning_dim=None):
aux_loss_coeff=0.01, conditioning_dim=None, use_adaln_zero=False):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
@@ -135,7 +166,11 @@ def __init__(self, embed_dim, hidden_dim, num_experts=4, top_k=2,
self.experts = nn.ModuleList([
SwiGLUExpert(embed_dim, hidden_dim) for _ in range(num_experts)
])
self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
self.use_adaln_zero = use_adaln_zero
if use_adaln_zero:
self.norm = AdaLNZeroNorm(embed_dim, conditioning_dim)
else:
self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)

self._aux_loss = None
self._expert_counts = None # per-expert token fractions from last forward
@@ -156,6 +191,9 @@ def forward(self, x, conditioning=None):
B, T, P, E = x.shape
residual = x

if self.use_adaln_zero:
x, gate = self.norm(x, conditioning) # pre-norm; gate applied after expert combine

# flatten spatial dims for routing: [B*T*P, E]
flat = x.reshape(-1, E)
N = flat.shape[0]
@@ -196,23 +234,27 @@

# reshape back; gated residual (AdaLN-Zero) or residual + post-norm (FiLM)
out = output.reshape(B, T, P, E)
if self.use_adaln_zero:
return residual + (gate * out if gate is not None else out) # [B, T, P, E]
return self.norm(residual + out, conditioning) # [B, T, P, E]

class STTransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, hidden_dim, causal=True, conditioning_dim=None,
use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01):
use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01,
use_adaln_zero=False):
super().__init__()
self.spatial_attn = SpatialAttention(embed_dim, num_heads, conditioning_dim)
self.temporal_attn = TemporalAttention(embed_dim, num_heads, causal, conditioning_dim)
self.spatial_attn = SpatialAttention(embed_dim, num_heads, conditioning_dim, use_adaln_zero)
self.temporal_attn = TemporalAttention(embed_dim, num_heads, causal, conditioning_dim, use_adaln_zero)
if use_moe:
self.ffn = MoESwiGLUFFN(
embed_dim, hidden_dim,
num_experts=num_experts, top_k=top_k_experts,
aux_loss_coeff=moe_aux_loss_coeff,
conditioning_dim=conditioning_dim,
use_adaln_zero=use_adaln_zero,
)
else:
self.ffn = SwiGLUFFN(embed_dim, hidden_dim, conditioning_dim)
self.ffn = SwiGLUFFN(embed_dim, hidden_dim, conditioning_dim, use_adaln_zero)

def forward(self, x, conditioning=None):
# x: [B, T, P, E]
Expand All @@ -224,7 +266,8 @@ def forward(self, x, conditioning=None):

class STTransformer(nn.Module):
def __init__(self, embed_dim, num_heads, hidden_dim, num_blocks, causal=True, conditioning_dim=None,
use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01):
use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01,
use_adaln_zero=False):
super().__init__()
# calculate temporal PE dim
self.temporal_dim = (embed_dim // 3) & ~1 # round down to even number
@@ -235,6 +278,7 @@ def __init__(self, embed_dim, num_heads, hidden_dim, num_blocks, causal=True, conditioning_dim=None,
embed_dim, num_heads, hidden_dim, causal, conditioning_dim,
use_moe=use_moe, num_experts=num_experts,
top_k_experts=top_k_experts, moe_aux_loss_coeff=moe_aux_loss_coeff,
use_adaln_zero=use_adaln_zero,
)
for _ in range(num_blocks)
])
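One property worth checking, assuming `STTransformerBlock.forward` keeps the `(x, conditioning)` signature shown above: with `use_adaln_zero=True`, every residual gate is zero-initialized, so a freshly constructed conditioned block is exactly the identity.

```python
import torch
from models.st_transformer import STTransformerBlock

block = STTransformerBlock(embed_dim=128, num_heads=8, hidden_dim=256,
                           conditioning_dim=3, use_adaln_zero=True)
x = torch.randn(2, 4, 16, 128)   # [B, T, P, E]
cond = torch.randn(2, 4, 3)      # [B, T, conditioning_dim]
with torch.no_grad():
    out = block(x, cond)
torch.testing.assert_close(out, x)  # holds only until training moves the gates
```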