47 changes: 33 additions & 14 deletions lectures/jax_intro.md
@@ -351,19 +351,20 @@
In particular, pure functions will always return the same result if invoked with the same arguments.



### Examples -- Pure and Impure

Here's an example of an *impure* function

```{code-cell} ipython3
tax_rate = 0.1
prices = [10.0, 20.0]

def add_tax(prices):
    for i, price in enumerate(prices):
        prices[i] = price * (1 + tax_rate)
    print('Post-tax prices: ', prices)
    return prices

add_tax(prices)
prices
```
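To see the side effect in action, here's a sketch that calls such a function twice --- because the list is mutated in place, the same call has a different effect each time:

```python
tax_rate = 0.1

def add_tax(prices):
    # Mutates the caller's list in place -- a side effect
    for i, price in enumerate(prices):
        prices[i] = price * (1 + tax_rate)
    return prices

prices = [10.0, 20.0]
add_tax(prices)
add_tax(prices)                        # "same" call, different effect
print([round(p, 2) for p in prices])   # [12.1, 24.2] -- tax applied twice
```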

This function fails to be pure because
@@ -375,15 +376,22 @@
Here's a *pure* version

```{code-cell} ipython3
tax_rate = 0.1
prices = (10.0, 20.0)

def add_tax_pure(prices, tax_rate):
    new_prices = [price * (1 + tax_rate) for price in prices]
    return new_prices

after_tax_prices = add_tax_pure(prices, tax_rate)
after_tax_prices
```

This is pure because

* all dependencies are explicit through function arguments
* it doesn't modify any external state
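A quick sketch confirming these properties --- repeated calls with the same arguments agree, and the caller's data is untouched:

```python
def add_tax_pure(prices, tax_rate):
    # Pure: all inputs are arguments, nothing external is modified
    return [price * (1 + tax_rate) for price in prices]

prices = (10.0, 20.0)
first = add_tax_pure(prices, 0.1)
second = add_tax_pure(prices, 0.1)
assert first == second            # same inputs, same outputs
assert prices == (10.0, 20.0)     # caller's data untouched
```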


### Why Functional Programming?

@@ -438,8 +446,8 @@
This function is *not pure* because:
* It's non-deterministic: same inputs, different outputs
* It has side effects: it modifies the global random number generator state

This is dangerous under parallelization --- we must carefully control what
happens in each thread.
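As an illustration with NumPy's legacy global RNG (a sketch --- the exact function used in the lecture may differ):

```python
import numpy as np

def draw():
    # Impure: reads and mutates NumPy's hidden global RNG state
    return np.random.uniform()

np.random.seed(42)
a = draw()
b = draw()
np.random.seed(42)   # reset the hidden global state
assert a == draw()   # reproducible only via this global side channel
assert a != b        # no arguments at all, yet different outputs
```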


### JAX
@@ -560,7 +568,11 @@
sense when we get to parallel programming.
The function below produces `k` (quasi-) independent random `n x n` matrices using `split`.

```{code-cell} ipython3
def gen_random_matrices(
        key,    # JAX key for random numbers
        n=2,    # Matrices will be n x n
        k=3     # Number of matrices to generate
    ):
    matrices = []
    for _ in range(k):
        key, subkey = jax.random.split(key)
        # draw an n x n matrix (normal draws assumed here)
        matrices.append(jax.random.normal(subkey, (n, n)))
    return matrices
```
@@ -583,7 +595,7 @@
This function is *pure*

### Benefits

As mentioned above, this explicitness is valuable:

* Reproducibility: Easy to reproduce results by reusing keys
* Parallelization: Control what happens on separate threads
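A minimal sketch of the reproducibility point, assuming the standard `jax.random` key API:

```python
import jax

key = jax.random.PRNGKey(0)
x1 = jax.random.normal(key, (3,))
x2 = jax.random.normal(key, (3,))     # reusing a key reproduces the draws
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, (3,))   # splitting gives fresh draws
print(bool((x1 == x2).all()))         # True
print(bool((x1 == y).all()))          # False
```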
@@ -672,8 +684,14 @@
The outcome is similar to the `cos` example --- JAX is faster, especially on the
second run after JIT compilation.

This is because the individual array operations are parallelized on the GPU.

But we are still using eager execution:

* lots of memory used by intermediate arrays
* lots of memory reads and writes
* many separate kernels launched on the GPU
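A back-of-the-envelope sketch of the eager-execution cost, using NumPy-style semantics as a stand-in:

```python
import numpy as np

x = np.ones(1_000_000)     # 8 MB of float64
# Eager execution: every operation materializes a full-size temporary
a = np.cos(x)              # intermediate array 1
b = a**2                   # intermediate array 2
c = b + 1.0                # intermediate array 3
print(c.nbytes)            # 8_000_000 -- and each step above pays this cost
```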

### Compiling the Whole Function

@@ -708,7 +726,8 @@
The runtime has improved again --- now because we fused all the operations

* Aggressive optimization based on entire computational sequence
* Eliminates multiple calls to the hardware accelerator

The memory footprint is also much lower --- no creation of intermediate arrays.

Incidentally, a more common syntax when targeting a function for the JIT
compiler is
127 changes: 67 additions & 60 deletions lectures/numpy_vs_numba_vs_jax.md
@@ -135,18 +135,36 @@

Let's switch to NumPy and use a larger grid

```{code-cell} ipython3
grid = np.linspace(-3, 3, 3_000)   # Large grid
```

As a first pass at vectorization we might try something like this

```{code-cell} ipython3
z = np.max(f(grid, grid))  # This is wrong!
```

The problem here is that `f(grid, grid)` doesn't replicate the nested loop.

In terms of the figure above, it only computes the values of `f` along the
diagonal.
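A tiny sketch of the issue, with a hypothetical elementwise stand-in for `f`:

```python
import numpy as np

f_toy = lambda x, y: x + 10 * y      # hypothetical stand-in for the lecture's f
grid_toy = np.array([0.0, 1.0, 2.0])

print(f_toy(grid_toy, grid_toy))     # [ 0. 11. 22.] -- diagonal pairs only
```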

To trick NumPy into calculating `f(x,y)` on every `x,y` pair, we need to use `np.meshgrid`.

Here we use `np.meshgrid` to create two-dimensional input arrays `x_mesh` and
`y_mesh` such that `f(x_mesh, y_mesh)` generates all evaluations on the product grid.


```{code-cell} ipython3
x_mesh, y_mesh = np.meshgrid(grid, grid)   # MATLAB style meshgrid

with qe.Timer():
    z_max_numpy = np.max(f(x_mesh, y_mesh))   # This works
```

In the vectorized version, all the looping takes place in compiled code.
@@ -159,11 +177,30 @@
The output should be close to one:

```{code-cell} ipython3
print(f"NumPy result: {z_max_numpy:.6f}")
```

### Memory Issues

So we have the right solution in reasonable time --- but memory usage is huge.

While the flat array `grid` uses little memory

```{code-cell} ipython3
grid.nbytes
```

the mesh grids are two-dimensional and hence very memory intensive

```{code-cell} ipython3
x_mesh.nbytes + y_mesh.nbytes
```

Moreover, NumPy's eager execution creates many intermediate arrays of the same size!

This kind of memory usage can be a big problem in actual research calculations.
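A quick sanity check on the numbers above (float64 is 8 bytes per element):

```python
n = 3_000
flat_bytes = n * 8           # one float64 grid of length n
mesh_bytes = 2 * n * n * 8   # x_mesh plus y_mesh
print(flat_bytes)            # 24_000 bytes, about 24 KB
print(mesh_bytes)            # 144_000_000 bytes, about 144 MB
```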


### A Comparison with Numba

Now let's see if we can achieve better performance using Numba with a simple loop.
Let's see if we can achieve better performance using Numba with a simple loop.

```{code-cell} ipython3
@numba.jit
def compute_max_numba(grid):
    # Simple nested loop over the product grid
    m = -np.inf
    for x in grid:
        for y in grid:
            z = f(x, y)
            if z > m:
                m = z
    return m

with qe.Timer():
    compute_max_numba(grid)
```

Notice how we are using almost no memory --- we just need the one-dimensional `grid`.

Moreover, execution speed is good.

On most machines, the Numba version will be somewhat faster than NumPy.

The reason is efficient machine code combined with fewer memory reads and writes.


### Parallelized Numba
@@ -301,27 +336,11 @@
The compilation overhead is a one-time cost that pays off when the function is called repeatedly.

### JAX plus vmap

Because we used `jax.jit` above, we avoided creating many intermediate arrays.

But we still create the big arrays `z_max`, `x_mesh`, and `y_mesh`.

Fortunately, we can avoid this by using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).

Here's how we can apply it to our problem.

```{code-cell} ipython3
@jax.jit
def compute_max_vmap(grid):
    # Construct a function that takes the max over all x for given y
    compute_column_max = lambda y: jnp.max(f(grid, y))
    # Vectorize the function so we can call it on all y simultaneously
    vectorized_compute_column_max = jax.vmap(compute_column_max)
    # Compute the column max at every row
    column_maxes = vectorized_compute_column_max(grid)
    # Compute the max of the column maxes and return
    return jnp.max(column_maxes)
```

Note that we never create
* the two-dimensional grid `x_mesh`,
* the two-dimensional grid `y_mesh` or
* the two-dimensional array `f(x,y)`

Like Numba, we just use the flat array `grid`.

And because everything is under a single `@jax.jit`, the compiler can fuse
all operations into one optimized kernel.
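Here's a toy sketch of the same staging idea, with a hypothetical elementwise `f_toy` in place of the lecture's `f`:

```python
import jax
import jax.numpy as jnp

g = jnp.linspace(-1.0, 1.0, 5)                 # toy stand-in for `grid`
f_toy = lambda x, y: x + y                     # hypothetical elementwise function

column_max = lambda y: jnp.max(f_toy(g, y))    # max over x at fixed y
column_maxes = jax.vmap(column_max)(g)         # all y at once, no meshgrid
z = jnp.max(column_maxes)
print(float(z))                                # 2.0, the max over the product grid
```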

Expand Down Expand Up @@ -378,18 +399,14 @@ In our view, JAX is the winner for vectorized operations.
It dominates NumPy both in terms of speed (via JIT-compilation and
parallelization) and memory efficiency (via vmap).

It also dominates Numba when run on the GPU.

```{note}
Numba can support GPU programming through `numba.cuda`, but then we need to
parallelize by hand. For most cases encountered in economics, econometrics,
and finance, it is far better to hand over to the JAX compiler for efficient
parallelization than to try to hand-code these routines ourselves.
```


## Sequential operations
@@ -554,8 +571,6 @@
The JAX versions, on the other hand, require either `lax.fori_loop` or `lax.scan`.
While JAX's `at[t].set` syntax does allow element-wise updates, the overall code
remains harder to read than the Numba equivalent.



## Overall recommendations
@@ -573,25 +588,17 @@
than traditional meshgrid-based vectorization.
In addition, JAX functions are automatically differentiable, as we explore in
{doc}`autodiff`.

For **sequential operations**, Numba has nicer syntax.

The code is natural and readable --- just a Python loop with a decorator ---
and performance is excellent.

JAX can handle sequential problems via `lax.fori_loop` or `lax.scan`, but
the syntax is less intuitive.
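For illustration, here's a sketch of a simple sequential recursion written with `lax.scan` (the update rule is hypothetical):

```python
import jax
import jax.numpy as jnp

# Hypothetical AR(1)-style sequential update: x_{t+1} = a * x_t + b
def trajectory(a, b, x0, n):
    def step(x, _):
        x_new = a * x + b
        return x_new, x_new                # (carry, output)
    _, path = jax.lax.scan(step, jnp.asarray(x0), None, length=n)
    return path

path = trajectory(0.5, 1.0, 0.0, 3)
print([float(v) for v in path])            # [1.0, 1.5, 1.75]
```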

On the other hand, the JAX versions support automatic differentiation.

That might be of interest if, say, we want to compute the sensitivities of a
trajectory to model parameters.
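For instance, here's a sketch differentiating the endpoint of a hypothetical scan-based recursion with respect to its parameter:

```python
import jax
import jax.numpy as jnp

def last_state(a, x0=0.0, n=20):
    # Hypothetical recursion x_{t+1} = a * x_t + 1, run for n steps
    def step(x, _):
        return a * x + 1.0, None
    x_final, _ = jax.lax.scan(step, jnp.asarray(x0), None, length=n)
    return x_final

# Sensitivity of the final state to the parameter a (close to 1/(1-a)^2 = 4)
g_val = jax.grad(last_state)(0.5)
print(float(g_val))
```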

A good rule of thumb: default to JAX for new projects, especially when
hardware acceleration or differentiability might be useful, and reach for Numba
when you have a tight sequential loop that needs to be fast and readable.
