From ab3364ff95d71ae550f99bd42b68df8f3d125a4e Mon Sep 17 00:00:00 2001 From: jeevesh415 Date: Mon, 4 May 2026 23:38:19 +0530 Subject: [PATCH] Add uncertainty and contrast cues to latent perceptual fusion --- models/layers.py | 2 +- models/vjepa/layers.py | 37 +++++++++++++++++++++----- models/vjepa/physics_engine.py | 47 ++++++++++++++++++++++++++-------- models/vjepa/predictor.py | 11 +++++--- models/vjepa/utils.py | 10 ++++++-- models/vjepa/vjepa_model.py | 11 +++++--- 6 files changed, 92 insertions(+), 26 deletions(-) diff --git a/models/layers.py b/models/layers.py index 08671f5f..4b4f17bc 100644 --- a/models/layers.py +++ b/models/layers.py @@ -39,7 +39,7 @@ def rotate_half(x: torch.Tensor): def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): # q, k: [bs, seq_len, num_heads, head_dim] - # cos, sin: [seq_len, head_dim] + # cos, sin: [seq_len, head_dim] or [bs, seq_len, head_dim] orig_dtype = q.dtype q = q.to(cos.dtype) k = k.to(cos.dtype) diff --git a/models/vjepa/layers.py b/models/vjepa/layers.py index 73e076d2..bf51da43 100644 --- a/models/vjepa/layers.py +++ b/models/vjepa/layers.py @@ -4,6 +4,10 @@ from typing import Tuple import geoopt +try: + import genesis as gs # optional high-fidelity graphics backend +except ImportError: + gs = None class LieGroupEquivariantLayer(nn.Module): """ @@ -51,14 +55,15 @@ def forward(self, x: torch.Tensor, group_element: torch.Tensor) -> torch.Tensor: class LatentRayMarcher(nn.Module): """ - High-Fidelity Volumetric Latent Ray-Marcher utilizing SOTA 'nerfacc' library. - Treats the latent space as a Neural Radiance Field (NeRF). - Integrates density and features along light rays with extreme efficiency and rigor. + High-Fidelity Volumetric Latent Ray-Marcher. + Can optionally use a Genesis backend (if installed) and otherwise falls back + to differentiable PyTorch integration. """ def __init__(self, dim: int, num_samples: int = 16): super().__init__() self.dim = dim self.num_samples = num_samples + self.has_genesis = gs is not None self.density_net = nn.Sequential( nn.Linear(dim, dim // 2), nn.SiLU(), @@ -69,6 +74,11 @@ def __init__(self, dim: int, num_samples: int = 16): nn.SiLU(), nn.Linear(dim, dim) ) + self.perceptual_fuser = nn.Sequential( + nn.Linear(dim + 6, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) def forward(self, latents: torch.Tensor, ray_dirs: torch.Tensor) -> torch.Tensor: """ @@ -78,8 +88,8 @@ def forward(self, latents: torch.Tensor, ray_dirs: torch.Tensor) -> torch.Tensor bs, n, d = latents.shape device = latents.device - # SOTA: nerfacc uses packed rays and intervals for massive speedups. - # For simplicity of the latent integration, we map our latents to ray samples. + # If Genesis is available and exposes a renderer API, users can plug it in here. + # We keep a robust differentiable fallback for portability. t_vals = torch.linspace(0.0, 1.0, self.num_samples, device=device) # Evolve all samples efficiently in parallel @@ -110,8 +120,21 @@ def forward(self, latents: torch.Tensor, ray_dirs: torch.Tensor) -> torch.Tensor weights = alpha * transmittance accumulated_features = (weights.unsqueeze(-1) * features).sum(dim=2) - - return accumulated_features + + # Human-vision-inspired cues: depth, blur/fuzziness, intensity, direction, + # uncertainty entropy, and local contrast. + depth = (weights * t_vals.view(1, 1, -1)).sum(dim=-1, keepdim=True) + blur = alpha.var(dim=-1, keepdim=True) + intensity = accumulated_features.norm(dim=-1, keepdim=True) / (d ** 0.5) + ray_strength = ray_dirs.norm(dim=-1, keepdim=True) + uncertainty = -(weights * (weights.clamp_min(1e-10)).log()).sum(dim=-1, keepdim=True) + contrast = (features[:, :, 1:] - features[:, :, :-1]).abs().mean(dim=(2, 3), keepdim=True) + contrast = contrast.squeeze(-1) + + perceptual_cues = torch.cat([depth, blur, intensity, ray_strength, uncertainty, contrast], dim=-1) + fused = self.perceptual_fuser(torch.cat([accumulated_features, perceptual_cues], dim=-1)) + + return fused def apply_rotary_pos_emb_3d(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): def rotate_half(x): diff --git a/models/vjepa/physics_engine.py b/models/vjepa/physics_engine.py index 31376138..9f50f433 100644 --- a/models/vjepa/physics_engine.py +++ b/models/vjepa/physics_engine.py @@ -1,6 +1,6 @@ import torch from torch import nn -from typing import Callable, Tuple, Optional +from typing import Optional from torchdiffeq import odeint_adjoint as odeint class HRMPhysicsODE(nn.Module): @@ -68,13 +68,40 @@ def __init__(self, dim: int, action_dim: int = 128): super().__init__() self.ode_func = HRMPhysicsODE(dim, action_dim) - def forward(self, z: torch.Tensor, delta_t: float = 1.0, action: Optional[torch.Tensor] = None): - # Evolve using the true Adjoint Method with dopri5 adaptive solver - self.ode_func.current_action = action - - t = torch.tensor([0.0, delta_t], device=z.device, dtype=z.dtype) - # odeint_adjoint handles O(1) memory backprop and adaptive steps perfectly - zt1 = odeint(self.ode_func, z, t, method='dopri5') - + def forward(self, z: torch.Tensor, delta_t: torch.Tensor | float = 1.0, action: Optional[torch.Tensor] = None): + # Scalar horizon path. + if not torch.is_tensor(delta_t): + self.ode_func.current_action = action + t = torch.tensor([0.0, float(delta_t)], device=z.device, dtype=z.dtype) + zt1 = odeint(self.ode_func, z, t, method='dopri5') + self.ode_func.current_action = None + return zt1[-1] + + # Tensor horizon path: keep sample-specific dt while minimizing solver calls. + if delta_t.ndim == 0: + delta_t = delta_t.expand(z.shape[0]) + else: + delta_t = delta_t.reshape(z.shape[0], -1).mean(dim=-1) + delta_t = delta_t.to(device=z.device, dtype=z.dtype) + + # Fast path: all samples share one horizon => one batched ODE solve. + if torch.allclose(delta_t, delta_t[0].expand_as(delta_t)): + self.ode_func.current_action = action + t = torch.stack([torch.zeros((), device=z.device, dtype=z.dtype), delta_t[0]]) + zt1 = odeint(self.ode_func, z, t, method='dopri5') + self.ode_func.current_action = None + return zt1[-1] + + # Group by unique horizons to reduce Python-loop overhead. + evolved = torch.empty_like(z) + unique_dt, inverse = torch.unique(delta_t, sorted=False, return_inverse=True) + for group_idx, dt in enumerate(unique_dt): + idx = torch.nonzero(inverse == group_idx, as_tuple=False).squeeze(-1) + self.ode_func.current_action = None if action is None else action.index_select(0, idx) + t = torch.stack([torch.zeros((), device=z.device, dtype=z.dtype), dt]) + z_group = z.index_select(0, idx) + zt1 = odeint(self.ode_func, z_group, t, method='dopri5') + evolved.index_copy_(0, idx, zt1[-1]) + self.ode_func.current_action = None - return zt1[-1] # Return state at t=delta_t + return evolved diff --git a/models/vjepa/predictor.py b/models/vjepa/predictor.py index 4563544d..a6a1d668 100644 --- a/models/vjepa/predictor.py +++ b/models/vjepa/predictor.py @@ -73,11 +73,16 @@ def forward(self, # 3. Continuous-Time Evolution (Neural ODE) # Condition the physics engine on the action if provided - evolved_state = self.physics_engine(world_state, delta_t.mean().item(), action=action) # (bs, D) + evolved_state = self.physics_engine(world_state, delta_t, action=action) # (bs, D) + + # 4. Memory Recall + Predictive Coding Loop + # Retrieve context-conditioned priors for each target token, then combine with + # evolved global dynamics for top-down planning. + mem_bank = world_state.unsqueeze(1).expand(-1, num_masked, -1) + memory_recall = self.memory.retrieve(mem_bank, target_queries) - # 4. Top-Down Predictive Coding Loop # z_H plans, z_L computes. Error signals flow bottom-up. - z_H = evolved_state.unsqueeze(1).expand(-1, num_masked, -1) + z_H = 0.5 * (evolved_state.unsqueeze(1).expand(-1, num_masked, -1) + memory_recall) z_L = target_queries for _h in range(self.h_cycles): diff --git a/models/vjepa/utils.py b/models/vjepa/utils.py index 21af85f1..233300e1 100644 --- a/models/vjepa/utils.py +++ b/models/vjepa/utils.py @@ -28,11 +28,17 @@ def apply_mask(x, mask): if mask.ndim == 1: mask = mask.unsqueeze(0).expand(bs, -1) + if mask.ndim == 2: + visible_counts = (~mask).sum(dim=1) + masked_counts = mask.sum(dim=1) + if not torch.all(visible_counts.eq(visible_counts[0])) or not torch.all(masked_counts.eq(masked_counts[0])): + raise ValueError("All samples must have the same number of visible/masked patches for batched stacking.") + visible_patches = [] masked_patches = [] - + for i in range(bs): visible_patches.append(x[i, ~mask[i]]) masked_patches.append(x[i, mask[i]]) - + return torch.stack(visible_patches), torch.stack(masked_patches) diff --git a/models/vjepa/vjepa_model.py b/models/vjepa/vjepa_model.py index 12127567..4bfc2200 100644 --- a/models/vjepa/vjepa_model.py +++ b/models/vjepa/vjepa_model.py @@ -81,9 +81,14 @@ def forward(self, batch: Dict[str, torch.Tensor]): full_cos, full_sin = self.context_encoder.rope(self.context_encoder.max_t, self.context_encoder.max_h, self.context_encoder.max_w) # Index cos_sin for masked positions - m = mask[0] if mask.ndim == 2 else mask - masked_cos = full_cos[m] - masked_sin = full_sin[m] + if mask.ndim == 1: + masked_cos = full_cos[mask] + masked_sin = full_sin[mask] + else: + # Support per-sample masks by building batched RoPE tensors. + masked_cos = torch.stack([full_cos[m_i] for m_i in mask], dim=0) + masked_sin = torch.stack([full_sin[m_i] for m_i in mask], dim=0) + masked_cos_sin = (masked_cos, masked_sin) # Predict masked latents