From 9a9cb8b72aad9f69db71cde6dc104cabc67ef19c Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Tue, 14 Apr 2026 12:25:52 -0400 Subject: [PATCH] Improve and simplify jax_intro and numpy_vs_numba_vs_jax lectures Co-Authored-By: Claude Opus 4.6 (1M context) --- lectures/jax_intro.md | 176 +++++++++++------------------- lectures/numpy_vs_numba_vs_jax.md | 129 ++++++++-------------- 2 files changed, 109 insertions(+), 196 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 4914e3de..4a65e7a4 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -416,15 +416,34 @@ directly into the graph-theoretic representations supported by JAX. Random number generation in JAX differs significantly from the patterns found in NumPy or MATLAB. -At first you might find the syntax rather verbose. -But the syntax and semantics are necessary to maintain the functional programming style we just discussed. -Moreover, full control of random state is essential for parallel programming, -such as when we want to run independent experiments along multiple threads. +### NumPy / MATLAB Approach +In NumPy / MATLAB, generation works by maintaining hidden global state. + +```{code-cell} ipython3 +np.random.seed(42) +print(np.random.randn(2)) +``` + +Each time we call a random function, the hidden state is updated: + +```{code-cell} ipython3 +print(np.random.randn(2)) +``` + +This function is *not pure* because: + +* It's non-deterministic: same inputs, different outputs +* It has side effects: it modifies the global random number generator state + +Dangerous under parallelization --- must carefully control what happens in each +thread! + + +### JAX -### Random number generation In JAX, the state of the random number generator is controlled explicitly. 
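For instance, here is a minimal sketch of explicit key handling (illustrative only; the ideas are developed in detail below):

```python
import jax

# Build an explicit PRNG key from a seed
key = jax.random.key(42)

# Derive fresh randomness by splitting, never by mutating hidden state
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (2,))

# The same subkey always reproduces the same draws
y = jax.random.normal(subkey, (2,))
print(bool((x == y).all()))  # True
```

All randomness flows through the key that the caller passes in, so there is no hidden state to manage.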
@@ -547,105 +566,30 @@ def gen_random_matrices(key, n=2, k=3): key, subkey = jax.random.split(key) A = jax.random.uniform(subkey, (n, n)) matrices.append(A) - print(A) return matrices ``` ```{code-cell} ipython3 seed = 42 key = jax.random.key(seed) -matrices = gen_random_matrices(key) -``` - -We can also use `fold_in` when iterating in a loop: - -```{code-cell} ipython3 -def gen_random_matrices(key, n=2, k=3): - matrices = [] - for i in range(k): - step_key = jax.random.fold_in(key, i) - A = jax.random.uniform(step_key, (n, n)) - matrices.append(A) - print(A) - return matrices -``` - -```{code-cell} ipython3 -key = jax.random.key(seed) -matrices = gen_random_matrices(key) -``` - - -### Why explicit random state? - -Why does JAX require this somewhat verbose approach to random number generation? - -One reason is to maintain pure functions. - -Let's see how random number generation relates to pure functions by comparing NumPy and JAX. - -#### NumPy's approach - -In NumPy's legacy random number generation API (which mimics MATLAB), generation -works by maintaining hidden global state. - -Each time we call a random function, this state is updated: - -```{code-cell} ipython3 -np.random.seed(42) -print(np.random.randn()) # Updates state of random number generator -print(np.random.randn()) # Updates state of random number generator +gen_random_matrices(key) ``` -Each call returns a different value, even though we're calling the same function with the same inputs (no arguments). - -This function is *not pure* because: - -* It's non-deterministic: same inputs (none, in this case) give different outputs -* It has side effects: it modifies the global random number generator state - - -#### JAX's approach - -As we saw above, JAX takes a different approach, making randomness explicit through keys. 
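Incidentally, `jax.random.split` can also return a whole batch of subkeys in one call, which avoids splitting inside a Python loop (a sketch, not necessarily the pattern used in the lecture):

```python
import jax

key = jax.random.key(42)

# One call yields a batch of independent subkeys
subkeys = jax.random.split(key, 3)

# Draw one 2x2 uniform matrix per subkey, without a Python loop
draw = lambda k: jax.random.uniform(k, (2, 2))
matrices = jax.vmap(draw)(subkeys)
print(matrices.shape)  # (3, 2, 2)
```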
- -For example, - -```{code-cell} ipython3 -def random_sum_jax(key): - key1, key2 = jax.random.split(key) - x = jax.random.normal(key1) - y = jax.random.normal(key2) - return x + y -``` - -With the same key, we always get the same result: - -```{code-cell} ipython3 -key = jax.random.key(42) -random_sum_jax(key) -``` - -```{code-cell} ipython3 -random_sum_jax(key) -``` +This function is *pure* -To get new draws we need to supply a new key. +* Deterministic: same inputs, same output +* No side effects: no hidden state is modified -The function `random_sum_jax` is pure because: -* It's deterministic: same key always produces same output -* No side effects: no hidden state is modified +### Benefits The explicitness of JAX brings significant benefits: * Reproducibility: Easy to reproduce results by reusing keys -* Parallelization: Each thread can have its own key without conflicts -* Debugging: No hidden state makes code easier to reason about +* Parallelization: Control what happens on separate threads +* Debugging: No hidden state makes code easier to test * JIT compatibility: The compiler can optimize pure functions more aggressively -The last point is expanded on in the next section. - ## JIT Compilation @@ -655,7 +599,12 @@ efficient machine code that varies with both task size and hardware. We saw the power of JAX's JIT compiler combined with parallel hardware when we {ref}`above `, when we applied `cos` to a large array. -Let's try the same thing with a more complex function: +Here we study JIT compilation for more complex functions + + +### With NumPy + +We'll try first with NumPy, using ```{code-cell} def f(x): @@ -663,9 +612,7 @@ def f(x): return y ``` -### With NumPy - -We'll try first with NumPy +Let's run with large `x` ```{code-cell} n = 50_000_000 @@ -678,11 +625,20 @@ with qe.Timer(): y = f(x) ``` +**Eager** execution model +* Each operation is executed immediately as it is encountered, materializing its + result before the next operation begins. 
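The eager model can be made explicit by unrolling the expression in `f` step by step (a small sketch):

```python
import numpy as np

x = np.linspace(1, 10, 5)

# Under eager execution, each step below materializes a full array
t1 = x**2            # intermediate array
t2 = 2 * t1          # intermediate array
t3 = np.cos(t2)      # intermediate array
t4 = np.sqrt(x)      # intermediate array
y_steps = t3 + t4 + x

# Identical to the one-line version used in f
y = np.cos(2 * x**2) + np.sqrt(x) + x
print(np.allclose(y, y_steps))  # True
```

For a large `x`, each of the intermediates `t1` through `t4` is a full-size array that must be written to and read back from memory.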
-### With JAX +Disadvantages -Now let's try again with JAX. +* Minimal parallelization +* Heavy memory footprint --- produces many intermediate arrays +* Lots of memory read/write + + + +### With JAX As a first pass, we replace `np` with `jnp` throughout: @@ -716,14 +672,15 @@ with qe.Timer(): The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation. -However, with JAX, we have another trick up our sleeve --- we can JIT-compile -the entire function, not just individual operations. +But we are still using eager execution --- lots of memory and read/write ### Compiling the Whole Function -The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing array -operations into a single optimized kernel. +Fortunately, with JAX, we have another trick up our sleeve --- we can JIT-compile +the entire function, not just individual operations. + +The compiler fuses all array operations into a single optimized kernel Let's try this with the function `f`: @@ -747,11 +704,11 @@ with qe.Timer(): jax.block_until_ready(y); ``` -The runtime has improved again --- now because we fused all the operations, -allowing the compiler to optimize more aggressively. +The runtime has improved again --- now because we fused all the operations -For example, the compiler can eliminate multiple calls to the hardware -accelerator and the creation of a number of intermediate arrays. +* Aggressive optimization based on entire computational sequence +* Eliminates multiple calls to the hardware accelerator +* No creation of intermediate arrays Incidentally, a more common syntax when targeting a function for the JIT compiler is @@ -777,16 +734,12 @@ subsequent calls with the same input shapes and types reuse the cached compiled code and run at full speed. - ### Compiling non-pure functions -Now that we've seen how powerful JIT compilation can be, it's important to -understand its relationship with pure functions. 
- While JAX will not usually throw errors when compiling impure functions, -execution becomes unpredictable. +execution becomes unpredictable! -Here's an illustration of this fact, using global variables: +Here's an illustration of this fact: ```{code-cell} ipython3 a = 1 # global @@ -871,16 +824,13 @@ for row in X: However, Python loops are slow and cannot be efficiently compiled or parallelized by JAX. -Using `vmap` keeps the computation on the accelerator and composes with other -JAX transformations like `jit` and `grad`: +With `vmap`, we can avoid loops and keep the computation on the accelerator: ```{code-cell} ipython3 -batch_mm_diff = jax.vmap(mm_diff) -batch_mm_diff(X) +batch_mm_diff = jax.vmap(mm_diff) # Create a new "vectorized" version +batch_mm_diff(X) # Apply to each row of X ``` -The function `mm_diff` was written for a single array, and `vmap` automatically -lifted it to operate row-wise over a matrix --- no loops, no reshaping. ### Combining transformations diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index 03b3bef2..e38b5ed6 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -127,37 +127,38 @@ m = -np.inf for x in grid: for y in grid: z = f(x, y) - if z > m: - m = z + m = max(m, z) ``` ### NumPy vectorization -If we switch to NumPy-style vectorization we can use a much larger grid and the -code executes relatively quickly. +Let's switch to NumPy and use a larger grid Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y` such that `f(x, y)` generates all evaluations on the product grid. -(This strategy dates back to MATLAB.) ```{code-cell} ipython3 +# Large grid grid = np.linspace(-3, 3, 3_000) -x, y = np.meshgrid(grid, grid) + +x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid with qe.Timer(): z_max_numpy = np.max(f(x, y)) - -print(f"NumPy result: {z_max_numpy:.6f}") ``` In the vectorized version, all the looping takes place in compiled code. 
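On a small grid, we can confirm that the meshgrid evaluation reproduces the nested loop exactly (a sketch using a hypothetical stand-in `g`, since the lecture's `f` is defined earlier):

```python
import numpy as np

# Hypothetical stand-in for the lecture's test function
def g(x, y):
    return np.cos(x**2 + y**2)

grid = np.linspace(-3, 3, 50)

# Nested-loop maximum over the product grid
m_loop = -np.inf
for x in grid:
    for y in grid:
        m_loop = max(m_loop, g(x, y))

# meshgrid evaluates g at every (x, y) pair in one vectorized call
x_mesh, y_mesh = np.meshgrid(grid, grid)
m_vec = np.max(g(x_mesh, y_mesh))

print(np.isclose(m_loop, m_vec))  # True
```

The vectorized version performs exactly the same evaluations, just inside compiled NumPy loops rather than Python loops.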
-Moreover, NumPy uses implicit multithreading, so that at least some parallelization occurs. +The use of `meshgrid` allows us to replicate the nested for loop. + +The output should be close to one: + +```{code-cell} ipython3 +print(f"NumPy result: {z_max_numpy:.6f}") +``` -(The parallelization cannot be highly efficient because the binary is compiled -before it sees the size of the arrays `x` and `y`.) ### A Comparison with Numba @@ -183,8 +184,6 @@ grid = np.linspace(-3, 3, 3_000) with qe.Timer(): # First run z_max_numba = compute_max_numba(grid) - -print(f"Numba result: {z_max_numba:.6f}") ``` Let's run again to eliminate compile time. @@ -230,8 +229,6 @@ Here's a warm up run and test. with qe.Timer(): # First run z_max_parallel = compute_max_numba_parallel(grid) - -print(f"Numba result: {z_max_parallel:.6f}") ``` Here's the timing for the pre-compiled version. @@ -242,18 +239,22 @@ with qe.Timer(): compute_max_numba_parallel(grid) ``` -If you have multiple cores, you should see at least some benefits from -parallelization here. +If you have multiple cores, you should see benefits from parallelization here. -For more powerful machines and larger grid sizes, parallelization can generate -major speed gains, even on the CPU. +Let's make sure we're still getting the right result (close to one): +```{code-cell} ipython3 +print(f"Numba result: {z_max_parallel:.6f}") +``` -### Vectorized code with JAX -On the surface, vectorized code in JAX is similar to NumPy code. +For powerful machines and larger grid sizes, parallelization can generate +useful speed gains, even on the CPU. + -But there are also some differences, which we highlight here. +### Vectorized code with JAX + +Let's try replicating the NumPy vectorized approach with JAX. 
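Much of the NumPy API has a `jax.numpy` counterpart with the same call signature, so the translation is largely mechanical (a quick sketch):

```python
import numpy as np
import jax.numpy as jnp

# Same call signatures, same values; jnp returns a device array
g_np = np.linspace(-3, 3, 5)
g_jx = jnp.linspace(-3, 3, 5)

print(np.allclose(np.asarray(g_jx), g_np))  # True
```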
Let's start with the function, which switches `np` to `jnp` and adds `jax.jit` @@ -264,8 +265,7 @@ def f(x, y): ``` -As with NumPy, to get the right shape and the correct nested `for` loop -calculation, we can use a `meshgrid` operation designed for this purpose: +We use the NumPy style meshgrid approach: ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) @@ -325,65 +325,25 @@ function that operates on single values into one that operates on arrays. Here's how we can apply it to our problem. -```{code-cell} ipython3 -# Set up f to compute f(x, y) at every x for any given y -f_vec_x = lambda y: f(grid, y) -# Create a second function that vectorizes this operation over all y -f_vec = jax.vmap(f_vec_x) -``` - -Now `f_vec` will compute `f(x,y)` at every `x,y` when called with the flat array `grid`. - -Let's see the timing: - -```{code-cell} ipython3 -with qe.Timer(): - z_max = jnp.max(f_vec(grid)) - z_max.block_until_ready() - -print(f"JAX vmap v1 result: {z_max:.6f}") -``` - -```{code-cell} ipython3 -with qe.Timer(): - z_max = jnp.max(f_vec(grid)) - z_max.block_until_ready() -``` - -By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version -uses far less memory with greatly changing run time. - -This is good --- but we are leaving speed gains on the table! - -First note that the code above computes the full two-dimensional array `f(x,y)`, which creates -overheads, before it takes the max. - -Second, the `jnp.max` call sits outside the JIT-compiled function `f`, so the -compiler cannot fuse these operations into a single kernel. 
- -We can fix both problems by pushing the max inside and wrapping everything in -a single `@jax.jit`: ```{code-cell} ipython3 @jax.jit def compute_max_vmap(grid): - # Construct a function that takes the max along each row + # Construct a function that takes the max over all x for given y f_vec_x_max = lambda y: jnp.max(f(grid, y)) - # Vectorize the function so we can call on all rows simultaneously + # Vectorize the function so we can call on all y simultaneously f_vec_max = jax.vmap(f_vec_x_max) - # Call the vectorized function and take the max - return jnp.max(f_vec_max(grid)) + # Compute the max across x at every y + maxes = f_vec_max(grid) + # Compute the max of the maxes and return + return jnp.max(maxes) ``` -Here - -* `f_vec_x_max` computes the max along any given row -* `f_vec_max` is a vectorized version that can compute the max of all rows in parallel. +Note that we never create -We apply this function to all rows and then take the max of the row maxes. - -Because we push the max inside, we never construct the full two-dimensional -array `f(x,y)`, saving even more memory. +* the two-dimensional grid `x_mesh` +* the two-dimensional grid `y_mesh` or +* the two-dimensional array `f(x,y)` And because everything is under a single `@jax.jit`, the compiler can fuse all operations into one optimized kernel. @@ -478,15 +438,16 @@ with qe.Timer(): Numba handles this sequential operation very efficiently. -Notice that the second run is significantly faster after JIT compilation completes. -Numba's compilation is typically quite fast, and the resulting code performance is excellent for sequential operations like this one. +### JAX Version +We cannot directly replace `numba.jit` with `jax.jit` because JAX arrays are immutable. -### JAX Version +But we can still implement this operation -Now let's create a JAX version using `at[t].set` style syntax, which, as -{ref}`discussed in the JAX lecture `, provides a workaround for immutable arrays. 
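As a quick reminder of the constraint (a minimal sketch of immutability and the functional update syntax):

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# Direct item assignment is rejected on immutable JAX arrays
try:
    x[0] = 1.0
except TypeError:
    print("in-place assignment rejected")

# The functional update returns a new array with the change applied
y = x.at[0].set(1.0)
print(y)  # [1. 0. 0.]
```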
+#### First Attempt + +Here's a workaround using the `at[t].set` syntax we {ref}`discussed in the JAX lecture `. We'll apply a `lax.fori_loop`, which is a version of a for loop that can be compiled by XLA. @@ -509,7 +470,7 @@ def qm_jax_fori(x0, n, α=4.0): * We hold `n` static because it affects array size and hence JAX wants to specialize on its value in the compiled code. * We pin to the CPU via `device=cpu` because this sequential workload consists of many small operations, leaving little opportunity for GPU parallelism. -Although `at[t].set` appears to create a new array at each step, inside a JIT-compiled function the compiler detects that the old array is no longer needed and performs the update in place. +Important: Although `at[t].set` appears to create a new array at each step, inside a JIT-compiled function the compiler detects that the old array is no longer needed and performs the update in place! Let's time it with the same parameters: @@ -531,8 +492,10 @@ with qe.Timer(): x_jax.block_until_ready() ``` -JAX is also quite efficient for this sequential operation. +JAX is also quite efficient for this sequential operation! + +#### Second Attempt There's another way we can implement the loop that uses `lax.scan`. @@ -573,12 +536,12 @@ with qe.Timer(): x_jax.block_until_ready() ``` -Both JAX and Numba deliver strong performance after compilation. +Surprisingly, JAX also delivers strong performance after compilation. ### Summary -While both Numba and JAX deliver strong performance for sequential operations, *there are significant differences in code readability and ease of use*. +While both Numba and JAX deliver strong performance for sequential operations, there are differences in code readability and ease of use. The Numba version is straightforward and natural to read: we simply allocate an array and fill it element by element using a standard Python loop.
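For reference, the `lax.scan` pattern can be sketched as follows (a hedged sketch; the lecture's implementation may differ in details such as how the initial condition is stored):

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def qm_scan(x0, n, α=4.0):
    # Each step maps x to αx(1-x) and records the new value
    def step(x, _):
        x_new = α * x * (1 - x)
        return x_new, x_new
    _, xs = jax.lax.scan(step, x0, None, length=n)
    # Prepend the initial condition to the recorded trajectory
    return jnp.concatenate([jnp.array([x0]), xs])

x = qm_scan(0.2, 5)
print(x)
```

Here `n` is held static because `length` must be known at trace time, mirroring the `fori_loop` version above.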