Add per-sample ODE horizons, batched RoPE/masking, Stiefel equivariant layer and enhanced ray marcher #2
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| import torch | ||
| from torch import nn | ||
| from typing import Callable, Tuple, Optional | ||
| from typing import Optional | ||
| from torchdiffeq import odeint_adjoint as odeint | ||
|
|
||
| class HRMPhysicsODE(nn.Module): | ||
|
|
| @@ -68,13 +68,40 @@ def __init__(self, dim: int, action_dim: int = 128): | ||
| super().__init__() | ||
| self.ode_func = HRMPhysicsODE(dim, action_dim) | ||
|
|
||
| def forward(self, z: torch.Tensor, delta_t: float = 1.0, action: Optional[torch.Tensor] = None): | ||
| # Evolve using the true Adjoint Method with dopri5 adaptive solver | ||
| self.ode_func.current_action = action | ||
|
|
||
| t = torch.tensor([0.0, delta_t], device=z.device, dtype=z.dtype) | ||
| # odeint_adjoint handles O(1) memory backprop and adaptive steps perfectly | ||
| zt1 = odeint(self.ode_func, z, t, method='dopri5') | ||
|
|
||
| def forward(self, z: torch.Tensor, delta_t: torch.Tensor | float = 1.0, action: Optional[torch.Tensor] = None): | ||
| # Scalar horizon path. | ||
| if not torch.is_tensor(delta_t): | ||
| self.ode_func.current_action = action | ||
| t = torch.tensor([0.0, float(delta_t)], device=z.device, dtype=z.dtype) | ||
| zt1 = odeint(self.ode_func, z, t, method='dopri5') | ||
| self.ode_func.current_action = None | ||
| return zt1[-1] | ||
|
|
||
| # Tensor horizon path: keep sample-specific dt while minimizing solver calls. | ||
| if delta_t.ndim == 0: | ||
| delta_t = delta_t.expand(z.shape[0]) | ||
| else: | ||
| delta_t = delta_t.reshape(z.shape[0], -1).mean(dim=-1) | ||
|
Comment on lines +81 to +84

**Path:** models/vjepa/physics_engine.py
**Line:** 81:84

**Suggestion (possible bug):** A 1-element tensor horizon (for example `torch.tensor([1.0])`) is treated as a batched tensor and passed to `reshape(z.shape[0], -1)`, which crashes when the batch size is greater than 1. Since the method contract says scalar or per-sample tensor, handle `delta_t.numel() == 1` as scalar expansion before reshaping.

Severity Level: Critical 🚨
- ❌ ContinuousTimeHRM crashes when given a 1-element tensor horizon.
- ⚠️ Breaks external callers using tensor scalar-style delta_t.

Steps of Reproduction ✅
1. In a script using this repo, import and instantiate `ContinuousTimeHRM` from `models/vjepa/physics_engine.py` (class defined at lines 61-69), e.g. `engine = ContinuousTimeHRM(dim)`.
2. Create a batched latent state `z` with batch size greater than 1, for example `z = torch.randn(2, dim)` on a GPU or CPU device.
3. Define a 1-element tensor horizon that is intended to act as a scalar for the whole batch, e.g. `delta_t = torch.tensor([1.0], device=z.device, dtype=z.dtype)`. Note that this is a 1-D tensor of length 1, not a Python float or a 0-D tensor.
4. Call `engine.forward(z, delta_t, action=None)`. Inside `ContinuousTimeHRM.forward` (`models/vjepa/physics_engine.py:71-85`), the scalar path is skipped because `torch.is_tensor(delta_t)` is True, and `delta_t.ndim == 0` is False, so the code executes `delta_t = delta_t.reshape(z.shape[0], -1).mean(dim=-1)` at line 84. That call tries to reshape a tensor of size 1 into shape `(2, -1)` and raises `RuntimeError: shape '[2, -1]' is invalid for input of size 1`. The physics engine (and thus `predictor.physics_engine` at `models/vjepa/predictor.py:35-38`) crashes for this natural usage pattern. |
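One way to address the suggestion is to normalize the horizon before any reshaping. The sketch below is illustrative only: `normalize_horizon` is a hypothetical helper that does not appear in the PR, and the leading-dimension check is an added assumption that is stricter than the PR's `reshape`. It treats any single-element input (Python float, 0-D tensor, or 1-element tensor) as a shared scalar and averages trailing dimensions otherwise, as the diff does.

```python
import torch


def normalize_horizon(delta_t, z: torch.Tensor) -> torch.Tensor:
    """Return a 1-D tensor of shape (batch,) with one horizon per sample.

    Hypothetical helper for illustration; not part of the PR.
    """
    batch = z.shape[0]
    dt = torch.as_tensor(delta_t, device=z.device, dtype=z.dtype)

    # Python float, 0-D tensor, or 1-element tensor: broadcast to the batch.
    if dt.numel() == 1:
        return dt.reshape(()).expand(batch)

    # Assumption: a per-sample horizon must have the batch as its leading dim.
    if dt.shape[0] != batch:
        raise ValueError(
            f"delta_t has leading dimension {dt.shape[0]}, expected {batch}"
        )

    # Per-sample tensor: average any trailing dimensions, as the PR does.
    return dt.reshape(batch, -1).mean(dim=-1)
```

With this in place, `normalize_horizon(torch.tensor([1.0]), torch.randn(2, 8))` yields a length-2 tensor instead of raising the `RuntimeError` described in the reproduction steps.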
||
| delta_t = delta_t.to(device=z.device, dtype=z.dtype) | ||
|
|
||
| # Fast path: all samples share one horizon => one batched ODE solve. | ||
| if torch.allclose(delta_t, delta_t[0].expand_as(delta_t)): | ||
| self.ode_func.current_action = action | ||
| t = torch.stack([torch.zeros((), device=z.device, dtype=z.dtype), delta_t[0]]) | ||
| zt1 = odeint(self.ode_func, z, t, method='dopri5') | ||
| self.ode_func.current_action = None | ||
| return zt1[-1] | ||
|
|
||
| # Group by unique horizons to reduce Python-loop overhead. | ||
| evolved = torch.empty_like(z) | ||
| unique_dt, inverse = torch.unique(delta_t, sorted=False, return_inverse=True) | ||
|
Grouping horizons with … |
||
| for group_idx, dt in enumerate(unique_dt): | ||
| idx = torch.nonzero(inverse == group_idx, as_tuple=False).squeeze(-1) | ||
| self.ode_func.current_action = None if action is None else action.index_select(0, idx) | ||
| t = torch.stack([torch.zeros((), device=z.device, dtype=z.dtype), dt]) | ||
| z_group = z.index_select(0, idx) | ||
| zt1 = odeint(self.ode_func, z_group, t, method='dopri5') | ||
| evolved.index_copy_(0, idx, zt1[-1]) | ||
|
|
||
| self.ode_func.current_action = None | ||
| return zt1[-1] # Return state at t=delta_t | ||
| return evolved | ||
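The group-by-unique-horizon pattern in the diff above can be tried in isolation. The following is a minimal sketch, assuming a stand-in solver: `evolve_stub` is hypothetical and replaces the `odeint` call with the closed-form solution of dz/dt = -z, so the example runs without `torchdiffeq`. `evolve_grouped` is likewise a hypothetical name; only the `torch.unique` / `index_select` / `index_copy_` structure mirrors the PR.

```python
import torch


def evolve_stub(z: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
    # Stand-in for the ODE solve (assumption): exact solution of dz/dt = -z
    # after a horizon of dt, i.e. z * exp(-dt).
    return z * torch.exp(-dt)


def evolve_grouped(z: torch.Tensor, delta_t: torch.Tensor) -> torch.Tensor:
    """Evolve each row of z by its own horizon, one solver call per unique horizon."""
    evolved = torch.empty_like(z)
    # inverse[i] is the index into unique_dt of sample i's horizon.
    unique_dt, inverse = torch.unique(delta_t, return_inverse=True)
    for group_idx, dt in enumerate(unique_dt):
        # Row indices of all samples that share this horizon.
        idx = torch.nonzero(inverse == group_idx).squeeze(-1)
        z_group = z.index_select(0, idx)
        # Write the evolved rows back into their original positions.
        evolved.index_copy_(0, idx, evolve_stub(z_group, dt))
    return evolved


if __name__ == "__main__":
    z = torch.ones(4, 3)
    delta_t = torch.tensor([0.5, 1.0, 0.5, 2.0])
    print(evolve_grouped(z, delta_t))
```

Here four samples with three distinct horizons need three solver calls rather than four. The saving grows with the number of repeated horizons in a batch; when every horizon is distinct, the loop degenerates to one call per sample.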