From e02d1d3830a1c2d4d2016ee7ae205b16414d6d90 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Tue, 14 Apr 2026 16:10:01 -0400 Subject: [PATCH 1/3] misc --- lectures/jax_intro.md | 47 +++++++---- lectures/numpy_vs_numba_vs_jax.md | 128 ++++++++++++++++-------------- 2 files changed, 101 insertions(+), 74 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 4a65e7a4..a2c679f9 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -351,19 +351,20 @@ In particular, pure functions will always return the same result if invoked with -### Examples +### Examples -- Pure and Impure -Here's an example of a *non-pure* function +Here's an example of an *impure* function ```{code-cell} ipython3 tax_rate = 0.1 -prices = [10.0, 20.0] def add_tax(prices): for i, price in enumerate(prices): prices[i] = price * (1 + tax_rate) - print('Post-tax prices: ', prices) - return prices + +prices = [10.0, 20.0] +add_tax(prices) +prices ``` This function fails to be pure because @@ -375,15 +376,22 @@ This function fails to be pure because Here's a *pure* version ```{code-cell} ipython3 -tax_rate = 0.1 -prices = (10.0, 20.0) def add_tax_pure(prices, tax_rate): new_prices = [price * (1 + tax_rate) for price in prices] return new_prices + +tax_rate = 0.1 +prices = (10.0, 20.0) +after_tax_prices = add_tax(prices) +after_tax_prices ``` -This pure version makes all dependencies explicit through function arguments, and doesn't modify any external state. +This is pure because + +* all dependencies are explicit through function arguments +* it doesn't modify any external state + ### Why Functional Programming? @@ -438,8 +446,8 @@ This function is *not pure* because: * It's non-deterministic: same inputs, different outputs * It has side effects: it modifies the global random number generator state -Dangerous under parallelization --- must carefully control what happens in each -thread! 
+This is dangerous under parallelization --- must carefully control what happens in each +thread. ### JAX @@ -560,7 +568,11 @@ sense when we get to parallel programming. The function below produces `k` (quasi-) independent random `n x n` matrices using `split`. ```{code-cell} ipython3 -def gen_random_matrices(key, n=2, k=3): +def gen_random_matrices( + key, # JAX key for random numbers + n=2, # Matrices will be n x n + k=3 # Number of matrices to generate + ): matrices = [] for _ in range(k): key, subkey = jax.random.split(key) @@ -583,7 +595,7 @@ This function is *pure* ### Benefits -The explicitness of JAX brings significant benefits: +As mentioned above, this explicitness is valuable: * Reproducibility: Easy to reproduce results by reusing keys * Parallelization: Control what happens on separate threads @@ -672,8 +684,14 @@ with qe.Timer(): The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation. -But we are still using eager execution --- lots of memory and read/write +This is because the individual array operations are parallelized on the GPU +But we are still using eager execution + +* lots of memory due to intermediate arrays +* lots of memory read/writes + +Also, many separate kernels launched on the GPU ### Compiling the Whole Function @@ -708,7 +726,8 @@ The runtime has improved again --- now because we fused all the operations * Aggressive optimization based on entire computational sequence * Eliminates multiple calls to the hardware accelerator -* No creation of intermediate arrays + +The memory footprint is also much lower --- no creation of intermediate arrays Incidentally, a more common syntax when targeting a function for the JIT compiler is diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index e38b5ed6..05b2b94e 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -135,18 +135,37 @@ for x in grid: Let's switch to NumPy 
and use a larger grid -Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y` such -that `f(x, y)` generates all evaluations on the product grid. +```{code-cell} ipython3 +grid = np.linspace(-3, 3, 3_000) # Large grid +``` + +As a first pass of vectorization we might try something like this + +```{code-cell} ipython3 +# Large grid +z = np.max(f(grid, grid)) # This is wrong! +``` + +The problem here is that `f(grid, grid)` doesn't obey the nested loop. + +In terms of the figure above, it only computes the values of `f` along the +diagonal. + +To trick NumPy into calculating `f(x,y)` on every `x,y` pair, we need to use `np.meshgrid`. + +Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y` + +such that `f(x, y)` generates all evaluations on the product grid. ```{code-cell} ipython3 # Large grid grid = np.linspace(-3, 3, 3_000) -x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid +x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid with qe.Timer(): - z_max_numpy = np.max(f(x, y)) + z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works ``` In the vectorized version, all the looping takes place in compiled code. @@ -159,11 +178,30 @@ The output should be close to one: print(f"NumPy result: {z_max_numpy:.6f}") ``` +### Memory Issues + +So we have the right solution reasonable time --- but memory usage is huge. + +While the flat arrays are low-memory + +```{code-cell} ipython3 +grid.nbytes +``` + +the mesh grids are two-dimensional and hence very memory intensive + +```{code-cell} ipython3 +x_mesh.nbytes + y_mesh.nbytes +``` + +Moreover, NumPy's eager execution creates many intermediate arrays of the same size! + +This kind of memory usage can be a big problem in actual research calculations. ### A Comparison with Numba -Now let's see if we can achieve better performance using Numba with a simple loop. +Let's see if we can achieve better performance using Numba with a simple loop. 
```{code-cell} ipython3 @numba.jit @@ -194,15 +232,13 @@ with qe.Timer(): compute_max_numba(grid) ``` -Depending on your machine, the Numba version might be either slower or faster than NumPy. +Notice how we are using almost no memory --- we just need the one-dimensional `grid` -In most cases we find that Numba is slightly better. +Moreover, execution speed is good. -On the one hand, NumPy combines efficient arithmetic with some -multithreading, which provides an advantage. +On most machines, the Numba version will be somewhat faster than NumPy. -On the other hand, the Numba routine uses much less memory, since we are only -working with a single one-dimensional grid. +The reason is efficient machine code plus less memory read-write. ### Parallelized Numba @@ -301,27 +337,11 @@ The compilation overhead is a one-time cost that pays off when the function is c ### JAX plus vmap -There is one problem with both the NumPy code and the JAX code above: - -While the flat arrays are low-memory - -```{code-cell} ipython3 -grid.nbytes -``` - -the mesh grids are memory intensive - -```{code-cell} ipython3 -x_mesh.nbytes + y_mesh.nbytes -``` +Because we use'd `jax.jit` above, we avoided creating many intermediate arrays. -This extra memory usage can be a big problem in actual research calculations. +But we still create the big arrays `z_max`, `x_mesh`, and `y_mesh`. -Fortunately, JAX admits a different approach -using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html). - -The idea of `vmap` is to break vectorization into stages, transforming a -function that operates on single values into one that operates on arrays. +Fortunately, we can avoid this by using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html). Here's how we can apply it to our problem. @@ -330,13 +350,13 @@ Here's how we can apply it to our problem. 
@jax.jit def compute_max_vmap(grid): # Construct a function that takes the max over all x for given y - f_vec_x_max = lambda y: jnp.max(f(grid, y)) + compute_column_max = lambda y: jnp.max(f(grid, y)) # Vectorize the function so we can call on all y simultaneously - f_vec_max = jax.vmap(f_vec_x_max) - # Compute the max across x at every y - maxes = f_vec_max(grid) - # Compute the max of the maxes and return - return jnp.max(maxes) + vectorized_compute_column_max = jax.vmap(compute_column_max) + # Compute the column max at every row + column_maxes = vectorized_compute_column_max(grid) + # Compute the max of the column maxes and return + return jnp.max(column_maxes) ``` Note that we never create @@ -345,6 +365,8 @@ Note that we never create * the two-dimensional grid `y_mesh` or * the two-dimensional array `f(x,y)` +Like Numba, we just use the flat array `grid`. + And because everything is under a single `@jax.jit`, the compiler can fuse all operations into one optimized kernel. @@ -378,18 +400,14 @@ In our view, JAX is the winner for vectorized operations. It dominates NumPy both in terms of speed (via JIT-compilation and parallelization) and memory efficiency (via vmap). -Moreover, the `vmap` approach can sometimes lead to significantly clearer code. - -While Numba is impressive, the beauty of JAX is that, with fully vectorized -operations, we can run exactly the same code on machines with hardware -accelerators and reap all the benefits without extra effort. +It also dominates Numba when run on the GPU. -Moreover, JAX already knows how to effectively parallelize many common array -operations, which is key to fast execution. - -For most cases encountered in economics, econometrics, and finance, it is +```{note} +Numba can support GPU programming through `numba.cuda` but then we need to +parallelize by hand. 
For most cases encountered in economics, econometrics, and finance, it is far better to hand over to the JAX compiler for efficient parallelization than to -try to hand code these routines ourselves. +try to hand-code these routines ourselves. +``` ## Sequential operations @@ -554,8 +572,6 @@ The JAX versions, on the other hand, require either `lax.fori_loop` or While JAX's `at[t].set` syntax does allow element-wise updates, the overall code remains harder to read than the Numba equivalent. -For this type of sequential operation, Numba is the clear winner in terms of -code clarity and ease of implementation. ## Overall recommendations @@ -573,7 +589,7 @@ than traditional meshgrid-based vectorization. In addition, JAX functions are automatically differentiable, as we explore in {doc}`autodiff`. -For **sequential operations**, Numba has clear advantages. +For **sequential operations**, Numba has nicer syntax. The code is natural and readable --- just a Python loop with a decorator --- and performance is excellent. @@ -581,17 +597,9 @@ and performance is excellent. JAX can handle sequential problems via `lax.fori_loop` or `lax.scan`, but the syntax is less intuitive. -```{note} -One important advantage of `lax.fori_loop` and `lax.scan` is that they -support automatic differentiation through the loop, which Numba cannot do. -If you need to differentiate through a sequential computation (e.g., computing -sensitivities of a trajectory to model parameters), JAX is the better choice -despite the less natural syntax. -``` +On the other hand, the JAX versions support automatic differentiation. -In practice, many problems involve a mix of both patterns. 
+That might be of interest if, say, we want to compute sensitivities of a
trajectory to model parameters.

-A good rule of thumb: default to JAX for new projects, especially when
-hardware acceleration or differentiability might be useful, and reach for Numba
-when you have a tight sequential loop that needs to be fast and readable.

From 3ea9f1aeb4add2477bb81c64409e229b72c2880a Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Tue, 14 Apr 2026 16:11:28 -0400 Subject: [PATCH 2/3] misc --- lectures/numpy_vs_numba_vs_jax.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index 05b2b94e..f45df634 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -153,8 +153,7 @@ diagonal. To trick NumPy into calculating `f(x,y)` on every `x,y` pair, we need to use `np.meshgrid`. -Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y` - +Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y` such that `f(x, y)` generates all evaluations on the product grid. @@ -180,7 +179,7 @@ print(f"NumPy result: {z_max_numpy:.6f}") ### Memory Issues -So we have the right solution reasonable time --- but memory usage is huge. +So we have the right solution in reasonable time --- but memory usage is huge. While the flat arrays are low-memory @@ -337,7 +336,7 @@ The compilation overhead is a one-time cost that pays off when the function is c ### JAX plus vmap -Because we use'd `jax.jit` above, we avoided creating many intermediate arrays. +Because we used `jax.jit` above, we avoided creating many intermediate arrays. But we still create the big arrays `z_max`, `x_mesh`, and `y_mesh`. 
From 19eba459d9471de2abdf8fcb4c90a980dc949ee9 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Tue, 14 Apr 2026 16:25:01 -0400 Subject: [PATCH 3/3] Fix call to add_tax_pure in jax_intro pure function example Co-Authored-By: Claude Opus 4.6 (1M context) --- lectures/jax_intro.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index a2c679f9..748f0eb3 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -383,7 +383,7 @@ def add_tax_pure(prices, tax_rate): tax_rate = 0.1 prices = (10.0, 20.0) -after_tax_prices = add_tax(prices) +after_tax_prices = add_tax_pure(prices, tax_rate) after_tax_prices ```