Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ There are still many TODOs which may offer significant performance gains...
- [ ] Try `RoPE`/`AliBi` Position Embeddings
- [ ] Add more datasets (Terraria, Street Fighter, \<your favorite retro videogame\>)
- [ ] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter)
- [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT)
- [x] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT) - Halton low-discrepancy schedule added, enable with `maskgit_schedule: "halton"` in `configs/training.yaml`
- [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean`
- [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters
- [ ] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames
Expand Down
3 changes: 3 additions & 0 deletions configs/inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ use_actions: false # use random actions
use_gt_actions: false # use lam-inferred actions
use_interactive_mode: true # use user-inputted actions

# MaskGIT unmasking schedule ("exp" or "halton")
maskgit_schedule: "exp"

# inference acceleration
amp: false
tf32: false
Expand Down
3 changes: 3 additions & 0 deletions configs/training.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ use_moe: false
num_experts: 4
top_k_experts: 2
moe_aux_loss_coeff: 0.01

# MaskGIT unmasking schedule for dynamics inference ("exp" or "halton")
maskgit_schedule: "exp"
32 changes: 30 additions & 2 deletions models/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,33 @@ def exp_schedule_torch(self, t, T, P_total, k, device):
return torch.tensor(P_total, dtype=result.dtype, device=device)
return result

@staticmethod
def _halton(n, base):
"""n-th value of the Halton low-discrepancy sequence in the given base."""
f, r = 1.0, 0.0
while n > 0:
f /= base
r += f * (n % base)
n //= base
return r

def halton_schedule_torch(self, t, T, P_total, device):
"""Halton low-discrepancy unmasking schedule.

Sorts the first T Halton values (base 2) and uses the t-th sorted entry
as the cumulative unmasking fraction. Compared to cosine and exponential
schedules this distributes unmasking steps more uniformly across the
confidence range, reducing over-commitment to early high-confidence tokens.
"""
seq = sorted(self._halton(i, 2) for i in range(1, T + 1))
ratio = seq[min(t, T - 1)]
result = torch.tensor(float(P_total) * ratio, device=device)
if t == T - 1:
return torch.tensor(float(P_total), device=device)
return result

@torch.no_grad()
def forward_inference(self, context_latents, prediction_horizon, num_steps, index_to_latents_fn, conditioning=None, schedule_k=5.0, temperature: float = 0.0):
def forward_inference(self, context_latents, prediction_horizon, num_steps, index_to_latents_fn, conditioning=None, schedule_k=5.0, temperature: float = 0.0, schedule: str = "exp"):
# MaskGIT-style iterative decoding across all prediction horizon steps
# context_latents: [B, T_ctx, P, L]
# T_ctx=context timesteps, H=prediction horizon, K=codebook size
Expand All @@ -108,7 +133,10 @@ def forward_inference(self, context_latents, prediction_horizon, num_steps, inde

P_total = H * P # total masked positions across the horizon window
for m in range(num_steps):
n_tokens_raw = self.exp_schedule_torch(m, num_steps, P_total, schedule_k, device)
if schedule == "halton":
n_tokens_raw = self.halton_schedule_torch(m, num_steps, P_total, device)
else:
n_tokens_raw = self.exp_schedule_torch(m, num_steps, P_total, schedule_k, device)

# predict logits for current input
logits, _, _ = self.forward(input_latents, training=False, conditioning=conditioning, targets=None) # [B, T_ctx+H, P, L^D]
Expand Down
1 change: 1 addition & 0 deletions scripts/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def idx_to_latents(idx):
index_to_latents_fn=idx_to_latents,
conditioning=action_latent,
temperature=args.temperature,
schedule=getattr(args, 'maskgit_schedule', 'exp'),
)

# decode next video tokens to frames
Expand Down
10 changes: 8 additions & 2 deletions utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ class DynamicsConfig:
num_experts: int = 4
top_k_experts: int = 2
moe_aux_loss_coeff: float = 0.01
# MaskGIT unmasking schedule ("exp", "halton")
maskgit_schedule: str = "exp"
# Optimizer
optimizer: str = "adamw"
muon_momentum: float = 0.95
Expand All @@ -221,7 +223,7 @@ class DynamicsConfig:
# other params
fps: Optional[int] = None
preload_ratio: Optional[float] = None

def __post_init__(self) -> None:
_validate_amp_fsdp(self.amp, self.distributed)
_validate_distibuted_training(self.nproc_per_node, self.distributed)
Expand Down Expand Up @@ -277,11 +279,13 @@ class TrainingConfig:
num_experts: int = 4
top_k_experts: int = 2
moe_aux_loss_coeff: float = 0.01
# MaskGIT unmasking schedule ("exp", "halton")
maskgit_schedule: str = "exp"
# Optimizer
optimizer: str = "adamw"
muon_momentum: float = 0.95
muon_backend_steps: int = 5

def __post_init__(self) -> None:
_validate_amp_fsdp(self.amp, self.distributed)
_validate_distibuted_training(self.nproc_per_node, self.distributed)
Expand Down Expand Up @@ -310,6 +314,8 @@ class InferenceConfig:
compile: bool
# Interactive mode (user enters action ids)
use_interactive_mode: bool
# MaskGIT unmasking schedule ("exp", "halton")
maskgit_schedule: str = "exp"
preload_ratio: Optional[float] = None


Expand Down