From 59017ca7a060e2952d11c126db6083d3c7548e43 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Sun, 12 Apr 2026 19:49:51 -0400 Subject: [PATCH] misc --- lectures/jax_intro.md | 15 +-- lectures/numba.md | 6 +- lectures/numpy_vs_numba_vs_jax.md | 182 ++++++++++++------------------ 3 files changed, 81 insertions(+), 122 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index fed7bcb3..1c6fb52a 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -645,12 +645,7 @@ efficient machine code that varies with both task size and hardware. We saw the power of JAX's JIT compiler combined with parallel hardware when we {ref}`above `, when we applied `cos` to a large array. -Let's try the same thing with a more complex function. - - -### Evaluating a more complicated function - -Consider the function +Let's try the same thing with a more complex function: ```{code-cell} def f(x): @@ -658,7 +653,7 @@ def f(x): return y ``` -#### With NumPy +### With NumPy We'll try first with NumPy @@ -675,7 +670,7 @@ with qe.Timer(): -#### With JAX +### With JAX Now let's try again with JAX. @@ -712,10 +707,10 @@ The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation. However, with JAX, we have another trick up our sleeve --- we can JIT-compile -the *entire* function, not just individual operations. +the entire function, not just individual operations. -### Compiling the whole function +### Compiling the Whole Function The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing array operations into a single optimized kernel. 
diff --git a/lectures/numba.md b/lectures/numba.md index 8c927191..85ba3ddd 100644 --- a/lectures/numba.md +++ b/lectures/numba.md @@ -124,7 +124,7 @@ n = 10_000_000 with qe.Timer() as timer1: # Time Python base version - x = qm(0.1, int(n)) + x = qm(0.1, n) ``` @@ -154,7 +154,7 @@ Let's time this new version: ```{code-cell} ipython3 with qe.Timer() as timer2: # Time jitted version - x = qm_numba(0.1, int(n)) + x = qm_numba(0.1, n) ``` This is a large speed gain. @@ -167,7 +167,7 @@ function has been compiled and is in memory: ```{code-cell} ipython3 with qe.Timer() as timer3: # Second run - x = qm_numba(0.1, int(n)) + x = qm_numba(0.1, n) ``` Here's the speed gain diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index f1aa62a9..6d08a4b2 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -147,7 +147,7 @@ that `f(x, y)` generates all evaluations on the product grid. grid = np.linspace(-3, 3, 3_000) x, y = np.meshgrid(grid, grid) -with qe.Timer(precision=8): +with qe.Timer(): z_max_numpy = np.max(f(x, y)) print(f"NumPy result: {z_max_numpy:.6f}") @@ -172,13 +172,17 @@ def compute_max_numba(grid): for x in grid: for y in grid: z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > m: - m = z + m = max(m, z) return m +``` +Let's test it: + +```{code-cell} ipython3 grid = np.linspace(-3, 3, 3_000) -with qe.Timer(precision=8): +with qe.Timer(): + # First run z_max_numba = compute_max_numba(grid) print(f"Numba result: {z_max_numba:.6f}") @@ -187,15 +191,17 @@ print(f"Numba result: {z_max_numba:.6f}") Let's run again to eliminate compile time. ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run compute_max_numba(grid) ``` -Depending on your machine, the Numba version can be a bit slower or a bit faster -than NumPy. +Depending on your machine, the Numba version might be either slower or faster than NumPy. 
-On one hand, NumPy combines efficient arithmetic (like Numba) with some -multithreading (unlike this Numba code), which provides an advantage. +In most cases we find that Numba is slightly better. + +On the one hand, NumPy combines efficient arithmetic with some +multithreading, which provides an advantage. On the other hand, the Numba routine uses much less memory, since we are only working with a single one-dimensional grid. @@ -205,48 +211,6 @@ working with a single one-dimensional grid. Now let's try parallelization with Numba using `prange`: -Here's a naive and *incorrect* attempt. - -```{code-cell} ipython3 -@numba.jit(parallel=True) -def compute_max_numba_parallel(grid): - n = len(grid) - m = -np.inf - for i in numba.prange(n): - for j in range(n): - x = grid[i] - y = grid[j] - z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > m: - m = z - return m - -``` - -This returns `-inf` --- the initial value of `m`, as if it were never updated: - -```{code-cell} ipython3 -z_max_parallel_incorrect = compute_max_numba_parallel(grid) -print(f"Numba result: {z_max_parallel_incorrect} 😱") -``` - -To understand why, recall that `prange` splits the outer loop across threads. - -Each thread gets its own private copy of `m`, initialized to `-np.inf`, and -correctly updates it within its chunk of iterations. - -But at the end of the loop, Numba needs to combine the per-thread copies of `m` -back into a single value --- a **reduction**. - -For patterns it recognizes, such as `m += z` (sum) or `m = max(m, z)` (max), -Numba knows the combining operator. - -But it does not recognize the `if z > m: m = z` pattern as a max reduction, so -the per-thread results are never combined and `m` retains its initial value. 
- -The simplest fix is to replace the conditional with `max`, which Numba -recognizes: - ```{code-cell} ipython3 @numba.jit(parallel=True) def compute_max_numba_parallel(grid): @@ -261,38 +225,21 @@ def compute_max_numba_parallel(grid): return m ``` -An alternative is to make the loop body fully independent across `i` and -handle the reduction ourselves: +Here's a warm up run and test. ```{code-cell} ipython3 -@numba.jit(parallel=True) -def compute_max_numba_parallel_v2(grid): - n = len(grid) - row_maxes = np.empty(n) - for i in numba.prange(n): - row_max = -np.inf - for j in range(n): - x = grid[i] - y = grid[j] - z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > row_max: - row_max = z - row_maxes[i] = row_max - return np.max(row_maxes) -``` - -Here each thread writes to a separate element of `row_maxes`, so we handle the -reduction ourselves via `np.max`. +with qe.Timer(): + # First run + z_max_parallel = compute_max_numba_parallel(grid) -```{code-cell} ipython3 -z_max_parallel = compute_max_numba_parallel(grid) print(f"Numba result: {z_max_parallel:.6f}") ``` -Here's the timing. +Here's the timing for the pre-compiled version. ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run compute_max_numba_parallel(grid) ``` @@ -309,8 +256,7 @@ On the surface, vectorized code in JAX is similar to NumPy code. But there are also some differences, which we highlight here. -Let's start with the function. 
- +Let's start with the function, which switches `np` to `jnp` and adds `jax.jit`: ```{code-cell} ipython3 @jax.jit @@ -325,9 +271,15 @@ calculation, we can use a `meshgrid` operation designed for this purpose: ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) x_mesh, y_mesh = jnp.meshgrid(grid, grid) +``` + +Now let's run and time the code: -with qe.Timer(precision=8): +```{code-cell} ipython3 +with qe.Timer(): + # First run z_max = jnp.max(f(x_mesh, y_mesh)) + # Hold interpreter z_max.block_until_ready() print(f"Plain vanilla JAX result: {z_max:.6f}") @@ -336,8 +288,10 @@ print(f"Plain vanilla JAX result: {z_max:.6f}") Let's run again to eliminate compile time. ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run z_max = jnp.max(f(x_mesh, y_mesh)) + # Hold interpreter z_max.block_until_ready() ``` @@ -384,7 +338,7 @@ Now `f_vec` will compute `f(x,y)` at every `x,y` when called with the flat array Let's see the timing: ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): z_max = jnp.max(f_vec(grid)) z_max.block_until_ready() @@ -392,20 +346,20 @@ print(f"JAX vmap v1 result: {z_max:.6f}") ``` ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): z_max = jnp.max(f_vec(grid)) z_max.block_until_ready() ``` By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version -uses far less memory, with similar runtime. +uses far less memory without greatly changing the run time. -But we are still leaving speed gains on the table. +This is good --- but we are leaving speed gains on the table! -The code above computes the full two-dimensional array `f(x,y)` and then takes -the max. +First note that the code above computes the full two-dimensional array `f(x,y)`, which creates +overheads, before it takes the max. 
-Moreover, the `jnp.max` call sits outside the JIT-compiled function `f`, so the +Second, the `jnp.max` call sits outside the JIT-compiled function `f`, so the compiler cannot fuse these operations into a single kernel. We can fix both problems by pushing the max inside and wrapping everything in @@ -438,8 +392,11 @@ all operations into one optimized kernel. Let's try it. ```{code-cell} ipython3 -with qe.Timer(precision=8): - z_max = compute_max_vmap(grid).block_until_ready() +with qe.Timer(): + # First run + z_max = compute_max_vmap(grid) + # Hold interpreter + z_max.block_until_ready() print(f"JAX vmap result: {z_max:.6f}") ``` @@ -447,8 +404,11 @@ print(f"JAX vmap result: {z_max:.6f}") Let's run it again to eliminate compilation time: ```{code-cell} ipython3 -with qe.Timer(precision=8): - z_max = compute_max_vmap(grid).block_until_ready() +with qe.Timer(): + # Second run + z_max = compute_max_vmap(grid) + # Hold interpreter + z_max.block_until_ready() ``` @@ -504,14 +464,16 @@ Let's generate a time series of length 10,000,000 and time the execution: ```{code-cell} ipython3 n = 10_000_000 -with qe.Timer(precision=8): +with qe.Timer(): + # First run x = qm(0.1, n) ``` Let's run it again to eliminate compilation time: ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run x = qm(0.1, n) ``` @@ -525,12 +487,13 @@ Numba's compilation is typically quite fast, and the resulting code performance Now let's create a JAX version using `lax.scan`: -(We'll hold `n` static because it affects array size and hence JAX wants to specialize on its value in the compiled code.) +(We'll hold `n` static because it affects array size and hence JAX wants to +specialize on its value in the compiled code.) 
```{code-cell} ipython3 cpu = jax.devices("cpu")[0] -@partial(jax.jit, static_argnums=(1,), device=cpu) +@partial(jax.jit, static_argnames=('n',), device=cpu) def qm_jax(x0, n, α=4.0): def update(x, t): x_new = α * x * (1 - x) @@ -543,29 +506,31 @@ def qm_jax(x0, n, α=4.0): This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array. ```{note} -Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator. - -The computation consists of many small sequential operations, leaving little -opportunity for the GPU to exploit parallelism. - -As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU -a better fit for this workload. - -Curious readers can try removing this option to see how performance changes. +We specify `device=cpu` in the `jax.jit` decorator because this computation +consists of many small sequential operations, leaving little opportunity for the +GPU to exploit parallelism. As a result, kernel-launch overhead tends to +dominate on the GPU, making the CPU a better +fit. ``` Let's time it with the same parameters: ```{code-cell} ipython3 -with qe.Timer(precision=8): - x_jax = qm_jax(0.1, n).block_until_ready() +with qe.Timer(): + # First run + x_jax = qm_jax(0.1, n) + # Hold interpreter + x_jax.block_until_ready() ``` Let's run it again to eliminate compilation overhead: ```{code-cell} ipython3 -with qe.Timer(precision=8): - x_jax = qm_jax(0.1, n).block_until_ready() +with qe.Timer(): + # Second run + x_jax = qm_jax(0.1, n) + # Hold interpreter + x_jax.block_until_ready() ``` JAX is also quite efficient for this sequential operation. @@ -575,8 +540,7 @@ Both JAX and Numba deliver strong performance after compilation. ### Summary -While both Numba and JAX deliver strong performance for sequential operations, -*there are significant differences in code readability and ease of use*. 
+While both Numba and JAX deliver strong performance for sequential operations, *there are significant differences in code readability and ease of use*. The Numba version is straightforward and natural to read: we simply allocate an array and fill it element by element using a standard Python loop.