Rewrite the prange race condition section with an accurate explanation
of why the result is always -inf (unrecognized reduction pattern, not a
classical race condition), add the simple max() fix, and keep the
row_maxes alternative. Restructure the vmap sections into a single
narrative arc explaining memory savings and kernel fusion. Remove
unsupported speed claims about Numba vs JAX for sequential operations.
Move autodiff advantage of lax.scan into a {note}. Fix label and
capitalization.
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
 If you are running this on a GPU, as we are, you should see another nontrivial speed gain.
-
 
 ### Summary
 
@@ -552,9 +570,7 @@ with qe.Timer(precision=8):
 
 JAX is also quite efficient for this sequential operation.
 
-Both JAX and Numba deliver strong performance after compilation, with Numba
-typically (but not always) offering slightly better speeds on purely sequential
-operations.
+Both JAX and Numba deliver strong performance after compilation.
 
 
 ### Summary
@@ -572,7 +588,7 @@ The JAX version, on the other hand, requires using `lax.scan`, which is signific
 Additionally, JAX's immutable arrays mean we cannot simply update array elements in place, making it hard to directly replicate the algorithm used by Numba.
 
 For this type of sequential operation, Numba is the clear winner in terms of
-code clarity and ease of implementation, as well as high performance.
+code clarity and ease of implementation.
 
 
 ## Overall recommendations
@@ -596,14 +612,15 @@ The code is natural and readable --- just a Python loop with a decorator ---
 and performance is excellent.
 
 JAX can handle sequential problems via `lax.scan`, but the syntax is less
-intuitive and the performance gain is minimal for purely sequential work.
+intuitive.
 
-That said, `lax.scan` has one important advantage: it supports automatic
+```{note}
+One important advantage of `lax.scan` is that it supports automatic
 differentiation through the loop, which Numba cannot do.
-
 If you need to differentiate through a sequential computation (e.g., computing
 sensitivities of a trajectory to model parameters), JAX is the better choice
 despite the less natural syntax.
+```
 
 In practice, many problems involve a mix of both patterns.
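
For reference, the point made in the `lax.scan` note of the last hunk can be illustrated with a minimal sketch. It is not taken from the lecture: the function `trajectory_endpoint`, the linear update `x_{t+1} = a * x_t`, and the values of `x0` and `n` are hypothetical choices made only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def trajectory_endpoint(a, n=10):
    """Iterate x_{t+1} = a * x_t for n steps via lax.scan and return x_n.

    Hypothetical toy model, used only to show differentiation through a scan.
    """
    x0 = jnp.float32(1.0)  # assumed initial condition

    def step(x, _):
        x_next = a * x
        return x_next, x_next  # (new carry, per-step output)

    x_final, _ = lax.scan(step, x0, None, length=n)
    return x_final


# Sensitivity of the endpoint to the parameter a, obtained by
# differentiating through the whole sequential loop.
d_endpoint = jax.grad(trajectory_endpoint)(0.9)
```

Under these assumptions the endpoint is `a**n`, so the gradient should agree with `n * a**(n - 1)` up to floating-point error, which gives a simple check on the result.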