4 changes: 2 additions & 2 deletions README.md
@@ -269,10 +269,10 @@ When you make a PR, please:

There are still many TODOs which may offer significant performance gains...

- [ ] Try `RoPE`/`AliBi` Position Embeddings
- [x] Try `RoPE`/`AliBi` Position Embeddings - `RoPE` added (1D temporal + 2D spatial), enable with `use_rope: true` in `configs/training.yaml`
- [ ] Add more datasets (Terraria, Street Fighter, \<your favorite retro videogame\>)
- [ ] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter)
- [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT)
- [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT) - cosine schedule added, enable with `maskgit_schedule: "cosine"` in `configs/training.yaml`
- [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean`
- [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters
- [ ] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames
3 changes: 3 additions & 0 deletions configs/inference.yaml
@@ -20,6 +20,9 @@ use_actions: false # use random actions
use_gt_actions: false # use lam-inferred actions
use_interactive_mode: true # use user-inputted actions

# MaskGIT unmasking schedule ("exp" or "cosine")
maskgit_schedule: "exp"

# inference acceleration
amp: false
tf32: false
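A minimal sketch of how the new `maskgit_schedule` key could be consumed at inference time. Only the config key and its two allowed values come from this PR; the plain PyYAML loading and the validation guard are assumptions (the repo may use a different config loader).

```python
# Sketch only: reads the new inference config key and validates it before
# handing it to the MaskGIT decoding loop. The guard itself is an assumption.
import yaml

with open("configs/inference.yaml") as f:
    cfg = yaml.safe_load(f)

schedule = cfg.get("maskgit_schedule", "exp")
assert schedule in ("exp", "cosine"), f"unknown maskgit_schedule: {schedule}"
# DynamicsModel.forward_inference(..., schedule=schedule) then picks
# exp_schedule_torch or cosine_schedule_torch inside its refinement loop.
```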
6 changes: 6 additions & 0 deletions configs/training.yaml
@@ -47,3 +47,9 @@ use_moe: false
num_experts: 4
top_k_experts: 2
moe_aux_loss_coeff: 0.01

# RoPE: replace additive sinusoidal PE with rotary position embeddings in all attention layers
use_rope: false

# MaskGIT unmasking schedule for dynamics inference ("exp" or "cosine")
maskgit_schedule: "exp"
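A minimal sketch of how a training script might thread the new flag into the models. The `use_rope` constructor argument on both models is what this PR adds; the PyYAML loading, the variable names, and relying on the default constructor arguments are assumptions.

```python
# Sketch only: config loading and names are assumed; use_rope is the new knob.
import yaml
from models.dynamics import DynamicsModel
from models.latent_actions import LatentActionModel

with open("configs/training.yaml") as f:
    cfg = yaml.safe_load(f)
use_rope = cfg.get("use_rope", False)

dynamics = DynamicsModel(use_rope=use_rope)   # rotary PE replaces the additive spatial PE
lam = LatentActionModel(use_rope=use_rope)    # threaded into encoder and decoder transformers
# cfg.get("maskgit_schedule", "exp") is only consumed at inference time,
# via forward_inference(..., schedule=...).
```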
30 changes: 23 additions & 7 deletions models/dynamics.py
@@ -9,21 +9,25 @@
class DynamicsModel(nn.Module):
def __init__(self, frame_size=(128, 128), patch_size=4, embed_dim=128, num_heads=8,
hidden_dim=128, num_blocks=4, num_bins=4, n_actions=8, conditioning_dim=3, latent_dim=5,
use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01):
use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01,
use_rope=False):
super().__init__()
H, W = frame_size
codebook_size = num_bins**latent_dim
self.use_rope = use_rope
grid_size = (H // patch_size, W // patch_size)

self.latent_embed = nn.Linear(latent_dim, embed_dim)
self.transformer = STTransformer(
embed_dim, num_heads, hidden_dim, num_blocks, causal=True,
conditioning_dim=conditioning_dim,
use_moe=use_moe, num_experts=num_experts,
top_k_experts=top_k_experts, moe_aux_loss_coeff=moe_aux_loss_coeff,
use_rope=use_rope, grid_size=grid_size,
)
self.output_mlp = nn.Linear(embed_dim, codebook_size)

# shared spatial-only PE (zeros in temporal tail)
# shared spatial-only PE (zeros in temporal tail); used when use_rope=False
pe_spatial = build_spatial_only_pe((H, W), patch_size, embed_dim, device='cpu', dtype=torch.float32) # [1,P,E]
self.register_buffer("pos_spatial_dec", pe_spatial, persistent=False)

@@ -59,9 +63,9 @@ def forward(self, discrete_latents, training=True, conditioning=None, targets=No

embeddings = self.latent_embed(discrete_latents) # [B, T, P, E]

# add spatial PE (affects only first 2/3 of dimensions)
# STTransformer adds temporal PE to last 1/3 of dimensions
embeddings = embeddings + self.pos_spatial_dec.to(embeddings.device, embeddings.dtype)
# add spatial PE when not using RoPE (STTransformer adds temporal PE to last 1/3 of dims)
if not self.use_rope:
embeddings = embeddings + self.pos_spatial_dec.to(embeddings.device, embeddings.dtype)
transformed = self.transformer(embeddings, conditioning=conditioning) # [B, T, P, E]

# transform to logits for each token in codebook
@@ -91,8 +95,17 @@ def exp_schedule_torch(self, t, T, P_total, k, device):
return torch.tensor(P_total, dtype=result.dtype, device=device)
return result

def cosine_schedule_torch(self, t, T, P_total, device):
# cosine schedule: reveals P_total * (1 - cos(pi/2 * t/T)) tokens by step t
x = t / max(T, 1)
ratio = 1.0 - math.cos(math.pi / 2.0 * x)
result = torch.tensor(float(P_total) * ratio, device=device)
if t == T - 1:
return torch.tensor(float(P_total), device=device)
return result

@torch.no_grad()
def forward_inference(self, context_latents, prediction_horizon, num_steps, index_to_latents_fn, conditioning=None, schedule_k=5.0, temperature: float = 0.0):
def forward_inference(self, context_latents, prediction_horizon, num_steps, index_to_latents_fn, conditioning=None, schedule_k=5.0, temperature: float = 0.0, schedule: str = "exp"):
# MaskGIT-style iterative decoding across all prediction horizon steps
# context_latents: [B, T_ctx, P, L]
# T_ctx=context timesteps, H=prediction horizon, K=codebook size
@@ -108,7 +121,10 @@ def forward_inference(self, context_latents, prediction_horizon, num_steps, inde

P_total = H * P # total masked positions across the horizon window
for m in range(num_steps):
n_tokens_raw = self.exp_schedule_torch(m, num_steps, P_total, schedule_k, device)
if schedule == "cosine":
n_tokens_raw = self.cosine_schedule_torch(m, num_steps, P_total, device)
else:
n_tokens_raw = self.exp_schedule_torch(m, num_steps, P_total, schedule_k, device)

# predict logits for current input
logits, _, _ = self.forward(input_latents, training=False, conditioning=conditioning, targets=None) # [B, T_ctx+H, P, L^D]
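A standalone sketch of the token counts the new cosine schedule reveals per step, mirroring the formula in `cosine_schedule_torch` above; the horizon and step counts are arbitrary example numbers, not values from the repo.

```python
import math

def cosine_revealed(t, T, P_total):
    # mirror of cosine_schedule_torch: reveal P_total * (1 - cos(pi/2 * t/T))
    # tokens by step t, forcing a full reveal on the final step
    if t == T - 1:
        return float(P_total)
    x = t / max(T, 1)
    return float(P_total) * (1.0 - math.cos(math.pi / 2.0 * x))

P_total = 2 * 1024   # e.g. 2-frame horizon x 1024 patches (arbitrary numbers)
T = 8                # MaskGIT refinement steps (arbitrary)
for t in range(T):
    print(t, round(cosine_revealed(t, T, P_total)))
# The count rises slowly at first and accelerates toward the end, so most
# tokens are committed in the later, higher-confidence refinement passes.
```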
25 changes: 16 additions & 9 deletions models/latent_actions.py
@@ -11,11 +11,14 @@
NUM_LATENT_ACTIONS_BINS = 2

class LatentActionsEncoder(nn.Module):
def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
hidden_dim=256, num_blocks=4, action_dim=3):
def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
hidden_dim=256, num_blocks=4, action_dim=3, use_rope=False):
super().__init__()
H, W = frame_size
grid_size = (H // patch_size, W // patch_size)
self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True)
self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True,
use_rope=use_rope, grid_size=grid_size)

# embeddings to discrete latent bottleneck actions
self.action_head = nn.Sequential(
@@ -50,10 +53,14 @@ def forward(self, frames):

class LatentActionsDecoder(nn.Module):
def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
hidden_dim=256, num_blocks=4, conditioning_dim=3):
hidden_dim=256, num_blocks=4, conditioning_dim=3, use_rope=False):
super().__init__()
H, W = frame_size
grid_size = (H // patch_size, W // patch_size)
self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True, conditioning_dim=conditioning_dim)
self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True,
conditioning_dim=conditioning_dim,
use_rope=use_rope, grid_size=grid_size)

# embeddings to mixed frame output patches
self.frame_head = nn.Sequential(
@@ -97,14 +104,14 @@ def forward(self, frames, actions, training=True):
return pred_frames # [B, T-1, C, H, W]

class LatentActionModel(nn.Module):
def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128,
num_heads=8, hidden_dim=256, num_blocks=4):
def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128,
num_heads=8, hidden_dim=256, num_blocks=4, use_rope=False):
super().__init__()
assert math.log(n_actions, NUM_LATENT_ACTIONS_BINS).is_integer(), f"n_actions must be a power of {NUM_LATENT_ACTIONS_BINS}"
self.action_dim=int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS))
self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim)
self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim, use_rope=use_rope)
self.quantizer = FiniteScalarQuantizer(latent_dim=self.action_dim, num_bins=NUM_LATENT_ACTIONS_BINS)
self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim)
self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim, use_rope=use_rope)
self.var_target = 0.01
self.var_lambda = 100.0

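A small sanity-check sketch for the RoPE wiring above: with the latent action model defaults (`embed_dim=128`, `num_heads=8`, `patch_size=8`), the per-head dimension satisfies the divisibility constraints of the rotary tables. The `head_dim = embed_dim // num_heads` split is an assumption about STTransformer (not shown in this diff); everything else follows from the defaults in this file.

```python
# Sketch only: checks the constraints the new RoPE path relies on.
# head_dim = embed_dim // num_heads is assumed (standard multi-head split).
embed_dim, num_heads = 128, 8
head_dim = embed_dim // num_heads                   # 16
assert head_dim % 2 == 0                            # 1D temporal RoPE
assert head_dim % 4 == 0                            # 2D spatial RoPE (y/x halves)

frame_size, patch_size = (128, 128), 8              # latent action model defaults
grid_size = (frame_size[0] // patch_size, frame_size[1] // patch_size)  # (16, 16)
# grid_size is what LatentActionsEncoder/Decoder now pass to STTransformer so
# spatial attention can build its (Hp x Wp) rotary cos/sin tables.
```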
77 changes: 75 additions & 2 deletions models/positional_encoding.py
@@ -1,7 +1,7 @@
import torch
from einops import rearrange, repeat

# TODO: Try RoPE / AliBi

def sincos_1d(L, D, device, dtype):
# 1d sinusoidal position encoding where element j of ith patch embedding is encoded as:
# PE[i, 2j] = sin(i / 10000^(2j/D)) # even indices
@@ -57,4 +57,77 @@ def build_spatial_only_pe(frame_size, patch_size, embed_dim, device='cpu', dtype
], dim=-1) # [Hp, Wp, E]

pe_spatial = rearrange(pe_spatial, 'hp wp e -> 1 (hp wp) e') # [1, P, E]
return pe_spatial # [1, P, E]
return pe_spatial # [1, P, E]


# ---------------------------------------------------------------------------
# Rotary Position Embeddings (RoPE)
# ---------------------------------------------------------------------------

def rope_1d_cos_sin(seq_len, head_dim, device, dtype):
"""1D RoPE frequencies using the chunked convention: pairs (i, i+D/2).
Returns cos, sin each of shape [seq_len, head_dim].
"""
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
half = head_dim // 2
freqs = 1.0 / (10000 ** (torch.arange(0, half, device=device, dtype=dtype) / half)) # [D/2]
pos = torch.arange(seq_len, device=device, dtype=dtype) # [L]
angles = torch.outer(pos, freqs) # [L, D/2]
angles = torch.cat([angles, angles], dim=-1) # [L, D]
return torch.cos(angles), torch.sin(angles)


def rope_2d_cos_sin(Hp, Wp, head_dim, device, dtype):
"""2D RoPE frequencies for a (Hp x Wp) patch grid.

The first D/2 head dims encode the y-axis (row) and the last D/2 encode
the x-axis (col). Each half is treated as an independent 1D RoPE group
so the two axes do not mix during rotation.

Returns cos, sin each of shape [Hp*Wp, head_dim].
"""
assert head_dim % 4 == 0, f"head_dim must be divisible by 4 for 2D RoPE, got {head_dim}"
half = head_dim // 2
cos_y, sin_y = rope_1d_cos_sin(Hp, half, device, dtype) # [Hp, D/2]
cos_x, sin_x = rope_1d_cos_sin(Wp, half, device, dtype) # [Wp, D/2]
# broadcast to patch grid
cos_y = cos_y[:, None].expand(Hp, Wp, half).reshape(Hp * Wp, half) # [P, D/2]
sin_y = sin_y[:, None].expand(Hp, Wp, half).reshape(Hp * Wp, half)
cos_x = cos_x[None, :].expand(Hp, Wp, half).reshape(Hp * Wp, half) # [P, D/2]
sin_x = sin_x[None, :].expand(Hp, Wp, half).reshape(Hp * Wp, half)
return torch.cat([cos_y, cos_x], dim=-1), torch.cat([sin_y, sin_x], dim=-1) # each [P, D]


def apply_rope_1d(x, cos, sin):
"""Apply 1D RoPE to Q or K. Chunked convention: pairs (i, i+D/2).

x : [N, H, L, D]
cos: [L, D]
sin: [L, D]
"""
half = x.shape[-1] // 2
x1, x2 = x[..., :half], x[..., half:]
rot = torch.cat([-x2, x1], dim=-1) # rotate_half for chunked pairs
return x * cos[None, None] + rot * sin[None, None]


def apply_rope_2d(x, cos, sin):
"""Apply 2D RoPE to Q or K. First D/2 dims = y-axis, last D/2 dims = x-axis.

Each axis is rotated independently within its own D/2 block using the
chunked sub-pair convention: pairs (i, i+D/4) within each half.

x : [N, H, P, D]
cos: [P, D]
sin: [P, D]
"""
half = x.shape[-1] // 2
quarter = half // 2
# y-axis block
x1a, x1b = x[..., :quarter], x[..., quarter:half]
rot_y = torch.cat([-x1b, x1a], dim=-1) # [N, H, P, D/2]
# x-axis block
x2a, x2b = x[..., half:half + quarter], x[..., half + quarter:]
rot_x = torch.cat([-x2b, x2a], dim=-1) # [N, H, P, D/2]
rot = torch.cat([rot_y, rot_x], dim=-1) # [N, H, P, D]
return x * cos[None, None] + rot * sin[None, None]
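A minimal usage sketch for the new helpers: build the 2D rotary tables for a patch grid and rotate query/key tensors. The shapes follow the docstrings above; the tensor sizes themselves are arbitrary.

```python
import torch
from models.positional_encoding import rope_2d_cos_sin, apply_rope_2d

N, H, D = 2, 8, 16          # batch*time, heads, per-head dim (D % 4 == 0)
Hp, Wp = 16, 16             # patch grid, P = Hp * Wp
P = Hp * Wp

q = torch.randn(N, H, P, D)
k = torch.randn(N, H, P, D)

cos, sin = rope_2d_cos_sin(Hp, Wp, D, device=q.device, dtype=q.dtype)  # each [P, D]
q_rot = apply_rope_2d(q, cos, sin)   # y-axis block (first D/2) and x-axis block
k_rot = apply_rope_2d(k, cos, sin)   # (last D/2) are rotated independently
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```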