
feat: add pre-tokenized dataset cache to accelerate dynamics training #26

Open

tashapais wants to merge 1 commit into AlmondGod:main from tashapais:feat/patch-embed-cache

Conversation

@tashapais

Summary

  • Adds scripts/preprocess_tokens.py: a one-time script that runs a trained VideoTokenizer over the full dataset and saves the resulting token indices as an [N, P] int32 dataset in HDF5, with metadata (latent_dim, num_bins, codebook_size) stored as HDF5 attrs (sketched below).
  • Adds TokenizedVideoDataset to datasets/datasets.py: loads the pre-tokenized HDF5 and returns [T, P] index sequences, matching the (tokens, 0) interface of VideoHDF5Dataset.
  • Updates train_dynamics.py: when cached_tokens_path is set in config and the file exists, the dataloader returns token indices directly. The video tokenizer forward pass is skipped each training step, eliminating the per-batch tokenization overhead.
  • Adds cached_tokens_path: Optional[str] = None to DynamicsConfig.

Why it matters: dynamics training repeats the same tokenization pass every step. With a large dataset and a GPU-bound tokenizer, this can cut 30-50% off wall-clock training time.
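
For concreteness, the core of the write path can be sketched as follows. This is a minimal illustration under stated assumptions, not the script itself: the tokenizer.encode API, the "tokens" dataset name, and the write_token_cache helper are hypothetical; only the [N, P] int32 layout and the latent_dim / num_bins / codebook_size attrs come from the summary above.

import h5py
import numpy as np
import torch

@torch.no_grad()
def write_token_cache(tokenizer, frames, output_path,
                      latent_dim, num_bins, codebook_size, batch_size=256):
    """frames: [N, C, H, W] float tensor; writes an [N, P] int32 "tokens" dataset."""
    tokenizer.eval()
    n = frames.shape[0]
    with h5py.File(output_path, "w") as f:
        dset = None
        for start in range(0, n, batch_size):
            batch = frames[start:start + batch_size]
            idx = tokenizer.encode(batch)                  # assumed API: returns [B, P] token indices
            idx = idx.cpu().numpy().astype(np.int32)
            if dset is None:
                # create the [N, P] dataset once P is known from the first batch
                dset = f.create_dataset("tokens", shape=(n, idx.shape[1]), dtype=np.int32)
            dset[start:start + idx.shape[0]] = idx
        # metadata stored as HDF5 attrs, as described above
        f.attrs["latent_dim"] = latent_dim
        f.attrs["num_bins"] = num_bins
        f.attrs["codebook_size"] = codebook_size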

Usage

# 1. pre-tokenize (one time)
python scripts/preprocess_tokens.py \
    --video_tokenizer_path runs/.../video_tokenizer \
    --dataset PONG \
    --output_path data/pong_tokens.h5 \
    --device cuda

# 2. set in configs/dynamics.yaml
# cached_tokens_path: data/pong_tokens.h5

# 3. train as usual — tokenizer forward pass is skipped
python scripts/train_dynamics.py --config configs/dynamics.yaml
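
For orientation, the branch in train_dynamics.py amounts to something like the sketch below. cached_tokens_path, TokenizedVideoDataset, and VideoHDF5Dataset are the names from the summary; the build_dataset wrapper, the import path, and the dataset_path field are illustrative assumptions.

import os
from datasets.datasets import TokenizedVideoDataset, VideoHDF5Dataset  # import path assumed

def build_dataset(config):
    # If the cache exists, batches are already [T, P] token indices, so the
    # per-step video tokenizer forward pass is skipped entirely.
    if config.cached_tokens_path and os.path.exists(config.cached_tokens_path):
        return TokenizedVideoDataset(config.cached_tokens_path), True
    # Default (cached_tokens_path: null): raw frames, tokenized on the fly each step.
    return VideoHDF5Dataset(config.dataset_path), False  # dataset_path field is illustrative

Either branch yields batches the dynamics model consumes as token indices; only the cached branch avoids the per-step tokenizer forward pass.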

Test plan

  • Run preprocess_tokens.py on a small dataset and verify the HDF5 file is created with the correct [N, P] shape (see the check below)
  • Train dynamics with cached_tokens_path set and verify the loss curve matches raw-frame training
  • Verify that cached_tokens_path: null (the default) falls back to the original frame-loading path
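
A quick check for the first item, assuming the cache stores a single "tokens" dataset (name assumed) at the path used in the Usage example:

import h5py

with h5py.File("data/pong_tokens.h5", "r") as f:
    tokens = f["tokens"]                   # dataset name assumed
    print(tokens.shape, tokens.dtype)      # expect (N, P), int32
    print(dict(f.attrs))                   # latent_dim, num_bins, codebook_size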

Adds scripts/preprocess_tokens.py: runs a trained VideoTokenizer over an
entire dataset and saves token indices as [N, P] int32 to HDF5.

Adds TokenizedVideoDataset: loads pre-tokenized HDF5 and returns [T, P]
index sequences, with the same (tokens, 0) interface as VideoHDF5Dataset.
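
A minimal sketch of what such a dataset can look like, assuming a torch-style Dataset over a single "tokens" array of shape [N, P]; the seq_len argument and per-call file opening are illustrative choices, not necessarily what datasets/datasets.py does:

import h5py
import torch
from torch.utils.data import Dataset

class TokenizedVideoDataset(Dataset):
    def __init__(self, h5_path, seq_len=16):
        self.h5_path = h5_path
        self.seq_len = seq_len
        with h5py.File(h5_path, "r") as f:
            self.num_frames = f["tokens"].shape[0]

    def __len__(self):
        return self.num_frames - self.seq_len + 1

    def __getitem__(self, i):
        # open lazily per call so the dataset stays picklable with num_workers > 0
        with h5py.File(self.h5_path, "r") as f:
            window = f["tokens"][i:i + self.seq_len]   # [T, P] int32
        tokens = torch.from_numpy(window).long()
        return tokens, 0   # same (tokens, 0) interface as VideoHDF5Dataset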

In train_dynamics.py, if cached_tokens_path is set and exists, the dataloader
returns token indices directly, skipping the video tokenizer forward pass
each training step. This eliminates the tokenizer bottleneck for repeated
runs on the same dataset.

Use: python scripts/preprocess_tokens.py --video_tokenizer_path <ckpt> \
         --dataset PONG --output_path data/pong_tokens.h5
Then: set cached_tokens_path in configs/dynamics.yaml