Add per-sample ODE horizons, batched RoPE/masking, Stiefel equivariant layer and enhanced ray marcher #2
    if delta_t.ndim == 0:
        delta_t = delta_t.expand(z.shape[0])
    else:
        delta_t = delta_t.reshape(z.shape[0], -1).mean(dim=-1)
Suggestion: A 1-element tensor horizon (for example torch.tensor([1.0])) is treated as a batched tensor and passed to reshape(z.shape[0], -1), which crashes when batch size is greater than 1. Since the method contract says scalar or per-sample tensor, handle delta_t.numel() == 1 as scalar-expansion before reshaping. [possible bug]
Severity Level: Critical 🚨
- ❌ `ContinuousTimeHRM` crashes when given a 1-element tensor horizon.
- ⚠️ Breaks external callers using a tensor scalar-style `delta_t`.

Steps of Reproduction ✅
1. In a script using this repo, import and instantiate `ContinuousTimeHRM` from
`models/vjepa/physics_engine.py` (class defined at lines 61-69), e.g. `engine =
ContinuousTimeHRM(dim)`.
2. Create a batched latent state `z` with batch size greater than 1, for example `z =
torch.randn(2, dim)` on a GPU or CPU device.
3. Define a 1-element tensor horizon that is intended to act as a scalar for the whole
batch, e.g. `delta_t = torch.tensor([1.0], device=z.device, dtype=z.dtype)` (note: this is
a 1D tensor of length 1, not a Python float or 0-D tensor).
4. Call `engine.forward(z, delta_t, action=None)`; inside `ContinuousTimeHRM.forward`
(`models/vjepa/physics_engine.py:71-85`), the scalar path is skipped
(`torch.is_tensor(delta_t)` is True), `delta_t.ndim == 0` is False, and the code executes
`delta_t = delta_t.reshape(z.shape[0], -1).mean(dim=-1)` at line 84, which attempts to
reshape a tensor of size 1 into shape `(2, -1)` and raises a `RuntimeError: shape '[2,
-1]' is invalid for input of size 1`, causing the physics engine (and thus
`predictor.physics_engine` at `models/vjepa/predictor.py:35-38`) to crash for this natural
usage pattern.
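
A minimal sketch of the guard this suggestion describes, assuming `delta_t` is either a Python number or a tensor. `normalize_horizon` is a hypothetical helper name used for illustration, not a function in the PR; it expands any single-element horizon before the per-sample reshape runs.

```python
import torch


def normalize_horizon(delta_t, batch_size):
    """Return a 1-D horizon tensor of shape (batch_size,).

    Hypothetical helper: mirrors the branch quoted in the review, plus the
    numel() == 1 guard the suggestion proposes.
    """
    if not torch.is_tensor(delta_t):
        # Python float/int: build one horizon per sample.
        return torch.full((batch_size,), float(delta_t))
    if delta_t.numel() == 1:
        # Covers 0-D tensors and 1-element tensors such as torch.tensor([1.0]).
        return delta_t.reshape(1).expand(batch_size)
    # Per-sample tensor: collapse any trailing dims to one value per sample.
    return delta_t.reshape(batch_size, -1).mean(dim=-1)
```

With this shape of check, `torch.tensor([1.0])` and `torch.tensor(1.0)` take the same path, so the reshape only ever sees tensors that already carry one or more values per sample.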
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ab3364ff95
    # Group by unique horizons to reduce Python-loop overhead.
    evolved = torch.empty_like(z)
    unique_dt, inverse = torch.unique(delta_t, sorted=False, return_inverse=True)
Preserve autograd for tensor delta_t grouping
Grouping horizons with torch.unique(delta_t, ...) breaks backprop when delta_t.requires_grad=True: PyTorch raises NotImplementedError for _unique2 during backward. Since this commit newly supports tensor-valued delta_t, any experiment with differentiable or learned horizons will now fail at training time.
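
One way to keep the grouping while leaving `torch.unique` out of the autograd graph is sketched below. This is an illustration under stated assumptions, not the PR's code: `group_by_horizon` is a hypothetical helper, and it assumes `delta_t` is a 1-D float tensor with one horizon per sample.

```python
import torch


def group_by_horizon(delta_t):
    """Return a list of (indices, horizon) pairs, one per distinct horizon.

    Grouping runs on a detached copy, so torch.unique never enters the
    autograd graph; each horizon is re-read from the original tensor.
    """
    unique_vals, inverse = torch.unique(
        delta_t.detach(), sorted=False, return_inverse=True
    )
    groups = []
    for g in range(unique_vals.numel()):
        idx = (inverse == g).nonzero(as_tuple=True)[0]
        # All entries in a group are equal, so the mean is the shared
        # horizon, and it keeps every group member connected to the graph.
        groups.append((idx, delta_t[idx].mean()))
    return groups
```

Each returned horizon can then be passed to the solver for the samples in `idx`. Note that sharing one scalar per group means every member receives the group-averaged gradient rather than its own, which may or may not be acceptable for learned horizons.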
User description
Motivation
Description
- Update `ContinuousTimeHRM.forward` to accept scalar or per-sample `delta_t` (`torch.Tensor | float`) and handle grouped solves by unique horizon values to reduce solver calls, while preserving adjoint integration (`odeint` with `dopri5`).
- Add a `LieGroupEquivariantLayer` based on `geoopt.manifolds.Stiefel` plus low-rank generators, and add an enriched `LatentRayMarcher` that optionally detects a `genesis` backend, computes NeRF-style accumulation and fuses human-vision-inspired cues through a `perceptual_fuser` network.
- Batch RoPE in `VJEPA.forward` by stacking per-sample `cos, sin` tensors, and update the `apply_rotary_pos_emb` docstring to indicate `cos`/`sin` can be batched.
- Make `apply_mask` enforce that all samples in a batch have the same number of visible/masked patches and raise a `ValueError` if they differ, to avoid silent shape mismatches when stacking.
- Pass per-sample `delta_t` into the physics engine (no longer reducing to `delta_t.mean().item()`), add a memory recall blending step for `z_H`, and integrate optional `ray_marcher` shadow features into predicted latents.
- [...] (`typing` usage) and small API resilience (optional `genesis` import fallback and `has_genesis` flag).

Testing
- Ran `ContinuousTimeHRM` with a scalar `delta_t` and with a per-sample `delta_t` tensor and verified forward passes complete without shape errors (passed).
- Ran `VJEPA` with a dummy batch (random video, mask, `delta_t`) to validate masked RoPE indexing, `apply_mask` behavior and predictor integration (passed).
- Tested `apply_mask` mismatch detection to ensure `ValueError` is raised on uneven per-sample mask counts (passed).
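
The mask-count check described above could look roughly like the following. This is a sketch only: `check_uniform_mask_counts` is a hypothetical name, and the mask layout (a bool tensor of shape `(batch, num_patches)` with `True` marking a masked patch) is an assumption rather than something the PR text confirms.

```python
import torch


def check_uniform_mask_counts(mask):
    """Return the shared number of masked patches per sample.

    Assumes `mask` is a bool tensor of shape (batch, num_patches) where
    True marks a masked patch; raises ValueError if the counts differ.
    """
    counts = mask.sum(dim=1)
    if not bool((counts == counts[0]).all()):
        raise ValueError(
            f"apply_mask needs equal masked counts per sample, got {counts.tolist()}"
        )
    return int(counts[0])
```

Failing early here matters because stacking per-sample selections of different lengths would otherwise raise a less informative shape error further downstream.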
CodeAnt-AI Description
Support per-sample masking, time horizons, and richer latent perception
What Changed
Impact
- ✅ Fewer batched masking shape errors
- ✅ Safer variable-timestep prediction
- ✅ Richer latent features for masked video prediction