From 71f76ad98ce0da99e6247f6b674c6c16da14def3 Mon Sep 17 00:00:00 2001
From: John Stachurski
Date: Mon, 13 Apr 2026 08:28:43 -0400
Subject: [PATCH] Add lax.fori_loop example and improve jax_intro clarity

- Add lax.fori_loop example to sequential operations section, alongside
  existing lax.scan version
- Update summaries and recommendations to reference both approaches
- Improve jax_intro: reorganize intro text, fix block_until_ready style,
  add Size Experiment subheading, expand immutability explanation, add
  cross-reference labels
- Remove unused import

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 lectures/jax_intro.md             | 66 +++++++++++++----------
 lectures/numpy_vs_numba_vs_jax.md | 88 ++++++++++++++++++++++---------
 2 files changed, 102 insertions(+), 52 deletions(-)

diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md
index 1c6fb52a..4914e3de 100644
--- a/lectures/jax_intro.md
+++ b/lectures/jax_intro.md
@@ -46,20 +46,21 @@ import numpy as np
 import quantecon as qe
 ```
 
-Notice that we import `jax.numpy as jnp`, which provides a NumPy-like interface.
 
 ## JAX as a NumPy Replacement
 
-One of the attractive features of JAX is that, whenever possible, its array
-processing operations conform to the NumPy API.
-
-This means that, in many cases, we can use JAX as a drop-in NumPy replacement.
-
 Let's look at the similarities and differences between JAX and NumPy.
 
 ### Similarities
 
+Above we import `jax.numpy as jnp`, which provides a NumPy-like interface to
+array operations.
+
+One of the attractive features of JAX is that, whenever possible, this interface
+conforms to the NumPy API.
+
+As a result, we can often use JAX as a drop-in NumPy replacement.
 Here are some standard array operations using `jnp`:
 
@@ -79,7 +80,7 @@ print(jnp.sum(a))
 print(jnp.dot(a, a))
 ```
 
-However, the array object `a` is not a NumPy array:
+Note, however, that the array object `a` is not a NumPy array:
 
 ```{code-cell} ipython3
 a
@@ -104,11 +105,13 @@ Let's now look at some differences between JAX and NumPy array operations.
 (jax_speed)=
 #### Speed!
 
-Let's say we want to evaluate the cosine function at many points.
+One major difference is that JAX is faster --- and sometimes much faster.
+
+To illustrate, suppose that we want to evaluate the cosine function at many points.
 
 ```{code-cell}
 n = 50_000_000
-x = np.linspace(0, 10, n)
+x = np.linspace(0, 10, n)  # NumPy array
 ```
 
 ##### With NumPy
 
@@ -150,28 +153,24 @@ with qe.Timer():
     # First run
     y = jnp.cos(x)
     # Hold the interpreter until the array operation finishes
-    jax.block_until_ready(y);
+    y.block_until_ready()
 ```
 
 ```{note}
-Here, in order to measure actual speed, we use the `block_until_ready` method
-to hold the interpreter until the results of the computation are returned.
-
-This is necessary because JAX uses asynchronous dispatch, which
+Above, the `block_until_ready` method
+holds the interpreter until the results of the computation are returned.
+This is necessary for timing execution because JAX uses asynchronous dispatch, which
 allows the Python interpreter to run ahead of numerical computations.
-
-For non-timed code, you can drop the line containing `block_until_ready`.
 ```
 
-And let's time it again.
-
+Now let's time it again.
 
 ```{code-cell}
 with qe.Timer():
     # Second run
     y = jnp.cos(x)
     # Hold interpreter
-    jax.block_until_ready(y);
+    y.block_until_ready()
 ```
 
 On a GPU, this code runs much faster than its NumPy equivalent.
@@ -190,7 +189,11 @@ being used (as well as the data type).
 
 The size matters for generating optimized code because efficient parallelization
 requires matching the size of the task to the available hardware.
-We can verify the claim that JAX specializes on array size by changing the input size and watching the runtimes.
+
+#### Size Experiment
+
+We can verify the claim that JAX specializes on array size by changing the input
+size and watching the runtimes.
 
 ```{code-cell}
 x = jnp.linspace(0, 10, n + 1)
 ```
 
@@ -201,7 +204,7 @@ with qe.Timer():
     # First run
     y = jnp.cos(x)
     # Hold interpreter
-    jax.block_until_ready(y);
+    y.block_until_ready()
 ```
 
 
@@ -210,7 +213,7 @@ with qe.Timer():
     # Second run
     y = jnp.cos(x)
     # Hold interpreter
-    jax.block_until_ready(y);
+    y.block_until_ready()
 ```
 
 The run time increases and then falls again (this will be more obvious on the GPU).
@@ -263,7 +266,8 @@ a[0] = 1
 a
 ```
 
-In JAX this fails!
+In JAX this fails 😱.
+
 
 ```{code-cell} ipython3
 a = jnp.linspace(0, 1, 3)
@@ -278,14 +282,19 @@ except Exception as e:
 ```
 
 
-The designers of JAX chose to make arrays immutable because JAX uses a
-functional programming style, which we discuss below.
+The designers of JAX chose to make arrays immutable because
+
+1. JAX uses a *functional programming style* and
+2. functional programming typically avoids mutable data.
+
+We discuss these ideas {ref}`below <jax_func>`.
 
 
-#### A workaround
+(jax_at_workaround)=
+#### A Workaround
 
-We note that JAX does provide an alternative to in-place array modification
-using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
+JAX does provide a direct alternative to in-place array modification
+via the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
 
 ```{code-cell} ipython3
 a = jnp.linspace(0, 1, 3)
@@ -308,6 +317,7 @@ Hence, for the most part, we try to avoid this syntax.
 
 (Although it can in fact be efficient inside JIT-compiled functions -- but
 let's put this aside for now.)
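To make that last point concrete, here is a small sketch of an `at`-style update inside a JIT-compiled function; the function name `set_first` is illustrative only, not from the lecture:

```python
# Sketch: an `at[...].set` update inside a jitted function.
# The name `set_first` is illustrative, not from the lecture.
import jax
import jax.numpy as jnp

@jax.jit
def set_first(a):
    # Functionally this "returns a new array", but under JIT the
    # compiler is free to reuse the old buffer once it is no longer needed.
    return a.at[0].set(1.0)

a = jnp.zeros(3)
b = set_first(a)
print(a)  # the original array is unchanged
print(b)  # the "updated" array
```

Outside of JIT-compiled code, however, each such call really does allocate a fresh array, which is why the lectures avoid this syntax for general use.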
+(jax_func)=
 ## Functional Programming
 
 From JAX's documentation:
diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index 6d08a4b2..03b3bef2 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -54,7 +54,6 @@ tags: [hide-output]
 We will use the following imports.
 
 ```{code-cell} ipython3
-import random
 from functools import partial
 
 import numpy as np
@@ -483,18 +482,67 @@ Notice that the second run is significantly faster after JIT compilation
 completes.
 
 Numba's compilation is typically quite fast, and the resulting code performance
 is excellent for sequential operations like this one.
 
+
 ### JAX Version
 
-Now let's create a JAX version using `lax.scan`:
+Now let's create a JAX version using `at[t].set` style syntax, which, as
+{ref}`discussed in the JAX lecture <jax_at_workaround>`, provides a workaround
+for immutable arrays.
 
-(We'll hold `n` static because it affects array size and hence JAX wants to
-specialize on its value in the compiled code.)
+We'll use `lax.fori_loop`, which is a version of a for loop that can be
+compiled by XLA.
 
 ```{code-cell} ipython3
 cpu = jax.devices("cpu")[0]
 
-@partial(jax.jit, static_argnames=('n',), device=cpu)
-def qm_jax(x0, n, α=4.0):
+@partial(jax.jit, static_argnames=("n",), device=cpu)
+def qm_jax_fori(x0, n, α=4.0):
+
+    x = jnp.empty(n + 1).at[0].set(x0)
+
+    def update(t, x):
+        return x.at[t + 1].set(α * x[t] * (1 - x[t]))
+
+    x = lax.fori_loop(0, n, update, x)
+    return x
+```
+
+* We hold `n` static because it affects array size and hence JAX wants to
+  specialize on its value in the compiled code.
+* We pin to the CPU via `device=cpu` because this sequential workload consists
+  of many small operations, leaving little opportunity for GPU parallelism.
+
+Although `at[t].set` appears to create a new array at each step, inside a
+JIT-compiled function the compiler detects that the old array is no longer
+needed and performs the update in place.
+
+Let's time it with the same parameters:
+
+```{code-cell} ipython3
+with qe.Timer():
+    # First run
+    x_jax = qm_jax_fori(0.1, n)
+    # Hold interpreter
+    x_jax.block_until_ready()
+```
+
+Let's run it again to eliminate compilation overhead:
+
+```{code-cell} ipython3
+with qe.Timer():
+    # Second run
+    x_jax = qm_jax_fori(0.1, n)
+    # Hold interpreter
+    x_jax.block_until_ready()
+```
+
+JAX is also quite efficient for this sequential operation.
+
+
+There's another way we can implement the loop that uses `lax.scan`.
+
+This alternative is arguably more in line with JAX's functional approach ---
+although the syntax is difficult to remember.
+
+
+```{code-cell} ipython3
+@partial(jax.jit, static_argnames=("n",), device=cpu)
+def qm_jax_scan(x0, n, α=4.0):
     def update(x, t):
         x_new = α * x * (1 - x)
         return x_new, x_new
@@ -505,20 +553,12 @@ def qm_jax(x0, n, α=4.0):
 This code is not easy to read but, in essence, `lax.scan` repeatedly calls
 `update` and accumulates the returns `x_new` into an array.
 
-```{note}
-We specify `device=cpu` in the `jax.jit` decorator because this computation
-consists of many small sequential operations, leaving little opportunity for the
-GPU to exploit parallelism. As a result, kernel-launch overhead tends to
-dominate on the GPU, making the CPU a better
-fit.
-```
-
 Let's time it with the same parameters:
 
 ```{code-cell} ipython3
 with qe.Timer():
     # First run
-    x_jax = qm_jax(0.1, n)
+    x_jax = qm_jax_scan(0.1, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
 
 Let's run it again to eliminate compilation overhead:
 
 ```{code-cell} ipython3
 with qe.Timer():
     # Second run
-    x_jax = qm_jax(0.1, n)
+    x_jax = qm_jax_scan(0.1, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
 
-JAX is also quite efficient for this sequential operation.
-
 Both JAX and Numba deliver strong performance after compilation.
 
@@ -547,9 +585,11 @@ array and fill it element by element using a standard Python loop.
This is exactly how most programmers think about the algorithm. -The JAX version, on the other hand, requires using `lax.scan`, which is significantly less intuitive. +The JAX versions, on the other hand, require either `lax.fori_loop` or +`lax.scan`, both of which are less intuitive than a standard Python loop. -Additionally, JAX's immutable arrays mean we cannot simply update array elements in place, making it hard to directly replicate the algorithm used by Numba. +While JAX's `at[t].set` syntax does allow element-wise updates, the overall code +remains harder to read than the Numba equivalent. For this type of sequential operation, Numba is the clear winner in terms of code clarity and ease of implementation. @@ -575,12 +615,12 @@ For **sequential operations**, Numba has clear advantages. The code is natural and readable --- just a Python loop with a decorator --- and performance is excellent. -JAX can handle sequential problems via `lax.scan`, but the syntax is less -intuitive. +JAX can handle sequential problems via `lax.fori_loop` or `lax.scan`, but +the syntax is less intuitive. ```{note} -One important advantage of `lax.scan` is that it supports automatic -differentiation through the loop, which Numba cannot do. +One important advantage of `lax.fori_loop` and `lax.scan` is that they +support automatic differentiation through the loop, which Numba cannot do. If you need to differentiate through a sequential computation (e.g., computing sensitivities of a trajectory to model parameters), JAX is the better choice despite the less natural syntax.
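
To illustrate the note above, here is a minimal sketch, separate from the patch itself, showing JAX differentiating straight through `lax.scan` on the same quadratic map; the helper name `final_state` is ours, chosen for illustration:

```python
# Sketch: reverse-mode differentiation through a sequential loop.
# `final_state` is an illustrative helper, not defined in the lectures.
import jax
import jax.numpy as jnp
from jax import lax

def final_state(x0, n=10, α=4.0):
    # Iterate the quadratic map x_{t+1} = α x_t (1 - x_t) for n steps.
    def update(x, _):
        x_new = α * x * (1 - x)
        return x_new, None
    x_final, _ = lax.scan(update, x0, None, length=n)
    return x_final

# Sensitivity of the trajectory's endpoint to the initial condition:
print(jax.grad(final_state)(0.1))
```

Numba has no counterpart to this: its compiled loops are fast, but they are opaque to automatic differentiation.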