diff --git a/README.md b/README.md
index 0f00c0e..76e8df8 100644
--- a/README.md
+++ b/README.md
@@ -271,7 +271,7 @@ There are still many TODOs which may offer significant performance gains...
 
 - [ ] Try `RoPE`/`AliBi` Position Embeddings
 - [ ] Add more datasets (Terraria, Street Fighter, \)
-- [ ] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter)
+- [x] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter)
 - [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT)
 - [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean`
 - [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters
diff --git a/configs/inference.yaml b/configs/inference.yaml
index 616d8f1..01a27fe 100644
--- a/configs/inference.yaml
+++ b/configs/inference.yaml
@@ -23,4 +23,7 @@ use_interactive_mode: true # use user-inputted actions
 # inference acceleration
 amp: false
 tf32: false
-compile: false
\ No newline at end of file
+compile: false
+
+# AdaLN-Zero: must match the training config used for loaded checkpoints
+use_adaln_zero: false
\ No newline at end of file
diff --git a/configs/training.yaml b/configs/training.yaml
index 0ca6528..62f0451 100644
--- a/configs/training.yaml
+++ b/configs/training.yaml
@@ -42,6 +42,9 @@ optimizer: "adamw"
 muon_momentum: 0.95
 muon_backend_steps: 5
 
+# AdaLN-Zero: pre-norm conditioning with zero-init gate (alternative to FiLM)
+use_adaln_zero: false
+
 # MoE (dynamics model only): replaces SwiGLU FFN with top-k routed experts
 use_moe: false
 num_experts: 4
diff --git a/models/dynamics.py b/models/dynamics.py
index 62f2c7a..ae61087 100644
--- a/models/dynamics.py
+++ b/models/dynamics.py
@@ -9,7 +9,8 @@
 class DynamicsModel(nn.Module):
     def __init__(self, frame_size=(128, 128), patch_size=4, embed_dim=128, num_heads=8,
                  hidden_dim=128, num_blocks=4, num_bins=4, n_actions=8, conditioning_dim=3, latent_dim=5,
-                 use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01):
+                 use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01,
+                 use_adaln_zero=False):
         super().__init__()
         H, W = frame_size
         codebook_size = num_bins**latent_dim
@@ -20,6 +21,7 @@ def __init__(self, frame_size=(128, 128), patch_size=4, embed_dim=128, num_heads
             conditioning_dim=conditioning_dim,
             use_moe=use_moe, num_experts=num_experts, top_k_experts=top_k_experts, moe_aux_loss_coeff=moe_aux_loss_coeff,
+            use_adaln_zero=use_adaln_zero,
         )
 
         self.output_mlp = nn.Linear(embed_dim, codebook_size)
diff --git a/models/latent_actions.py b/models/latent_actions.py
index 900ab1b..0690e99 100644
--- a/models/latent_actions.py
+++ b/models/latent_actions.py
@@ -11,11 +11,12 @@
 NUM_LATENT_ACTIONS_BINS = 2
 
 
 class LatentActionsEncoder(nn.Module):
-    def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, 
-                 hidden_dim=256, num_blocks=4, action_dim=3):
+    def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
+                 hidden_dim=256, num_blocks=4, action_dim=3, use_adaln_zero=False):
         super().__init__()
         self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
-        self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True)
+        self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True,
+                                         use_adaln_zero=use_adaln_zero)
         # embeddings to discrete latent bottleneck actions
         self.action_head = nn.Sequential(
@@ -50,10 +51,11 @@ def forward(self, frames):
 
 class LatentActionsDecoder(nn.Module):
     def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
-                 hidden_dim=256, num_blocks=4, conditioning_dim=3):
+                 hidden_dim=256, num_blocks=4, conditioning_dim=3, use_adaln_zero=False):
         super().__init__()
         self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
-        self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True, conditioning_dim=conditioning_dim)
+        self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True,
+                                         conditioning_dim=conditioning_dim, use_adaln_zero=use_adaln_zero)
 
         # embeddings to mixed frame output patches
         self.frame_head = nn.Sequential(
@@ -97,14 +99,14 @@ def forward(self, frames, actions, training=True):
         return pred_frames # [B, T-1, C, H, W]
 
 
 class LatentActionModel(nn.Module):
-    def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128, 
-                 num_heads=8, hidden_dim=256, num_blocks=4):
+    def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128,
+                 num_heads=8, hidden_dim=256, num_blocks=4, use_adaln_zero=False):
         super().__init__()
         assert math.log(n_actions, NUM_LATENT_ACTIONS_BINS).is_integer(), f"n_actions must be a power of {NUM_LATENT_ACTIONS_BINS}"
-        self.action_dim=int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS))
-        self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim)
+        self.action_dim = int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS))
+        self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim, use_adaln_zero=use_adaln_zero)
         self.quantizer = FiniteScalarQuantizer(latent_dim=self.action_dim, num_bins=NUM_LATENT_ACTIONS_BINS)
-        self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim)
+        self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim, use_adaln_zero=use_adaln_zero)
         self.var_target = 0.01
         self.var_lambda = 100.0
diff --git a/models/norms.py b/models/norms.py
index b5cc38d..377a808 100644
--- a/models/norms.py
+++ b/models/norms.py
@@ -2,6 +2,40 @@
 import torch.nn as nn
 from einops import repeat
 
+class AdaLNZeroNorm(nn.Module):
+    """Pre-norm with zero-init conditioning MLP producing (scale, shift, gate).
+
+    Gate alpha is zero-initialized so residual paths start as identity.
+    Returns (normed_x, gate); caller computes: x + gate * sublayer(normed_x).
+    When unconditioned, gate is None and caller computes: x + sublayer(normed_x).
+    """
+    def __init__(self, embed_dim, conditioning_dim=None):
+        super().__init__()
+        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+        self.to_params = None
+        if conditioning_dim is not None:
+            self.to_params = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(conditioning_dim, 3 * embed_dim), # scale, shift, gate
+            )
+            nn.init.zeros_(self.to_params[-1].weight)
+            nn.init.zeros_(self.to_params[-1].bias)
+
+    def forward(self, x, conditioning=None):
+        # x: [B, T, P, E]; returns (normed_x: [B,T,P,E], gate: [B,T,P,E] or None)
+        x_normed = self.norm(x)
+        if self.to_params is None or conditioning is None:
+            return x_normed, None
+        B, T, P, E = x_normed.shape
+        params = self.to_params(conditioning) # [B, T, 3E]
+        params = repeat(params, 'b t e -> b t p e', p=P) # [B, T, P, 3E]
+        scale, shift, gate = params.chunk(3, dim=-1) # each [B, T, P, E]
+        if scale.shape[1] == x.shape[1] - 1:
+            pad = lambda t: torch.cat([torch.zeros_like(t[:, :1]), t], dim=1)
+            scale, shift, gate = pad(scale), pad(shift), pad(gate)
+        return x_normed * (1 + scale) + shift, gate
+
+
 class RMSNorm(nn.Module):
     def __init__(self, embed_dim, eps=1e-5):
         super().__init__()
diff --git a/models/st_transformer.py b/models/st_transformer.py
index a1621e7..f6b2d3a 100644
--- a/models/st_transformer.py
+++ b/models/st_transformer.py
@@ -2,13 +2,13 @@ import torch.nn as nn
 from einops import rearrange
 from models.positional_encoding import build_spatial_only_pe, sincos_time
-from models.norms import AdaptiveNormalizer
+from models.norms import AdaptiveNormalizer, AdaLNZeroNorm
 from models.patch_embed import PatchEmbedding
 import math
 import torch.nn.functional as F
 
 
 class SpatialAttention(nn.Module):
-    def __init__(self, embed_dim, num_heads, conditioning_dim=None):
+    def __init__(self, embed_dim, num_heads, conditioning_dim=None, use_adaln_zero=False):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
@@ -20,16 +20,25 @@ def __init__(self, embed_dim, num_heads, conditioning_dim=None):
         self.v_proj = nn.Linear(embed_dim, embed_dim)
         self.out_proj = nn.Linear(embed_dim, embed_dim)
 
-        self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
+        self.use_adaln_zero = use_adaln_zero
+        if use_adaln_zero:
+            self.norm = AdaLNZeroNorm(embed_dim, conditioning_dim)
+        else:
+            self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
 
     def forward(self, x, conditioning=None):
         B, T, P, E = x.shape
 
-        # project to Q, K, V and split into heads: [B, T, P, E] -> [(B*T), H, P, E/H]
+        if self.use_adaln_zero:
+            x_in, gate = self.norm(x, conditioning) # pre-norm + conditioning params
+        else:
+            x_in = x
+
+        # project to Q, K, V and split into heads: [B, T, P, E] -> [(B*T), H, P, E/H]
         # (4 dims to work with torch compile attention)
-        q = rearrange(self.q_proj(x), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
-        k = rearrange(self.k_proj(x), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
-        v = rearrange(self.v_proj(x), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
+        q = rearrange(self.q_proj(x_in), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
+        k = rearrange(self.k_proj(x_in), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
+        v = rearrange(self.v_proj(x_in), 'B T P (H D) -> (B T) H P D', H=self.num_heads)
 
         k_t = k.transpose(-2, -1) # [(B*T), H, P, D, P]
 
@@ -42,35 +51,46 @@ def forward(self, x, conditioning=None):
         # out proj to mix head information
         attn_out = self.out_proj(attn_output) # [B, T, P, E]
 
-        # residual and optionally conditioned norm
-        out = self.norm(x + attn_out, conditioning) # [B, T, P, E]
-
-        return out # [B, T, P, E]
+        if self.use_adaln_zero:
+            # gated residual; gate is None when unconditioned (identity)
+            return x + (gate * attn_out if gate is not None else attn_out) # [B, T, P, E]
+        else:
+            # residual and optionally conditioned post-norm (FiLM)
+            return self.norm(x + attn_out, conditioning) # [B, T, P, E]
 
 
 class TemporalAttention(nn.Module):
-    def __init__(self, embed_dim, num_heads, causal=True, conditioning_dim=None):
+    def __init__(self, embed_dim, num_heads, causal=True, conditioning_dim=None, use_adaln_zero=False):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == embed_dim
-        
+
         self.q_proj = nn.Linear(embed_dim, embed_dim)
         self.k_proj = nn.Linear(embed_dim, embed_dim)
         self.v_proj = nn.Linear(embed_dim, embed_dim)
         self.out_proj = nn.Linear(embed_dim, embed_dim)
-        
-        self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
+
+        self.use_adaln_zero = use_adaln_zero
+        if use_adaln_zero:
+            self.norm = AdaLNZeroNorm(embed_dim, conditioning_dim)
+        else:
+            self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
         self.causal = causal
-        
+
     def forward(self, x, conditioning=None):
         B, T, P, E = x.shape
-        
-        # project to Q, K, V and split into heads: [B, T, P, E] -> [(B*P), H, T, D]
+
+        if self.use_adaln_zero:
+            x_in, gate = self.norm(x, conditioning) # pre-norm + conditioning params
+        else:
+            x_in = x
+
+        # project to Q, K, V and split into heads: [B, T, P, E] -> [(B*P), H, T, D]
         # (4 dims to work with torch compile attention)
-        q = rearrange(self.q_proj(x), 'b t p (h d) -> (b p) h t d', h=self.num_heads)
-        k = rearrange(self.k_proj(x), 'b t p (h d) -> (b p) h t d', h=self.num_heads)
-        v = rearrange(self.v_proj(x), 'b t p (h d) -> (b p) h t d', h=self.num_heads) # [B, P, H, T, D]
+        q = rearrange(self.q_proj(x_in), 'b t p (h d) -> (b p) h t d', h=self.num_heads)
+        k = rearrange(self.k_proj(x_in), 'b t p (h d) -> (b p) h t d', h=self.num_heads)
+        v = rearrange(self.v_proj(x_in), 'b t p (h d) -> (b p) h t d', h=self.num_heads)
 
         k_t = k.transpose(-2, -1) # [(B*P), H, T, D, T]
 
@@ -89,26 +109,37 @@ def forward(self, x, conditioning=None):
         # out proj to mix head information
         attn_out = self.out_proj(attn_output) # [B, T, P, E]
 
-        # residual and optionally conditioned norm
-        out = self.norm(x + attn_out, conditioning) # [B, T, P, E]
-
-        return out # [B, T, P, E]
+        if self.use_adaln_zero:
+            return x + (gate * attn_out if gate is not None else attn_out) # [B, T, P, E]
+        else:
+            # residual and optionally conditioned post-norm (FiLM)
+            return self.norm(x + attn_out, conditioning) # [B, T, P, E]
 
 
 class SwiGLUFFN(nn.Module):
     # swiglu(x) = W3(sigmoid(W1(x) + b1) * (W2(x) + b2)) + b3
-    def __init__(self, embed_dim, hidden_dim, conditioning_dim=None):
+    def __init__(self, embed_dim, hidden_dim, conditioning_dim=None, use_adaln_zero=False):
         super().__init__()
         h = math.floor(2 * hidden_dim / 3)
         self.w_v = nn.Linear(embed_dim, h)
         self.w_g = nn.Linear(embed_dim, h)
         self.w_o = nn.Linear(h, embed_dim)
-        self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
+        self.use_adaln_zero = use_adaln_zero
+        if use_adaln_zero:
+            self.norm = AdaLNZeroNorm(embed_dim, conditioning_dim)
+        else:
+            self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
 
     def forward(self, x, conditioning=None):
-        v = F.silu(self.w_v(x)) # [B, T, P, h]
-        g = self.w_g(x) # [B, T, P, h]
-        out = self.w_o(v * g) # [B, T, P, E]
-        return self.norm(x + out, conditioning) # [B, T, P, E]
+        if self.use_adaln_zero:
+            x_in, gate = self.norm(x, conditioning)
+        else:
+            x_in = x
+        v = F.silu(self.w_v(x_in)) # [B, T, P, h]
+        g = self.w_g(x_in) # [B, T, P, h]
+        out = self.w_o(v * g) # [B, T, P, E]
+        if self.use_adaln_zero:
+            return x + (gate * out if gate is not None else out) # [B, T, P, E]
+        return self.norm(x + out, conditioning) # [B, T, P, E]
 
 
 class SwiGLUExpert(nn.Module):
@@ -125,7 +156,7 @@ def forward(self, x):
 
 class MoESwiGLUFFN(nn.Module):
     def __init__(self, embed_dim, hidden_dim, num_experts=4, top_k=2,
-                 aux_loss_coeff=0.01, conditioning_dim=None):
+                 aux_loss_coeff=0.01, conditioning_dim=None, use_adaln_zero=False):
         super().__init__()
         self.num_experts = num_experts
         self.top_k = top_k
@@ -135,7 +166,11 @@ def __init__(self, embed_dim, hidden_dim, num_experts=4, top_k=2,
         self.experts = nn.ModuleList([
             SwiGLUExpert(embed_dim, hidden_dim) for _ in range(num_experts)
         ])
-        self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
+        self.use_adaln_zero = use_adaln_zero
+        if use_adaln_zero:
+            self.norm = AdaLNZeroNorm(embed_dim, conditioning_dim)
+        else:
+            self.norm = AdaptiveNormalizer(embed_dim, conditioning_dim)
         self._aux_loss = None
         self._expert_counts = None # per-expert token fractions from last forward
 
@@ -156,6 +191,9 @@ def forward(self, x, conditioning=None):
         B, T, P, E = x.shape
         residual = x
+        if self.use_adaln_zero:
+            x, gate = self.norm(x, conditioning) # pre-norm; gate applied after expert combine
+
         # flatten spatial dims for routing: [B*T*P, E]
         flat = x.reshape(-1, E)
         N = flat.shape[0]
 
@@ -196,23 +234,27 @@ def forward(self, x, conditioning=None):
         # reshape back and apply residual + norm
         out = output.reshape(B, T, P, E)
+        if self.use_adaln_zero:
+            return residual + (gate * out if gate is not None else out) # [B, T, P, E]
         return self.norm(residual + out, conditioning) # [B, T, P, E]
 
 
 class STTransformerBlock(nn.Module):
     def __init__(self, embed_dim, num_heads, hidden_dim, causal=True, conditioning_dim=None,
-                 use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01):
+                 use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01,
+                 use_adaln_zero=False):
         super().__init__()
-        self.spatial_attn = SpatialAttention(embed_dim, num_heads, conditioning_dim)
-        self.temporal_attn = TemporalAttention(embed_dim, num_heads, causal, conditioning_dim)
+        self.spatial_attn = SpatialAttention(embed_dim, num_heads, conditioning_dim, use_adaln_zero)
+        self.temporal_attn = TemporalAttention(embed_dim, num_heads, causal, conditioning_dim, use_adaln_zero)
         if use_moe:
             self.ffn = MoESwiGLUFFN(
                 embed_dim, hidden_dim,
                 num_experts=num_experts, top_k=top_k_experts,
                 aux_loss_coeff=moe_aux_loss_coeff, conditioning_dim=conditioning_dim,
+                use_adaln_zero=use_adaln_zero,
             )
         else:
-            self.ffn = SwiGLUFFN(embed_dim, hidden_dim, conditioning_dim)
+            self.ffn = SwiGLUFFN(embed_dim, hidden_dim, conditioning_dim, use_adaln_zero)
 
     def forward(self, x, conditioning=None):
         # x: [B, T, P, E]
@@ -224,7 +266,8 @@ def forward(self, x, conditioning=None):
 
 class STTransformer(nn.Module):
     def __init__(self, embed_dim, num_heads, hidden_dim, num_blocks, causal=True, conditioning_dim=None,
-                 use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01):
+                 use_moe=False, num_experts=4, top_k_experts=2, moe_aux_loss_coeff=0.01,
+                 use_adaln_zero=False):
         super().__init__()
         # calculate temporal PE dim
         self.temporal_dim = (embed_dim // 3) & ~1 # round down to even number
@@ -235,6 +278,7 @@ def __init__(self, embed_dim, num_heads, hidden_dim, num_blocks, causal=True, co
                 embed_dim, num_heads, hidden_dim, causal, conditioning_dim,
                 use_moe=use_moe, num_experts=num_experts, top_k_experts=top_k_experts, moe_aux_loss_coeff=moe_aux_loss_coeff,
+                use_adaln_zero=use_adaln_zero,
             )
             for _ in range(num_blocks)
         ])
diff --git a/models/video_tokenizer.py b/models/video_tokenizer.py
index 0a56c8b..8ae98c5 100644
--- a/models/video_tokenizer.py
+++ b/models/video_tokenizer.py
@@ -9,11 +9,12 @@ from models.positional_encoding import build_spatial_only_pe
 
 
 class VideoTokenizerEncoder(nn.Module):
-    def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8, 
-                 hidden_dim=256, num_blocks=4, latent_dim=5):
+    def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
+                 hidden_dim=256, num_blocks=4, latent_dim=5, use_adaln_zero=False):
         super().__init__()
         self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
-        self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True)
+        self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True,
+                                         use_adaln_zero=use_adaln_zero)
         self.latent_head = nn.Sequential(
             nn.LayerNorm(embed_dim),
             nn.Linear(embed_dim, latent_dim)
@@ -46,15 +47,16 @@ def forward(self, tokens): # [B, T, P, E]
 
 class VideoTokenizerDecoder(nn.Module):
     def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
-                 hidden_dim=256, num_blocks=4, latent_dim=5):
+                 hidden_dim=256, num_blocks=4, latent_dim=5, use_adaln_zero=False):
         super().__init__()
         H, W = frame_size
         self.patch_size = patch_size
         self.Hp, self.Wp = H // patch_size, W // patch_size
         self.num_patches = self.Hp * self.Wp
-        
+
         self.latent_embed = nn.Linear(latent_dim, embed_dim)
-        self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True)
+        self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True,
+                                         use_adaln_zero=use_adaln_zero)
         self.frame_head = PixelShuffleFrameHead(embed_dim, patch_size=patch_size, channels=3, H=H, W=W)
 
         # first 2/3 spatial PE (temporal is last 1/3)
@@ -78,10 +80,10 @@ def forward(self, latents):
 
 class VideoTokenizer(nn.Module):
     def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
-                 hidden_dim=256, num_blocks=4, latent_dim=3, num_bins=4):
+                 hidden_dim=256, num_blocks=4, latent_dim=3, num_bins=4, use_adaln_zero=False):
         super().__init__()
-        self.encoder = VideoTokenizerEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, latent_dim)
-        self.decoder = VideoTokenizerDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, latent_dim)
+        self.encoder = VideoTokenizerEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, latent_dim, use_adaln_zero)
+        self.decoder = VideoTokenizerDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, latent_dim, use_adaln_zero)
         self.quantizer = FiniteScalarQuantizer(latent_dim, num_bins)
         self.codebook_size = num_bins**latent_dim
diff --git a/scripts/train_dynamics.py b/scripts/train_dynamics.py
index 934dc55..0280adf 100644
--- a/scripts/train_dynamics.py
+++ b/scripts/train_dynamics.py
@@ -81,6 +81,7 @@ def main():
         num_experts=getattr(args, 'num_experts', 4),
         top_k_experts=getattr(args, 'top_k_experts', 2),
         moe_aux_loss_coeff=getattr(args, 'moe_aux_loss_coeff', 0.01),
+        use_adaln_zero=getattr(args, 'use_adaln_zero', False),
     ).to(args.device)
     if args.checkpoint:
         dynamics_model, _ = load_dynamics_from_checkpoint(
diff --git a/scripts/train_latent_actions.py b/scripts/train_latent_actions.py
index 1c3f18d..6436ca8 100644
--- a/scripts/train_latent_actions.py
+++ b/scripts/train_latent_actions.py
@@ -57,6 +57,7 @@ def main():
         hidden_dim=args.hidden_dim,
         num_blocks=args.num_blocks,
         n_actions=args.n_actions,
+        use_adaln_zero=getattr(args, 'use_adaln_zero', False),
     ).to(args.device)
     if args.checkpoint:
         model, _ = load_latent_actions_from_checkpoint(
diff --git a/scripts/train_video_tokenizer.py b/scripts/train_video_tokenizer.py
index fc808bb..5398aba 100644
--- a/scripts/train_video_tokenizer.py
+++ b/scripts/train_video_tokenizer.py
@@ -51,7 +51,7 @@ def main():
     # print("Length of validation data:", len(validation_data))
     # init model and optional ckpt load
     model = VideoTokenizer(
-        frame_size=(args.frame_size, args.frame_size), 
+        frame_size=(args.frame_size, args.frame_size),
        patch_size=args.patch_size,
        embed_dim=args.embed_dim,
        num_heads=args.num_heads,
@@ -59,6 +59,7 @@ def main():
         num_blocks=args.num_blocks,
         latent_dim=args.latent_dim,
         num_bins=args.num_bins,
+        use_adaln_zero=getattr(args, 'use_adaln_zero', False),
     ).to(args.device)
     if args.checkpoint:
         model, _ = load_videotokenizer_from_checkpoint(
diff --git a/utils/config.py b/utils/config.py
index 846c7b2..5a841a0 100644
--- a/utils/config.py
+++ b/utils/config.py
@@ -117,6 +117,7 @@ class VideoTokenizerConfig:
     # other params
     fps: Optional[int] = None
     preload_ratio: Optional[float] = None
+    use_adaln_zero: bool = False
 
     def __post_init__(self) -> None:
         _validate_amp_fsdp(self.amp, self.distributed)
@@ -163,6 +164,7 @@ class LatentActionsConfig:
     # other params
     fps: Optional[int] = None
     preload_ratio: Optional[float] = None
+    use_adaln_zero: bool = False
 
     def __post_init__(self) -> None:
         _validate_amp_fsdp(self.amp, self.distributed)
@@ -221,6 +223,7 @@ class DynamicsConfig:
     # other params
     fps: Optional[int] = None
     preload_ratio: Optional[float] = None
+    use_adaln_zero: bool = False
 
     def __post_init__(self) -> None:
         _validate_amp_fsdp(self.amp, self.distributed)
@@ -272,6 +275,7 @@ class TrainingConfig:
     n_updates: Optional[int] = None # number of optimizer.step(), excluding grad_accum_step
     fps: Optional[int] = None
     preload_ratio: Optional[float] = None
+    use_adaln_zero: bool = False
     # MoE (dynamics only)
     use_moe: bool = False
     num_experts: int = 4
@@ -311,6 +315,7 @@ class InferenceConfig:
     # Interactive mode (user enters action ids)
     use_interactive_mode: bool
     preload_ratio: Optional[float] = None
+    use_adaln_zero: bool = False
 
 
 def load_config(config_cls, default_config_path: Optional[str] = None):
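For readers skimming the patch, below is a minimal, self-contained sketch of the residual wiring this change enables when `use_adaln_zero: true`. It mirrors the intent of `AdaLNZeroNorm` and the gated-residual call sites in `models/st_transformer.py`, but it is a simplified illustration with assumed details: a generic `[B, N, E]` token layout instead of the repo's `[B, T, P, E]`, a per-sequence conditioning vector, a toy MLP standing in for the attention/FFN sublayer, and no T-1 conditioning padding. Treat it as a sketch of the pattern, not the repo's exact API.

import torch
import torch.nn as nn

class AdaLNZeroSketch(nn.Module):
    """One AdaLN-Zero-conditioned residual branch (DiT-style), simplified.

    Pre-norm path: x + gate * sublayer(norm(x) * (1 + scale) + shift), where
    (scale, shift, gate) are regressed from the conditioning vector and the
    final Linear is zero-initialized, so the branch starts as an identity map.
    """
    def __init__(self, embed_dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        # stand-in sublayer; in the PR this role is played by spatial/temporal attention or the SwiGLU FFN
        self.sublayer = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim), nn.SiLU(), nn.Linear(4 * embed_dim, embed_dim)
        )
        self.to_params = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 3 * embed_dim))
        nn.init.zeros_(self.to_params[-1].weight)  # zero-init => scale = shift = gate = 0 at step 0
        nn.init.zeros_(self.to_params[-1].bias)

    def forward(self, x, cond):
        # x: [B, N, E], cond: [B, cond_dim]
        scale, shift, gate = self.to_params(cond).unsqueeze(1).chunk(3, dim=-1)  # each [B, 1, E]
        h = self.norm(x) * (1 + scale) + shift  # conditioned pre-norm; the FiLM path norms after the residual add instead
        return x + gate * self.sublayer(h)      # zero gate => the residual branch contributes nothing initially

if __name__ == "__main__":
    block = AdaLNZeroSketch(embed_dim=128, cond_dim=3)
    x, cond = torch.randn(2, 16, 128), torch.randn(2, 3)
    print(torch.allclose(block(x, cond), x))  # True at init: the block is an identity map

The key contrast with the existing FiLM path is visible in the diff itself: `AdaptiveNormalizer` is applied after the residual add (`self.norm(x + attn_out, conditioning)`), whereas AdaLN-Zero modulates the input before the sublayer and gates its output. Because the two options change the module structure, the `use_adaln_zero` flag in `configs/inference.yaml` has to match the training configuration of any loaded checkpoint.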