Skip to content

Vectorize full equation for JAX jit and vmap#61

Draft
ealt wants to merge 1 commit intoeric/mathfrom
cursor/vectorize-full-equation-for-jax-jit-and-vmap-b4ea
Draft

Vectorize full equation for JAX jit and vmap#61
ealt wants to merge 1 commit intoeric/mathfrom
cursor/vectorize-full-equation-for-jax-jit-and-vmap-b4ea

Conversation

@ealt
Copy link
Collaborator

@ealt ealt commented Aug 5, 2025

Refactor full_equation to be compatible with JAX jit and vmap for improved performance and parallelization.

The original full_equation contained JAX vectorization blockers such as data-dependent while loops, dynamic array indexing, and mutable state. This refactoring replaces these with JAX primitives like jax.lax.scan, jax.lax.fori_loop, and jax.lax.select to enable efficient compilation and vectorization.


Open in Cursor Open in Web

Co-authored-by: ericallenalt <ericallenalt@gmail.com>
@cursor
Copy link

cursor bot commented Aug 5, 2025

Cursor Agent can help with this pull request. Just @cursor in comments and I'll start working on changes in this branch.
Learn more about Cursor Agents

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants