Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def rotate_half(x: torch.Tensor):

def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
# q, k: [bs, seq_len, num_heads, head_dim]
# cos, sin: [seq_len, head_dim]
# cos, sin: [seq_len, head_dim] or [bs, seq_len, head_dim]
orig_dtype = q.dtype
q = q.to(cos.dtype)
k = k.to(cos.dtype)
Expand Down
37 changes: 30 additions & 7 deletions models/vjepa/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from typing import Tuple

import geoopt
try:
import genesis as gs # optional high-fidelity graphics backend
except ImportError:
gs = None

class LieGroupEquivariantLayer(nn.Module):
"""
Expand Down Expand Up @@ -51,14 +55,15 @@ def forward(self, x: torch.Tensor, group_element: torch.Tensor) -> torch.Tensor:

class LatentRayMarcher(nn.Module):
"""
High-Fidelity Volumetric Latent Ray-Marcher utilizing SOTA 'nerfacc' library.
Treats the latent space as a Neural Radiance Field (NeRF).
Integrates density and features along light rays with extreme efficiency and rigor.
High-Fidelity Volumetric Latent Ray-Marcher.
Can optionally use a Genesis backend (if installed) and otherwise falls back
to differentiable PyTorch integration.
"""
def __init__(self, dim: int, num_samples: int = 16):
super().__init__()
self.dim = dim
self.num_samples = num_samples
self.has_genesis = gs is not None
self.density_net = nn.Sequential(
nn.Linear(dim, dim // 2),
nn.SiLU(),
Expand All @@ -69,6 +74,11 @@ def __init__(self, dim: int, num_samples: int = 16):
nn.SiLU(),
nn.Linear(dim, dim)
)
self.perceptual_fuser = nn.Sequential(
nn.Linear(dim + 6, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)

def forward(self, latents: torch.Tensor, ray_dirs: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -78,8 +88,8 @@ def forward(self, latents: torch.Tensor, ray_dirs: torch.Tensor) -> torch.Tensor
bs, n, d = latents.shape
device = latents.device

# SOTA: nerfacc uses packed rays and intervals for massive speedups.
# For simplicity of the latent integration, we map our latents to ray samples.
# If Genesis is available and exposes a renderer API, users can plug it in here.
# We keep a robust differentiable fallback for portability.
t_vals = torch.linspace(0.0, 1.0, self.num_samples, device=device)

# Evolve all samples efficiently in parallel
Expand Down Expand Up @@ -110,8 +120,21 @@ def forward(self, latents: torch.Tensor, ray_dirs: torch.Tensor) -> torch.Tensor
weights = alpha * transmittance

accumulated_features = (weights.unsqueeze(-1) * features).sum(dim=2)

return accumulated_features

# Human-vision-inspired cues: depth, blur/fuzziness, intensity, direction,
# uncertainty entropy, and local contrast.
depth = (weights * t_vals.view(1, 1, -1)).sum(dim=-1, keepdim=True)
blur = alpha.var(dim=-1, keepdim=True)
intensity = accumulated_features.norm(dim=-1, keepdim=True) / (d ** 0.5)
ray_strength = ray_dirs.norm(dim=-1, keepdim=True)
uncertainty = -(weights * (weights.clamp_min(1e-10)).log()).sum(dim=-1, keepdim=True)
contrast = (features[:, :, 1:] - features[:, :, :-1]).abs().mean(dim=(2, 3), keepdim=True)
Comment thread
jeevesh415 marked this conversation as resolved.
contrast = contrast.squeeze(-1)

perceptual_cues = torch.cat([depth, blur, intensity, ray_strength, uncertainty, contrast], dim=-1)
fused = self.perceptual_fuser(torch.cat([accumulated_features, perceptual_cues], dim=-1))

return fused

def apply_rotary_pos_emb_3d(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
def rotate_half(x):
Expand Down
47 changes: 37 additions & 10 deletions models/vjepa/physics_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch import nn
from typing import Callable, Tuple, Optional
from typing import Optional
from torchdiffeq import odeint_adjoint as odeint

class HRMPhysicsODE(nn.Module):
Expand Down Expand Up @@ -68,13 +68,40 @@ def __init__(self, dim: int, action_dim: int = 128):
super().__init__()
self.ode_func = HRMPhysicsODE(dim, action_dim)

def forward(self, z: torch.Tensor, delta_t: float = 1.0, action: Optional[torch.Tensor] = None):
# Evolve using the true Adjoint Method with dopri5 adaptive solver
self.ode_func.current_action = action

t = torch.tensor([0.0, delta_t], device=z.device, dtype=z.dtype)
# odeint_adjoint handles O(1) memory backprop and adaptive steps perfectly
zt1 = odeint(self.ode_func, z, t, method='dopri5')

def forward(self, z: torch.Tensor, delta_t: torch.Tensor | float = 1.0, action: Optional[torch.Tensor] = None):
# Scalar horizon path.
if not torch.is_tensor(delta_t):
self.ode_func.current_action = action
t = torch.tensor([0.0, float(delta_t)], device=z.device, dtype=z.dtype)
zt1 = odeint(self.ode_func, z, t, method='dopri5')
self.ode_func.current_action = None
return zt1[-1]

# Tensor horizon path: keep sample-specific dt while minimizing solver calls.
if delta_t.ndim == 0:
delta_t = delta_t.expand(z.shape[0])
else:
delta_t = delta_t.reshape(z.shape[0], -1).mean(dim=-1)
Comment on lines +81 to +84
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: A 1-element tensor horizon (for example torch.tensor([1.0])) is treated as a batched tensor and passed to reshape(z.shape[0], -1), which crashes when batch size is greater than 1. Since the method contract says scalar or per-sample tensor, handle delta_t.numel() == 1 as scalar-expansion before reshaping. [possible bug]

Severity Level: Critical 🚨
- ❌ ContinuousTimeHRM crashes when given 1-element tensor horizon.
- ⚠️ Breaks external callers using tensor scalar-style delta_t.
Steps of Reproduction ✅
1. In a script using this repo, import and instantiate `ContinuousTimeHRM` from
`models/vjepa/physics_engine.py` (class defined at lines 61-69), e.g. `engine =
ContinuousTimeHRM(dim)`.

2. Create a batched latent state `z` with batch size greater than 1, for example `z =
torch.randn(2, dim)` on a GPU or CPU device.

3. Define a 1-element tensor horizon that is intended to act as a scalar for the whole
batch, e.g. `delta_t = torch.tensor([1.0], device=z.device, dtype=z.dtype)` (note: this is
a 1D tensor of length 1, not a Python float or 0-D tensor).

4. Call `engine.forward(z, delta_t, action=None)`; inside `ContinuousTimeHRM.forward`
(`models/vjepa/physics_engine.py:71-85`), the scalar path is skipped
(`torch.is_tensor(delta_t)` is True), `delta_t.ndim == 0` is False, and the code executes
`delta_t = delta_t.reshape(z.shape[0], -1).mean(dim=-1)` at line 84, which attempts to
reshape a tensor of size 1 into shape `(2, -1)` and raises a `RuntimeError: shape '[2,
-1]' is invalid for input of size 1`, causing the physics engine (and thus
`predictor.physics_engine` at `models/vjepa/predictor.py:35-38`) to crash for this natural
usage pattern.

Fix in Cursor | Fix in VSCode Claude

(Use Cmd/Ctrl + Click for best experience)

Prompt for AI Agent 🤖
This is a comment left during a code review.

**Path:** models/vjepa/physics_engine.py
**Line:** 81:84
**Comment:**
	*Possible Bug: A 1-element tensor horizon (for example `torch.tensor([1.0])`) is treated as a batched tensor and passed to `reshape(z.shape[0], -1)`, which crashes when batch size is greater than 1. Since the method contract says scalar or per-sample tensor, handle `delta_t.numel() == 1` as scalar-expansion before reshaping.

Validate the correctness of the flagged issue. If correct, How can I resolve this? If you propose a fix, implement it and please make it concise.
Once fix is implemented, also check other comments on the same PR, and ask user if the user wants to fix the rest of the comments as well. if said yes, then fetch all the comments validate the correctness and implement a minimal fix
👍 | 👎

delta_t = delta_t.to(device=z.device, dtype=z.dtype)

# Fast path: all samples share one horizon => one batched ODE solve.
if torch.allclose(delta_t, delta_t[0].expand_as(delta_t)):
self.ode_func.current_action = action
t = torch.stack([torch.zeros((), device=z.device, dtype=z.dtype), delta_t[0]])
zt1 = odeint(self.ode_func, z, t, method='dopri5')
self.ode_func.current_action = None
return zt1[-1]

# Group by unique horizons to reduce Python-loop overhead.
evolved = torch.empty_like(z)
unique_dt, inverse = torch.unique(delta_t, sorted=False, return_inverse=True)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve autograd for tensor delta_t grouping

Grouping horizons with torch.unique(delta_t, ...) breaks backprop when delta_t.requires_grad=True: PyTorch raises NotImplementedError for _unique2 during backward. Since this commit newly supports tensor-valued delta_t, any experiment with differentiable or learned horizons will now fail at training time.

Useful? React with 👍 / 👎.

for group_idx, dt in enumerate(unique_dt):
idx = torch.nonzero(inverse == group_idx, as_tuple=False).squeeze(-1)
self.ode_func.current_action = None if action is None else action.index_select(0, idx)
t = torch.stack([torch.zeros((), device=z.device, dtype=z.dtype), dt])
z_group = z.index_select(0, idx)
zt1 = odeint(self.ode_func, z_group, t, method='dopri5')
evolved.index_copy_(0, idx, zt1[-1])

self.ode_func.current_action = None
return zt1[-1] # Return state at t=delta_t
return evolved
11 changes: 8 additions & 3 deletions models/vjepa/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,16 @@ def forward(self,

# 3. Continuous-Time Evolution (Neural ODE)
# Condition the physics engine on the action if provided
evolved_state = self.physics_engine(world_state, delta_t.mean().item(), action=action) # (bs, D)
evolved_state = self.physics_engine(world_state, delta_t, action=action) # (bs, D)

# 4. Memory Recall + Predictive Coding Loop
# Retrieve context-conditioned priors for each target token, then combine with
# evolved global dynamics for top-down planning.
mem_bank = world_state.unsqueeze(1).expand(-1, num_masked, -1)
memory_recall = self.memory.retrieve(mem_bank, target_queries)

# 4. Top-Down Predictive Coding Loop
# z_H plans, z_L computes. Error signals flow bottom-up.
z_H = evolved_state.unsqueeze(1).expand(-1, num_masked, -1)
z_H = 0.5 * (evolved_state.unsqueeze(1).expand(-1, num_masked, -1) + memory_recall)
z_L = target_queries

for _h in range(self.h_cycles):
Expand Down
10 changes: 8 additions & 2 deletions models/vjepa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,17 @@ def apply_mask(x, mask):
if mask.ndim == 1:
mask = mask.unsqueeze(0).expand(bs, -1)

if mask.ndim == 2:
visible_counts = (~mask).sum(dim=1)
masked_counts = mask.sum(dim=1)
if not torch.all(visible_counts.eq(visible_counts[0])) or not torch.all(masked_counts.eq(masked_counts[0])):
raise ValueError("All samples must have the same number of visible/masked patches for batched stacking.")

visible_patches = []
masked_patches = []

for i in range(bs):
visible_patches.append(x[i, ~mask[i]])
masked_patches.append(x[i, mask[i]])

return torch.stack(visible_patches), torch.stack(masked_patches)
11 changes: 8 additions & 3 deletions models/vjepa/vjepa_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ def forward(self, batch: Dict[str, torch.Tensor]):
full_cos, full_sin = self.context_encoder.rope(self.context_encoder.max_t, self.context_encoder.max_h, self.context_encoder.max_w)

# Index cos_sin for masked positions
m = mask[0] if mask.ndim == 2 else mask
masked_cos = full_cos[m]
masked_sin = full_sin[m]
if mask.ndim == 1:
masked_cos = full_cos[mask]
masked_sin = full_sin[mask]
else:
# Support per-sample masks by building batched RoPE tensors.
masked_cos = torch.stack([full_cos[m_i] for m_i in mask], dim=0)
masked_sin = torch.stack([full_sin[m_i] for m_i in mask], dim=0)

masked_cos_sin = (masked_cos, masked_sin)

# Predict masked latents
Expand Down