diff --git a/README.md b/README.md index 0f00c0e..424fd39 100644 --- a/README.md +++ b/README.md @@ -270,7 +270,7 @@ When you make a PR, please: There are still many TODOs which may offer significant performance gains... - [ ] Try `RoPE`/`AliBi` Position Embeddings -- [ ] Add more datasets (Terraria, Street Fighter, \) +- [x] Add more datasets (Terraria, Street Fighter, \) - [ ] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter) - [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT) - [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean` diff --git a/datasets/data_utils.py b/datasets/data_utils.py index db0489d..71655ee 100644 --- a/datasets/data_utils.py +++ b/datasets/data_utils.py @@ -9,7 +9,7 @@ import numpy as np import matplotlib.pyplot as plt from torchvision.utils import make_grid -from datasets.datasets import PongDataset, SonicDataset, PolePositionDataset, PicoDoomDataset, ZeldaDataset +from datasets.datasets import PongDataset, SonicDataset, PolePositionDataset, PicoDoomDataset, ZeldaDataset, StreetFighterDataset, TerrariaDataset, SpaceInvadersDataset DEFAULT_NUM_WORKERS = 2 DEFAULT_PREFETCH_FACTOR = 2 @@ -108,6 +108,39 @@ def load_zelda(num_frames=4, fps=15, preload_ratio=1): ) +def load_street_fighter(num_frames=4, fps=15, preload_ratio=1): + return _load_video_dataset_pair( + StreetFighterDataset, + '/data/street_fighter.mp4', + '/data/street_fighter_frames.h5', + num_frames=num_frames, + fps=fps, + preload_ratio=preload_ratio + ) + + +def load_terraria(num_frames=4, fps=15, preload_ratio=1): + return _load_video_dataset_pair( + TerrariaDataset, + '/data/terraria.mp4', + '/data/terraria_frames.h5', + num_frames=num_frames, + fps=fps, + preload_ratio=preload_ratio + ) + + +def load_space_invaders(num_frames=4, fps=15, preload_ratio=1): + return _load_video_dataset_pair( + SpaceInvadersDataset, + '/data/space_invaders.mp4', + '/data/space_invaders_frames.h5', + num_frames=num_frames, + fps=fps, + preload_ratio=preload_ratio + ) + + def data_loaders(train_data, val_data, batch_size, distributed=False, rank=0, world_size=1): train_sampler = None val_sampler = None @@ -152,6 +185,12 @@ def load_data_and_data_loaders(dataset, batch_size, num_frames=1, distributed=Fa training_data, validation_data = load_picodoom(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) elif dataset == 'ZELDA': training_data, validation_data = load_zelda(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) + elif dataset == 'STREET_FIGHTER': + training_data, validation_data = load_street_fighter(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) + elif dataset == 'TERRARIA': + training_data, validation_data = load_terraria(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) + elif dataset == 'SPACE_INVADERS': + training_data, validation_data = load_space_invaders(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) else: raise ValueError('Invalid dataset') diff --git a/datasets/datasets.py b/datasets/datasets.py index e8ee557..f46b044 100644 --- a/datasets/datasets.py +++ b/datasets/datasets.py @@ -138,7 +138,6 @@ def __del__(self): if hasattr(self, 'h5_file'): self.h5_file.close() -# TODO: add more datasets class PongDataset(VideoHDF5Dataset): def __init__(self, video_path, transform=None, save_path=None, train=True, num_frames=1, resolution=(64, 64), fps=30, preload_ratio=1): super().__init__( @@ -227,3 +226,57 @@ def __init__(self, video_path, transform=None, save_path=None, train=True, num_f preprocess_read_step=1, preprocess_slice=None, ) + +class StreetFighterDataset(VideoHDF5Dataset): + def __init__(self, video_path, transform=None, save_path=None, train=True, num_frames=4, resolution=(128, 128), fps=15, preload_ratio=1): + super().__init__( + video_path=video_path, + transform=transform, + save_path=save_path, + train=train, + num_frames=num_frames, + resize_to=resolution, + fps=fps, + preload_ratio=preload_ratio, + sequence_stride=None, + load_chunk_size=1000, + load_start_index=100, + preprocess_read_step=1, + preprocess_slice=None, + ) + +class TerrariaDataset(VideoHDF5Dataset): + def __init__(self, video_path, transform=None, save_path=None, train=True, num_frames=4, resolution=(128, 128), fps=15, preload_ratio=1): + super().__init__( + video_path=video_path, + transform=transform, + save_path=save_path, + train=train, + num_frames=num_frames, + resize_to=resolution, + fps=fps, + preload_ratio=preload_ratio, + sequence_stride=None, + load_chunk_size=1000, + load_start_index=100, + preprocess_read_step=1, + preprocess_slice=None, + ) + +class SpaceInvadersDataset(VideoHDF5Dataset): + def __init__(self, video_path, transform=None, save_path=None, train=True, num_frames=4, resolution=(64, 64), fps=15, preload_ratio=1): + super().__init__( + video_path=video_path, + transform=transform, + save_path=save_path, + train=train, + num_frames=num_frames, + resize_to=resolution, + fps=fps, + preload_ratio=preload_ratio, + sequence_stride=None, + load_chunk_size=1000, + load_start_index=0, + preprocess_read_step=2, + preprocess_slice=None, + )