diff --git a/README.md b/README.md
index 0f00c0e..7e3b522 100644
--- a/README.md
+++ b/README.md
@@ -272,7 +272,7 @@ There are still many TODOs which may offer significant performance gains...
 - [ ] Try `RoPE`/`ALiBi` Position Embeddings
 - [ ] Add more datasets (Terraria, Street Fighter, ...)
 - [ ] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter)
-- [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT)
+- [x] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT) - Halton low-discrepancy schedule added (cosine still TODO); enable with `maskgit_schedule: "halton"` in `configs/inference.yaml` or `configs/training.yaml`
 - [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean`
 - [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters
 - [ ] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames
diff --git a/configs/inference.yaml b/configs/inference.yaml
index 616d8f1..fd363d9 100644
--- a/configs/inference.yaml
+++ b/configs/inference.yaml
@@ -20,6 +20,9 @@
 use_actions: false # use random actions
 use_gt_actions: false # use lam-inferred actions
 use_interactive_mode: true # use user-inputted actions
+# MaskGIT unmasking schedule ("exp" or "halton")
+maskgit_schedule: "exp"
+
 # inference acceleration
 amp: false
 tf32: false
diff --git a/configs/training.yaml b/configs/training.yaml
index 0ca6528..2fe9245 100644
--- a/configs/training.yaml
+++ b/configs/training.yaml
@@ -47,3 +47,6 @@ use_moe: false
 num_experts: 4
 top_k_experts: 2
 moe_aux_loss_coeff: 0.01
+
+# MaskGIT unmasking schedule for dynamics inference ("exp" or "halton")
+maskgit_schedule: "exp"
diff --git a/models/dynamics.py b/models/dynamics.py
index 62f2c7a..3f9c655 100644
--- a/models/dynamics.py
+++ b/models/dynamics.py
@@ -91,8 +91,33 @@ def exp_schedule_torch(self, t, T, P_total, k, device):
             return torch.tensor(P_total, dtype=result.dtype, device=device)
         return result
 
+    @staticmethod
+    def _halton(n, base):
+        """n-th value of the Halton low-discrepancy sequence in the given base."""
+        f, r = 1.0, 0.0
+        while n > 0:
+            f /= base
+            r += f * (n % base)
+            n //= base
+        return r
+
+    def halton_schedule_torch(self, t, T, P_total, device):
+        """Halton low-discrepancy unmasking schedule.
+
+        Sorts the first T Halton values (base 2) and uses the t-th sorted entry
+        as the cumulative unmasking fraction. Compared to cosine and exponential
+        schedules, this distributes unmasking steps more uniformly across the
+        confidence range, reducing over-commitment to early high-confidence tokens.
+        """
+        seq = sorted(self._halton(i, 2) for i in range(1, T + 1))
+        ratio = seq[min(t, T - 1)]
+        result = torch.tensor(float(P_total) * ratio, device=device)
+        if t == T - 1:
+            return torch.tensor(float(P_total), device=device)
+        return result
+
     @torch.no_grad()
-    def forward_inference(self, context_latents, prediction_horizon, num_steps, index_to_latents_fn, conditioning=None, schedule_k=5.0, temperature: float = 0.0):
+    def forward_inference(self, context_latents, prediction_horizon, num_steps, index_to_latents_fn, conditioning=None, schedule_k=5.0, temperature: float = 0.0, schedule: str = "exp"):
         # MaskGIT-style iterative decoding across all prediction horizon steps
         # context_latents: [B, T_ctx, P, L]
         # T_ctx=context timesteps, H=prediction horizon, K=codebook size
@@ -108,7 +133,10 @@ def forward_inference(self, context_latents, prediction_horizon, num_steps, inde
         P_total = H * P  # total masked positions across the horizon window
 
         for m in range(num_steps):
-            n_tokens_raw = self.exp_schedule_torch(m, num_steps, P_total, schedule_k, device)
+            if schedule == "halton":
+                n_tokens_raw = self.halton_schedule_torch(m, num_steps, P_total, device)
+            else:
+                n_tokens_raw = self.exp_schedule_torch(m, num_steps, P_total, schedule_k, device)
 
             # predict logits for current input
             logits, _, _ = self.forward(input_latents, training=False, conditioning=conditioning, targets=None)  # [B, T_ctx+H, P, L^D]
diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index 10f6439..76687a8 100644
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -128,6 +128,7 @@ def idx_to_latents(idx):
         index_to_latents_fn=idx_to_latents,
         conditioning=action_latent,
         temperature=args.temperature,
+        schedule=getattr(args, 'maskgit_schedule', 'exp'),
     )
 
     # decode next video tokens to frames
diff --git a/utils/config.py b/utils/config.py
index 846c7b2..6cb074a 100644
--- a/utils/config.py
+++ b/utils/config.py
@@ -212,6 +212,8 @@ class DynamicsConfig:
     num_experts: int = 4
     top_k_experts: int = 2
     moe_aux_loss_coeff: float = 0.01
+    # MaskGIT unmasking schedule ("exp", "halton")
+    maskgit_schedule: str = "exp"
     # Optimizer
     optimizer: str = "adamw"
     muon_momentum: float = 0.95
@@ -221,7 +223,7 @@ class DynamicsConfig:
     # other params
     fps: Optional[int] = None
     preload_ratio: Optional[float] = None
-    
+
     def __post_init__(self) -> None:
         _validate_amp_fsdp(self.amp, self.distributed)
         _validate_distibuted_training(self.nproc_per_node, self.distributed)
@@ -277,11 +279,13 @@ class TrainingConfig:
     num_experts: int = 4
     top_k_experts: int = 2
     moe_aux_loss_coeff: float = 0.01
+    # MaskGIT unmasking schedule ("exp", "halton")
+    maskgit_schedule: str = "exp"
     # Optimizer
     optimizer: str = "adamw"
     muon_momentum: float = 0.95
     muon_backend_steps: int = 5
-    
+
     def __post_init__(self) -> None:
         _validate_amp_fsdp(self.amp, self.distributed)
         _validate_distibuted_training(self.nproc_per_node, self.distributed)
@@ -310,6 +314,8 @@ class InferenceConfig:
     compile: bool
     # Interactive mode (user enters action ids)
     use_interactive_mode: bool
+    # MaskGIT unmasking schedule ("exp", "halton")
+    maskgit_schedule: str = "exp"
     preload_ratio: Optional[float] = None
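
Note on the schedule added in `models/dynamics.py`: the snippet below is a standalone, torch-free mirror of `_halton` and `halton_schedule_torch` for sanity-checking the per-step unmasking budgets. The helper name `halton_cumulative_counts` and the values `T=8`, `P_total=256` are illustrative examples, not taken from the repo.

    # Standalone sketch of the Halton unmasking schedule (mirrors the diff above).
    def halton(n, base):
        """n-th value of the Halton low-discrepancy sequence in the given base."""
        f, r = 1.0, 0.0
        while n > 0:
            f /= base
            r += f * (n % base)
            n //= base
        return r

    def halton_cumulative_counts(T, P_total):
        # Sort the first T base-2 Halton values; the t-th sorted entry is the
        # cumulative fraction of tokens unmasked after step t. The final step
        # always unmasks everything, matching halton_schedule_torch.
        seq = sorted(halton(i, 2) for i in range(1, T + 1))
        counts = [round(P_total * r) for r in seq]
        counts[-1] = P_total
        return counts

    # e.g. 8 decoding steps over a 16x16 latent grid (P_total = 256):
    print(halton_cumulative_counts(8, 256))
    # -> [16, 32, 64, 96, 128, 160, 192, 256]

For `T=8` the cumulative fractions come out to 1/16, 1/8, 1/4, 3/8, 1/2, 5/8, 3/4, then 1 at the forced final step: roughly even spacing across decoding, which is the "more uniform" behaviour the docstring claims relative to exponential curves. At run time the schedule is selected with `maskgit_schedule: "halton"` in `configs/inference.yaml` or `configs/training.yaml`.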