60 changes: 35 additions & 25 deletions datasets/data_utils.py
@@ -17,18 +17,23 @@
 DEFAULT_PERSISTENT_WORKERS = True
 
 
-def _default_video_transform():
-    return transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-    ])
+def _default_video_transform(resolution=None):
+    ops = [transforms.ToTensor()]
+    if resolution is not None:
+        size = (resolution, resolution) if isinstance(resolution, int) else resolution
+        ops.append(transforms.Resize(size))
+    ops.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
+    return transforms.Compose(ops)
 
 
-def _load_video_dataset_pair(dataset_cls, video_rel_path, h5_rel_path, num_frames, transform=None, fps=30, preload_ratio=1, **kwargs):
+def _load_video_dataset_pair(dataset_cls, video_rel_path, h5_rel_path, num_frames, transform=None, fps=30, preload_ratio=1, resolution=None, **kwargs):
     current_folder_path = os.getcwd()
     video_path = current_folder_path + video_rel_path
     preprocessed_path = current_folder_path + h5_rel_path
-    transform = _default_video_transform() if transform is None else transform
+    transform = _default_video_transform(resolution=resolution) if transform is None else transform
+    dataset_kwargs = dict(kwargs)
+    if resolution is not None:
+        dataset_kwargs["resolution"] = (resolution, resolution) if isinstance(resolution, int) else resolution
 
     train = dataset_cls(
         video_path,
@@ -38,7 +43,7 @@ def _load_video_dataset_pair(dataset_cls, video_rel_path, h5_rel_path, num_frame
         num_frames=num_frames,
         fps=fps,
         preload_ratio=preload_ratio,
-        **kwargs
+        **dataset_kwargs
     )
     val = dataset_cls(
         video_path,
@@ -48,63 +53,68 @@ def _load_video_dataset_pair(dataset_cls, video_rel_path, h5_rel_path, num_frame
         num_frames=num_frames,
         fps=fps,
         preload_ratio=preload_ratio,
-        **kwargs
+        **dataset_kwargs
     )
     return train, val
 
 
-def load_pong(num_frames=1, fps=15, preload_ratio=1):
+def load_pong(num_frames=1, fps=15, preload_ratio=1, resolution=None):
     return _load_video_dataset_pair(
         PongDataset,
         '/data/pong.mp4',
         '/data/pong_frames.h5',
         num_frames=num_frames,
         fps=fps,
-        preload_ratio=preload_ratio
+        preload_ratio=preload_ratio,
+        resolution=resolution,
     )
 
 
-def load_sonic(num_frames=4, fps=15, preload_ratio=1):
+def load_sonic(num_frames=4, fps=15, preload_ratio=1, resolution=None):
     return _load_video_dataset_pair(
         SonicDataset,
         '/data/sonic_frames.mp4',
         '/data/sonic_frames.h5',
         num_frames=num_frames,
         fps=fps,
-        preload_ratio=preload_ratio
+        preload_ratio=preload_ratio,
+        resolution=resolution,
     )
 
 
-def load_pole_position(num_frames=4, fps=15, preload_ratio=1):
+def load_pole_position(num_frames=4, fps=15, preload_ratio=1, resolution=None):
     return _load_video_dataset_pair(
         PolePositionDataset,
         '/data/pole_position.mp4',
         '/data/pole_position_frames.h5',
         num_frames=num_frames,
         fps=fps,
-        preload_ratio=preload_ratio
+        preload_ratio=preload_ratio,
+        resolution=resolution,
    )
 
 
-def load_picodoom(num_frames=4, fps=30, preload_ratio=1):
+def load_picodoom(num_frames=4, fps=30, preload_ratio=1, resolution=None):
    return _load_video_dataset_pair(
         PicoDoomDataset,
         '/data/picodoom cleaned.mp4',
         '/data/picodoom_frames.h5',
         num_frames=num_frames,
         fps=30,
-        preload_ratio=preload_ratio
+        preload_ratio=preload_ratio,
+        resolution=resolution,
     )
 
 
-def load_zelda(num_frames=4, fps=15, preload_ratio=1):
+def load_zelda(num_frames=4, fps=15, preload_ratio=1, resolution=None):
     return _load_video_dataset_pair(
         ZeldaDataset,
         '/data/Zelda oot2d 1 Cut.mp4',
         '/data/zelda_frames.h5',
         num_frames=num_frames,
         fps=fps,
-        preload_ratio=preload_ratio
+        preload_ratio=preload_ratio,
+        resolution=resolution,
     )
 
 
@@ -141,17 +151,17 @@ def data_loaders(train_data, val_data, batch_size, distributed=False, rank=0, wo
     return train_loader, val_loader
 
 
-def load_data_and_data_loaders(dataset, batch_size, num_frames=1, distributed=False, rank=0, world_size=1, fps=15, preload_ratio=1):
+def load_data_and_data_loaders(dataset, batch_size, num_frames=1, distributed=False, rank=0, world_size=1, fps=15, preload_ratio=1, frame_size=None):
     if dataset == 'PONG':
-        training_data, validation_data = load_pong(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
+        training_data, validation_data = load_pong(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio, resolution=frame_size)
     elif dataset == 'SONIC':
-        training_data, validation_data = load_sonic(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
+        training_data, validation_data = load_sonic(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio, resolution=frame_size)
     elif dataset == 'POLE_POSITION':
-        training_data, validation_data = load_pole_position(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
+        training_data, validation_data = load_pole_position(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio, resolution=frame_size)
     elif dataset == 'PICODOOM':
-        training_data, validation_data = load_picodoom(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
+        training_data, validation_data = load_picodoom(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio, resolution=frame_size)
     elif dataset == 'ZELDA':
-        training_data, validation_data = load_zelda(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio)
+        training_data, validation_data = load_zelda(num_frames=num_frames, fps=fps, preload_ratio=preload_ratio, resolution=frame_size)
     else:
         raise ValueError('Invalid dataset')
 
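Review note: a minimal usage sketch of the new resolution plumbing, to make the int-vs-tuple handling above concrete. The import path mirrors this file and the dummy numpy frame is made up; note that Resize runs after ToTensor, i.e. on tensors, which assumes a torchvision version with tensor support in Resize.

import numpy as np
from datasets.data_utils import _default_video_transform  # import path assumed from this PR

frame = np.zeros((240, 320, 3), dtype=np.uint8)      # dummy H x W x C uint8 frame
transform = _default_video_transform(resolution=64)  # an int is promoted to a square (64, 64)
out = transform(frame)                               # ToTensor -> Resize -> Normalize
print(out.shape)                                     # torch.Size([3, 64, 64]), values in [-1, 1]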
68 changes: 50 additions & 18 deletions models/latent_actions.py
@@ -5,22 +5,29 @@
 import torch.distributed as dist
 import math
 from einops import rearrange, repeat, reduce
-from models.st_transformer import STTransformer, PatchEmbedding
+from models.st_transformer import STTransformer, PatchEmbedding, SpatialAttention
 from models.fsq import FiniteScalarQuantizer
 
 NUM_LATENT_ACTIONS_BINS = 2
 
 class LatentActionsEncoder(nn.Module):
     def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads=8,
-                 hidden_dim=256, num_blocks=4, action_dim=3):
+                 hidden_dim=256, num_blocks=4, action_dim=3, use_windowed_attention=False):
         super().__init__()
+        self.use_windowed_attention = use_windowed_attention
         self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
         self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True)
 
+        if self.use_windowed_attention:
+            self.window_attn = SpatialAttention(embed_dim, num_heads)
+            action_head_dim = embed_dim
+        else:
+            action_head_dim = embed_dim * 2
+
         # embeddings to discrete latent bottleneck actions
         self.action_head = nn.Sequential(
-            nn.LayerNorm(embed_dim * 2),
-            nn.Linear(embed_dim * 2, 4 * action_dim),
+            nn.LayerNorm(action_head_dim),
+            nn.Linear(action_head_dim, 4 * action_dim),
             nn.GELU(),
             nn.Linear(4 * action_dim, action_dim)
         )
@@ -32,19 +39,34 @@ def forward(self, frames):
         embeddings = self.patch_embed(frames)  # [B, T, P, E]
         transformed = self.transformer(embeddings)
 
-        # TODO: try attention pooling + mean instead of mean + concat
-        # mean pool over patches (since one action per frame)
-        pooled = transformed.mean(dim=2)  # [B, T, E]
+        if not self.use_windowed_attention:
+            pooled = transformed.mean(dim=2)  # [B, T, E]
+            current_frames = pooled[:, :-1]  # [B, T-1, E]
+            next_frames = pooled[:, 1:]  # [B, T-1, E]
+            combined = torch.cat([current_frames, next_frames], dim=-1)
+            return self.action_head(combined)  # [B, T-1, A]
 
-        # combine features from current and next frame
-        actions = []
-        for t in range(seq_len - 1):
-            # concat current and next frame features
-            combined = torch.cat([pooled[:, t], pooled[:, t+1]], dim=1)  # [B, E*2]
-            action = self.action_head(combined)  # [B, A]
-            actions.append(action)
+        # windowed attention over length-2 windows (current and next frame)
+        # we want to combine frame t and t+1
+
+        # 1. Create windows: [B, T-1, 2, P, E]
+        #    We slice to get T-1 windows
+        current_frames = transformed[:, :-1]  # [B, T-1, P, E]
+        next_frames = transformed[:, 1:]  # [B, T-1, P, E]
+
+        # 2. Concatenate along patch dimension to treat as one large spatial sequence
+        #    [B, T-1, 2*P, E]
+        windows = torch.cat([current_frames, next_frames], dim=2)
+
+        # 3. Apply spatial attention
+        #    SpatialAttention expects [B, T, P, E]
+        attended = self.window_attn(windows)  # [B, T-1, 2*P, E]
+
+        # 4. Mean pool over the combined patches
+        pooled = attended.mean(dim=2)  # [B, T-1, E]
+
-        actions = torch.stack(actions, dim=1)  # [B, T-1, A]
+        # 5. Project to actions
+        actions = self.action_head(pooled)  # [B, T-1, A]
 
         return actions
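Review note: to sanity-check the shapes in the windowed path, here is a standalone sketch of the length-2 window construction. SpatialAttention's internals are not part of this diff, so a plain nn.MultiheadAttention applied independently to each window stands in for it; all sizes are toy values.

import torch
import torch.nn as nn
from einops import rearrange

B, T, P, E, heads = 2, 5, 16, 128, 8
transformed = torch.randn(B, T, P, E)  # stand-in for the STTransformer output

# slice into T-1 overlapping (t, t+1) windows and merge them along the patch axis
windows = torch.cat([transformed[:, :-1], transformed[:, 1:]], dim=2)  # [B, T-1, 2P, E]

attn = nn.MultiheadAttention(E, heads, batch_first=True)  # stub for SpatialAttention
flat = rearrange(windows, 'b t p e -> (b t) p e')         # attend within each window only
out, _ = attn(flat, flat, flat)
attended = rearrange(out, '(b t) p e -> b t p e', b=B)    # [B, T-1, 2P, E]

pooled = attended.mean(dim=2)  # [B, T-1, E]
print(pooled.shape)            # torch.Size([2, 4, 128])

Each window mixes patches from frame t and frame t+1 before pooling, so the pooled feature can encode cross-frame correspondences that a per-frame mean pool followed by concat cannot.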

@@ -98,11 +120,21 @@ def forward(self, frames, actions, training=True):

 class LatentActionModel(nn.Module):
     def __init__(self, frame_size=(128, 128), n_actions=8, patch_size=8, embed_dim=128,
-                 num_heads=8, hidden_dim=256, num_blocks=4):
+                 num_heads=8, hidden_dim=256, num_blocks=4, use_windowed_attention=False):
         super().__init__()
         assert math.log(n_actions, NUM_LATENT_ACTIONS_BINS).is_integer(), f"n_actions must be a power of {NUM_LATENT_ACTIONS_BINS}"
         self.action_dim=int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS))
-        self.encoder = LatentActionsEncoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, action_dim=self.action_dim)
+        self.use_windowed_attention = use_windowed_attention
+        self.encoder = LatentActionsEncoder(
+            frame_size,
+            patch_size,
+            embed_dim,
+            num_heads,
+            hidden_dim,
+            num_blocks,
+            action_dim=self.action_dim,
+            use_windowed_attention=use_windowed_attention,
+        )
         self.quantizer = FiniteScalarQuantizer(latent_dim=self.action_dim, num_bins=NUM_LATENT_ACTIONS_BINS)
         self.decoder = LatentActionsDecoder(frame_size, patch_size, embed_dim, num_heads, hidden_dim, num_blocks, conditioning_dim=self.action_dim)
         self.var_target = 0.01
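Review note: a worked example of the assert above. With NUM_LATENT_ACTIONS_BINS = 2, n_actions must be a power of two and action_dim = log2(n_actions); the FSQ bottleneck with 2 bins per latent dimension then expresses exactly 2 ** action_dim distinct actions.

import math

NUM_LATENT_ACTIONS_BINS = 2
for n_actions in (2, 4, 8, 16):
    action_dim = int(math.log(n_actions, NUM_LATENT_ACTIONS_BINS))
    assert NUM_LATENT_ACTIONS_BINS ** action_dim == n_actions  # bijective code <-> action
    print(n_actions, '->', action_dim)  # e.g. 8 -> 3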
@@ -136,4 +168,4 @@ def encode(self, frames):

     @property
     def model_type(self) -> str:
-        return ModelType.LatentActionModel
\ No newline at end of file
+        return ModelType.LatentActionModel
7 changes: 7 additions & 0 deletions scripts/run_inference.py
@@ -72,6 +72,13 @@ def missing(path: Optional[str]) -> bool:
         data_overrides = {'preload_ratio': args.preload_ratio}
     else:
         data_overrides = {}
+    # infer frame_size from the loaded video_tokenizer (its patch_embed retains the trained spatial dims)
+    try:
+        _trained_frame_size = video_tokenizer.encoder.patch_embed.frame_size
+        h = _trained_frame_size[0] if isinstance(_trained_frame_size, (tuple, list)) else _trained_frame_size
+        data_overrides['frame_size'] = h
+    except AttributeError:
+        pass
     _, _, data_loader, _, _ = load_data_and_data_loaders(
         dataset=args.dataset, batch_size=1, num_frames=frames_to_load, **data_overrides)
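Review note: the try/except keeps older checkpoints working; if the tokenizer's PatchEmbedding predates a frame_size attribute, the AttributeError is swallowed and data_overrides is left untouched. A toy illustration, where the stub classes and the infer_frame_size helper are hypothetical and not part of the codebase:

class _OldPatchEmbed:  # no frame_size attribute, as in older checkpoints
    pass

class _NewPatchEmbed:
    frame_size = (64, 64)

def infer_frame_size(patch_embed, data_overrides):
    try:
        fs = patch_embed.frame_size
        data_overrides['frame_size'] = fs[0] if isinstance(fs, (tuple, list)) else fs
    except AttributeError:
        pass  # fall back to the dataset's native resolution
    return data_overrides

print(infer_frame_size(_NewPatchEmbed(), {}))  # {'frame_size': 64}
print(infer_frame_size(_OldPatchEmbed(), {}))  # {}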
1 change: 1 addition & 0 deletions scripts/train_dynamics.py
@@ -141,6 +141,7 @@ def main():
         distributed=dist_setup['is_distributed'],
         rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
         world_size=dist_setup['world_size'],
+        frame_size=args.frame_size,
         **data_overrides,
     )
     train_iter = iter(training_loader)
2 changes: 2 additions & 0 deletions scripts/train_latent_actions.py
@@ -45,6 +45,7 @@ def main():
         distributed=dist_setup['is_distributed'],
         rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
         world_size=dist_setup['world_size'],
+        frame_size=args.frame_size,
         **data_overrides,
     )
 
@@ -57,6 +58,7 @@ def main():
         hidden_dim=args.hidden_dim,
         num_blocks=args.num_blocks,
         n_actions=args.n_actions,
+        use_windowed_attention=args.use_windowed_attention,
     ).to(args.device)
     if args.checkpoint:
         model, _ = load_latent_actions_from_checkpoint(
1 change: 1 addition & 0 deletions scripts/train_video_tokenizer.py
@@ -45,6 +45,7 @@ def main():
         distributed=dist_setup['is_distributed'],
         rank=dist_setup['device_mesh'].get_rank() if dist_setup['device_mesh'] is not None else 0,
         world_size=dist_setup['world_size'],
+        frame_size=args.frame_size,
         **data_overrides,
     )
     # print("Length of training data:", len(training_data))
2 changes: 2 additions & 0 deletions utils/config.py
@@ -160,6 +160,7 @@ class LatentActionsConfig:
     muon_backend_steps: int = 5
     # device
     device: DeviceType = DeviceType.CUDA
+    use_windowed_attention: bool = False
     # other params
     fps: Optional[int] = None
     preload_ratio: Optional[float] = None
@@ -269,6 +270,7 @@ class TrainingConfig:
     batch_size_per_gpu: Optional[int] = None
     gradient_accumulation_steps: Optional[int] = None
     log_interval: Optional[int] = None
+    use_windowed_attention: Optional[bool] = None
     n_updates: Optional[int] = None  # number of optimizer.step(), excluding grad_accum_step
     fps: Optional[int] = None
     preload_ratio: Optional[float] = None
22 changes: 18 additions & 4 deletions utils/inference_utils.py
@@ -7,19 +7,33 @@
 from utils.utils import load_videotokenizer_from_checkpoint, load_latent_actions_from_checkpoint, load_dynamics_from_checkpoint
 from einops import repeat
 
-def load_models(video_tokenizer_path, latent_actions_path, dynamics_path, device, use_actions=True):
-    # Load tokenizer and dynamics, and Latent Actions if using actions
+def load_models(video_tokenizer_path, latent_actions_path, dynamics_path, device, use_actions=True, load_dynamics=True):
+    # Load tokenizer and optionally dynamics/latent actions.
     video_tokenizer, _vt_ckpt = load_videotokenizer_from_checkpoint(video_tokenizer_path, device)
     video_tokenizer.eval()
     latent_action_model = None
     if use_actions:
         latent_action_model, _latent_action_ckpt = load_latent_actions_from_checkpoint(latent_actions_path, device)
         latent_action_model.eval()
-    dynamics_model, _dyn_ckpt = load_dynamics_from_checkpoint(dynamics_path, device)
-    dynamics_model.eval()
+    dynamics_model = None
+    if load_dynamics:
+        dynamics_model, _dyn_ckpt = load_dynamics_from_checkpoint(dynamics_path, device)
+        dynamics_model.eval()
     return video_tokenizer, latent_action_model, dynamics_model
 
 
+def reconstruct_frames(video_tokenizer, frames):
+    indices = video_tokenizer.tokenize(frames)
+    latents = video_tokenizer.quantizer.get_latents_from_indices(indices)
+    return video_tokenizer.detokenize(latents)
+
+
+def compute_frame_mse(predicted_frames, target_frames):
+    predicted = ((predicted_frames.detach().to(torch.float32) + 1) / 2).clamp(0, 1)
+    target = ((target_frames.detach().to(torch.float32) + 1) / 2).clamp(0, 1)
+    return torch.mean((predicted - target) ** 2).item()
+
+
 def visualize_inference(predicted_frames, ground_truth_frames, inferred_actions, fps, use_actions=True):
     # Move to CPU and convert to numpy
     predicted_frames = predicted_frames.detach().cpu()
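Review note: a quick check of the new compute_frame_mse convention. Inputs are assumed to be in [-1, 1] (the Normalize output range used by the transforms) and are remapped to [0, 1] before the MSE, so identical frames score 0.0 and frames at opposite ends of the range score 1.0.

import torch
from utils.inference_utils import compute_frame_mse

pred = torch.full((1, 3, 64, 64), -1.0)   # all "black" in the [-1, 1] convention
target = torch.full((1, 3, 64, 64), 1.0)  # all "white"
print(compute_frame_mse(pred, pred))      # 0.0
print(compute_frame_mse(pred, target))    # 1.0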
6 changes: 4 additions & 2 deletions utils/utils.py
@@ -248,8 +248,9 @@ def load_latent_actions_from_checkpoint(checkpoint_path, device, model = None, i
         'num_heads': cfg.get('num_heads', 8),
         'hidden_dim': cfg.get('hidden_dim', 256),
         'num_blocks': cfg.get('num_blocks', 4),
+        'use_windowed_attention': cfg.get('use_windowed_attention', False),
     }
-    if model is None:
+    if model is None or getattr(model.encoder, 'use_windowed_attention', False) != kwargs['use_windowed_attention']:
         model = LatentActionModel(**kwargs)
     set_model_state_dict(
         model=model,
@@ -275,7 +276,8 @@ def load_dynamics_from_checkpoint(checkpoint_path, device, model = None, is_dist
     conditioning_dim = cfg.get('conditioning_dim', None)
     if conditioning_dim is None:
         cond_inferred = None
-        for k, v in model_sd.items():
+        state_items = model_sd.items() if isinstance(model_sd, dict) else []
+        for k, v in state_items:
             # Linear weight shape: [out_features, in_features]; in_features is conditioning dim
             if k.endswith('to_gamma_beta.1.weight'):
                 cond_inferred = int(v.shape[1])
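Review note: the new guard makes the conditioning-dim inference safe when model_sd is not a plain dict. The dimension is read off in_features of the to_gamma_beta Linear weight. A toy state dict showing the lookup; the blocks.0. prefix is invented and only the key suffix matters:

import torch

model_sd = {
    'blocks.0.to_gamma_beta.1.weight': torch.zeros(256, 3),  # [out_features, in_features]
    'blocks.0.to_gamma_beta.1.bias': torch.zeros(256),
}

cond_inferred = None
state_items = model_sd.items() if isinstance(model_sd, dict) else []
for k, v in state_items:
    if k.endswith('to_gamma_beta.1.weight'):
        cond_inferred = int(v.shape[1])  # in_features == conditioning dim
        break
print(cond_inferred)  # 3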