
feat: RoPE positional embeddings and cosine MaskGIT scheduler#21

Open
tashapais wants to merge 1 commit into AlmondGod:main from tashapais:feat/rope-and-cosine-maskgit

Conversation

@tashapais

Summary

  • RoPE positional embeddings (use_rope: true in configs/training.yaml): replaces additive sinusoidal PE with Rotary Position Embeddings applied directly to Q/K in attention. Temporal attention uses 1D RoPE across the time axis; spatial attention uses 2D RoPE with independent y-axis and x-axis encodings split across the head dimension. When enabled, the additive sinusoidal temporal and spatial PEs are skipped.

  • Cosine MaskGIT unmasking schedule (maskgit_schedule: "cosine" in config): adds a cosine schedule alongside the existing exponential one. The cosine variant reveals P_total * (1 - cos(pi/2 * t/T)) tokens by step t, so early steps unmask only the few highest-confidence tokens and the reveal rate accelerates toward the end of decoding. Selectable at inference time via configs/inference.yaml.

Both flags default to false/"exp", so existing checkpoints and training runs are unaffected.
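For intuition, the cumulative reveal count that a helper like cosine_schedule_torch computes can be sketched in pure Python (the rounding behavior and the guarantee that all tokens are revealed at the final step are assumptions of this sketch, not taken from the diff):

```python
import math

def cosine_num_revealed(p_total: int, t: int, total_steps: int) -> int:
    # Cumulative number of tokens revealed by step t under the cosine
    # schedule: P_total * (1 - cos(pi/2 * t/T)). Starts slow (few tokens,
    # highest confidence first) and accelerates toward the final step.
    frac = 1.0 - math.cos(math.pi / 2.0 * t / total_steps)
    return min(p_total, round(p_total * frac))
```

Note the zero slope at t = 0: the first steps commit to very few tokens, which is the front-loaded-confidence behavior described above.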

Changes

  • models/positional_encoding.py: rope_1d_cos_sin, rope_2d_cos_sin, apply_rope_1d, apply_rope_2d
  • models/st_transformer.py: use_rope + grid_size plumbed through SpatialAttention, TemporalAttention, STTransformerBlock, STTransformer
  • models/video_tokenizer.py, models/latent_actions.py, models/dynamics.py: use_rope parameter
  • models/dynamics.py: cosine_schedule_torch, schedule param on forward_inference
  • utils/config.py: use_rope on all model configs, maskgit_schedule on DynamicsConfig, TrainingConfig, InferenceConfig
  • configs/training.yaml, configs/inference.yaml: new fields with defaults
  • Training/inference scripts: pass new flags through

Test plan

  • rope_1d_cos_sin / rope_2d_cos_sin output correct shapes
  • STTransformer forward pass identical shape with and without RoPE
  • VideoTokenizer forward pass with use_rope=True
  • LatentActionModel forward pass with use_rope=True
  • DynamicsModel.forward_inference with both schedule="exp" and schedule="cosine"
  • DynamicsModel.forward_inference with use_rope=True + schedule="cosine"
  • Full training run + visual inference coherence check (requires GPU, per contributing guidelines)

Implements two TODOs from the contributor list:

1. RoPE (Rotary Position Embeddings) as a drop-in replacement for the
   additive sinusoidal positional encodings. When use_rope=true:
   - TemporalAttention applies 1D RoPE to Q/K across the time axis
   - SpatialAttention applies 2D RoPE to Q/K using independent y/x axis
     encodings in the first and second halves of the head dimension
   - Additive temporal and spatial PEs are skipped entirely
   Enable with `use_rope: true` in configs/training.yaml.

2. Cosine MaskGIT unmasking schedule alongside the existing exponential
   schedule. The cosine variant reveals tokens proportional to
   1 - cos(pi/2 * t/T), unmasking only the most confident tokens in
   early steps and accelerating later in decoding.
   Select with `maskgit_schedule: "cosine"` in configs/training.yaml
   or configs/inference.yaml (default: "exp", preserving prior behavior).

All new flags default to false/"exp" so existing checkpoints and runs
are unaffected.
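The y/x head-dimension split described in item 1 can be sketched self-contained as follows (helper names here are hypothetical, not the ones in models/positional_encoding.py):

```python
import torch

def _cos_sin_1d(n: int, half_dim: int, base: float = 10000.0):
    # Standard RoPE angle tables for n positions over half_dim channels.
    inv_freq = base ** (-torch.arange(0, half_dim, 2).float() / half_dim)
    ang = torch.arange(n).float()[:, None] * inv_freq[None, :]
    return ang.cos(), ang.sin()

def _rotate(x, cos, sin):
    # Rotate each (even, odd) channel pair of x by the given angles.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def apply_rope_2d(x: torch.Tensor, h: int, w: int) -> torch.Tensor:
    # x: (..., H*W, D) flattened spatial tokens. The first half of the
    # head dim carries y-axis rotations, the second half x-axis rotations.
    d = x.shape[-1] // 2
    cy, sy = _cos_sin_1d(h, d)
    cx, sx = _cos_sin_1d(w, d)
    # Expand the per-row/per-column tables to one entry per (y, x) token.
    cy, sy = cy.repeat_interleave(w, dim=0), sy.repeat_interleave(w, dim=0)
    cx, sx = cx.repeat(h, 1), sx.repeat(h, 1)
    return torch.cat([_rotate(x[..., :d], cy, sy),
                      _rotate(x[..., d:], cx, sx)], dim=-1)
```

Token (y=0, x=0) is left unrotated, and each half of the head dimension encodes only its own axis, so row and column offsets contribute independently to the attention scores.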