Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ There are still many TODOs which may offer significant performance gains...
- [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT)
- [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean`
- [ ] Spend more compute on a much larger training run, scale to multi-billions of parameters
- [ ] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames
- [x] Accelerate dynamics training by producing, saving, and loading pre-processed image patch embeddings instead of full frames
- [x] Implement Mixture of Experts in the Feedforward Network - added by [eren23](https://github.com/eren23) in [#20](https://github.com/AlmondGod/tinyworlds/pull/20)
- [x] Try different optimizers (`Muon`, `SOAP`) - added by [eren23](https://github.com/eren23) in [#20](https://github.com/AlmondGod/tinyworlds/pull/20)
- [x] Train on more GPUs by adding `FSDP` Support — added by [alekseymalakhov11](https://github.com/alekseymalakhov11) in [#11](https://github.com/AlmondGod/tinyworlds/pull/11)
Expand Down
4 changes: 4 additions & 0 deletions configs/dynamics.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,9 @@ num_blocks: 8
video_tokenizer_path:
latent_actions_path:

# Optional: path to pre-tokenized HDF5 (from scripts/preprocess_tokens.py)
# If set and file exists, skips video tokenizer forward pass each training step
cached_tokens_path:

# resume from checkpoint
checkpoint:
13 changes: 12 additions & 1 deletion datasets/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from datasets.datasets import PongDataset, SonicDataset, PolePositionDataset, PicoDoomDataset, ZeldaDataset
from datasets.datasets import PongDataset, SonicDataset, PolePositionDataset, PicoDoomDataset, ZeldaDataset, TokenizedVideoDataset

DEFAULT_NUM_WORKERS = 2
DEFAULT_PREFETCH_FACTOR = 2
Expand Down Expand Up @@ -141,6 +141,17 @@ def data_loaders(train_data, val_data, batch_size, distributed=False, rank=0, wo
return train_loader, val_loader


def load_tokenized_dataset(tokens_path: str, batch_size: int, num_frames: int = 4,
                           preload_ratio: float = 1.0, distributed: bool = False,
                           rank: int = 0, world_size: int = 1):
    """Load a pre-tokenized dataset (from preprocess_tokens.py) for dynamics training.

    Args:
        tokens_path: path to the HDF5 file of per-frame patch-token indices.
        batch_size: per-GPU batch size handed to the data loaders.
        num_frames: number of consecutive frames per sample window.
        preload_ratio: fraction of the token rows to load into memory.
        distributed / rank / world_size: forwarded to data_loaders for DDP sharding.

    Returns:
        (train_data, val_data, train_loader, val_loader) -- same tuple shape as
        the frame-based loaders so it can be used as a drop-in replacement.
    """
    # Build the dataset ONCE and share it between the two loaders: constructing
    # it twice (as before) preloaded the full token array into RAM twice for no
    # benefit, since both instances read identical data.
    # NOTE(review): train and val still see the same data -- the cached-token
    # path has no held-out split; validation metrics are not independent here.
    dataset = TokenizedVideoDataset(tokens_path, num_frames=num_frames, preload_ratio=preload_ratio)
    train_loader, val_loader = data_loaders(dataset, dataset, batch_size,
                                            distributed=distributed, rank=rank, world_size=world_size)
    return dataset, dataset, train_loader, val_loader


def load_data_and_data_loaders(dataset, batch_size, num_frames=1, distributed=False, rank=0, world_size=1, fps=15, preload_ratio=1):
if dataset == 'PONG':
training_data, validation_data = load_pong(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
Expand Down
25 changes: 24 additions & 1 deletion datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,30 @@
import torch
from typing import Optional, Tuple, Union

# TODO: Try pre-caching video tokens and have dataloader load video tokens instead of frames
class TokenizedVideoDataset(Dataset):
    """Loads pre-tokenized patch-token indices produced by scripts/preprocess_tokens.py.

    Returns (indices, 0) where indices is [T, P] int64 (cast from stored int32) --
    a sequence of T frames, each represented as P patch-token indices. Matches the
    interface of VideoHDF5Dataset so it can be dropped in as a replacement for
    dynamics training.
    """
    def __init__(self, tokens_path: str, num_frames: int = 4, preload_ratio: Optional[float] = None):
        """
        Args:
            tokens_path: HDF5 file containing a 'tokens' dataset of shape [N, P].
            num_frames: window length T of consecutive frames per sample.
            preload_ratio: fraction of rows to load into memory; None loads all.
        """
        with h5py.File(tokens_path, 'r') as f:
            total = len(f['tokens'])
            # Clamp to [0, total] so out-of-range ratios can't over/under-index.
            n = total if preload_ratio is None else max(0, min(total, int(total * preload_ratio)))
            self.data = f['tokens'][:n]  # [N, P] int32, fully materialized in RAM
        self.num_frames = num_frames

    def __len__(self) -> int:
        # N frames yield N - T + 1 valid windows of length T. The previous
        # "N - T" dropped the final valid window (off-by-one).
        return max(0, len(self.data) - self.num_frames + 1)

    def __getitem__(self, index: int):
        if index >= len(self):
            raise IndexError(f"Index {index} out of range for dataset of length {len(self)}")
        tokens = self.data[index:index + self.num_frames]  # [T, P]
        # .long(): downstream embedding/gather ops require int64 indices.
        return torch.from_numpy(tokens).long(), 0


class VideoHDF5Dataset(Dataset):
def __init__(
self,
Expand Down
87 changes: 87 additions & 0 deletions scripts/preprocess_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Pre-tokenize a video dataset using a trained VideoTokenizer and save indices to HDF5.

The output HDF5 contains:
tokens: [N, P] int32 -- one row of P patch-token indices per frame
metadata: attrs with frame_size, patch_size, latent_dim, num_bins

Usage:
python scripts/preprocess_tokens.py \
--video_tokenizer_path runs/.../video_tokenizer \
--dataset PONG \
--output_path data/pong_tokens.h5 \
--device cuda \
--batch_size 64
"""
import os
import argparse
import torch
import h5py
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from utils.utils import load_videotokenizer_from_checkpoint
from datasets.data_utils import load_data_and_data_loaders


def parse_args():
    """Build and parse the command-line interface for token pre-processing."""
    p = argparse.ArgumentParser()
    # Required inputs/outputs.
    p.add_argument("--video_tokenizer_path", required=True,
                   help="Path to video tokenizer checkpoint dir")
    p.add_argument("--dataset", required=True,
                   help="Dataset name (PONG, SONIC, ZELDA, ...)")
    p.add_argument("--output_path", required=True,
                   help="Output HDF5 path for token indices")
    # Optional runtime knobs.
    p.add_argument("--device", default="cuda", help="Device to run tokenizer on")
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--num_frames", type=int, default=1,
                   help="Frames per sample (use 1 to tokenize each frame independently)")
    p.add_argument("--preload_ratio", type=float, default=1.0)
    return p.parse_args()


@torch.no_grad()
def main():
    """Tokenize every frame of a dataset with a trained VideoTokenizer and
    write the per-frame patch-token indices (plus quantizer metadata) to HDF5."""
    args = parse_args()

    print(f"Loading video tokenizer from {args.video_tokenizer_path}")
    video_tokenizer, _ = load_videotokenizer_from_checkpoint(
        checkpoint_path=args.video_tokenizer_path,
        device=args.device,
    )
    video_tokenizer.eval()
    video_tokenizer.to(args.device)

    print(f"Loading dataset {args.dataset}")
    train_data, val_data, _, _, _ = load_data_and_data_loaders(
        dataset=args.dataset,
        batch_size=args.batch_size,
        num_frames=args.num_frames,
        preload_ratio=args.preload_ratio,
    )

    # Tokenize both splits sequentially; each batch contributes B*T frame rows.
    token_chunks = []
    for split_name, split_data in (("train", train_data), ("val", val_data)):
        loader = DataLoader(split_data, batch_size=args.batch_size, shuffle=False,
                            drop_last=False, num_workers=2)
        print(f"Tokenizing {split_name} split ({len(split_data)} samples)...")
        for frames, _ in tqdm(loader):
            frames = frames.to(args.device)  # [B, T, C, H, W]
            indices = video_tokenizer.tokenize(frames)  # [B, T, P]
            # Collapse batch and time so each row is one frame's P tokens.
            flat = indices.flatten(0, 1)  # [B*T, P]
            token_chunks.append(flat.cpu().numpy().astype(np.int32))

    all_tokens = np.concatenate(token_chunks, axis=0)  # [N, P]
    print(f"Total frames tokenized: {all_tokens.shape[0]}, patches per frame: {all_tokens.shape[1]}")

    # save to HDF5
    os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
    with h5py.File(args.output_path, 'w') as f:
        f.create_dataset('tokens', data=all_tokens, compression='lzf')
        # store metadata so consumers can reconstruct quantizer params
        patch_embed = video_tokenizer.encoder.patch_embed
        f.attrs['patch_size'] = getattr(patch_embed, 'patch_size', -1)
        f.attrs['latent_dim'] = video_tokenizer.quantizer.latent_dim
        f.attrs['num_bins'] = video_tokenizer.quantizer.num_bins
        f.attrs['codebook_size'] = video_tokenizer.codebook_size
        f.attrs['patches_per_frame'] = int(all_tokens.shape[1])
    print(f"Saved tokenized dataset to {args.output_path}")


if __name__ == "__main__":
    main()
68 changes: 45 additions & 23 deletions scripts/train_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tqdm import tqdm
from einops import rearrange
from models.dynamics import DynamicsModel
from datasets.data_utils import visualize_reconstruction, load_data_and_data_loaders
from datasets.data_utils import visualize_reconstruction, load_data_and_data_loaders, load_tokenized_dataset
from tqdm import tqdm
from einops import rearrange
from utils.wandb_utils import (
Expand Down Expand Up @@ -128,21 +128,37 @@ def main():

unwrap_model(dynamics_model).train()

# dataloader
data_overrides = {}
if hasattr(args, 'fps') and args.fps is not None:
data_overrides['fps'] = args.fps
if hasattr(args, 'preload_ratio') and args.preload_ratio is not None:
data_overrides['preload_ratio'] = args.preload_ratio
_, _, training_loader, _, _ = load_data_and_data_loaders(
dataset=args.dataset,
batch_size=args.batch_size_per_gpu,
num_frames=args.context_length,
distributed=dist_setup['is_distributed'],
rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
world_size=dist_setup['world_size'],
**data_overrides,
)
# dataloader: use pre-tokenized cache if available, otherwise load raw frames
cached_tokens_path = getattr(args, 'cached_tokens_path', None)
using_cached_tokens = cached_tokens_path is not None and os.path.isfile(cached_tokens_path)
if using_cached_tokens:
if is_main:
print(f"Using pre-tokenized dataset from {cached_tokens_path}")
preload_ratio = args.preload_ratio if hasattr(args, 'preload_ratio') and args.preload_ratio else 1.0
_, _, training_loader, _ = load_tokenized_dataset(
tokens_path=cached_tokens_path,
batch_size=args.batch_size_per_gpu,
num_frames=args.context_length,
preload_ratio=preload_ratio,
distributed=dist_setup['is_distributed'],
rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
world_size=dist_setup['world_size'],
)
else:
data_overrides = {}
if hasattr(args, 'fps') and args.fps is not None:
data_overrides['fps'] = args.fps
if hasattr(args, 'preload_ratio') and args.preload_ratio is not None:
data_overrides['preload_ratio'] = args.preload_ratio
_, _, training_loader, _, _ = load_data_and_data_loaders(
dataset=args.dataset,
batch_size=args.batch_size_per_gpu,
num_frames=args.context_length,
distributed=dist_setup['is_distributed'],
rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
world_size=dist_setup['world_size'],
**data_overrides,
)
train_iter = iter(training_loader)

use_moe = getattr(args, 'use_moe', False)
Expand All @@ -161,15 +177,21 @@ def main():
train_iter = iter(training_loader) # reset iterator when epoch ends
x, _ = next(train_iter)

x = x.to(args.device, non_blocking=True) # [batch_size, seq_len, channels, height, width]
x = x.to(args.device, non_blocking=True)

# get video tokens for batch
video_tokens = video_tokenizer.tokenize(x) # [B, T, P]
video_latents = video_tokenizer.quantizer.get_latents_from_indices(video_tokens, dim=-1) # [B, T, P, L]
if args.use_actions:
quantized_actions = latent_action_model.encode(x) # [B, T - 1, A]
if using_cached_tokens:
# x is already token indices [B, T, P] -- skip video tokenizer forward pass
video_tokens = x # [B, T, P]
video_latents = video_tokenizer.quantizer.get_latents_from_indices(video_tokens, dim=-1) # [B, T, P, L]
quantized_actions = None # action conditioning not supported with cached tokens
else:
quantized_actions = None
# x is raw frames [B, T, C, H, W]
video_tokens = video_tokenizer.tokenize(x) # [B, T, P]
video_latents = video_tokenizer.quantizer.get_latents_from_indices(video_tokens, dim=-1) # [B, T, P, L]
if args.use_actions:
quantized_actions = latent_action_model.encode(x) # [B, T - 1, A]
else:
quantized_actions = None

# predict masked frame latents with dynamics model (masking in dynamics model)
with train_ctx:
Expand Down
1 change: 1 addition & 0 deletions utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class DynamicsConfig:
# other params
fps: Optional[int] = None
preload_ratio: Optional[float] = None
cached_tokens_path: Optional[str] = None

def __post_init__(self) -> None:
_validate_amp_fsdp(self.amp, self.distributed)
Expand Down