
feat: replace action tokenizer with windowed attention#16

Open
imitation-alpha wants to merge 3 commits into AlmondGod:main from imitation-alpha:feature/action-tokenizer-window-attention

Conversation

@imitation-alpha

Summary

This PR replaces the "mean pool + concat" mechanism in the LatentActionsEncoder with a "length-2 windowed attention + mean" mechanism. This change aims to better capture temporal dependencies between adjacent frames during action tokenization.
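Roughly, the idea looks like the minimal sketch below (not the repository's actual code: nn.MultiheadAttention stands in for models.st_transformer.SpatialAttention, and the tensor shapes are assumptions):

```python
import torch
import torch.nn as nn

class WindowedActionEncoder(nn.Module):
    """Sketch of the new path: attend over a length-2 window of adjacent
    frames, then mean-pool the attended tokens into one action embedding."""

    def __init__(self, dim: int = 128, num_heads: int = 4):
        super().__init__()
        # nn.MultiheadAttention is a stand-in for SpatialAttention here.
        self.window_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, z_t: torch.Tensor, z_next: torch.Tensor) -> torch.Tensor:
        # z_t, z_next: (B, N, D) patch tokens for frame t and frame t+1.
        window = torch.cat([z_t, z_next], dim=1)        # (B, 2N, D) length-2 window
        attended, _ = self.window_attn(window, window, window)
        return attended.mean(dim=1)                     # (B, D) latent action embedding

def mean_pool_concat(z_t: torch.Tensor, z_next: torch.Tensor) -> torch.Tensor:
    """The path being replaced: mean-pool each frame, then concatenate."""
    return torch.cat([z_t.mean(dim=1), z_next.mean(dim=1)], dim=-1)  # (B, 2D)
```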

Changes

  • Modified models/latent_actions.py:
    • Imported SpatialAttention from models.st_transformer.
    • Updated LatentActionsEncoder to use SpatialAttention on concatenated windows of current and next frames.
    • Removed the old mean pooling and concatenation logic.

Verification

  • Verified the implementation with a synthetic test script (scripts/verify_latent_actions.py - deleted after verification).
  • Confirmed that the model processes input frames and produces output actions with the correct dimensions.
  • Loss calculation works as expected.

Notes

  • This is a breaking change for LatentActionsEncoder checkpoints.

@imitation-alpha force-pushed the feature/action-tokenizer-window-attention branch from 93ed906 to 05765b0 on November 29, 2025 06:09
@AlmondGod
Owner

this looks great! can you train a working world model to confirm the impact of the change?

@NewJerseyStyle

Sorry to interrupt. I am not an expert, but I am curious if there are "KPIs" to be monitored to evaluate the impact of a change?
For example:

  • How to confirm it does not get worse, monitor steps used to converge?
  • How to confirm it gets better, monitor the loss of the model?

@AlmondGod
Owner

Sorry to interrupt. I am not an expert, but I am curious if there are "KPIs" to be monitored to evaluate the impact of a change? For example:

  • How to confirm it does not get worse, monitor steps used to converge?
  • How to confirm it gets better, monitor the loss of the model?

yes, I'll add a PR requirements section to the readme specifying the necessary criteria

@imitation-alpha
Author

Testing Results: Windowed Attention vs Mean-Pool+Concat

Tested the full 3-stage pipeline (Video Tokenizer → Latent Actions → Dynamics) on the PicoDoom dataset (17,935 frames, 30% preload), batch_size=16, on CPU (M4 Pro, 64GB RAM). Ran both 1K and 10K steps per stage to evaluate short and long training behavior.

Stage 1: Video Tokenizer (identical model, 0.14M params)

Both branches converge identically (~0.006 at 10K steps), as expected since this stage is unchanged.

Stage 2: Latent Actions

| | PR (windowed attn, 78K params) | Main (mean-pool+concat, 74K params) |
|---|---|---|
| 1K steps | Loss: 0.041, Codebook: 75%, Enc Var: 0.101 | Loss: 0.031, Codebook: 50%, Enc Var: 0.024 |
| 10K steps | Loss: 0.028, Codebook: 50%, Enc Var: 0.054 | Loss: 0.027, Codebook: 75%, Enc Var: 0.138 |
| Speed | ~3.8 it/s | ~5.5 it/s |
  • At 1K steps, PR shows higher codebook usage (75% vs 50%) and 4x higher encoder variance
  • At 10K steps, both converge to similar loss (~0.027-0.028) and both achieve 100% codebook usage at some point during training
  • Main branch is ~45% faster per step
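For context on the Codebook and Enc Var columns above, here is a hedged sketch of how those two metrics are commonly computed for a VQ-style latent action model (assumed definitions, not the repository's logging code):

```python
import torch

def codebook_usage(indices: torch.Tensor, codebook_size: int) -> float:
    """Fraction of codebook entries selected at least once, given the VQ code
    indices chosen by the quantizer over a batch or logging window."""
    return torch.unique(indices).numel() / codebook_size

def encoder_variance(z_e: torch.Tensor) -> float:
    """Variance of the pre-quantization encoder outputs; a rough proxy for how
    spread out the embeddings are (very low variance suggests collapse)."""
    return z_e.var().item()

# Example with a 4-entry codebook: 3 of 4 codes used -> 75% usage.
indices = torch.tensor([0, 2, 2, 3, 0])
print(codebook_usage(indices, codebook_size=4))  # 0.75
```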

Stage 3: Dynamics (0.17M params)

| Steps | PR (windowed attn) | Main (mean-pool+concat) | Difference |
|---|---|---|---|
| 1K | 2.848 | 3.594 | PR 21% lower |
| 10K | 4.009 | 3.960 | Main 1.2% lower |
  • At 1K steps, PR shows 21% lower dynamics loss — faster early convergence
  • At 10K steps, the gap closes to ~1.2% with main slightly ahead — both approaches converge to similar quality

Training Curves (10K steps)

Video Tokenizer (both branches):

```
Step     0 → 1K → 2K → 3K → 4K → 5K → 6K → 7K → 8K → 9K
Loss  0.31  0.036 0.011 0.009 0.008 0.007 0.006 0.006 0.006 0.006
```

Latent Actions (PR / Main):

```
Step      0        1K           5K           9K
Loss   1.23/1.21  0.030/0.027  0.023/0.030  0.028/0.027
Cdbk   50%/25%    50%/50%      75%/100%     50%/75%
EncVar 0.00/0.00  0.035/0.017  0.224/0.207  0.054/0.138
```

Dynamics (PR / Main):

```
Step      0          1K         3K         5K         7K         9K
Loss   7.03/7.09  5.29/5.30  4.41/4.39  4.26/4.22  4.12/4.07  4.01/3.96
```

Summary

The windowed attention replacement converges faster early on (21% lower dynamics loss at 1K steps) but both approaches reach similar quality given enough training (~1.2% difference at 10K steps). The PR adds ~45% latency per step in the latent actions stage due to the attention computation. The full pipeline trains end-to-end without issues on both branches.

Environment: Mac M4 Pro 64GB, CPU-only, PicoDoom dataset, Python 3.13, PyTorch 2.10.0

@imitation-alpha
Author

@AlmondGod friendly ping!

@AlmondGod
Owner

This looks amazing! Thank you for this work.
Looks ready to merge except one last thing, can you visualize the inference of the model trained post-commit so we can be sure there’s no degradation?

Will add this to the PR requirements.

Thanks again and I look forward to seeing this merged!

@AlmondGod
Owner

@imitation-alpha friendly ping, want to get this PR merged after your amazing work!

@imitation-alpha
Author

@AlmondGod here's the post-commit inference visualization you asked for.

Setup

Trained both branches end-to-end on ZELDA at frame_size 64 with matched hyperparameters:

  • Pre-PR baseline: main (mean-pool + concat encoder)
  • Post-PR: this branch @ 05765b0 (windowed SpatialAttention)

| Stage | n_updates | batch × grad_accum |
|---|---|---|
| Video tokenizer | 3,000 | 16 × 4 |
| Latent actions | 5,000 | 16 × 4 |
| Dynamics | 8,000 | 8 × 4 |

The video tokenizer is unchanged by this PR, so it was trained once and the same checkpoint was reused on both branches' Dynamics stages. LAM and Dynamics were retrained from scratch on each branch (the LAM checkpoint format is a breaking change).

Inference: mode=autoregressive, generation_steps=16, context_window=4, use_gt_actions=true, identical sample_index on both runs.
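For reference, a minimal sketch of what that autoregressive evaluation loop does (the `dynamics` callable, tensor shapes, and context handling are assumptions, not the repository's inference script):

```python
import torch

def autoregressive_rollout(dynamics, gt_actions, context_frames,
                           generation_steps=16, context_window=4):
    """Roll the dynamics model forward, feeding its own predictions back in.
    context_frames: list of ground-truth frame tokens, each (B, ...).
    gt_actions: (B, T, ...) ground-truth latent actions (use_gt_actions=true)."""
    frames = list(context_frames)
    for t in range(generation_steps):
        window = torch.stack(frames[-context_window:], dim=1)  # last K frames as context
        next_frame = dynamics(window, gt_actions[:, t])        # predict the next frame
        frames.append(next_frame)
    generated = torch.stack(frames[len(context_frames):], dim=1)
    return generated  # (B, generation_steps, ...) for MSE against ground truth
```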

Results

| Variant | Autoregressive MSE | Teacher-forced MSE |
|---|---|---|
| Pre-PR (mean-pool + concat) | 0.0141 | |
| Post-PR (windowed SpatialAttention) | 0.0081 | 0.0060 |

Top: pre-PR rollout. Bottom: post-PR rollout. Each panel shows GT (top row) and predicted (bottom row) over the 16-step autoregressive rollout.

Pre vs Post comparison

No degradation visible — autoregressive MSE drops ~42% on this sample, and the rollout stays visually coherent for the full 16 steps.

Caveats

  • Single sample / single seed. This is a sanity check, not a statistical claim. It lines up with the convergence story from the earlier numbers upthread (similar quality at the same compute).
  • Compute was constrained by a 6 GB RTX 2060, hence the small batches and short training budgets. Absolute MSE would drop with longer training; the relative comparison is what matters.

Two small things I had to patch locally to get the PR branch to run

Not blocking — flagging in case you want them in this PR or a follow-up:

  1. scripts/train_latent_actions.py and scripts/train_dynamics.py don't pass frame_size=args.frame_size to load_data_and_data_loaders, so the dataloader silently returns the dataset's default 128×128 frames while the model is built for args.frame_size. This causes a pos_spatial shape mismatch in patch_embed. One-line fix per script (both patches are sketched after this list).
  2. torch.compile(..., mode="reduce-overhead", dynamic=True) trips on shape inference of pos_spatial inside the windowed attention path on this hardware. Setting compile: false in the configs is the easy workaround; pinning the spatial shape annotation would be the proper fix.
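For concreteness, roughly what the two local patches look like (call signatures and config field names are illustrative, not the repository's exact code; only the added frame_size keyword and the compile toggle are the point):

```python
# 1) scripts/train_latent_actions.py / scripts/train_dynamics.py:
#    forward the configured frame size to the dataloader so the data matches
#    the model's pos_spatial / patch_embed shapes.
loaders = load_data_and_data_loaders(
    dataset=args.dataset,
    batch_size=args.batch_size,
    frame_size=args.frame_size,  # previously omitted -> dataset fell back to 128x128
)

# 2) Workaround used for these runs: disable compilation via the config
#    (compile: false). Pinning the pos_spatial shape annotation would be the
#    proper fix for the dynamic-shape inference failure.
if config.compile:
    model = torch.compile(model, mode="reduce-overhead", dynamic=True)
```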

LGTM from a generation-quality standpoint — happy to see this merged.

@AlmondGod
Owner

ok, this looks so great, last last thing: you've verified it's backwards compatible? i think having spatial attention instantiated regardless isn't the best design decision, it should be an optional setting instead so we have backwards compatibility

@imitation-alpha
Author

Good call. I updated the PR in 5d8125e so the windowed attention path is optional instead of always instantiated.

What changed:

  • Added use_windowed_attention, defaulting to false for backward compatibility.
  • Default mode now preserves the original mean-pool + concat encoder.
  • window_attn is only instantiated when use_windowed_attention=true.
  • Checkpoint loading reads the flag from the saved config and defaults missing values to false, so older LAM checkpoints load with the legacy architecture.

I also kept the frame-size plumbing fixes needed for the verification path, so the dataloader/inference code uses the configured/trained frame size consistently.
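Roughly, the shape of the change (a sketch under assumed checkpoint keys and call signatures, not the actual diff; only use_windowed_attention and window_attn come from the PR itself):

```python
import torch
import torch.nn as nn

class LatentActionsEncoder(nn.Module):
    def __init__(self, dim: int = 128, use_windowed_attention: bool = False):
        super().__init__()
        self.use_windowed_attention = use_windowed_attention
        # The attention module only exists when the new path is requested, so
        # legacy checkpoints keep an identical state_dict layout.
        self.window_attn = (
            nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
            if use_windowed_attention
            else None
        )

def load_encoder(checkpoint_path: str) -> LatentActionsEncoder:
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    cfg = ckpt.get("config", {})  # checkpoint layout is an assumption
    encoder = LatentActionsEncoder(
        dim=cfg.get("dim", 128),
        # Older checkpoints carry no flag -> default to the legacy mean-pool path.
        use_windowed_attention=cfg.get("use_windowed_attention", False),
    )
    encoder.load_state_dict(ckpt["model"])
    return encoder
```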

Verified:

  • Forward pass works for both legacy and windowed modes.
  • Synthetic checkpoints load correctly for both missing flag and use_windowed_attention=true.
  • Existing old latent-actions checkpoint loads successfully with use_windowed_attention=false and no window_attn.
  • py_compile passes on the edited Python files.

So backward compatibility is now explicit, and the windowed attention behavior is opt-in.

@AlmondGod
Owner

beautiful, looks like it still has conflicts? other than that ready to merge

@imitation-alpha
Author

done
