Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ There are still many TODOs which may offer significant performance gains...
- [ ] Add more datasets (Terraria, Street Fighter, \<your favorite retro videogame\>)
- [ ] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter)
- [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT)
- [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean`
- [x] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean`
- [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters
- [ ] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames
- [x] Implement Mixture of Experts in the Feedforward Network - added by [eren23](https://github.com/eren23) in [#20](https://github.com/AlmondGod/tinyworlds/pull/20)
Expand Down
3 changes: 3 additions & 0 deletions configs/training.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ optimizer: "adamw"
muon_momentum: 0.95
muon_backend_steps: 5

# Windowed attention in action tokenizer encoder (alternative to mean pool + concat)
use_windowed_attention: false

# MoE (dynamics model only): replaces SwiGLU FFN with top-k routed experts
use_moe: false
num_experts: 4
Expand Down
100 changes: 72 additions & 28 deletions models/latent_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,43 +10,87 @@

NUM_LATENT_ACTIONS_BINS = 2


class WindowedFrameAttention(nn.Module):
    """Self-attention over a length-2 window of consecutive frame patches.

    For each pair of frames (t, t+1), concatenates their P patch embeddings
    into a 2P sequence, applies multi-head self-attention, and mean-pools
    back to a single E-dim vector. This lets patches in both frames interact
    before pooling, giving richer inter-frame signal than mean pool + concat.

    Note: uses native ``view``/``transpose``/``reshape`` for the head split
    and merge instead of ``einops.rearrange`` — the layouts are identical
    ('b n (h d) -> b h n d' and back), so behavior is unchanged while the
    third-party dependency is dropped for this module.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # embed_dim must split evenly across the attention heads
        assert self.head_dim * num_heads == embed_dim
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, frame_a, frame_b):
        """Attend jointly over both frames' patches, then pool.

        Args:
            frame_a: [B, P, E] patch embeddings of frame t.
            frame_b: [B, P, E] patch embeddings of frame t+1.

        Returns:
            [B, E] mean of the 2P attended patch embeddings.
        """
        bsz = frame_a.size(0)
        # Pre-norm over the concatenated 2P-token sequence.
        tokens = self.norm(torch.cat((frame_a, frame_b), dim=1))  # [B, 2P, E]
        n_tok = tokens.size(1)

        def split_heads(t):
            # [B, N, E] -> [B, H, N, D] (same layout as 'b n (h d) -> b h n d')
            return t.view(bsz, n_tok, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(tokens))  # [B, H, 2P, D]
        k = split_heads(self.k_proj(tokens))
        v = split_heads(self.v_proj(tokens))

        # Scaled dot-product attention over all 2P tokens (no masking).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [B, H, 2P, 2P]
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # [B, H, 2P, D]

        # Merge heads back: [B, H, 2P, D] -> [B, 2P, E].
        out = out.transpose(1, 2).reshape(bsz, n_tok, -1)
        out = self.out_proj(out)  # [B, 2P, E]
        return out.mean(dim=1)  # [B, E] -- mean pool over all 2P attended patches


class LatentActionsEncoder(nn.Module):
    """Encode a frame sequence into one latent action per frame transition.

    Patch-embeds frames, runs a causal spatio-temporal transformer, then for
    each consecutive pair (t, t+1) pools the two frames' patch features and
    projects them to a low-dim action vector. Two pooling modes:

    - windowed attention (``use_windowed_attention=True``): joint attention
      over both frames' patches, pooled to E dims;
    - default: mean pool each frame over patches, concatenate to 2E dims.

    NOTE(review): reconstructed from an interleaved unified-diff rendering of
    this class; the post-change behavior is preserved and the duplicated
    action-head construction in the two branches is consolidated.
    """

    def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
                 hidden_dim=256, num_blocks=4, action_dim=3, use_windowed_attention=False):
        super().__init__()
        self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
        self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True)
        self.use_windowed_attention = use_windowed_attention

        # Pooled-feature width depends on the pooling mode: windowed attention
        # yields E dims, mean pool + concat yields 2E dims. The head itself is
        # identical otherwise, so build it once.
        if use_windowed_attention:
            self.window_attn = WindowedFrameAttention(embed_dim, num_heads)
            head_in = embed_dim
        else:
            head_in = embed_dim * 2

        # embeddings to discrete latent bottleneck actions
        self.action_head = nn.Sequential(
            nn.LayerNorm(head_in),
            nn.Linear(head_in, 4 * action_dim),
            nn.GELU(),
            nn.Linear(4 * action_dim, action_dim),
        )

    def forward(self, frames):
        """Return one action vector per consecutive frame pair.

        Args:
            frames: [B, T, C, H, W] frame sequence.

        Returns:
            [B, T-1, A] latent action logits, one per transition.
        """
        batch_size, seq_len, C, H, W = frames.shape

        embeddings = self.patch_embed(frames)       # [B, T, P, E]
        transformed = self.transformer(embeddings)  # [B, T, P, E]

        actions = []
        for t in range(seq_len - 1):
            if self.use_windowed_attention:
                # length-2 windowed attention across all patches from both frames, then mean pool
                pooled = self.window_attn(transformed[:, t], transformed[:, t + 1])  # [B, E]
            else:
                # mean pool over patches then concatenate consecutive frame features
                pooled_t = transformed[:, t].mean(dim=1)       # [B, E]
                pooled_t1 = transformed[:, t + 1].mean(dim=1)  # [B, E]
                pooled = torch.cat([pooled_t, pooled_t1], dim=1)  # [B, E*2]
            actions.append(self.action_head(pooled))  # [B, A]

        return torch.stack(actions, dim=1)  # [B, T-1, A]

class LatentActionsDecoder(nn.Module):
def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
Expand Down Expand Up @@ -97,12 +141,12 @@ def forward(self, frames, actions, training=True):
return pred_frames # [B, T-1, C, H, W]

class LatentActionModel(nn.Module):
def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128,
num_heads=8, hidden_dim=256, num_blocks=4):
def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128,
num_heads=8, hidden_dim=256, num_blocks=4, use_windowed_attention=False):
super().__init__()
assert math.log(n_actions, NUM_LATENT_ACTIONS_BINS).is_integer(), f"n_actions must be a power of {NUM_LATENT_ACTIONS_BINS}"
self.action_dim=int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS))
self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim)
self.action_dim = int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS))
self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim, use_windowed_attention=use_windowed_attention)
self.quantizer = FiniteScalarQuantizer(latent_dim=self.action_dim, num_bins=NUM_LATENT_ACTIONS_BINS)
self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim)
self.var_target = 0.01
Expand Down
1 change: 1 addition & 0 deletions scripts/train_latent_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def main():
hidden_dim=args.hidden_dim,
num_blocks=args.num_blocks,
n_actions=args.n_actions,
use_windowed_attention=getattr(args, 'use_windowed_attention', False),
).to(args.device)
if args.checkpoint:
model, _ = load_latent_actions_from_checkpoint(
Expand Down
2 changes: 2 additions & 0 deletions utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class LatentActionsConfig:
# other params
fps: Optional[int] = None
preload_ratio: Optional[float] = None
use_windowed_attention: bool = False

def __post_init__(self) -> None:
_validate_amp_fsdp(self.amp, self.distributed)
Expand Down Expand Up @@ -272,6 +273,7 @@ class TrainingConfig:
n_updates: Optional[int] = None # number of optimizer.step(), excluding grad_accum_step
fps: Optional[int] = None
preload_ratio: Optional[float] = None
use_windowed_attention: bool = False
# MoE (dynamics only)
use_moe: bool = False
num_experts: int = 4
Expand Down