From c055fc9fdd4aac96fd8662d8dd7776e17af8e8c7 Mon Sep 17 00:00:00 2001 From: Tasha Date: Thu, 16 Apr 2026 12:49:56 -0700 Subject: [PATCH] feat: add pre-tokenized dataset cache to accelerate dynamics training Adds scripts/preprocess_tokens.py: runs a trained VideoTokenizer over an entire dataset and saves token indices as [N, P] int32 to HDF5. Adds TokenizedVideoDataset: loads pre-tokenized HDF5 and returns [T, P] index sequences, with the same (tokens, 0) interface as VideoHDF5Dataset. In train_dynamics.py, if cached_tokens_path is set and exists, the dataloader returns token indices directly, skipping the video tokenizer forward pass each training step. This eliminates the tokenizer bottleneck for repeated runs on the same dataset. Use: python scripts/preprocess_tokens.py --video_tokenizer_path \ --dataset PONG --output_path data/pong_tokens.h5 Then: set cached_tokens_path in configs/dynamics.yaml --- README.md | 2 +- configs/dynamics.yaml | 4 ++ datasets/data_utils.py | 13 +++++- datasets/datasets.py | 25 ++++++++++- scripts/preprocess_tokens.py | 87 ++++++++++++++++++++++++++++++++++++ scripts/train_dynamics.py | 68 ++++++++++++++++++---------- utils/config.py | 1 + 7 files changed, 174 insertions(+), 26 deletions(-) create mode 100644 scripts/preprocess_tokens.py diff --git a/README.md b/README.md index 0f00c0e..8d9fc3a 100644 --- a/README.md +++ b/README.md @@ -275,7 +275,7 @@ There are still many TODOs which may offer significant performance gains... - [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT) - [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean` - [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters -- [ ] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames +- [x] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames - [x] Implement Mixture of Experts in the Feedforward Network - added by [eren23](https://github.com/eren23) in [#20](https://github.com/AlmondGod/tinyworlds/pull/20) - [x] Try different optimizers (`Muon`, `SOAP`) - added by [eren23](https://github.com/eren23) in [#20](https://github.com/AlmondGod/tinyworlds/pull/20) - [x] Train on more GPUs by adding `FSDP` Support — added by [alekseymalakhov11](https://github.com/alekseymalakhov11) in [#11](https://github.com/AlmondGod/tinyworlds/pull/11) diff --git a/configs/dynamics.yaml b/configs/dynamics.yaml index 6db4a5f..746a628 100644 --- a/configs/dynamics.yaml +++ b/configs/dynamics.yaml @@ -17,5 +17,9 @@ num_blocks: 8 video_tokenizer_path: latent_actions_path: +# Optional: path to pre-tokenized HDF5 (from scripts/preprocess_tokens.py) +# If set and file exists, skips video tokenizer forward pass each training step +cached_tokens_path: + # resume from checkpoint checkpoint: \ No newline at end of file diff --git a/datasets/data_utils.py b/datasets/data_utils.py index db0489d..bb7254d 100644 --- a/datasets/data_utils.py +++ b/datasets/data_utils.py @@ -9,7 +9,7 @@ import numpy as np import matplotlib.pyplot as plt from torchvision.utils import make_grid -from datasets.datasets import PongDataset, SonicDataset, PolePositionDataset, PicoDoomDataset, ZeldaDataset +from datasets.datasets import PongDataset, SonicDataset, PolePositionDataset, PicoDoomDataset, ZeldaDataset, TokenizedVideoDataset DEFAULT_NUM_WORKERS = 2 DEFAULT_PREFETCH_FACTOR = 2 @@ -141,6 +141,17 @@ def data_loaders(train_data, val_data, batch_size, distributed=False, rank=0, wo return train_loader, val_loader +def load_tokenized_dataset(tokens_path: str, batch_size: int, num_frames: int = 4, + preload_ratio: float = 1.0, distributed: bool = False, + rank: int = 0, world_size: int = 1): + """Load a pre-tokenized dataset (from preprocess_tokens.py) for dynamics training.""" + train_data = TokenizedVideoDataset(tokens_path, num_frames=num_frames, preload_ratio=preload_ratio) + val_data = TokenizedVideoDataset(tokens_path, num_frames=num_frames, preload_ratio=preload_ratio) + train_loader, val_loader = data_loaders(train_data, val_data, batch_size, + distributed=distributed, rank=rank, world_size=world_size) + return train_data, val_data, train_loader, val_loader + + def load_data_and_data_loaders(dataset, batch_size, num_frames=1, distributed=False, rank=0, world_size=1, fps=15, preload_ratio=1): if dataset == 'PONG': training_data, validation_data = load_pong(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) diff --git a/datasets/datasets.py b/datasets/datasets.py index e8ee557..52f9cfa 100644 --- a/datasets/datasets.py +++ b/datasets/datasets.py @@ -7,7 +7,30 @@ import torch from typing import Optional, Tuple, Union -# TODO: Try pre-caching video tokens and have dataloader load video tokens instead of frames +class TokenizedVideoDataset(Dataset): + """Loads pre-tokenized patch-token indices produced by scripts/preprocess_tokens.py. + + Returns (indices, 0) where indices is [T, P] int32 -- a sequence of T frames, + each represented as P patch-token indices. Matches the interface of VideoHDF5Dataset + so it can be dropped in as a replacement for dynamics training. + """ + def __init__(self, tokens_path: str, num_frames: int = 4, preload_ratio: Optional[float] = None): + with h5py.File(tokens_path, 'r') as f: + total = len(f['tokens']) + n = total if preload_ratio is None else max(0, min(total, int(total * preload_ratio))) + self.data = f['tokens'][:n] # [N, P] int32 + self.num_frames = num_frames + + def __len__(self) -> int: + return max(0, len(self.data) - self.num_frames) + + def __getitem__(self, index: int): + if index >= len(self): + raise IndexError(f"Index {index} out of range for dataset of length {len(self)}") + tokens = self.data[index:index + self.num_frames] # [T, P] + return torch.from_numpy(tokens).long(), 0 + + class VideoHDF5Dataset(Dataset): def __init__( self, diff --git a/scripts/preprocess_tokens.py b/scripts/preprocess_tokens.py new file mode 100644 index 0000000..d63bdf9 --- /dev/null +++ b/scripts/preprocess_tokens.py @@ -0,0 +1,87 @@ +"""Pre-tokenize a video dataset using a trained VideoTokenizer and save indices to HDF5. + +The output HDF5 contains: + tokens: [N, P] int32 -- one row of P patch-token indices per frame + metadata: attrs with frame_size, patch_size, latent_dim, num_bins + +Usage: + python scripts/preprocess_tokens.py \ + --video_tokenizer_path runs/.../video_tokenizer \ + --dataset PONG \ + --output_path data/pong_tokens.h5 \ + --device cuda \ + --batch_size 64 +""" +import os +import argparse +import torch +import h5py +import numpy as np +from tqdm import tqdm +from torch.utils.data import DataLoader +from utils.utils import load_videotokenizer_from_checkpoint +from datasets.data_utils import load_data_and_data_loaders + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--video_tokenizer_path", required=True, help="Path to video tokenizer checkpoint dir") + parser.add_argument("--dataset", required=True, help="Dataset name (PONG, SONIC, ZELDA, ...)") + parser.add_argument("--output_path", required=True, help="Output HDF5 path for token indices") + parser.add_argument("--device", default="cuda", help="Device to run tokenizer on") + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_frames", type=int, default=1, help="Frames per sample (use 1 to tokenize each frame independently)") + parser.add_argument("--preload_ratio", type=float, default=1.0) + return parser.parse_args() + + +@torch.no_grad() +def main(): + args = parse_args() + + print(f"Loading video tokenizer from {args.video_tokenizer_path}") + video_tokenizer, _ = load_videotokenizer_from_checkpoint( + checkpoint_path=args.video_tokenizer_path, + device=args.device, + ) + video_tokenizer.eval() + video_tokenizer.to(args.device) + + print(f"Loading dataset {args.dataset}") + train_data, val_data, _, _, _ = load_data_and_data_loaders( + dataset=args.dataset, + batch_size=args.batch_size, + num_frames=args.num_frames, + preload_ratio=args.preload_ratio, + ) + + all_tokens = [] + for split_name, split_data in [("train", train_data), ("val", val_data)]: + loader = DataLoader(split_data, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=2) + print(f"Tokenizing {split_name} split ({len(split_data)} samples)...") + for frames, _ in tqdm(loader): + frames = frames.to(args.device) # [B, T, C, H, W] + indices = video_tokenizer.tokenize(frames) # [B, T, P] + B, T, P = indices.shape + indices = indices.reshape(B * T, P) # flatten T into N + all_tokens.append(indices.cpu().numpy().astype(np.int32)) + + all_tokens = np.concatenate(all_tokens, axis=0) # [N, P] + print(f"Total frames tokenized: {all_tokens.shape[0]}, patches per frame: {all_tokens.shape[1]}") + + # save to HDF5 + os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True) + with h5py.File(args.output_path, 'w') as f: + f.create_dataset('tokens', data=all_tokens, compression='lzf') + # store metadata so consumers can reconstruct quantizer params + vt = video_tokenizer + f.attrs['patch_size'] = vt.encoder.patch_embed.patch_size if hasattr(vt.encoder.patch_embed, 'patch_size') else -1 + f.attrs['latent_dim'] = vt.quantizer.latent_dim + f.attrs['num_bins'] = vt.quantizer.num_bins + f.attrs['codebook_size'] = vt.codebook_size + f.attrs['patches_per_frame'] = int(all_tokens.shape[1]) + print(f"Saved tokenized dataset to {args.output_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/train_dynamics.py b/scripts/train_dynamics.py index 934dc55..ad6dc07 100644 --- a/scripts/train_dynamics.py +++ b/scripts/train_dynamics.py @@ -4,7 +4,7 @@ from tqdm import tqdm from einops import rearrange from models.dynamics import DynamicsModel -from datasets.data_utils import visualize_reconstruction, load_data_and_data_loaders +from datasets.data_utils import visualize_reconstruction, load_data_and_data_loaders, load_tokenized_dataset from tqdm import tqdm from einops import rearrange from utils.wandb_utils import ( @@ -128,21 +128,37 @@ def main(): unwrap_model(dynamics_model).train() - # dataloader - data_overrides = {} - if hasattr(args, 'fps') and args.fps is not None: - data_overrides['fps'] = args.fps - if hasattr(args, 'preload_ratio') and args.preload_ratio is not None: - data_overrides['preload_ratio'] = args.preload_ratio - _, _, training_loader, _, _ = load_data_and_data_loaders( - dataset=args.dataset, - batch_size=args.batch_size_per_gpu, - num_frames=args.context_length, - distributed=dist_setup['is_distributed'], - rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0, - world_size=dist_setup['world_size'], - **data_overrides, - ) + # dataloader: use pre-tokenized cache if available, otherwise load raw frames + cached_tokens_path = getattr(args, 'cached_tokens_path', None) + using_cached_tokens = cached_tokens_path is not None and os.path.isfile(cached_tokens_path) + if using_cached_tokens: + if is_main: + print(f"Using pre-tokenized dataset from {cached_tokens_path}") + preload_ratio = args.preload_ratio if hasattr(args, 'preload_ratio') and args.preload_ratio else 1.0 + _, _, training_loader, _ = load_tokenized_dataset( + tokens_path=cached_tokens_path, + batch_size=args.batch_size_per_gpu, + num_frames=args.context_length, + preload_ratio=preload_ratio, + distributed=dist_setup['is_distributed'], + rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0, + world_size=dist_setup['world_size'], + ) + else: + data_overrides = {} + if hasattr(args, 'fps') and args.fps is not None: + data_overrides['fps'] = args.fps + if hasattr(args, 'preload_ratio') and args.preload_ratio is not None: + data_overrides['preload_ratio'] = args.preload_ratio + _, _, training_loader, _, _ = load_data_and_data_loaders( + dataset=args.dataset, + batch_size=args.batch_size_per_gpu, + num_frames=args.context_length, + distributed=dist_setup['is_distributed'], + rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0, + world_size=dist_setup['world_size'], + **data_overrides, + ) train_iter = iter(training_loader) use_moe = getattr(args, 'use_moe', False) @@ -161,15 +177,21 @@ def main(): train_iter = iter(training_loader) # reset iterator when epoch ends x, _ = next(train_iter) - x = x.to(args.device, non_blocking=True) # [batch_size, seq_len, channels, height, width] + x = x.to(args.device, non_blocking=True) - # get video tokens for batch - video_tokens = video_tokenizer.tokenize(x) # [B, T, P] - video_latents = video_tokenizer.quantizer.get_latents_from_indices(video_tokens, dim=-1) # [B, T, P, L] - if args.use_actions: - quantized_actions = latent_action_model.encode(x) # [B, T - 1, A] + if using_cached_tokens: + # x is already token indices [B, T, P] -- skip video tokenizer forward pass + video_tokens = x # [B, T, P] + video_latents = video_tokenizer.quantizer.get_latents_from_indices(video_tokens, dim=-1) # [B, T, P, L] + quantized_actions = None # action conditioning not supported with cached tokens else: - quantized_actions = None + # x is raw frames [B, T, C, H, W] + video_tokens = video_tokenizer.tokenize(x) # [B, T, P] + video_latents = video_tokenizer.quantizer.get_latents_from_indices(video_tokens, dim=-1) # [B, T, P, L] + if args.use_actions: + quantized_actions = latent_action_model.encode(x) # [B, T - 1, A] + else: + quantized_actions = None # predict masked frame latents with dynamics model (masking in dynamics model) with train_ctx: diff --git a/utils/config.py b/utils/config.py index 846c7b2..74a70a5 100644 --- a/utils/config.py +++ b/utils/config.py @@ -221,6 +221,7 @@ class DynamicsConfig: # other params fps: Optional[int] = None preload_ratio: Optional[float] = None + cached_tokens_path: Optional[str] = None def __post_init__(self) -> None: _validate_amp_fsdp(self.amp, self.distributed)