From 35edb96e8c969a9e55587e31e3ca44503bd97c54 Mon Sep 17 00:00:00 2001 From: Tasha Date: Thu, 16 Apr 2026 12:35:36 -0700 Subject: [PATCH] feat: add Street Fighter, Terraria, and Space Invaders dataset classes Adds StreetFighterDataset, TerrariaDataset, and SpaceInvadersDataset subclasses of VideoHDF5Dataset with appropriate resolution and fps defaults. Also adds corresponding loader functions and dispatch cases in data_utils.py. --- README.md | 2 +- datasets/data_utils.py | 41 ++++++++++++++++++++++++++++++- datasets/datasets.py | 55 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 95 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0f00c0e..424fd39 100644 --- a/README.md +++ b/README.md @@ -270,7 +270,7 @@ When you make a PR, please: There are still many TODOs which may offer significant performance gains... - [ ] Try `RoPE`/`AliBi` Position Embeddings -- [ ] Add more datasets (Terraria, Street Fighter, \) +- [x] Add more datasets (Terraria, Street Fighter, \) - [ ] Try [AdaLN-Zero](https://arxiv.org/pdf/2212.09748) instead of `FiLM` (adds a pre-scale parameter) - [ ] Add new schedulers for MaskGIT like cosine and [Halton](https://github.com/valeoai/Halton-MaskGIT) - [ ] Replace `mean pool + concat` in the action tokenizer with `length-2 windowed attention + mean` diff --git a/datasets/data_utils.py b/datasets/data_utils.py index db0489d..71655ee 100644 --- a/datasets/data_utils.py +++ b/datasets/data_utils.py @@ -9,7 +9,7 @@ import numpy as np import matplotlib.pyplot as plt from torchvision.utils import make_grid -from datasets.datasets import PongDataset, SonicDataset, PolePositionDataset, PicoDoomDataset, ZeldaDataset +from datasets.datasets import PongDataset, SonicDataset, PolePositionDataset, PicoDoomDataset, ZeldaDataset, StreetFighterDataset, TerrariaDataset, SpaceInvadersDataset DEFAULT_NUM_WORKERS = 2 DEFAULT_PREFETCH_FACTOR = 2 @@ -108,6 +108,39 @@ def load_zelda(num_frames=4, fps=15, preload_ratio=1): ) +def load_street_fighter(num_frames=4, fps=15, preload_ratio=1): + return _load_video_dataset_pair( + StreetFighterDataset, + '/data/street_fighter.mp4', + '/data/street_fighter_frames.h5', + num_frames=num_frames, + fps=fps, + preload_ratio=preload_ratio + ) + + +def load_terraria(num_frames=4, fps=15, preload_ratio=1): + return _load_video_dataset_pair( + TerrariaDataset, + '/data/terraria.mp4', + '/data/terraria_frames.h5', + num_frames=num_frames, + fps=fps, + preload_ratio=preload_ratio + ) + + +def load_space_invaders(num_frames=4, fps=15, preload_ratio=1): + return _load_video_dataset_pair( + SpaceInvadersDataset, + '/data/space_invaders.mp4', + '/data/space_invaders_frames.h5', + num_frames=num_frames, + fps=fps, + preload_ratio=preload_ratio + ) + + def data_loaders(train_data, val_data, batch_size, distributed=False, rank=0, world_size=1): train_sampler = None val_sampler = None @@ -152,6 +185,12 @@ def load_data_and_data_loaders(dataset, batch_size, num_frames=1, distributed=Fa training_data, validation_data = load_picodoom(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) elif dataset == 'ZELDA': training_data, validation_data = load_zelda(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) + elif dataset == 'STREET_FIGHTER': + training_data, validation_data = load_street_fighter(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) + elif dataset == 'TERRARIA': + training_data, validation_data = load_terraria(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) + elif dataset == 'SPACE_INVADERS': + training_data, validation_data = load_space_invaders(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio) else: raise ValueError('Invalid dataset') diff --git a/datasets/datasets.py b/datasets/datasets.py index e8ee557..f46b044 100644 --- a/datasets/datasets.py +++ b/datasets/datasets.py @@ -138,7 +138,6 @@ def __del__(self): if hasattr(self, 'h5_file'): self.h5_file.close() -# TODO: add more datasets class PongDataset(VideoHDF5Dataset): def __init__(self, video_path, transform=None, save_path=None, train=True, num_frames=1, resolution=(64, 64), fps=30, preload_ratio=1): super().__init__( @@ -227,3 +226,57 @@ def __init__(self, video_path, transform=None, save_path=None, train=True, num_f preprocess_read_step=1, preprocess_slice=None, ) + +class StreetFighterDataset(VideoHDF5Dataset): + def __init__(self, video_path, transform=None, save_path=None, train=True, num_frames=4, resolution=(128, 128), fps=15, preload_ratio=1): + super().__init__( + video_path=video_path, + transform=transform, + save_path=save_path, + train=train, + num_frames=num_frames, + resize_to=resolution, + fps=fps, + preload_ratio=preload_ratio, + sequence_stride=None, + load_chunk_size=1000, + load_start_index=100, + preprocess_read_step=1, + preprocess_slice=None, + ) + +class TerrariaDataset(VideoHDF5Dataset): + def __init__(self, video_path, transform=None, save_path=None, train=True, num_frames=4, resolution=(128, 128), fps=15, preload_ratio=1): + super().__init__( + video_path=video_path, + transform=transform, + save_path=save_path, + train=train, + num_frames=num_frames, + resize_to=resolution, + fps=fps, + preload_ratio=preload_ratio, + sequence_stride=None, + load_chunk_size=1000, + load_start_index=100, + preprocess_read_step=1, + preprocess_slice=None, + ) + +class SpaceInvadersDataset(VideoHDF5Dataset): + def __init__(self, video_path, transform=None, save_path=None, train=True, num_frames=4, resolution=(64, 64), fps=15, preload_ratio=1): + super().__init__( + video_path=video_path, + transform=transform, + save_path=save_path, + train=train, + num_frames=num_frames, + resize_to=resolution, + fps=fps, + preload_ratio=preload_ratio, + sequence_stride=None, + load_chunk_size=1000, + load_start_index=0, + preprocess_read_step=2, + preprocess_slice=None, + )