Commit 94dd7d2

Authored by jstac and claude
Improve NumPy vs Numba vs JAX lecture (#525)
Rewrite the prange race condition section with an accurate explanation of why the result is always -inf (unrecognized reduction pattern, not a classical race condition), add the simple max() fix, and keep the row_maxes alternative. Restructure the vmap sections into a single narrative arc explaining memory savings and kernel fusion. Remove unsupported speed claims about Numba vs JAX for sequential operations. Move autodiff advantage of lax.scan into a {note}. Fix label and capitalization.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4c025df commit 94dd7d2

File tree: 1 file changed (+66, −49 lines)

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 66 additions & 49 deletions
@@ -9,7 +9,7 @@ kernelspec:
   name: python3
 ---
 
-(parallel)=
+(numpy_numba_jax)=
 ```{raw} jupyter
 <div id="qe-notebook-header" align="right" style="text-align:right;">
 <a href="https://quantecon.org/" title="quantecon.org">
@@ -141,7 +141,7 @@ code executes relatively quickly.
 Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y` such
 that `f(x, y)` generates all evaluations on the product grid.
 
-(This strategy dates back to Matlab.)
+(This strategy dates back to MATLAB.)
 
 ```{code-cell} ipython3
 grid = np.linspace(-3, 3, 3_000)
@@ -223,24 +223,50 @@ def compute_max_numba_parallel(grid):
 ```
 
-Usually this returns an incorrect result:
+This returns `-inf` --- the initial value of `m`, as if it were never updated:
 
 ```{code-cell} ipython3
 z_max_parallel_incorrect = compute_max_numba_parallel(grid)
 print(f"Numba result: {z_max_parallel_incorrect} 😱")
 ```
 
-The reason is that the variable `m` is shared across threads and not properly controlled.
+To understand why, recall that `prange` splits the outer loop across threads.
 
-When multiple threads try to read and write `m` simultaneously, they interfere with each other.
+Each thread gets its own private copy of `m`, initialized to `-np.inf`, and
+correctly updates it within its chunk of iterations.
 
-Threads read stale values of `m` or overwrite each other's updates --- or `m` never gets updated from its initial value.
+But at the end of the loop, Numba needs to combine the per-thread copies of `m`
+back into a single value --- a **reduction**.
 
-Here's a more carefully written version.
+For patterns it recognizes, such as `m += z` (sum) or `m = max(m, z)` (max),
+Numba knows the combining operator.
+
+But it does not recognize the `if z > m: m = z` pattern as a max reduction, so
+the per-thread results are never combined and `m` retains its initial value.
+
+The simplest fix is to replace the conditional with `max`, which Numba
+recognizes:
 
 ```{code-cell} ipython3
 @numba.jit(parallel=True)
 def compute_max_numba_parallel(grid):
+    n = len(grid)
+    m = -np.inf
+    for i in numba.prange(n):
+        for j in range(n):
+            x = grid[i]
+            y = grid[j]
+            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
+            m = max(m, z)
+    return m
+```
+
+An alternative is to make the loop body fully independent across `i` and
+handle the reduction ourselves:
+
+```{code-cell} ipython3
+@numba.jit(parallel=True)
+def compute_max_numba_parallel_v2(grid):
     n = len(grid)
     row_maxes = np.empty(n)
     for i in numba.prange(n):
@@ -255,11 +281,8 @@ def compute_max_numba_parallel(grid):
     return np.max(row_maxes)
 ```
 
-Now the code block that `for i in numba.prange(n)` acts over is independent
-across `i`.
-
-Each thread writes to a separate element of the array `row_maxes` and
-the parallelization is safe.
+Here each thread writes to a separate element of `row_maxes`, so we handle the
+reduction ourselves via `np.max`.
 
 ```{code-cell} ipython3
 z_max_parallel = compute_max_numba_parallel(grid)
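As a pure-Python illustration of the reduction behavior this change describes (no Numba needed): each "thread" keeps a private running max, and the partial results are combined at the end. The chunking and thread count here are illustrative assumptions, not Numba's actual scheduling.

```python
import numpy as np

def parallel_max_sketch(values, n_threads=4):
    # Split the iterations the way prange splits a loop across threads
    chunks = np.array_split(values, n_threads)
    partials = []
    for chunk in chunks:
        # Each "thread" starts from its own private -inf ...
        m = -np.inf
        for v in chunk:
            m = max(m, v)   # the reduction pattern Numba recognizes
        partials.append(m)
    # ... and the per-thread results are combined at the end.
    # When Numba does not recognize the pattern, this combining
    # step never happens and the result stays at -inf.
    return max(partials)

vals = np.cos(np.linspace(-3, 3, 1_000))
print(parallel_max_sketch(vals) == vals.max())   # True
```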
@@ -325,7 +348,7 @@ The compilation overhead is a one-time cost that pays off when the function is c
 
 ### JAX plus vmap
 
-There is one problem with both the NumPy code and the JAX code:
+There is one problem with both the NumPy code and the JAX code above:
 
 While the flat arrays are low-memory
@@ -341,12 +364,13 @@ x_mesh.nbytes + y_mesh.nbytes
 
 This extra memory usage can be a big problem in actual research calculations.
 
-Fortunately, JAX admits a different approach
+Fortunately, JAX admits a different approach
 using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).
 
-#### Version 1
+The idea of `vmap` is to break vectorization into stages, transforming a
+function that operates on single values into one that operates on arrays.
 
-Here's one way we can apply `vmap`.
+Here's how we can apply it to our problem.
 
 ```{code-cell} ipython3
 # Set up f to compute f(x, y) at every x for any given y
@@ -373,33 +397,23 @@ with qe.Timer(precision=8):
     z_max.block_until_ready()
 ```
 
-By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version uses far less memory.
-
-When run on a CPU, its runtime is similar to that of the meshgrid version.
-
-When run on a GPU, it is usually significantly faster.
-
-In fact, using `vmap` has another advantage: It allows us to break vectorization up into stages.
-
-This leads to code that is often easier to comprehend than traditional vectorized code.
-
-We will investigate these ideas more when we tackle larger problems.
-
+By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version
+uses far less memory, with similar runtime.
 
-### vmap version 2
+But we are still leaving speed gains on the table.
 
-We can be still more memory efficient using vmap.
+The code above computes the full two-dimensional array `f(x,y)` and then takes
+the max.
 
-While we avoid large input arrays in the preceding version,
-we still create the large output array `f(x,y)` before we compute the max.
+Moreover, the `jnp.max` call sits outside the JIT-compiled function `f`, so the
+compiler cannot fuse these operations into a single kernel.
 
-Let's try a slightly different approach that takes the max to the inside.
-
-Because of this change, we never compute the two-dimensional array `f(x,y)`.
+We can fix both problems by pushing the max inside and wrapping everything in
+a single `@jax.jit`:
 
 ```{code-cell} ipython3
 @jax.jit
-def compute_max_vmap_v2(grid):
+def compute_max_vmap(grid):
     # Construct a function that takes the max along each row
     f_vec_x_max = lambda y: jnp.max(f(grid, y))
     # Vectorize the function so we can call on all rows simultaneously
@@ -408,31 +422,35 @@ def compute_max_vmap_v2(grid):
     return jnp.max(f_vec_max(grid))
 ```
 
-Here
+Here
 
 * `f_vec_x_max` computes the max along any given row
 * `f_vec_max` is a vectorized version that can compute the max of all rows in parallel.
 
 We apply this function to all rows and then take the max of the row maxes.
 
+Because we push the max inside, we never construct the full two-dimensional
+array `f(x,y)`, saving even more memory.
+
+And because everything is under a single `@jax.jit`, the compiler can fuse
+all operations into one optimized kernel.
+
 Let's try it.
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    z_max = compute_max_vmap_v2(grid).block_until_ready()
+    z_max = compute_max_vmap(grid).block_until_ready()
 
-print(f"JAX vmap v2 result: {z_max:.6f}")
+print(f"JAX vmap result: {z_max:.6f}")
 ```
 
 Let's run it again to eliminate compilation time:
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    z_max = compute_max_vmap_v2(grid).block_until_ready()
+    z_max = compute_max_vmap(grid).block_until_ready()
 ```
 
-If you are running this on a GPU, as we are, you should see another nontrivial speed gain.
-
 
 ### Summary
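The memory saving from pushing the max inside can be illustrated without JAX at all. This NumPy sketch (grid size and function chosen to mirror the lecture's example) contrasts materializing the full two-dimensional array with reducing one row at a time:

```python
import numpy as np

def f(x, y):
    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)

grid = np.linspace(-3, 3, 300)

# Meshgrid version: materializes a full 300 x 300 array before reducing
x_mesh, y_mesh = np.meshgrid(grid, grid)
full_max = f(x_mesh, y_mesh).max()

# "Max pushed inside": only one row (length 300) is in memory at a time
staged_max = max(f(grid, y).max() for y in grid)

print(np.isclose(full_max, staged_max))   # True
```

The real JAX version goes further: under a single `@jax.jit`, the per-row max and the final reduction can be fused rather than executed as a Python loop.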
@@ -552,9 +570,7 @@ with qe.Timer(precision=8):
 
 JAX is also quite efficient for this sequential operation.
 
-Both JAX and Numba deliver strong performance after compilation, with Numba
-typically (but not always) offering slightly better speeds on purely sequential
-operations.
+Both JAX and Numba deliver strong performance after compilation.
 
 
 ### Summary
@@ -572,7 +588,7 @@ The JAX version, on the other hand, requires using `lax.scan`, which is signific
 
 Additionally, JAX's immutable arrays mean we cannot simply update array elements in place, making it hard to directly replicate the algorithm used by Numba.
 
 For this type of sequential operation, Numba is the clear winner in terms of
-code clarity and ease of implementation, as well as high performance.
+code clarity and ease of implementation.
 
 
 ## Overall recommendations
@@ -596,14 +612,15 @@ The code is natural and readable --- just a Python loop with a decorator ---
 and performance is excellent.
 
 JAX can handle sequential problems via `lax.scan`, but the syntax is less
-intuitive and the performance gain is minimal for purely sequential work.
+intuitive.
 
-That said, `lax.scan` has one important advantage: it supports automatic
+```{note}
+One important advantage of `lax.scan` is that it supports automatic
 differentiation through the loop, which Numba cannot do.
-
 If you need to differentiate through a sequential computation (e.g., computing
 sensitivities of a trajectory to model parameters), JAX is the better choice
 despite the less natural syntax.
+```
 
 In practice, many problems involve a mix of both patterns.
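For readers unfamiliar with `lax.scan`, its carry-threading semantics can be sketched in pure Python. This is only a sketch of the semantics: the real `jax.lax.scan` is JIT-compiled and, as the note added in this commit points out, differentiable.

```python
# A pure-Python sketch of lax.scan semantics: thread a carry through
# the sequence and collect the per-step outputs.
def scan_sketch(f, init, xs):
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

# Example: a cumulative sum expressed as a scan
step = lambda c, x: (c + x, c + x)
final, cumsums = scan_sketch(step, 0, [1, 2, 3, 4])
print(final, cumsums)   # 10 [1, 3, 6, 10]
```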
0 commit comments
