From 68396d0eed604cd263c64b915fe0de5ba11ef027 Mon Sep 17 00:00:00 2001
From: Tasha
Date: Thu, 16 Apr 2026 11:19:17 -0700
Subject: [PATCH] feat: add RoPE positional embeddings and cosine MaskGIT
 scheduler

Implements two TODOs from the contributor list:

1. RoPE (Rotary Position Embeddings) as a drop-in replacement for the
   additive sinusoidal positional encodings. When use_rope=true:
   - TemporalAttention applies 1D RoPE to Q/K across the time axis
   - SpatialAttention applies 2D RoPE to Q/K using independent y/x axis
     encodings in the first and second halves of the head dimension
   - Additive temporal and spatial PEs are skipped entirely
   Enable with `use_rope: true` in configs/training.yaml.

2. Cosine MaskGIT unmasking schedule alongside the existing exponential
   schedule. The cosine variant reveals a cumulative fraction
   1 - cos(pi/2 * t/T) of the tokens by step t, committing few tokens in
   the early, low-confidence steps and accelerating toward the end of
   decoding. Select with `maskgit_schedule: "cosine"` in
   configs/training.yaml or configs/inference.yaml (default: "exp",
   preserving prior behavior).

All new flags default to false/"exp", so existing checkpoints and runs
are unaffected.
---
 README.md                        |  4 +-
 configs/inference.yaml           |  3 ++
 configs/training.yaml            |  6 +++
 models/dynamics.py               | 30 ++++++++++---
 models/latent_actions.py         | 25 +++++++----
 models/positional_encoding.py    | 77 +++++++++++++++++++++++++++++++-
 models/st_transformer.py         | 68 ++++++++++++++++++----------
 models/video_tokenizer.py        | 32 +++++++------
 scripts/run_inference.py         |  1 +
 scripts/train_dynamics.py        |  1 +
 scripts/train_latent_actions.py  |  1 +
 scripts/train_video_tokenizer.py |  3 +-
 utils/config.py                  | 22 +++++++--
 13 files changed, 212 insertions(+), 61 deletions(-)

diff --git a/README.md b/README.md
index 0f00c0e..0cccd87 100644
--- a/README.md
+++ b/README.md
@@ -269,10 +269,10 @@ When you make a PR, please:
 There are still many TODOs which may offer significant performance gains...
 
-- [ ] Try `RoPE`/`AliBi` Position Embeddings +- [x] Try `RoPE`/`AliBi` Position Embeddings - `RoPE` added (1D temporal + 2D spatial), enable with `use_rope: true` in `configs/training.yaml` - [ ] Add more datasets (Terraria, Street Fighter, \) - [ ] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter) -- [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT) +- [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT) - cosine schedule added, enable with `maskgit_schedule: "cosine"` in `configs/training.yaml` - [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean` - [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters - [ ] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames diff --git a/configs/inference.yaml b/configs/inference.yaml index 616d8f1..adaef20 100644 --- a/configs/inference.yaml +++ b/configs/inference.yaml @@ -20,6 +20,9 @@ use_actions: false # use random actions use_gt_actions: false # use lam-inferred actions use_interactive_mode: true # use user-inputted actions +# MaskGIT unmasking schedule ("exp" or "cosine") +maskgit_schedule: "exp" + # inference acceleration amp: false tf32: false diff --git a/configs/training.yaml b/configs/training.yaml index 0ca6528..6e7de6c 100644 --- a/configs/training.yaml +++ b/configs/training.yaml @@ -47,3 +47,9 @@ use_moe: false num_experts: 4 top_k_experts: 2 moe_aux_loss_coeff: 0.01 + +# RoPE: replace additive sinusoidal PE with rotary position embeddings in all attention layers +use_rope: false + +# MaskGIT unmasking schedule for dynamics inference ("exp" or "cosine") +maskgit_schedule: "exp" diff --git a/models/dynamics.py b/models/dynamics.py index 62f2c7a..c7c5492 100644 --- a/models/dynamics.py +++ b/models/dynamics.py @@ -9,10 +9,13 @@ class DynamicsModel(nn.Module): def __init__(self, frame_size=(128, 128), patch_size=4, embed_dim=128, num_heads=8, hidden_dim=128, num_blocks=4, num_bins=4, n_actions=8, conditioning_dim=3, latent_dim=5, - use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01): + use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01, + use_rope=False): super().__init__() H, W = frame_size codebook_size = num_bins**latent_dim + self.use_rope = use_rope + grid_size = (H // patch_size, W // patch_size) self.latent_embed = nn.Linear(latent_dim, embed_dim) self.transformer = STTransformer( @@ -20,10 +23,11 @@ def __init__(self, frame_size=(128, 128), patch_size=4, embed_dim=128, num_heads conditioning_dim=conditioning_dim, use_moe=use_moe, num_experts=num_experts, top_k_experts=top_k_experts, moe_aux_loss_coeff=moe_aux_loss_coeff, + use_rope=use_rope, grid_size=grid_size, ) self.output_mlp = nn.Linear(embed_dim, codebook_size) - # shared spatial-only PE (zeros in temporal tail) + # shared spatial-only PE (zeros in temporal tail); used when use_rope=False pe_spatial = build_spatial_only_pe((H, W), patch_size, embed_dim, device='cpu', dtype=torch.float32) # [1,P,E] self.register_buffer("pos_spatial_dec", pe_spatial, persistent=False) @@ -59,9 +63,9 @@ def forward(self, discrete_latents, training=True, conditioning=None, targets=No embeddings = self.latent_embed(discrete_latents) # [B, T, P, E] - # add spatial PE (affects only first 2/3 of dimensions) - # STTransformer adds temporal PE to last 1/3 of 
dimensions - embeddings = embeddings + self.pos_spatial_dec.to(embeddings.device, embeddings.dtype) + # add spatial PE when not using RoPE (STTransformer adds temporal PE to last 1/3 of dims) + if not self.use_rope: + embeddings = embeddings + self.pos_spatial_dec.to(embeddings.device, embeddings.dtype) transformed = self.transformer(embeddings, conditioning=conditioning) # [B, T, P, E] # transform to logits for each token in codebook @@ -91,8 +95,17 @@ def exp_schedule_torch(self, t, T, P_total, k, device): return torch.tensor(P_total, dtype=result.dtype, device=device) return result + def cosine_schedule_torch(self, t, T, P_total, device): + # cosine schedule: reveals P_total * (1 - cos(pi/2 * t/T)) tokens by step t + x = t / max(T, 1) + ratio = 1.0 - math.cos(math.pi / 2.0 * x) + result = torch.tensor(float(P_total) * ratio, device=device) + if t == T - 1: + return torch.tensor(float(P_total), device=device) + return result + @torch.no_grad() - def forward_inference(self, context_latents, prediction_horizon, num_steps, index_to_latents_fn, conditioning=None, schedule_k=5.0, temperature: float = 0.0): + def forward_inference(self, context_latents, prediction_horizon, num_steps, index_to_latents_fn, conditioning=None, schedule_k=5.0, temperature: float = 0.0, schedule: str = "exp"): # MaskGIT-style iterative decoding across all prediction horizon steps # context_latents: [B, T_ctx, P, L] # T_ctx=context timesteps, H=prediction horizon, K=codebook size @@ -108,7 +121,10 @@ def forward_inference(self, context_latents, prediction_horizon, num_steps, inde P_total = H * P # total masked positions across the horizon window for m in range(num_steps): - n_tokens_raw = self.exp_schedule_torch(m, num_steps, P_total, schedule_k, device) + if schedule == "cosine": + n_tokens_raw = self.cosine_schedule_torch(m, num_steps, P_total, device) + else: + n_tokens_raw = self.exp_schedule_torch(m, num_steps, P_total, schedule_k, device) # predict logits for current input logits, _, _ = self.forward(input_latents, training=False, conditioning=conditioning, targets=None) # [B, T_ctx+H, P, L^D] diff --git a/models/latent_actions.py b/models/latent_actions.py index 900ab1b..f25373d 100644 --- a/models/latent_actions.py +++ b/models/latent_actions.py @@ -11,11 +11,14 @@ NUM_LATENT_ACTIONS_BINS = 2 class LatentActionsEncoder(nn.Module): - def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, - hidden_dim=256, num_blocks=4, action_dim=3): + def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, + hidden_dim=256, num_blocks=4, action_dim=3, use_rope=False): super().__init__() + H, W = frame_size + grid_size = (H // patch_size, W // patch_size) self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim) - self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True) + self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True, + use_rope=use_rope, grid_size=grid_size) # embeddings to discrete latent bottleneck actions self.action_head = nn.Sequential( @@ -50,10 +53,14 @@ def forward(self, frames): class LatentActionsDecoder(nn.Module): def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, - hidden_dim=256, num_blocks=4, conditioning_dim=3): + hidden_dim=256, num_blocks=4, conditioning_dim=3, use_rope=False): super().__init__() + H, W = frame_size + grid_size = (H // patch_size, W // patch_size) self.patch_embed = PatchEmbedding(frame_size, patch_size, 
embed_dim) - self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True, conditioning_dim=conditioning_dim) + self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True, + conditioning_dim=conditioning_dim, + use_rope=use_rope, grid_size=grid_size) # embeddings to mixed frame output patches self.frame_head = nn.Sequential( @@ -97,14 +104,14 @@ def forward(self, frames, actions, training=True): return pred_frames # [B, T-1, C, H, W] class LatentActionModel(nn.Module): - def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128, - num_heads=8, hidden_dim=256, num_blocks=4): + def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128, + num_heads=8, hidden_dim=256, num_blocks=4, use_rope=False): super().__init__() assert math.log(n_actions, NUM_LATENT_ACTIONS_BINS).is_integer(), f"n_actions must be a power of {NUM_LATENT_ACTIONS_BINS}" self.action_dim=int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS)) - self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim) + self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim, use_rope=use_rope) self.quantizer = FiniteScalarQuantizer(latent_dim=self.action_dim, num_bins=NUM_LATENT_ACTIONS_BINS) - self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim) + self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim, use_rope=use_rope) self.var_target = 0.01 self.var_lambda = 100.0 diff --git a/models/positional_encoding.py b/models/positional_encoding.py index 18b358c..3bffdcc 100644 --- a/models/positional_encoding.py +++ b/models/positional_encoding.py @@ -1,7 +1,7 @@ import torch from einops import rearrange, repeat -# TODO: Try RoPE / AliBi + def sincos_1d(L, D, device, dtype): # 1d sinusoidal position encoding where element j of ith patch embedding is encoded as: # PE[i, 2j] = sin(i / 10000^(2j/D)) # even indices @@ -57,4 +57,77 @@ def build_spatial_only_pe(frame_size, patch_size, embed_dim, device='cpu', dtype ], dim=-1) # [Hp, Wp, E] pe_spatial = rearrange(pe_spatial, 'hp wp e -> 1 (hp wp) e') # [1, P, E] - return pe_spatial # [1, P, E] \ No newline at end of file + return pe_spatial # [1, P, E] + + +# --------------------------------------------------------------------------- +# Rotary Position Embeddings (RoPE) +# --------------------------------------------------------------------------- + +def rope_1d_cos_sin(seq_len, head_dim, device, dtype): + """1D RoPE frequencies using the chunked convention: pairs (i, i+D/2). + Returns cos, sin each of shape [seq_len, head_dim]. + """ + assert head_dim % 2 == 0, "head_dim must be even for RoPE" + half = head_dim // 2 + freqs = 1.0 / (10000 ** (torch.arange(0, half, device=device, dtype=dtype) / half)) # [D/2] + pos = torch.arange(seq_len, device=device, dtype=dtype) # [L] + angles = torch.outer(pos, freqs) # [L, D/2] + angles = torch.cat([angles, angles], dim=-1) # [L, D] + return torch.cos(angles), torch.sin(angles) + + +def rope_2d_cos_sin(Hp, Wp, head_dim, device, dtype): + """2D RoPE frequencies for a (Hp x Wp) patch grid. + + The first D/2 head dims encode the y-axis (row) and the last D/2 encode + the x-axis (col). 
Each half is treated as an independent 1D RoPE group + so the two axes do not mix during rotation. + + Returns cos, sin each of shape [Hp*Wp, head_dim]. + """ + assert head_dim % 4 == 0, f"head_dim must be divisible by 4 for 2D RoPE, got {head_dim}" + half = head_dim // 2 + cos_y, sin_y = rope_1d_cos_sin(Hp, half, device, dtype) # [Hp, D/2] + cos_x, sin_x = rope_1d_cos_sin(Wp, half, device, dtype) # [Wp, D/2] + # broadcast to patch grid + cos_y = cos_y[:, None].expand(Hp, Wp, half).reshape(Hp * Wp, half) # [P, D/2] + sin_y = sin_y[:, None].expand(Hp, Wp, half).reshape(Hp * Wp, half) + cos_x = cos_x[None, :].expand(Hp, Wp, half).reshape(Hp * Wp, half) # [P, D/2] + sin_x = sin_x[None, :].expand(Hp, Wp, half).reshape(Hp * Wp, half) + return torch.cat([cos_y, cos_x], dim=-1), torch.cat([sin_y, sin_x], dim=-1) # each [P, D] + + +def apply_rope_1d(x, cos, sin): + """Apply 1D RoPE to Q or K. Chunked convention: pairs (i, i+D/2). + + x : [N, H, L, D] + cos: [L, D] + sin: [L, D] + """ + half = x.shape[-1] // 2 + x1, x2 = x[..., :half], x[..., half:] + rot = torch.cat([-x2, x1], dim=-1) # rotate_half for chunked pairs + return x * cos[None, None] + rot * sin[None, None] + + +def apply_rope_2d(x, cos, sin): + """Apply 2D RoPE to Q or K. First D/2 dims = y-axis, last D/2 dims = x-axis. + + Each axis is rotated independently within its own D/2 block using the + chunked sub-pair convention: pairs (i, i+D/4) within each half. + + x : [N, H, P, D] + cos: [P, D] + sin: [P, D] + """ + half = x.shape[-1] // 2 + quarter = half // 2 + # y-axis block + x1a, x1b = x[..., :quarter], x[..., quarter:half] + rot_y = torch.cat([-x1b, x1a], dim=-1) # [N, H, P, D/2] + # x-axis block + x2a, x2b = x[..., half:half + quarter], x[..., half + quarter:] + rot_x = torch.cat([-x2b, x2a], dim=-1) # [N, H, P, D/2] + rot = torch.cat([rot_y, rot_x], dim=-1) # [N, H, P, D] + return x * cos[None, None] + rot * sin[None, None] \ No newline at end of file diff --git a/models/st_transformer.py b/models/st_transformer.py index a1621e7..eafba54 100644 --- a/models/st_transformer.py +++ b/models/st_transformer.py @@ -1,14 +1,17 @@ import torch import torch.nn as nn from einops import rearrange -from models.positional_encoding import build_spatial_only_pe, sincos_time +from models.positional_encoding import ( + build_spatial_only_pe, sincos_time, + rope_1d_cos_sin, rope_2d_cos_sin, apply_rope_1d, apply_rope_2d, +) from models.norms import AdaptiveNormalizer from models.patch_embed import PatchEmbedding import math import torch.nn.functional as F class SpatialAttention(nn.Module): - def __init__(self, embed_dim, num_heads, conditioning_dim=None): + def __init__(self, embed_dim, num_heads, conditioning_dim=None, use_rope=False, grid_size=None): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads @@ -21,16 +24,24 @@ def __init__(self, embed_dim, num_heads, conditioning_dim=None): self.out_proj = nn.Linear(embed_dim, embed_dim) self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim) + self.use_rope = use_rope + self.grid_size = grid_size # (Hp, Wp) required when use_rope=True def forward(self, x, conditioning=None): B, T, P, E = x.shape - # project to Q, K, V and split into heads: [B, T, P, E] -> [(B*T), H, P, E/H] + # project to Q, K, V and split into heads: [B, T, P, E] -> [(B*T), H, P, E/H] # (4 dims to work with torch compile attention) q = rearrange(self.q_proj(x), 'B T P (H D) -> (B T) H P D', H=self.num_heads) k = rearrange(self.k_proj(x), 'B T P (H D) -> (B T) H P D', H=self.num_heads) v = 
rearrange(self.v_proj(x), 'B T P (H D) -> (B T) H P D', H=self.num_heads) + if self.use_rope: + Hp, Wp = self.grid_size + cos, sin = rope_2d_cos_sin(Hp, Wp, self.head_dim, x.device, x.dtype) # [P, D] + q = apply_rope_2d(q, cos, sin) + k = apply_rope_2d(k, cos, sin) + k_t = k.transpose(-2, -1) # [(B*T), H, P, D, P] # attention(q, k, v) = softmax(qk^T / sqrt(d)) v @@ -48,30 +59,36 @@ def forward(self, x, conditioning=None): return out # [B, T, P, E] class TemporalAttention(nn.Module): - def __init__(self, embed_dim, num_heads, causal=True, conditioning_dim=None): + def __init__(self, embed_dim, num_heads, causal=True, conditioning_dim=None, use_rope=False): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == embed_dim - + self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) - + self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim) self.causal = causal - + self.use_rope = use_rope + def forward(self, x, conditioning=None): B, T, P, E = x.shape - - # project to Q, K, V and split into heads: [B, T, P, E] -> [(B*P), H, T, D] + + # project to Q, K, V and split into heads: [B, T, P, E] -> [(B*P), H, T, D] # (4 dims to work with torch compile attention) q = rearrange(self.q_proj(x), 'b t p (h d) -> (b p) h t d', h=self.num_heads) k = rearrange(self.k_proj(x), 'b t p (h d) -> (b p) h t d', h=self.num_heads) v = rearrange(self.v_proj(x), 'b t p (h d) -> (b p) h t d', h=self.num_heads) # [B, P, H, T, D] + if self.use_rope: + cos, sin = rope_1d_cos_sin(T, self.head_dim, x.device, x.dtype) # [T, D] + q = apply_rope_1d(q, cos, sin) + k = apply_rope_1d(k, cos, sin) + k_t = k.transpose(-2, -1) # [(B*P), H, T, D, T] # attention(q, k, v) = softmax(qk^T / sqrt(d)) v @@ -200,10 +217,11 @@ def forward(self, x, conditioning=None): class STTransformerBlock(nn.Module): def __init__(self, embed_dim, num_heads, hidden_dim, causal=True, conditioning_dim=None, - use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01): + use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01, + use_rope=False, grid_size=None): super().__init__() - self.spatial_attn = SpatialAttention(embed_dim, num_heads, conditioning_dim) - self.temporal_attn = TemporalAttention(embed_dim, num_heads, causal, conditioning_dim) + self.spatial_attn = SpatialAttention(embed_dim, num_heads, conditioning_dim, use_rope=use_rope, grid_size=grid_size) + self.temporal_attn = TemporalAttention(embed_dim, num_heads, causal, conditioning_dim, use_rope=use_rope) if use_moe: self.ffn = MoESwiGLUFFN( embed_dim, hidden_dim, @@ -224,9 +242,11 @@ def forward(self, x, conditioning=None): class STTransformer(nn.Module): def __init__(self, embed_dim, num_heads, hidden_dim, num_blocks, causal=True, conditioning_dim=None, - use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01): + use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01, + use_rope=False, grid_size=None): super().__init__() - # calculate temporal PE dim + self.use_rope = use_rope + # temporal PE dims only needed when not using RoPE self.temporal_dim = (embed_dim // 3) & ~1 # round down to even number self.spatial_dims = embed_dim - self.temporal_dim # rest goes to spatial @@ -235,24 +255,26 @@ def __init__(self, embed_dim, num_heads, hidden_dim, num_blocks, causal=True, co embed_dim, num_heads, hidden_dim, causal, 
conditioning_dim, use_moe=use_moe, num_experts=num_experts, top_k_experts=top_k_experts, moe_aux_loss_coeff=moe_aux_loss_coeff, + use_rope=use_rope, grid_size=grid_size, ) for _ in range(num_blocks) ]) - + def forward(self, x, conditioning=None): # x: [B, T, P, E] # conditioning: [B, T, E] B, T, P, E = x.shape - tpe = sincos_time(T, self.temporal_dim, x.device, x.dtype) # [T, E/3] - # temporal PE (pad with 0s for first 2/3s spatial PE, last 1/3 temporal PE) - tpe_padded = torch.cat([ - torch.zeros(T, self.spatial_dims, device=x.device, dtype=x.dtype), - tpe - ], dim=-1) # [T, E] - x = x + tpe_padded[None, :, None, :] # [B,T,P,E] + if not self.use_rope: + tpe = sincos_time(T, self.temporal_dim, x.device, x.dtype) # [T, E/3] + # temporal PE (pad with 0s for first 2/3s spatial PE, last 1/3 temporal PE) + tpe_padded = torch.cat([ + torch.zeros(T, self.spatial_dims, device=x.device, dtype=x.dtype), + tpe + ], dim=-1) # [T, E] + x = x + tpe_padded[None, :, None, :] # [B,T,P,E] - # apply transformer blocks + # apply transformer blocks (RoPE applied inside each attention layer when use_rope=True) for block in self.blocks: x = block(x, conditioning) return x diff --git a/models/video_tokenizer.py b/models/video_tokenizer.py index 0a56c8b..b74287f 100644 --- a/models/video_tokenizer.py +++ b/models/video_tokenizer.py @@ -9,11 +9,14 @@ from models.positional_encoding import build_spatial_only_pe class VideoTokenizerEncoder(nn.Module): - def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, - hidden_dim=256, num_blocks=4, latent_dim=5): + def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, + hidden_dim=256, num_blocks=4, latent_dim=5, use_rope=False): super().__init__() + H, W = frame_size + grid_size = (H // patch_size, W // patch_size) self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim) - self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True) + self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True, + use_rope=use_rope, grid_size=grid_size) self.latent_head = nn.Sequential( nn.LayerNorm(embed_dim), nn.Linear(embed_dim, latent_dim) @@ -46,28 +49,31 @@ def forward(self, tokens): # [B, T, P, E] class VideoTokenizerDecoder(nn.Module): def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, - hidden_dim=256, num_blocks=4, latent_dim=5): + hidden_dim=256, num_blocks=4, latent_dim=5, use_rope=False): super().__init__() H, W = frame_size self.patch_size = patch_size self.Hp, self.Wp = H // patch_size, W // patch_size self.num_patches = self.Hp * self.Wp - + self.use_rope = use_rope + self.latent_embed = nn.Linear(latent_dim, embed_dim) - self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True) + self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True, + use_rope=use_rope, grid_size=(self.Hp, self.Wp)) self.frame_head = PixelShuffleFrameHead(embed_dim, patch_size=patch_size, channels=3, H=H, W=W) - # first 2/3 spatial PE (temporal is last 1/3) + # spatial PE (only used when not using RoPE) pe_spatial_dec = build_spatial_only_pe((H, W), self.patch_size, embed_dim, device='cpu', dtype=torch.float32) # [1,P,E] self.register_buffer("pos_spatial_dec", pe_spatial_dec, persistent=False) def forward(self, latents): # latents: [B, T, P, L] - # embed latents and add spatial PE + # embed latents and optionally add spatial PE embedding = self.latent_embed(latents) # [B, T, 
P, E] - embedding = embedding + self.pos_spatial_dec.to(dtype=embedding.dtype, device=embedding.device) + if not self.use_rope: + embedding = embedding + self.pos_spatial_dec.to(dtype=embedding.dtype, device=embedding.device) - # apply transformer (temporal PE added inside) + # apply transformer (temporal PE or RoPE applied inside) embedding = self.transformer(embedding) # [B, T, P, E] # reconstruct frames using patch-wise head @@ -78,10 +84,10 @@ def forward(self, latents): class VideoTokenizer(nn.Module): def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, - hidden_dim=256, num_blocks=4, latent_dim=3, num_bins=4): + hidden_dim=256, num_blocks=4, latent_dim=3, num_bins=4, use_rope=False): super().__init__() - self.encoder = VideoTokenizerEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, latent_dim) - self.decoder = VideoTokenizerDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, latent_dim) + self.encoder = VideoTokenizerEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, latent_dim, use_rope=use_rope) + self.decoder = VideoTokenizerDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, latent_dim, use_rope=use_rope) self.quantizer = FiniteScalarQuantizer(latent_dim, num_bins) self.codebook_size = num_bins**latent_dim diff --git a/scripts/run_inference.py b/scripts/run_inference.py index 10f6439..76687a8 100644 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -128,6 +128,7 @@ def idx_to_latents(idx): index_to_latents_fn=idx_to_latents, conditioning=action_latent, temperature=args.temperature, + schedule=getattr(args, 'maskgit_schedule', 'exp'), ) # decode next video tokens to frames diff --git a/scripts/train_dynamics.py b/scripts/train_dynamics.py index 934dc55..10e9224 100644 --- a/scripts/train_dynamics.py +++ b/scripts/train_dynamics.py @@ -81,6 +81,7 @@ def main(): num_experts=getattr(args, 'num_experts', 4), top_k_experts=getattr(args, 'top_k_experts', 2), moe_aux_loss_coeff=getattr(args, 'moe_aux_loss_coeff', 0.01), + use_rope=getattr(args, 'use_rope', False), ).to(args.device) if args.checkpoint: dynamics_model, _ = load_dynamics_from_checkpoint( diff --git a/scripts/train_latent_actions.py b/scripts/train_latent_actions.py index 1c3f18d..48d9fac 100644 --- a/scripts/train_latent_actions.py +++ b/scripts/train_latent_actions.py @@ -57,6 +57,7 @@ def main(): hidden_dim=args.hidden_dim, num_blocks=args.num_blocks, n_actions=args.n_actions, + use_rope=getattr(args, 'use_rope', False), ).to(args.device) if args.checkpoint: model, _ = load_latent_actions_from_checkpoint( diff --git a/scripts/train_video_tokenizer.py b/scripts/train_video_tokenizer.py index fc808bb..b6270b7 100644 --- a/scripts/train_video_tokenizer.py +++ b/scripts/train_video_tokenizer.py @@ -51,7 +51,7 @@ def main(): # print("Length of validation data:", len(validation_data)) # init model and optional ckpt load model = VideoTokenizer( - frame_size=(args.frame_size, args.frame_size), + frame_size=(args.frame_size, args.frame_size), patch_size=args.patch_size, embed_dim=args.embed_dim, num_heads=args.num_heads, @@ -59,6 +59,7 @@ def main(): num_blocks=args.num_blocks, latent_dim=args.latent_dim, num_bins=args.num_bins, + use_rope=getattr(args, 'use_rope', False), ).to(args.device) if args.checkpoint: model, _ = load_videotokenizer_from_checkpoint( diff --git a/utils/config.py b/utils/config.py index 846c7b2..a56e1f1 100644 --- a/utils/config.py +++ b/utils/config.py @@ 
-108,6 +108,8 @@ class VideoTokenizerConfig: wandb_project: str # resume from checkpoint checkpoint: Optional[str] + # RoPE + use_rope: bool = False # Optimizer optimizer: str = "adamw" muon_momentum: float = 0.95 @@ -117,7 +119,7 @@ class VideoTokenizerConfig: # other params fps: Optional[int] = None preload_ratio: Optional[float] = None - + def __post_init__(self) -> None: _validate_amp_fsdp(self.amp, self.distributed) _validate_distibuted_training(self.nproc_per_node, self.distributed) @@ -154,6 +156,8 @@ class LatentActionsConfig: wandb_project: str # resume from checkpoint checkpoint: Optional[str] + # RoPE + use_rope: bool = False # Optimizer optimizer: str = "adamw" muon_momentum: float = 0.95 @@ -163,7 +167,7 @@ class LatentActionsConfig: # other params fps: Optional[int] = None preload_ratio: Optional[float] = None - + def __post_init__(self) -> None: _validate_amp_fsdp(self.amp, self.distributed) _validate_distibuted_training(self.nproc_per_node, self.distributed) @@ -212,6 +216,10 @@ class DynamicsConfig: num_experts: int = 4 top_k_experts: int = 2 moe_aux_loss_coeff: float = 0.01 + # RoPE + use_rope: bool = False + # MaskGIT unmasking schedule ("exp" or "cosine") + maskgit_schedule: str = "exp" # Optimizer optimizer: str = "adamw" muon_momentum: float = 0.95 @@ -221,7 +229,7 @@ class DynamicsConfig: # other params fps: Optional[int] = None preload_ratio: Optional[float] = None - + def __post_init__(self) -> None: _validate_amp_fsdp(self.amp, self.distributed) _validate_distibuted_training(self.nproc_per_node, self.distributed) @@ -277,11 +285,15 @@ class TrainingConfig: num_experts: int = 4 top_k_experts: int = 2 moe_aux_loss_coeff: float = 0.01 + # RoPE + use_rope: bool = False + # MaskGIT unmasking schedule ("exp" or "cosine") + maskgit_schedule: str = "exp" # Optimizer optimizer: str = "adamw" muon_momentum: float = 0.95 muon_backend_steps: int = 5 - + def __post_init__(self) -> None: _validate_amp_fsdp(self.amp, self.distributed) _validate_distibuted_training(self.nproc_per_node, self.distributed) @@ -310,6 +322,8 @@ class InferenceConfig: compile: bool # Interactive mode (user enters action ids) use_interactive_mode: bool + # MaskGIT unmasking schedule ("exp" or "cosine") + maskgit_schedule: str = "exp" preload_ratio: Optional[float] = None
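
Two quick sanity checks for the new code paths follow. First, the cosine
schedule's cumulative reveal counts: a minimal standalone sketch that mirrors
cosine_schedule_torch above (the horizon T=8 and token budget P_total=1024 are
arbitrary illustration values, not repo defaults):

    import math

    def cosine_reveal(t, T, P_total):
        # tokens revealed by step t: P_total * (1 - cos(pi/2 * t/T)),
        # truncated to int for display; the model keeps a float tensor
        if t == T - 1:
            return P_total  # final step always unmasks everything
        return int(P_total * (1.0 - math.cos(math.pi / 2.0 * t / T)))

    T, P_total = 8, 1024
    print([cosine_reveal(t, T, P_total) for t in range(T)])
    # [0, 19, 77, 172, 299, 455, 632, 1024] -- few tokens early, many late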
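
Second, RoPE is a pure per-pair rotation of Q/K, so applying it must preserve
per-position vector norms. A small check against the new helpers (a sketch,
assuming models.positional_encoding is importable from the repo root):

    import torch
    from models.positional_encoding import rope_2d_cos_sin, apply_rope_2d

    Hp, Wp, head_dim = 4, 4, 16  # head_dim must be divisible by 4 for 2D RoPE
    cos, sin = rope_2d_cos_sin(Hp, Wp, head_dim, device="cpu", dtype=torch.float32)
    q = torch.randn(2, 8, Hp * Wp, head_dim)  # [N, heads, P, D]
    q_rot = apply_rope_2d(q, cos, sin)
    # rotations are isometries: norms match before and after, per position
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)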