feat: RoPE positional embeddings and cosine MaskGIT scheduler #21
Open
tashapais wants to merge 1 commit into
Implements two TODOs from the contributor list:
1. RoPE (Rotary Position Embeddings) as a drop-in replacement for the
additive sinusoidal positional encodings. When use_rope=true:
- TemporalAttention applies 1D RoPE to Q/K across the time axis
- SpatialAttention applies 2D RoPE to Q/K using independent y/x axis
encodings in the first and second halves of the head dimension
- Additive temporal and spatial PEs are skipped entirely
Enable with `use_rope: true` in configs/training.yaml.
2. Cosine MaskGIT unmasking schedule alongside the existing exponential
schedule. By step t of T, the cosine variant has revealed a fraction
1 - cos(pi/2 * t/T) of the tokens, so only a few high-confidence tokens
are committed in early steps and most are revealed late in decoding.
Select with `maskgit_schedule: "cosine"` in configs/training.yaml
or configs/inference.yaml (default: "exp", preserving prior behavior).
All new flags default to false/"exp" so existing checkpoints and runs
are unaffected.
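
For reviewers unfamiliar with RoPE, the rotation described in item 1 can be sketched in plain Python on a single query/key vector. This is an illustration only, not the code in models/positional_encoding.py: the function names (`rope_1d_sketch`, `rope_2d_sketch`), the interleaved channel pairing, and the frequency base of 10000 are all assumptions made for the example.

```python
import math
from typing import List


def rope_angles(pos: float, dim: int, base: float = 10000.0) -> List[float]:
    """One rotation angle per channel pair (dim must be even).

    The base of 10000 is an assumed default, not taken from the PR.
    """
    return [pos / (base ** (2 * i / dim)) for i in range(dim // 2)]


def rotate_pairs(vec: List[float], angles: List[float]) -> List[float]:
    """Rotate each consecutive (even, odd) channel pair by its angle."""
    out: List[float] = []
    for i, theta in enumerate(angles):
        a, b = vec[2 * i], vec[2 * i + 1]
        c, s = math.cos(theta), math.sin(theta)
        out += [a * c - b * s, a * s + b * c]
    return out


def rope_1d_sketch(vec: List[float], t: float) -> List[float]:
    """Hypothetical 1D RoPE: rotate a Q or K vector by its time index t."""
    return rotate_pairs(vec, rope_angles(t, len(vec)))


def rope_2d_sketch(vec: List[float], y: float, x: float) -> List[float]:
    """Hypothetical 2D RoPE: the first half of the head dimension encodes y,
    the second half encodes x (len(vec) must be divisible by 4)."""
    half = len(vec) // 2
    return rope_1d_sketch(vec[:half], y) + rope_1d_sketch(vec[half:], x)


def dot(a: List[float], b: List[float]) -> float:
    """Plain dot product, standing in for one attention logit."""
    return sum(p * q for p, q in zip(a, b))
```

Because each channel pair is only rotated, the Q·K dot product after rotation depends on the position offset rather than on absolute positions, which is why the additive PEs can be skipped when RoPE is enabled.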
Summary
- RoPE positional embeddings (`use_rope: true` in `configs/training.yaml`): replaces additive sinusoidal PE with Rotary Position Embeddings applied directly to Q/K in attention. Temporal attention uses 1D RoPE across the time axis; spatial attention uses 2D RoPE with independent y-axis and x-axis encodings split across the head dimension. When enabled, the additive sinusoidal temporal and spatial PEs are skipped.
- Cosine MaskGIT unmasking schedule (`maskgit_schedule: "cosine"` in config): adds a cosine schedule alongside the existing exponential one. The cosine variant reveals `P_total * (1 - cos(pi/2 * t/T))` tokens by step `t`, so only a few high-confidence tokens are committed in early steps and most are revealed late in decoding. Selectable at inference time via `configs/inference.yaml`.

Both flags default to `false`/`"exp"`, so existing checkpoints and training runs are unaffected.

Changes
- `models/positional_encoding.py`: `rope_1d_cos_sin`, `rope_2d_cos_sin`, `apply_rope_1d`, `apply_rope_2d`
- `models/st_transformer.py`: `use_rope` + `grid_size` plumbed through `SpatialAttention`, `TemporalAttention`, `STTransformerBlock`, `STTransformer`
- `models/video_tokenizer.py`, `models/latent_actions.py`, `models/dynamics.py`: `use_rope` parameter
- `models/dynamics.py`: `cosine_schedule_torch`, `schedule` param on `forward_inference`
- `utils/config.py`: `use_rope` on all model configs, `maskgit_schedule` on `DynamicsConfig`, `TrainingConfig`, `InferenceConfig`
- `configs/training.yaml`, `configs/inference.yaml`: new fields with defaults

Test plan
- `rope_1d_cos_sin` / `rope_2d_cos_sin` output correct shapes
- `STTransformer` forward pass produces identical output shapes with and without RoPE
- `VideoTokenizer` forward pass with `use_rope=True`
- `LatentActionModel` forward pass with `use_rope=True`
- `DynamicsModel.forward_inference` with both `schedule="exp"` and `schedule="cosine"`
- `DynamicsModel.forward_inference` with `use_rope=True` + `schedule="cosine"`
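
To make the shape of the cosine schedule concrete, the formula from the summary can be tabulated with a few lines of plain Python. This is a rough sketch and not the PR's `cosine_schedule_torch`: the helper name `cosine_reveal_counts`, the rounding rule, and the clamp on the final step are assumptions chosen for the example.

```python
import math
from typing import List


def cosine_reveal_counts(total_tokens: int, num_steps: int) -> List[int]:
    """Cumulative tokens revealed after each step t = 1..T using
    total_tokens * (1 - cos(pi/2 * t/T)).

    Rounding to the nearest integer is an assumption; the PR may
    floor, ceil, or enforce a minimum of one token per step.
    """
    counts: List[int] = []
    for t in range(1, num_steps + 1):
        frac = 1.0 - math.cos(math.pi / 2 * t / num_steps)
        counts.append(round(total_tokens * frac))
    # Guard against floating-point error so the last step reveals everything.
    counts[-1] = total_tokens
    return counts


def per_step_reveals(counts: List[int]) -> List[int]:
    """Tokens newly revealed at each step, from the cumulative counts."""
    return [c - p for c, p in zip(counts, [0] + counts[:-1])]
```

Calling `per_step_reveals(cosine_reveal_counts(256, 8))` shows the per-step increments growing over the run: halfway through (t = T/2) the cumulative fraction is 1 - cos(pi/4), roughly 29%, so the bulk of the tokens is committed in the later steps.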