feat: replace action tokenizer with windowed attention #16
imitation-alpha wants to merge 3 commits into
Conversation
Force-pushed from 93ed906 to 05765b0
this looks great! can you train a working world model to confirm the impact of the change?
Sorry to interrupt. I'm not an expert, but I'm curious: are there "KPIs" to monitor to evaluate the impact of a change?
yes, I'll add a README section to the PR specifying the necessary criteria
Testing Results: Windowed Attention vs Mean-Pool + Concat

Tested the full 3-stage pipeline (Video Tokenizer → Latent Actions → Dynamics) on the PicoDoom dataset (17,935 frames, 30% preload), batch_size=16, on CPU (M4 Pro, 64 GB RAM). Ran both 1K and 10K steps per stage to evaluate short and long training behavior.

Stage 1: Video Tokenizer (identical model, 0.14M params)

Both branches converge identically (~0.006 at 10K steps), as expected since this stage is unchanged.

Stage 2: Latent Actions
Stage 3: Dynamics (0.17M params)
Training Curves (10K steps)

[Loss curves: Video Tokenizer (both branches); Latent Actions (PR / Main); Dynamics (PR / Main)]

Summary

The windowed attention replacement converges faster early on (21% lower dynamics loss at 1K steps), but both approaches reach similar quality given enough training (~1.2% difference at 10K steps). The PR adds ~45% latency per step in the latent actions stage due to the attention computation. The full pipeline trains end-to-end without issues on both branches.

Environment: Mac M4 Pro 64GB, CPU-only, PicoDoom dataset, Python 3.13, PyTorch 2.10.0
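For reference, the two mechanisms being compared can be sketched roughly as below. Shapes, class names, and the attention configuration are illustrative assumptions, not the repo's actual code; the input is taken to be patch embeddings for two adjacent frames.

```python
import torch
import torch.nn as nn

def mean_pool_concat(curr, nxt):
    # Baseline (main): mean-pool each frame over its patches, then
    # concatenate the two pooled vectors -> (B, 2D)
    return torch.cat([curr.mean(dim=1), nxt.mean(dim=1)], dim=-1)

class WindowedAttentionPool(nn.Module):
    # PR variant (sketch): self-attend over the concatenated length-2
    # window of frame tokens, then mean-pool the attended tokens -> (B, D)
    def __init__(self, dim, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, curr, nxt):
        window = torch.cat([curr, nxt], dim=1)      # (B, 2N, D)
        attended, _ = self.attn(window, window, window)
        return attended.mean(dim=1)                 # (B, D)

B, N, D = 2, 16, 32
curr, nxt = torch.randn(B, N, D), torch.randn(B, N, D)
print(mean_pool_concat(curr, nxt).shape)          # torch.Size([2, 64])
print(WindowedAttentionPool(D)(curr, nxt).shape)  # torch.Size([2, 32])
```

The extra attention over 2N tokens is consistent with the ~45% per-step latency increase reported for the latent actions stage.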
@AlmondGod friendly ping!
This looks amazing! Thank you for this work. Will add this to the PR requirements. Thanks again and I look forward to seeing this merged!
@imitation-alpha friendly ping, want to get this PR merged after your amazing work!
@AlmondGod here's the post-commit inference visualization you asked for.

Setup

Trained both branches end-to-end on ZELDA at frame_size 64 with matched hyperparameters:
The video tokenizer is unchanged by this PR, so it was trained once and the same checkpoint was reused for both branches' Dynamics stages. LAM and Dynamics were retrained from scratch on each branch (the LAM checkpoint format is a breaking change).

Inference:

Results
Top: pre-PR rollout. Bottom: post-PR rollout. Each panel shows GT (top row) and predicted (bottom row) over the 16-step autoregressive rollout. No degradation visible: autoregressive MSE drops ~42% on this sample, and the rollout stays visually coherent for the full 16 steps.

Caveats
Two small things I had to patch locally to get the PR branch to run

Not blocking; flagging in case you want them in this PR or a follow-up:
LGTM from a generation-quality standpoint; happy to see this merged.
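The autoregressive rollout MSE reported above can be computed along these lines. This is a sketch: `dynamics` is a hypothetical placeholder for the trained dynamics model, and the latent/action shapes are assumptions.

```python
import torch

def rollout_mse(dynamics, latents, actions, steps=16):
    # latents: (T+1, D) ground-truth latent states; actions: (T, A)
    # Feed the model's own predictions back in (autoregressive rollout)
    # and average the squared error against ground truth at each step.
    pred = latents[0]
    errs = []
    for t in range(steps):
        pred = dynamics(pred, actions[t])
        errs.append(((pred - latents[t + 1]) ** 2).mean())
    return torch.stack(errs).mean()

# Toy check with a trivial "dynamics" stand-in that just returns its input:
D, A, T = 8, 4, 16
latents = torch.randn(T + 1, D)
actions = torch.randn(T, A)
mse = rollout_mse(lambda z, a: z, latents, actions)
print(mse.item() >= 0)  # True
```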
ok, this looks so great. one last thing: have you verified it's backwards compatible? i think having spatial attention instantiated regardless isn't the best design decision; it should be an optional setting instead so we have backwards compatibility
Good call. I updated the PR.

What changed:
I also kept the frame-size plumbing fixes needed for the verification path, so the dataloader/inference code uses the configured/trained frame size consistently.

Verified:
So backward compatibility is now explicit, and the windowed attention behavior is opt-in.
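The opt-in switch might look something like the sketch below. The config field name `use_windowed_attention` and the builder are hypothetical, not the PR's actual code; the point is that the default preserves the old mean-pool + concat path.

```python
from dataclasses import dataclass

@dataclass
class LatentActionsConfig:
    latent_dim: int = 32
    # Defaults to False so existing configs and checkpoints keep
    # the old mean-pool + concat behavior (backwards compatible).
    use_windowed_attention: bool = False

def build_encoder(cfg):
    # Only construct the attention path when explicitly enabled,
    # so SpatialAttention is not instantiated unconditionally.
    if cfg.use_windowed_attention:
        return "windowed-attention encoder"
    return "mean-pool + concat encoder"

print(build_encoder(LatentActionsConfig()))                             # old path
print(build_encoder(LatentActionsConfig(use_windowed_attention=True)))  # new path
```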
beautiful. looks like it still has conflicts? other than that, ready to merge
done

Summary
This PR replaces the "mean pool + concat" mechanism in the `LatentActionsEncoder` with a "length-2 windowed attention + mean" mechanism. This change aims to better capture temporal dependencies between adjacent frames during action tokenization.

Changes
- `models/latent_actions.py`: imports `SpatialAttention` from `models.st_transformer`; updates `LatentActionsEncoder` to use `SpatialAttention` on concatenated windows of current and next frames.

Verification
- `scripts/verify_latent_actions.py` (deleted after verification).

Notes
- Breaking change: existing `LatentActionsEncoder` checkpoints are incompatible with the new format.
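Since the checkpoint format changes, it may be worth failing fast on pre-PR checkpoints rather than hitting a cryptic `load_state_dict` error. A minimal sketch, using stand-in modules (the real encoder classes and key names will differ):

```python
import torch
import torch.nn as nn

def is_compatible(model, state_dict):
    # A checkpoint is loadable (with strict=True) only if its keys
    # exactly match the model's current state_dict keys.
    return set(model.state_dict().keys()) == set(state_dict.keys())

# Stand-ins: a "pre-PR" encoder and a "post-PR" encoder with extra layers.
old_encoder = nn.Linear(4, 4)
new_encoder = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

print(is_compatible(old_encoder, old_encoder.state_dict()))  # True
print(is_compatible(new_encoder, old_encoder.state_dict()))  # False: old ckpt rejected
```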