diff --git a/README.md b/README.md
index 0f00c0e..8d9fc3a 100644
--- a/README.md
+++ b/README.md
@@ -275,7 +275,7 @@ There are still many TODOs which may offer significant performance gains...
 - [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT)
 - [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean`
 - [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters
-- [ ] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames
+- [x] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames
 - [x] Implement Mixture of Experts in the Feedforward Network - added by [eren23](https://github.com/eren23) in [#20](https://github.com/AlmondGod/tinyworlds/pull/20)
 - [x] Try different optimizers (`Muon`, `SOAP`) - added by [eren23](https://github.com/eren23) in [#20](https://github.com/AlmondGod/tinyworlds/pull/20)
 - [x] Train on more GPUs by adding `FSDP` Support — added by [alekseymalakhov11](https://github.com/alekseymalakhov11) in [#11](https://github.com/AlmondGod/tinyworlds/pull/11)
diff --git a/configs/dynamics.yaml b/configs/dynamics.yaml
index 6db4a5f..746a628 100644
--- a/configs/dynamics.yaml
+++ b/configs/dynamics.yaml
@@ -17,5 +17,9 @@ num_blocks: 8
 video_tokenizer_path:
 latent_actions_path:
 
+# Optional: path to a pre-tokenized HDF5 file (from scripts/preprocess_tokens.py)
+# If set and the file exists, the video tokenizer forward pass is skipped each training step
+cached_tokens_path:
+
 # resume from checkpoint
 checkpoint:
\ No newline at end of file
diff --git a/datasets/data_utils.py b/datasets/data_utils.py
index db0489d..bb7254d 100644
--- a/datasets/data_utils.py
+++ b/datasets/data_utils.py
@@ -9,7 +9,7 @@
 import numpy as np
 import matplotlib.pyplot as plt
 from torchvision.utils import make_grid
-from datasets.datasets import PongDataset, SonicDataset, PolePositionDataset, PicoDoomDataset, ZeldaDataset
+from datasets.datasets import PongDataset, SonicDataset, PolePositionDataset, PicoDoomDataset, ZeldaDataset, TokenizedVideoDataset
 
 DEFAULT_NUM_WORKERS = 2
 DEFAULT_PREFETCH_FACTOR = 2
@@ -141,6 +141,17 @@ def data_loaders(train_data, val_data, batch_size, distributed=False, rank=0, wo
     return train_loader, val_loader
 
 
+def load_tokenized_dataset(tokens_path: str, batch_size: int, num_frames: int = 4,
+                           preload_ratio: float = 1.0, distributed: bool = False,
+                           rank: int = 0, world_size: int = 1):
+    """Load a pre-tokenized dataset (from scripts/preprocess_tokens.py) for dynamics training."""
+    train_data = TokenizedVideoDataset(tokens_path, num_frames=num_frames, preload_ratio=preload_ratio)
+    val_data = TokenizedVideoDataset(tokens_path, num_frames=num_frames, preload_ratio=preload_ratio)
+    train_loader, val_loader = data_loaders(train_data, val_data, batch_size,
+                                            distributed=distributed, rank=rank, world_size=world_size)
+    return train_data, val_data, train_loader, val_loader
+
+
 def load_data_and_data_loaders(dataset, batch_size, num_frames=1, distributed=False, rank=0, world_size=1, fps=15, preload_ratio=1):
     if dataset == 'PONG':
         training_data, validation_data = load_pong(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
diff --git a/datasets/datasets.py b/datasets/datasets.py
index e8ee557..52f9cfa 100644
--- a/datasets/datasets.py
+++ b/datasets/datasets.py
@@ -7,7 +7,30 @@
 import torch
 from typing import Optional, Tuple, Union
 
-# TODO: Try pre-caching video tokens and have dataloader load video tokens instead of frames
+class TokenizedVideoDataset(Dataset):
+    """Loads pre-tokenized patch-token indices produced by scripts/preprocess_tokens.py.
+
+    Returns (indices, 0) where indices is [T, P] int64 (stored as int32 on disk) -- a sequence of T frames,
+    each represented as P patch-token indices. Matches the interface of VideoHDF5Dataset
+    so it can be dropped in as a replacement for dynamics training.
+    """
+    def __init__(self, tokens_path: str, num_frames: int = 4, preload_ratio: Optional[float] = None):
+        with h5py.File(tokens_path, 'r') as f:
+            total = len(f['tokens'])
+            n = total if preload_ratio is None else max(0, min(total, int(total * preload_ratio)))
+            self.data = f['tokens'][:n]  # [N, P] int32, loaded fully into memory
+        self.num_frames = num_frames
+
+    def __len__(self) -> int:
+        return max(0, len(self.data) - self.num_frames)
+
+    def __getitem__(self, index: int):
+        if index >= len(self):
+            raise IndexError(f"Index {index} out of range for dataset of length {len(self)}")
+        tokens = self.data[index:index + self.num_frames]  # [T, P]
+        return torch.from_numpy(tokens).long(), 0
+
+
 class VideoHDF5Dataset(Dataset):
     def __init__(
         self,
diff --git a/scripts/preprocess_tokens.py b/scripts/preprocess_tokens.py
new file mode 100644
index 0000000..d63bdf9
--- /dev/null
+++ b/scripts/preprocess_tokens.py
@@ -0,0 +1,87 @@
+"""Pre-tokenize a video dataset using a trained VideoTokenizer and save indices to HDF5.
+
+The output HDF5 contains:
+    tokens: [N, P] int32 -- one row of P patch-token indices per frame
+    metadata: attrs with patch_size, latent_dim, num_bins, codebook_size, patches_per_frame
+
+Usage:
+    python scripts/preprocess_tokens.py \
+        --video_tokenizer_path runs/.../video_tokenizer \
+        --dataset PONG \
+        --output_path data/pong_tokens.h5 \
+        --device cuda \
+        --batch_size 64
+"""
+import os
+import argparse
+import torch
+import h5py
+import numpy as np
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+from utils.utils import load_videotokenizer_from_checkpoint
+from datasets.data_utils import load_data_and_data_loaders
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--video_tokenizer_path", required=True, help="Path to video tokenizer checkpoint dir")
+    parser.add_argument("--dataset", required=True, help="Dataset name (PONG, SONIC, ZELDA, ...)")
+    parser.add_argument("--output_path", required=True, help="Output HDF5 path for token indices")
+    parser.add_argument("--device", default="cuda", help="Device to run tokenizer on")
+    parser.add_argument("--batch_size", type=int, default=64)
+    parser.add_argument("--num_frames", type=int, default=1, help="Frames per sample (use 1 to tokenize each frame independently)")
+    parser.add_argument("--preload_ratio", type=float, default=1.0)
+    return parser.parse_args()
+
+
+@torch.no_grad()
+def main():
+    args = parse_args()
+
+    print(f"Loading video tokenizer from {args.video_tokenizer_path}")
+    video_tokenizer, _ = load_videotokenizer_from_checkpoint(
+        checkpoint_path=args.video_tokenizer_path,
+        device=args.device,
+    )
+    video_tokenizer.eval()
+    video_tokenizer.to(args.device)
+
+    print(f"Loading dataset {args.dataset}")
+    train_data, val_data, _, _, _ = load_data_and_data_loaders(
+        dataset=args.dataset,
+        batch_size=args.batch_size,
+        num_frames=args.num_frames,
+        preload_ratio=args.preload_ratio,
+    )
+
+    all_tokens = []
+    for split_name, split_data in [("train", train_data), ("val", val_data)]:
+        loader = DataLoader(split_data, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=2)
+        print(f"Tokenizing {split_name} split ({len(split_data)} samples)...")
+        for frames, _ in tqdm(loader):
+            frames = frames.to(args.device)  # [B, T, C, H, W]
+            indices = video_tokenizer.tokenize(frames)  # [B, T, P]
+            B, T, P = indices.shape
+            indices = indices.reshape(B * T, P)  # flatten T into N
+            all_tokens.append(indices.cpu().numpy().astype(np.int32))
+
+    all_tokens = np.concatenate(all_tokens, axis=0)  # [N, P]
+    print(f"Total frames tokenized: {all_tokens.shape[0]}, patches per frame: {all_tokens.shape[1]}")
+
+    # save to HDF5
+    os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
+    with h5py.File(args.output_path, 'w') as f:
+        f.create_dataset('tokens', data=all_tokens, compression='lzf')
+        # store metadata so consumers can reconstruct quantizer params
+        vt = video_tokenizer
+        f.attrs['patch_size'] = vt.encoder.patch_embed.patch_size if hasattr(vt.encoder.patch_embed, 'patch_size') else -1
+        f.attrs['latent_dim'] = vt.quantizer.latent_dim
+        f.attrs['num_bins'] = vt.quantizer.num_bins
+        f.attrs['codebook_size'] = vt.codebook_size
+        f.attrs['patches_per_frame'] = int(all_tokens.shape[1])
+    print(f"Saved tokenized dataset to {args.output_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train_dynamics.py b/scripts/train_dynamics.py
index 934dc55..ad6dc07 100644
--- a/scripts/train_dynamics.py
+++ b/scripts/train_dynamics.py
@@ -4,7 +4,7 @@
 from tqdm import tqdm
 from einops import rearrange
 from models.dynamics import DynamicsModel
-from datasets.data_utils import visualize_reconstruction, load_data_and_data_loaders
+from datasets.data_utils import visualize_reconstruction, load_data_and_data_loaders, load_tokenized_dataset
 from tqdm import tqdm
 from einops import rearrange
 from utils.wandb_utils import (
@@ -128,21 +128,37 @@ def main():
     unwrap_model(dynamics_model).train()
 
-    # dataloader
-    data_overrides = {}
-    if hasattr(args, 'fps') and args.fps is not None:
-        data_overrides['fps'] = args.fps
-    if hasattr(args, 'preload_ratio') and args.preload_ratio is not None:
-        data_overrides['preload_ratio'] = args.preload_ratio
-    _, _, training_loader, _, _ = load_data_and_data_loaders(
-        dataset=args.dataset,
-        batch_size=args.batch_size_per_gpu,
-        num_frames=args.context_length,
-        distributed=dist_setup['is_distributed'],
-        rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
-        world_size=dist_setup['world_size'],
-        **data_overrides,
-    )
+    # dataloader: use pre-tokenized cache if available, otherwise load raw frames
+    cached_tokens_path = getattr(args, 'cached_tokens_path', None)
+    using_cached_tokens = cached_tokens_path is not None and os.path.isfile(cached_tokens_path)
+    if using_cached_tokens:
+        if is_main:
+            print(f"Using pre-tokenized dataset from {cached_tokens_path}")
+        preload_ratio = args.preload_ratio if hasattr(args, 'preload_ratio') and args.preload_ratio else 1.0
+        _, _, training_loader, _ = load_tokenized_dataset(
+            tokens_path=cached_tokens_path,
+            batch_size=args.batch_size_per_gpu,
+            num_frames=args.context_length,
+            preload_ratio=preload_ratio,
+            distributed=dist_setup['is_distributed'],
+            rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
+            world_size=dist_setup['world_size'],
+        )
+    else:
+        data_overrides = {}
+        if hasattr(args, 'fps') and args.fps is not None:
+            data_overrides['fps'] = args.fps
+        if hasattr(args, 'preload_ratio') and args.preload_ratio is not None:
+            data_overrides['preload_ratio'] = args.preload_ratio
+        _, _, training_loader, _, _ = load_data_and_data_loaders(
+            dataset=args.dataset,
+            batch_size=args.batch_size_per_gpu,
+            num_frames=args.context_length,
+            distributed=dist_setup['is_distributed'],
+            rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
+            world_size=dist_setup['world_size'],
+            **data_overrides,
+        )
 
     train_iter = iter(training_loader)
 
     use_moe = getattr(args, 'use_moe', False)
@@ -161,15 +177,21 @@ def main():
             train_iter = iter(training_loader)  # reset iterator when epoch ends
             x, _ = next(train_iter)
 
-        x = x.to(args.device, non_blocking=True)  # [batch_size, seq_len, channels, height, width]
+        x = x.to(args.device, non_blocking=True)
 
-        # get video tokens for batch
-        video_tokens = video_tokenizer.tokenize(x)  # [B, T, P]
-        video_latents = video_tokenizer.quantizer.get_latents_from_indices(video_tokens, dim=-1)  # [B, T, P, L]
-        if args.use_actions:
-            quantized_actions = latent_action_model.encode(x)  # [B, T - 1, A]
+        if using_cached_tokens:
+            # x is already token indices [B, T, P] -- skip video tokenizer forward pass
+            video_tokens = x  # [B, T, P]
+            video_latents = video_tokenizer.quantizer.get_latents_from_indices(video_tokens, dim=-1)  # [B, T, P, L]
+            quantized_actions = None  # action conditioning not supported with cached tokens
         else:
-            quantized_actions = None
+            # x is raw frames [B, T, C, H, W]
+            video_tokens = video_tokenizer.tokenize(x)  # [B, T, P]
+            video_latents = video_tokenizer.quantizer.get_latents_from_indices(video_tokens, dim=-1)  # [B, T, P, L]
+            if args.use_actions:
+                quantized_actions = latent_action_model.encode(x)  # [B, T - 1, A]
+            else:
+                quantized_actions = None
 
         # predict masked frame latents with dynamics model (masking in dynamics model)
         with train_ctx:
diff --git a/utils/config.py b/utils/config.py
index 846c7b2..74a70a5 100644
--- a/utils/config.py
+++ b/utils/config.py
@@ -221,6 +221,7 @@ class DynamicsConfig:
     # other params
     fps: Optional[int] = None
     preload_ratio: Optional[float] = None
+    cached_tokens_path: Optional[str] = None
 
     def __post_init__(self) -> None:
         _validate_amp_fsdp(self.amp, self.distributed)
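
Usage sketch (not part of the patch): run scripts/preprocess_tokens.py once offline, then set `cached_tokens_path` in configs/dynamics.yaml so scripts/train_dynamics.py skips the per-step tokenizer forward pass. Below is a minimal Python sketch of consuming the cache directly through the helpers added above; the checkpoint and HDF5 paths are placeholders. Two caveats visible in the patch: the cached path disables action conditioning (quantized_actions is None), and load_tokenized_dataset currently builds its train and val datasets from the same token file.

import torch

from datasets.data_utils import load_tokenized_dataset
from utils.utils import load_videotokenizer_from_checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer is needed only to map cached indices back to latents,
# not to tokenize frames each training step.
video_tokenizer, _ = load_videotokenizer_from_checkpoint(
    checkpoint_path="runs/example/video_tokenizer",  # placeholder checkpoint dir
    device=device,
)
video_tokenizer.eval()

# Both returned loaders read the same HDF5 cache written by scripts/preprocess_tokens.py.
_, _, train_loader, _ = load_tokenized_dataset(
    tokens_path="data/pong_tokens.h5",  # placeholder cache path
    batch_size=8,
    num_frames=4,  # should match the dynamics model's context_length
)

tokens, _ = next(iter(train_loader))  # [B, T, P] int64 patch-token indices
tokens = tokens.to(device)
with torch.no_grad():
    latents = video_tokenizer.quantizer.get_latents_from_indices(tokens, dim=-1)  # [B, T, P, L]
print(tokens.shape, latents.shape)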