From 05765b0c9b9937ff997777ff258c438705eb8272 Mon Sep 17 00:00:00 2001
From: Imitation Alpha
Date: Sat, 29 Nov 2025 01:02:58 -0500
Subject: [PATCH 1/2] feat: replace action tokenizer with windowed attention

---
 models/latent_actions.py | 41 +++++++++++++++++++++++++---------------
 1 file changed, 26 insertions(+), 15 deletions(-)

diff --git a/models/latent_actions.py b/models/latent_actions.py
index 900ab1b..c073aee 100644
--- a/models/latent_actions.py
+++ b/models/latent_actions.py
@@ -5,7 +5,7 @@
 import torch.distributed as dist
 import math
 from einops import rearrange, repeat, reduce
-from models.st_transformer import STTransformer, PatchEmbedding
+from models.st_transformer import STTransformer, PatchEmbedding, SpatialAttention
 from models.fsq import FiniteScalarQuantizer
 
 NUM_LATENT_ACTIONS_BINS = 2
@@ -17,10 +17,13 @@ def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads
         self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
         self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True)
 
+        # windowed attention for action tokenization
+        self.window_attn = SpatialAttention(embed_dim, num_heads)
+
         # embeddings to discrete latent bottleneck actions
         self.action_head = nn.Sequential(
-            nn.LayerNorm(embed_dim * 2),
-            nn.Linear(embed_dim * 2, 4 * action_dim),
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, 4 * action_dim),
             nn.GELU(),
             nn.Linear(4 * action_dim, action_dim)
         )
@@ -32,19 +35,27 @@ def forward(self, frames):
         embeddings = self.patch_embed(frames)  # [B, T, P, E]
         transformed = self.transformer(embeddings)
 
-        # TODO: try attention pooling + mean instead of mean + concat
-        # mean pool over patches (since one action per frame)
-        pooled = transformed.mean(dim=2)  # [B, T, E]
-
-        # combine features from current and next frame
-        actions = []
-        for t in range(seq_len - 1):
-            # concat current and next frame features
-            combined = torch.cat([pooled[:, t], pooled[:, t+1]], dim=1)  # [B, E*2]
-            action = self.action_head(combined)  # [B, A]
-            actions.append(action)
+        # windowed attention over length-2 windows (current and next frame)
+        # we want to combine frame t and t+1
+
+        # 1. Slice to get T-1 windows of (current, next) frame features
+        current_frames = transformed[:, :-1]  # [B, T-1, P, E]
+        next_frames = transformed[:, 1:]  # [B, T-1, P, E]
+
+        # 2. Concatenate along patch dimension to treat each window as one large spatial sequence
+        windows = torch.cat([current_frames, next_frames], dim=2)  # [B, T-1, 2*P, E]
+
+        # 3. Apply spatial attention
+        # SpatialAttention expects [B, T, P, E]
+        attended = self.window_attn(windows)  # [B, T-1, 2*P, E]
+
+        # 4. Mean pool over the combined patches
+        pooled = attended.mean(dim=2)  # [B, T-1, E]
 
-        actions = torch.stack(actions, dim=1)  # [B, T-1, A]
+        # 5. Project to actions
+        actions = self.action_head(pooled)  # [B, T-1, A]
 
         return actions

From 5d8125e87f93c5a297a84ea1c8b99dc7ee750f3d Mon Sep 17 00:00:00 2001
From: imitation
Date: Sat, 9 May 2026 19:39:12 -0400
Subject: [PATCH 2/2] Make windowed action encoder optional

---
 datasets/data_utils.py           | 60 +++++++++++++++++++-------------
 models/latent_actions.py         | 37 +++++++++++++++-----
 scripts/run_inference.py         |  9 ++++-
 scripts/train_dynamics.py        |  1 +
 scripts/train_latent_actions.py  |  2 ++
 scripts/train_video_tokenizer.py |  1 +
 utils/config.py                  |  2 ++
 utils/inference_utils.py         | 22 +++++++++---
 utils/utils.py                   |  6 ++--
 9 files changed, 100 insertions(+), 40 deletions(-)

diff --git a/datasets/data_utils.py b/datasets/data_utils.py
index db0489d..14da821 100644
--- a/datasets/data_utils.py
+++ b/datasets/data_utils.py
@@ -17,18 +17,23 @@
 DEFAULT_PERSISTENT_WORKERS = True
 
-def _default_video_transform():
-    return transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-    ])
+def _default_video_transform(resolution=None):
+    ops = [transforms.ToTensor()]
+    if resolution is not None:
+        size = (resolution, resolution) if isinstance(resolution, int) else resolution
+        ops.append(transforms.Resize(size))
+    ops.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
+    return transforms.Compose(ops)
 
-def _load_video_dataset_pair(dataset_cls, video_rel_path, h5_rel_path, num_frames, transform=None, fps=30, preload_ratio=1, **kwargs):
+def _load_video_dataset_pair(dataset_cls, video_rel_path, h5_rel_path, num_frames, transform=None, fps=30, preload_ratio=1, resolution=None, **kwargs):
     current_folder_path = os.getcwd()
     video_path = current_folder_path + video_rel_path
     preprocessed_path = current_folder_path + h5_rel_path
-    transform = _default_video_transform() if transform is None else transform
+    transform = _default_video_transform(resolution=resolution) if transform is None else transform
+    dataset_kwargs = dict(kwargs)
+    if resolution is not None:
+        dataset_kwargs["resolution"] = (resolution, resolution) if isinstance(resolution, int) else resolution
 
     train = dataset_cls(
         video_path,
@@ -38,7 +43,7 @@ def _load_video_dataset_pair(dataset_cls, video_rel_path, h5_rel_path, num_frame
         num_frames=num_frames,
         fps=fps,
         preload_ratio=preload_ratio,
-        **kwargs
+        **dataset_kwargs
     )
     val = dataset_cls(
         video_path,
@@ -48,63 +53,68 @@ def _load_video_dataset_pair(dataset_cls, video_rel_path, h5_rel_path, num_frame
         num_frames=num_frames,
         fps=fps,
         preload_ratio=preload_ratio,
-        **kwargs
+        **dataset_kwargs
     )
     return train, val
 
-def load_pong(num_frames=1, fps=15, preload_ratio=1):
+def load_pong(num_frames=1, fps=15, preload_ratio=1, resolution=None):
     return _load_video_dataset_pair(
         PongDataset,
         '/data/pong.mp4',
         '/data/pong_frames.h5',
         num_frames=num_frames,
         fps=fps,
-        preload_ratio=preload_ratio
+        preload_ratio=preload_ratio,
+        resolution=resolution,
     )
 
-def load_sonic(num_frames=4, fps=15, preload_ratio=1):
+def load_sonic(num_frames=4, fps=15, preload_ratio=1, resolution=None):
     return _load_video_dataset_pair(
         SonicDataset,
         '/data/sonic_frames.mp4',
         '/data/sonic_frames.h5',
         num_frames=num_frames,
         fps=fps,
-        preload_ratio=preload_ratio
+        preload_ratio=preload_ratio,
+        resolution=resolution,
     )
 
-def load_pole_position(num_frames=4, fps=15, preload_ratio=1):
+def load_pole_position(num_frames=4, fps=15, preload_ratio=1, resolution=None):
     return _load_video_dataset_pair(
         PolePositionDataset,
         '/data/pole_position.mp4',
         '/data/pole_position_frames.h5',
         num_frames=num_frames,
         fps=fps,
-        preload_ratio=preload_ratio
+        preload_ratio=preload_ratio,
+        resolution=resolution,
     )
 
-def load_picodoom(num_frames=4, fps=30, preload_ratio=1):
+def load_picodoom(num_frames=4, fps=30, preload_ratio=1, resolution=None):
     return _load_video_dataset_pair(
         PicoDoomDataset,
         '/data/picodoom cleaned.mp4',
         '/data/picodoom_frames.h5',
         num_frames=num_frames,
         fps=30,
-        preload_ratio=preload_ratio
+        preload_ratio=preload_ratio,
+        resolution=resolution,
     )
 
-def load_zelda(num_frames=4, fps=15, preload_ratio=1):
+def load_zelda(num_frames=4, fps=15, preload_ratio=1, resolution=None):
     return _load_video_dataset_pair(
         ZeldaDataset,
         '/data/Zelda oot2d 1 Cut.mp4',
         '/data/zelda_frames.h5',
         num_frames=num_frames,
         fps=fps,
-        preload_ratio=preload_ratio
+        preload_ratio=preload_ratio,
+        resolution=resolution,
     )
 
@@ -141,17 +151,17 @@ def data_loaders(train_data, val_data, batch_size, distributed=False, rank=0, wo
     return train_loader, val_loader
 
-def load_data_and_data_loaders(dataset, batch_size, num_frames=1, distributed=False, rank=0, world_size=1, fps=15, preload_ratio=1):
+def load_data_and_data_loaders(dataset, batch_size, num_frames=1, distributed=False, rank=0, world_size=1, fps=15, preload_ratio=1, frame_size=None):
     if dataset == 'PONG':
-        training_data, validation_data = load_pong(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
+        training_data, validation_data = load_pong(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio, resolution=frame_size)
     elif dataset == 'SONIC':
-        training_data, validation_data = load_sonic(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
+        training_data, validation_data = load_sonic(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio, resolution=frame_size)
     elif dataset == 'POLE_POSITION':
-        training_data, validation_data = load_pole_position(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
+        training_data, validation_data = load_pole_position(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio, resolution=frame_size)
     elif dataset == 'PICODOOM':
-        training_data, validation_data = load_picodoom(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
+        training_data, validation_data = load_picodoom(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio, resolution=frame_size)
     elif dataset == 'ZELDA':
-        training_data, validation_data = load_zelda(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
+        training_data, validation_data = load_zelda(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio, resolution=frame_size)
     else:
         raise ValueError('Invalid dataset')
 
diff --git a/models/latent_actions.py b/models/latent_actions.py
index c073aee..08c057d 100644
--- a/models/latent_actions.py
+++ b/models/latent_actions.py
@@ -12,18 +12,22 @@
 class LatentActionsEncoder(nn.Module):
     def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
-                 hidden_dim=256, num_blocks=4, action_dim=3):
+                 hidden_dim=256, num_blocks=4, action_dim=3, use_windowed_attention=False):
         super().__init__()
+        self.use_windowed_attention = use_windowed_attention
         self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
         self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True)
 
-        # windowed attention for action tokenization
-        self.window_attn = SpatialAttention(embed_dim, num_heads)
+        if self.use_windowed_attention:
+            self.window_attn = SpatialAttention(embed_dim, num_heads)
+            action_head_dim = embed_dim
+        else:
+            action_head_dim = embed_dim * 2
 
         # embeddings to discrete latent bottleneck actions
         self.action_head = nn.Sequential(
-            nn.LayerNorm(embed_dim),
-            nn.Linear(embed_dim, 4 * action_dim),
+            nn.LayerNorm(action_head_dim),
+            nn.Linear(action_head_dim, 4 * action_dim),
             nn.GELU(),
             nn.Linear(4 * action_dim, action_dim)
         )
@@ -35,6 +39,13 @@ def forward(self, frames):
         embeddings = self.patch_embed(frames)  # [B, T, P, E]
         transformed = self.transformer(embeddings)
 
+        if not self.use_windowed_attention:
+            pooled = transformed.mean(dim=2)  # [B, T, E]
+            current_frames = pooled[:, :-1]  # [B, T-1, E]
+            next_frames = pooled[:, 1:]  # [B, T-1, E]
+            combined = torch.cat([current_frames, next_frames], dim=-1)
+            return self.action_head(combined)  # [B, T-1, A]
+
         # windowed attention over length-2 windows (current and next frame)
         # we want to combine frame t and t+1
 
@@ -109,11 +120,21 @@ def forward(self, frames, actions, training=True):
 
 class LatentActionModel(nn.Module):
     def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128,
-                 num_heads=8, hidden_dim=256, num_blocks=4):
+                 num_heads=8, hidden_dim=256, num_blocks=4, use_windowed_attention=False):
         super().__init__()
         assert math.log(n_actions, NUM_LATENT_ACTIONS_BINS).is_integer(), f"n_actions must be a power of {NUM_LATENT_ACTIONS_BINS}"
         self.action_dim=int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS))
-        self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim)
+        self.use_windowed_attention = use_windowed_attention
+        self.encoder = LatentActionsEncoder(
+            frame_size,
+            patch_size,
+            embed_dim,
+            num_heads,
+            hidden_dim,
+            num_blocks,
+            action_dim=self.action_dim,
+            use_windowed_attention=use_windowed_attention,
+        )
         self.quantizer = FiniteScalarQuantizer(latent_dim=self.action_dim, num_bins=NUM_LATENT_ACTIONS_BINS)
         self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim)
         self.var_target = 0.01
@@ -147,4 +168,4 @@ def encode(self, frames):
 
     @property
     def model_type(self) -> str:
-        return ModelType.LatentActionModel
\ No newline at end of file
+        return ModelType.LatentActionModel
diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index 6bf54a9..577b550 100644
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -27,7 +27,7 @@ def main():
 
     # check if any path is missing
     def missing(path: Optional[str]) -> bool:
-        return (path is None) or (not os.path.isfile(path))
+        return (path is None) or (not (os.path.isfile(path) or os.path.isdir(path)))
 
     # resolve latest checkpoints if requested or any path missing
     base_dir = os.getcwd()
@@ -72,6 +72,13 @@ def missing(path: Optional[str]) -> bool:
         data_overrides = {'preload_ratio': args.preload_ratio}
     else:
         data_overrides = {}
+    # infer frame_size from the loaded video_tokenizer (its patch_embed retains the trained spatial dims)
+    try:
+        _trained_frame_size = video_tokenizer.encoder.patch_embed.frame_size
+        h = _trained_frame_size[0] if isinstance(_trained_frame_size, (tuple, list)) else _trained_frame_size
+        data_overrides['frame_size'] = h
+    except AttributeError:
+        pass
     _, _, data_loader, _, _ = load_data_and_data_loaders(
         dataset=args.dataset, batch_size=1, num_frames=frames_to_load, **data_overrides)
 
diff --git a/scripts/train_dynamics.py b/scripts/train_dynamics.py
index 147368c..9268896 100644
--- a/scripts/train_dynamics.py
+++ b/scripts/train_dynamics.py
@@ -150,6 +150,7 @@ def main():
         distributed=dist_setup['is_distributed'],
         rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
         world_size=dist_setup['world_size'],
+        frame_size=args.frame_size,
         **data_overrides,
     )
     train_iter = iter(training_loader)
diff --git a/scripts/train_latent_actions.py b/scripts/train_latent_actions.py
index 850be9d..cf3779c 100644
--- a/scripts/train_latent_actions.py
+++ b/scripts/train_latent_actions.py
@@ -46,6 +46,7 @@ def main():
         distributed=dist_setup['is_distributed'],
         rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
         world_size=dist_setup['world_size'],
+        frame_size=args.frame_size,
         **data_overrides,
     )
 
@@ -58,6 +59,7 @@ def main():
         hidden_dim=args.hidden_dim,
         num_blocks=args.num_blocks,
         n_actions=args.n_actions,
+        use_windowed_attention=args.use_windowed_attention,
     ).to(args.device)
     if args.checkpoint:
         model, _ = load_latent_actions_from_checkpoint(
diff --git a/scripts/train_video_tokenizer.py b/scripts/train_video_tokenizer.py
index 9e05935..ffeb0d4 100644
--- a/scripts/train_video_tokenizer.py
+++ b/scripts/train_video_tokenizer.py
@@ -46,6 +46,7 @@ def main():
         distributed=dist_setup['is_distributed'],
         rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
         world_size=dist_setup['world_size'],
+        frame_size=args.frame_size,
         **data_overrides,
     )
     # print("Length of training data:", len(training_data))
diff --git a/utils/config.py b/utils/config.py
index 0e7a40c..924af67 100644
--- a/utils/config.py
+++ b/utils/config.py
@@ -152,6 +152,7 @@ class LatentActionsConfig:
     checkpoint: Optional[str]
     # device
     device: DeviceType = DeviceType.CUDA
+    use_windowed_attention: bool = False
     # other params
     fps: Optional[int] = None
     preload_ratio: Optional[float] = None
@@ -252,6 +253,7 @@ class TrainingConfig:
     batch_size_per_gpu: Optional[int] = None
     gradient_accumulation_steps: Optional[int] = None
     log_interval: Optional[int] = None
+    use_windowed_attention: Optional[bool] = None
     n_updates: Optional[int] = None  # number of optimizer.step(), excluding grad_accum_step
     fps: Optional[int] = None
     preload_ratio: Optional[float] = None
diff --git a/utils/inference_utils.py b/utils/inference_utils.py
index 8ae3037..846ac5a 100644
--- a/utils/inference_utils.py
+++ b/utils/inference_utils.py
@@ -7,19 +7,33 @@
 from utils.utils import load_videotokenizer_from_checkpoint, load_latent_actions_from_checkpoint, load_dynamics_from_checkpoint
 from einops import repeat
 
-def load_models(video_tokenizer_path, latent_actions_path, dynamics_path, device, use_actions=True):
-    # Load tokenizer and dynamics, and Latent Actions if using actions
+def load_models(video_tokenizer_path, latent_actions_path, dynamics_path, device, use_actions=True, load_dynamics=True):
+    # Load tokenizer and optionally dynamics/latent actions.
     video_tokenizer, _vt_ckpt = load_videotokenizer_from_checkpoint(video_tokenizer_path, device)
     video_tokenizer.eval()
     latent_action_model = None
     if use_actions:
         latent_action_model, _latent_action_ckpt = load_latent_actions_from_checkpoint(latent_actions_path, device)
         latent_action_model.eval()
-    dynamics_model, _dyn_ckpt = load_dynamics_from_checkpoint(dynamics_path, device)
-    dynamics_model.eval()
+    dynamics_model = None
+    if load_dynamics:
+        dynamics_model, _dyn_ckpt = load_dynamics_from_checkpoint(dynamics_path, device)
+        dynamics_model.eval()
     return video_tokenizer, latent_action_model, dynamics_model
 
+
+def reconstruct_frames(video_tokenizer, frames):
+    indices = video_tokenizer.tokenize(frames)
+    latents = video_tokenizer.quantizer.get_latents_from_indices(indices)
+    return video_tokenizer.detokenize(latents)
+
+
+def compute_frame_mse(predicted_frames, target_frames):
+    predicted = ((predicted_frames.detach().to(torch.float32) + 1) / 2).clamp(0, 1)
+    target = ((target_frames.detach().to(torch.float32) + 1) / 2).clamp(0, 1)
+    return torch.mean((predicted - target) ** 2).item()
+
+
 def visualize_inference(predicted_frames, ground_truth_frames, inferred_actions, fps, use_actions=True):
     # Move to CPU and convert to numpy
     predicted_frames = predicted_frames.detach().cpu()
diff --git a/utils/utils.py b/utils/utils.py
index fbb6173..e98bb8b 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -248,8 +248,9 @@ def load_latent_actions_from_checkpoint(checkpoint_path, device, model = None, i
         'num_heads': cfg.get('num_heads', 8),
         'hidden_dim': cfg.get('hidden_dim', 256),
         'num_blocks': cfg.get('num_blocks', 4),
+        'use_windowed_attention': cfg.get('use_windowed_attention', False),
     }
-    if model is None:
+    if model is None or getattr(model.encoder, 'use_windowed_attention', False) != kwargs['use_windowed_attention']:
        model = LatentActionModel(**kwargs)
     set_model_state_dict(
         model=model,
@@ -275,7 +276,8 @@ def load_dynamics_from_checkpoint(checkpoint_path, device, model = None, is_dist
     conditioning_dim = cfg.get('conditioning_dim', None)
     if conditioning_dim is None:
         cond_inferred = None
-        for k, v in model_sd.get('model', {}).items():
+        state_items = model_sd.items() if isinstance(model_sd, dict) else []
+        for k, v in state_items:
             # Linear weight shape: [out_features, in_features]; in_features is conditioning dim
             if k.endswith('to_gamma_beta.1.weight'):
                 cond_inferred = int(v.shape[1])
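
Taken together, the two patches make the windowed tokenizer opt-in: with the flag off, the encoder mean-pools patches and concatenates adjacent frame features (embed_dim * 2 into the head); with it on, it attends over the concatenated patch window and pools afterwards (embed_dim into the head). Either way, one action embedding per frame transition comes out, so downstream code is unchanged. A minimal sketch of exercising both paths; the [B, T, 3, H, W] frame layout and the random input are assumptions for illustration, not taken from the repo:

    import torch
    from models.latent_actions import LatentActionModel

    frames = torch.randn(2, 8, 3, 128, 128)  # assumed [B, T, C, H, W] at the default 128x128

    with torch.no_grad():
        for windowed in (False, True):
            model = LatentActionModel(n_actions=8, use_windowed_attention=windowed).eval()
            actions = model.encoder(frames)
            # both paths emit one pre-quantization action per frame transition
            assert actions.shape == (2, 7, model.action_dim)  # [B, T-1, A], A = log2(n_actions) = 3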
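On the data side, frame_size threads through load_data_and_data_loaders into the per-dataset loaders as resolution, which both resizes the default transform's output and is forwarded to the dataset class. A quick check of the transform ordering (Resize runs on the ToTensor output, so this assumes torchvision's tensor-mode Resize, v0.8+; the 240x320 input is made up):

    import numpy as np
    from datasets.data_utils import _default_video_transform

    transform = _default_video_transform(resolution=96)
    frame = np.zeros((240, 320, 3), dtype=np.uint8)  # H, W, C uint8, as ToTensor expects
    out = transform(frame)
    assert out.shape == (3, 96, 96)  # resized after ToTensor, then normalized to [-1, 1]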
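Finally, the new load_dynamics flag plus the reconstruct_frames/compute_frame_mse helpers allow tokenizer-only evaluation without a dynamics checkpoint. A sketch, with vt_path, frames, and device standing in for real values:

    from utils.inference_utils import load_models, reconstruct_frames, compute_frame_mse

    video_tokenizer, _, _ = load_models(vt_path, None, None, device,
                                        use_actions=False, load_dynamics=False)
    recon = reconstruct_frames(video_tokenizer, frames)
    print(compute_frame_mse(recon, frames))  # mean squared error in [0, 1] pixel space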