From 0529bc88154d8cb672057e52ced8f4501634d91b Mon Sep 17 00:00:00 2001 From: Tasha Date: Thu, 16 Apr 2026 12:45:22 -0700 Subject: [PATCH] feat: add windowed attention encoder for action tokenizer Adds WindowedFrameAttention: for each pair of consecutive frames (t, t+1), concatenates their P patch embeddings into a 2P sequence, applies self-attention, and mean-pools to a single embedding. Compared to mean pool + concat, this lets patches from both frames interact before pooling, giving a richer inter-frame signal for action inference. Controlled by use_windowed_attention (default False) on LatentActionModel, LatentActionsConfig, and TrainingConfig. In the windowed path the action head input shrinks from embed_dim*2 to embed_dim. --- README.md | 2 +- configs/training.yaml | 3 + models/latent_actions.py | 100 +++++++++++++++++++++++--------- scripts/train_latent_actions.py | 1 + utils/config.py | 2 + 5 files changed, 79 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 0f00c0e..f7f3ccf 100644 --- a/README.md +++ b/README.md @@ -273,7 +273,7 @@ There are still many TODOs which may offer significant performance gains... 
- [ ] Add more datasets (Terraria, Street Fighter, \) - [ ] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter) - [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT) -- [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean` +- [x] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean` - [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters - [ ] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames - [x] Implement Mixture of Experts in the Feedforward Network - added by [eren23](https://github.com/eren23) in [#20](https://github.com/AlmondGod/tinyworlds/pull/20) diff --git a/configs/training.yaml b/configs/training.yaml index 0ca6528..e40a9b2 100644 --- a/configs/training.yaml +++ b/configs/training.yaml @@ -42,6 +42,9 @@ optimizer: "adamw" muon_momentum: 0.95 muon_backend_steps: 5 +# Windowed attention in action tokenizer encoder (alternative to mean pool + concat) +use_windowed_attention: false + # MoE (dynamics model only): replaces SwiGLU FFN with top-k routed experts use_moe: false num_experts: 4 diff --git a/models/latent_actions.py b/models/latent_actions.py index 900ab1b..791c605 100644 --- a/models/latent_actions.py +++ b/models/latent_actions.py @@ -10,43 +10,87 @@ NUM_LATENT_ACTIONS_BINS = 2 + +class WindowedFrameAttention(nn.Module): + """Self-attention over a length-2 window of consecutive frame patches. + + For each pair of frames (t, t+1), concatenates their P patch embeddings + into a 2P sequence, applies self-attention, and mean-pools back to E. + This lets patches in both frames interact before pooling, giving richer + inter-frame signal than mean pool + concat. 
+ """ + def __init__(self, embed_dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == embed_dim + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, frame_a, frame_b): + # frame_a, frame_b: [B, P, E] + B, P, E = frame_a.shape + x = torch.cat([frame_a, frame_b], dim=1) # [B, 2P, E] + x = self.norm(x) + H, D = self.num_heads, self.head_dim + q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=H) # [B, H, 2P, D] + k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=H) + v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=H) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D) # [B, H, 2P, 2P] + attn = F.softmax(scores, dim=-1) + out = torch.matmul(attn, v) # [B, H, 2P, D] + out = rearrange(out, 'b h n d -> b n (h d)') # [B, 2P, E] + out = self.out_proj(out) # [B, 2P, E] + return out.mean(dim=1) # [B, E] -- mean pool over all 2P attended patches + + class LatentActionsEncoder(nn.Module): - def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, - hidden_dim=256, num_blocks=4, action_dim=3): + def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, + hidden_dim=256, num_blocks=4, action_dim=3, use_windowed_attention=False): super().__init__() self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim) self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True) - - # embeddings to discrete latent bottleneck actions - self.action_head = nn.Sequential( - nn.LayerNorm(embed_dim * 2), - nn.Linear(embed_dim * 2, 4 * action_dim), - nn.GELU(), - nn.Linear(4 * action_dim, action_dim) - ) + self.use_windowed_attention = use_windowed_attention + + if use_windowed_attention: + 
self.window_attn = WindowedFrameAttention(embed_dim, num_heads) + self.action_head = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 4 * action_dim), + nn.GELU(), + nn.Linear(4 * action_dim, action_dim), + ) + else: + self.action_head = nn.Sequential( + nn.LayerNorm(embed_dim * 2), + nn.Linear(embed_dim * 2, 4 * action_dim), + nn.GELU(), + nn.Linear(4 * action_dim, action_dim), + ) def forward(self, frames): # frames: [B, T, C, H, W] batch_size, seq_len, C, H, W = frames.shape embeddings = self.patch_embed(frames) # [B, T, P, E] - transformed = self.transformer(embeddings) + transformed = self.transformer(embeddings) # [B, T, P, E] - # TODO: try attention pooling + mean instead of mean + concat - # mean pool over patches (since one action per frame) - pooled = transformed.mean(dim=2) # [B, T, E] - - # combine features from current and next frame actions = [] for t in range(seq_len - 1): - # concat current and next frame features - combined = torch.cat([pooled[:, t], pooled[:, t+1]], dim=1) # [B, E*2] - action = self.action_head(combined) # [B, A] - actions.append(action) - - actions = torch.stack(actions, dim=1) # [B, T-1, A] - - return actions + if self.use_windowed_attention: + # length-2 windowed attention across all patches from both frames, then mean pool + pooled = self.window_attn(transformed[:, t], transformed[:, t + 1]) # [B, E] + else: + # mean pool over patches then concatenate consecutive frame features + pooled_t = transformed[:, t].mean(dim=1) # [B, E] + pooled_t1 = transformed[:, t + 1].mean(dim=1) # [B, E] + pooled = torch.cat([pooled_t, pooled_t1], dim=1) # [B, E*2] + actions.append(self.action_head(pooled)) # [B, A] + + return torch.stack(actions, dim=1) # [B, T-1, A] class LatentActionsDecoder(nn.Module): def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, @@ -97,12 +141,12 @@ def forward(self, frames, actions, training=True): return pred_frames # [B, T-1, C, H, W] class 
LatentActionModel(nn.Module): - def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128, - num_heads=8, hidden_dim=256, num_blocks=4): + def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128, + num_heads=8, hidden_dim=256, num_blocks=4, use_windowed_attention=False): super().__init__() assert math.log(n_actions, NUM_LATENT_ACTIONS_BINS).is_integer(), f"n_actions must be a power of {NUM_LATENT_ACTIONS_BINS}" - self.action_dim=int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS)) - self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim) + self.action_dim = int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS)) + self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim, use_windowed_attention=use_windowed_attention) self.quantizer = FiniteScalarQuantizer(latent_dim=self.action_dim, num_bins=NUM_LATENT_ACTIONS_BINS) self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim) self.var_target = 0.01 diff --git a/scripts/train_latent_actions.py b/scripts/train_latent_actions.py index 1c3f18d..b706d5b 100644 --- a/scripts/train_latent_actions.py +++ b/scripts/train_latent_actions.py @@ -57,6 +57,7 @@ def main(): hidden_dim=args.hidden_dim, num_blocks=args.num_blocks, n_actions=args.n_actions, + use_windowed_attention=getattr(args, 'use_windowed_attention', False), ).to(args.device) if args.checkpoint: model, _ = load_latent_actions_from_checkpoint( diff --git a/utils/config.py b/utils/config.py index 846c7b2..8ae1a79 100644 --- a/utils/config.py +++ b/utils/config.py @@ -163,6 +163,7 @@ class LatentActionsConfig: # other params fps: Optional[int] = None preload_ratio: Optional[float] = None + use_windowed_attention: bool = False def __post_init__(self) -> None: _validate_amp_fsdp(self.amp, 
self.distributed) @@ -272,6 +273,7 @@ class TrainingConfig: n_updates: Optional[int] = None # number of optimizer.step(), excluding grad_accum_step fps: Optional[int] = None preload_ratio: Optional[float] = None + use_windowed_attention: bool = False # MoE (dynamics only) use_moe: bool = False num_experts: int = 4