diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md
index fd119c6d..fed7bcb3 100644
--- a/lectures/jax_intro.md
+++ b/lectures/jax_intro.md
@@ -15,6 +15,9 @@ kernelspec:
 
 This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
 
+```{include} _admonition/gpu.md
+```
+
 JAX is a high-performance scientific computing library that provides
 
 * a [NumPy](https://en.wikipedia.org/wiki/NumPy)-like interface that can automatically parallelize across CPUs and GPUs,
@@ -33,9 +36,19 @@ In addition to what's in Anaconda, this lecture will need the following librarie
 !pip install jax quantecon
 ```
 
-```{include} _admonition/gpu.md
+We'll use the following imports
+
+```{code-cell} ipython3
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+import quantecon as qe
 ```
 
+Notice that we import `jax.numpy as jnp`, which provides a NumPy-like interface.
+
+
 ## JAX as a NumPy Replacement
 
 One of the attractive features of JAX is that, whenever possible, its array
@@ -47,17 +60,6 @@ Let's look at the similarities and differences between JAX and NumPy.
 
 ### Similarities
 
-We'll use the following imports
-
-```{code-cell} ipython3
-import jax
-import jax.numpy as jnp
-import matplotlib.pyplot as plt
-import numpy as np
-import quantecon as qe
-```
-
-Notice that we import `jax.numpy as jnp`, which provides a NumPy-like interface.
 
 Here are some standard array operations using `jnp`:
 
@@ -73,10 +75,6 @@ print(a)
 print(jnp.sum(a))
 ```
 
-```{code-cell} ipython3
-print(jnp.mean(a))
-```
-
 ```{code-cell} ipython3
 print(jnp.dot(a, a))
 ```
 
@@ -91,30 +89,12 @@ a
 type(a)
 ```
 
-Even scalar-valued maps on arrays return JAX arrays.
+Even scalar-valued maps on arrays return JAX arrays rather than scalars!
 
 ```{code-cell} ipython3
 jnp.sum(a)
 ```
 
-Operations on higher dimensional arrays are also similar to NumPy:
-
-```{code-cell} ipython3
-A = jnp.ones((2, 2))
-B = jnp.identity(2)
-A @ B
-```
-
-JAX's array interface also provides the `linalg` subpackage:
-
-```{code-cell} ipython3
-jnp.linalg.inv(B)  # Inverse of identity is identity
-```
-
-```{code-cell} ipython3
-eigvals, eigvecs = jnp.linalg.eigh(B)  # Computes eigenvalues and eigenvectors
-eigvals
-```
 
 ### Differences
 
@@ -137,6 +117,7 @@ Let's try with NumPy
 
 ```{code-cell}
 with qe.Timer():
+    # First NumPy timing
     y = np.cos(x)
 ```
 
@@ -144,6 +125,7 @@ And one more time.
 
 ```{code-cell}
 with qe.Timer():
+    # Second NumPy timing
     y = np.cos(x)
 ```
 
@@ -165,7 +147,9 @@ Let's time the same procedure.
 
 ```{code-cell}
 with qe.Timer():
+    # First run
     y = jnp.cos(x)
+    # Hold the interpreter until the array operation finishes
     jax.block_until_ready(y);
 ```
 
@@ -184,7 +168,9 @@ And let's time it again.
 
 ```{code-cell}
 with qe.Timer():
+    # Second run
     y = jnp.cos(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
@@ -212,14 +198,18 @@ x = jnp.linspace(0, 10, n + 1)
 
 ```{code-cell}
 with qe.Timer():
+    # First run
     y = jnp.cos(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
 ```{code-cell}
 with qe.Timer():
+    # Second run
     y = jnp.cos(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
@@ -294,7 +284,7 @@ functional programming style, which we discuss below.
 
 #### A workaround
 
-We note that JAX does provide a version of in-place array modification
+We note that JAX does provide an alternative to in-place array modification
 using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
 
 ```{code-cell} ipython3
@@ -387,11 +377,26 @@ This pure version makes all dependencies explicit through function arguments, an
 
 ### Why Functional Programming?
 
-JAX represents functions as computational graphs, which are then compiled or transformed (e.g., differentiated)
+At QuantEcon we love pure functions because they
+
+* Help testing: each function can operate in isolation
+* Promote deterministic behavior and hence reproducibility
+* Prevent bugs that arise from mutating shared state
+
+The JAX compiler loves pure functions and functional programming because
+
+* Data dependencies are explicit, which helps with optimizing complex computations
+* Pure functions are easier to differentiate (autodiff)
+* Pure functions are easier to parallelize and optimize (don't depend on shared mutable state)
+
+Another way to think of this is as follows:
+
+JAX represents functions as computational graphs, which are then compiled or
+transformed (e.g., differentiated).
 
 These computational graphs describe how a given set of inputs is transformed
 into an output.
 
-They are pure by construction.
+JAX's computational graphs are pure by construction.
 
 JAX uses a functional programming style so that user-built functions map
 directly into the graph-theoretic representations supported by JAX.
 
@@ -520,8 +525,8 @@ plt.tight_layout()
 plt.show()
 ```
 
-This syntax will seem unusual for a NumPy or Matlab user --- but will make a lot
-of sense when we progress to parallel programming.
+This syntax will seem unusual for a NumPy or Matlab user --- but will make more
+sense when we get to parallel programming.
 
 The function below produces `k` (quasi-) independent random `n x n` matrices
 using `split`.
 
@@ -664,6 +669,7 @@ x = np.linspace(0, 10, n)
 
 ```{code-cell}
 with qe.Timer():
+    # Time NumPy code
     y = f(x)
 ```
 
@@ -679,27 +685,31 @@ As a first pass, we replace `np` with `jnp` throughout:
 def f(x):
     y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2
     return y
-```
 
-Now let's time it.
-
-```{code-cell}
 x = jnp.linspace(0, 10, n)
 ```
 
+Now let's time it.
+
 ```{code-cell}
 with qe.Timer():
+    # First call
     y = f(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
 ```{code-cell}
 with qe.Timer():
+    # Second call
     y = f(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
-The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.
+The outcome is similar to the `cos` example --- JAX is faster, especially on the
+second run after JIT compilation.
 
 However, with JAX, we have another trick up our sleeve --- we can JIT-compile the
 *entire* function, not just individual operations.
 
@@ -718,13 +728,17 @@ f_jax = jax.jit(f)
 
 ```{code-cell}
 with qe.Timer():
+    # First run
     y = f_jax(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
 ```{code-cell}
 with qe.Timer():
+    # Second run
     y = f_jax(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
@@ -734,7 +748,6 @@ allowing the compiler to optimize more aggressively.
 
 For example, the compiler can eliminate multiple calls to the hardware
 accelerator and the creation of a number of intermediate arrays.
-
 Incidentally, a more common syntax when targeting a function for the JIT
 compiler is
 
@@ -811,21 +824,6 @@ f(x)
 
 Moral of the story: write pure functions when using JAX!
 
-### Summary
-
-Now we can see why both developers and compilers benefit from pure functions.
-
-We love pure functions because they
-
-* Help testing: each function can operate in isolation
-* Promote deterministic behavior and hence reproducibility
-* Prevent bugs that arise from mutating shared state
-
-The compiler loves pure functions and functional programming because
-
-* Data dependencies are explicit, which helps with optimizing complex computations
-* Pure functions are easier to differentiate (autodiff)
-* Pure functions are easier to parallelize and optimize (don't depend on shared mutable state)
 
 ## Vectorization with `vmap`
 
@@ -838,18 +836,18 @@ This avoids the need to manually write vectorized code or use explicit loops.
 ### A simple example
 
-Suppose we have a function that computes summary statistics for a single array:
+Suppose we have a function that computes the difference between the mean and median of an array of numbers.
 
 ```{code-cell} ipython3
-def summary(x):
-    return jnp.mean(x), jnp.median(x)
+def mm_diff(x):
+    return jnp.mean(x) - jnp.median(x)
 ```
 
 We can apply it to a single vector:
 
 ```{code-cell} ipython3
 x = jnp.array([1.0, 2.0, 5.0])
-summary(x)
+mm_diff(x)
 ```
 
 Now suppose we have a matrix and want to compute these statistics for each row.
 
@@ -862,7 +860,7 @@ X = jnp.array([[1.0, 2.0, 5.0],
               [1.0, 8.0, 9.0]])
 
 for row in X:
-    print(summary(row))
+    print(mm_diff(row))
 ```
 
 However, Python loops are slow and cannot be efficiently compiled or
@@ -872,11 +870,11 @@ Using `vmap` keeps the computation on the accelerator and composes with
 other JAX transformations like `jit` and `grad`:
 
 ```{code-cell} ipython3
-batch_summary = jax.vmap(summary)
-batch_summary(X)
+batch_mm_diff = jax.vmap(mm_diff)
+batch_mm_diff(X)
 ```
 
-The function `summary` was written for a single array, and `vmap` automatically
+The function `mm_diff` was written for a single array, and `vmap` automatically
 lifted it to operate row-wise over a matrix --- no loops, no reshaping.
 
 ### Combining transformations
 
@@ -886,8 +884,8 @@ One of JAX's strengths is that transformations compose naturally.
 
 For example, we can JIT-compile a vectorized function:
 
```{code-cell} ipython3
-fast_batch_summary = jax.jit(jax.vmap(summary))
-fast_batch_summary(X)
+fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))
+fast_batch_mm_diff(X)
 ```
 
 This composition of `jit`, `vmap`, and (as we'll see next) `grad` is central to
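
Editor's note: for reviewers who want to run the new `mm_diff` cells from the final hunks outside the notebook, here is a standalone sketch of the same `jit`/`vmap` composition, with the matrix `X` taken from the diff. The `sum_of_squares` function used to illustrate the `grad` leg is an assumption of this sketch, not from the lecture.

```python
import jax
import jax.numpy as jnp


def mm_diff(x):
    # Difference between the mean and the median of a 1-D array
    return jnp.mean(x) - jnp.median(x)


# Compose transformations: vectorize over rows, then JIT-compile
fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))

X = jnp.array([[1.0, 2.0, 5.0],
               [1.0, 8.0, 9.0]])
result = fast_batch_mm_diff(X)  # one mean-median gap per row

# grad composes the same way; shown here on an illustrative smooth function
def sum_of_squares(v):
    return jnp.sum(v**2)

batch_grad = jax.jit(jax.vmap(jax.grad(sum_of_squares)))
grads = batch_grad(X)  # per-row gradient, 2 * v for each row v
```

Because each transformation returns an ordinary callable, the order of composition is up to the user; JIT-compiling the outermost function lets XLA fuse the vectorized computation into a single kernel.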