diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index 34786c5..663bfe6 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,5 +1,5 @@ -source-sha: 4c025df0d52a5b6546938952294776d4bd4908ce -synced-at: "2026-04-10" +source-sha: 8d73de367a7f160dac777aa557f1c26069f84ea5 +synced-at: "2026-04-12" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 9731ffe..e50faa0 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -45,13 +45,16 @@ translation: 本讲座简要介绍 [Google JAX](https://github.com/jax-ml/jax)。 +```{include} _admonition/gpu.md +``` + JAX 是一个高性能科学计算库,提供以下功能: * 类似 [NumPy](https://en.wikipedia.org/wiki/NumPy) 的接口,可以在 CPU 和 GPU 上自动并行化, * 一个即时编译器,用于加速大量数值运算,以及 * [自动微分](https://en.wikipedia.org/wiki/Automatic_differentiation)。 -JAX 也在日益维护和提供[更多专业化的科学计算例程](https://docs.jax.dev/en/latest/jax.scipy.html),例如那些最初在 [SciPy](https://en.wikipedia.org/wiki/SciPy) 中找到的例程。 +JAX 也在日益维护和提供 [更多专业化的科学计算例程](https://docs.jax.dev/en/latest/jax.scipy.html),例如那些最初在 [SciPy](https://en.wikipedia.org/wiki/SciPy) 中找到的例程。 除了 Anaconda 中已有的内容外,本讲座还需要以下库: @@ -61,36 +64,28 @@ JAX 也在日益维护和提供[更多专业化的科学计算例程](https://do !pip install jax quantecon ``` -```{include} _admonition/gpu.md -``` - -## JAX 作为 NumPy 的替代品 - -JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API。 - -这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。 - -让我们来看看 JAX 和 NumPy 之间的异同。 - -### 相似之处 - 我们将使用以下导入: ```{code-cell} ipython3 import jax import jax.numpy as jnp import matplotlib.pyplot as plt -import matplotlib as mpl -import matplotlib.font_manager -FONTPATH = "_fonts/SourceHanSerifSC-SemiBold.otf" -mpl.font_manager.fontManager.addfont(FONTPATH) -mpl.rcParams['font.family'] = ['Source Han Serif SC'] import numpy as np import quantecon as qe ``` 注意我们导入了 `jax.numpy as jnp`,它提供了类似 NumPy 的接口。 +## JAX 作为 NumPy 的替代品 + +JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API。 + +这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。 + +让我们来看看 JAX 和 NumPy 之间的异同。 + +### 相似之处 + 以下是使用 `jnp` 进行的一些标准数组操作: ```{code-cell} ipython3 @@ -105,10 +100,6 @@ print(a) print(jnp.sum(a)) ``` -```{code-cell} ipython3 -print(jnp.mean(a)) -``` - ```{code-cell} ipython3 print(jnp.dot(a, a)) ``` @@ -123,34 +114,124 @@ a type(a) ``` -即使是数组上的标量值映射也会返回 JAX 数组。 +即使是数组上的标量值映射也会返回 JAX 数组,而不是标量! ```{code-cell} ipython3 jnp.sum(a) ``` -对高维数组的操作也与 NumPy 类似: +### 差异 -```{code-cell} ipython3 -A = jnp.ones((2, 2)) -B = jnp.identity(2) -A @ B +现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。 + +(jax_speed)= +#### 速度! + +假设我们想在许多点上计算余弦函数。 + +```{code-cell} +n = 50_000_000 +x = np.linspace(0, 10, n) ``` -JAX 的数组接口也提供了 `linalg` 子包: +##### 使用 NumPy -```{code-cell} ipython3 -jnp.linalg.inv(B) # Inverse of identity is identity +让我们先用 NumPy 试试: + +```{code-cell} +with qe.Timer(): + # First NumPy timing + y = np.cos(x) ``` -```{code-cell} ipython3 -eigvals, eigvecs = jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors -eigvals +再来一次。 + +```{code-cell} +with qe.Timer(): + # Second NumPy timing + y = np.cos(x) ``` -### 差异 +这里 -现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。 +* NumPy 使用预编译的二进制文件对浮点数数组应用余弦函数 +* 该二进制文件在本地机器的 CPU 上运行 + +##### 使用 JAX + +现在让我们用 JAX 试试。 + +```{code-cell} +x = jnp.linspace(0, 10, n) +``` + +让我们对相同的过程计时。 + +```{code-cell} +with qe.Timer(): + # First run + y = jnp.cos(x) + # Hold the interpreter until the array operation finishes + jax.block_until_ready(y); +``` + +```{note} +这里,为了测量实际速度,我们使用 `block_until_ready` 方法来阻塞解释器,直到计算结果返回。 + +这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。 + +对于非计时代码,可以删除包含 `block_until_ready` 的那一行。 +``` + +再来计时一次。 + +```{code-cell} +with qe.Timer(): + # Second run + y = jnp.cos(x) + # Hold interpreter + jax.block_until_ready(y); +``` + +在 GPU 上,此代码的运行速度远快于其 NumPy 等效代码。 + +此外,通常第二次运行比第一次更快,这是由于 JIT 编译的缘故。 + +这是因为即使是像 `jnp.cos` 这样的内置函数也是经过 JIT 编译的——第一次运行包含了编译时间。 + +为什么 JAX 要对像 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本? + +原因是 JIT 编译器希望针对所使用数组的*大小*(以及数据类型)进行专门优化。 + +大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。 + +我们可以通过更改输入大小并观察运行时间来验证 JAX 针对数组大小进行专门化的说法。 + +```{code-cell} +x = jnp.linspace(0, 10, n + 1) +``` + +```{code-cell} +with qe.Timer(): + # First run + y = jnp.cos(x) + # Hold interpreter + jax.block_until_ready(y); +``` + +```{code-cell} +with qe.Timer(): + # Second run + y = jnp.cos(x) + # Hold interpreter + jax.block_until_ready(y); +``` + +运行时间先增加后减少(这在 GPU 上会更明显)。 + +这与上面的讨论一致——更改数组大小后的第一次运行显示了编译开销。 + +关于 JIT 编译的进一步讨论见下文。 (jax_speed)= #### 速度! @@ -303,30 +384,13 @@ try: a[0] = 1 except Exception as e: print(e) - ``` -与不可变性一致,JAX 不支持原地操作: - -```{code-cell} ipython3 -a = np.array((2, 1)) -a.sort() # Unlike NumPy, does not mutate a -a -``` - -```{code-cell} ipython3 -a = jnp.array((2, 1)) -a_new = a.sort() # Instead, the sort method returns a new sorted array -a, a_new -``` - -JAX 的设计者选择将数组设为不可变的,因为 JAX 使用 [函数式编程](https://en.wikipedia.org/wiki/Functional_programming) 风格。 - -这个设计选择有重要的含义,我们接下来将对此进行探讨! +JAX 的设计者选择将数组设为不可变的,因为 JAX 使用函数式编程风格,我们将在下面讨论这一点。 #### 变通方法 -我们注意到 JAX 确实提供了一种使用 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) 进行原地数组修改的版本。 +我们注意到 JAX 确实提供了一种替代原地数组修改的方式,使用 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html)。 ```{code-cell} ipython3 a = jnp.linspace(0, 1, 3) @@ -410,15 +474,29 @@ def add_tax_pure(prices, tax_rate): 这个纯版本通过函数参数使所有依赖关系变得明确,并且不修改任何外部状态。 -### 为什么使用函数式编程? +### 为什么要函数式编程? -JAX 将函数表示为计算图,然后对其进行编译或变换(例如,求导)。 +在 QuantEcon,我们热爱纯函数,因为它们: + +* 有助于测试:每个函数可以独立运行 +* 促进确定性行为,从而提高可重复性 +* 防止由于修改共享状态而产生的错误 + +JAX 编译器热爱纯函数和函数式编程,因为: + +* 数据依赖关系是显式的,有助于优化复杂计算 +* 纯函数更易于微分(自动微分) +* 纯函数更易于并行化和优化(不依赖于共享的可变状态) + +另一种理解方式如下: + +JAX 将函数表示为计算图,然后对其进行编译或变换(例如,微分)。 这些计算图描述了给定的一组输入如何被转换为输出。 -它们在构造上就是纯粹的。 +JAX 的计算图在构造上是纯粹的。 -JAX 使用函数式编程风格,使得用户构建的函数能够直接映射到 JAX 所支持的图论表示中。 +JAX 使用函数式编程风格,以便用户构建的函数能够直接映射到 JAX 所支持的图论表示中。 ## 随机数 @@ -651,111 +729,13 @@ JAX 的显式性带来了显著的好处: JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。 -### 一个简单的示例 - -假设我们想在许多点上求余弦函数的值。 - -```{code-cell} -n = 50_000_000 -x = np.linspace(0, 10, n) -``` - -#### 使用 NumPy - -让我们先用 NumPy 试试: - -```{code-cell} -with qe.Timer(): - y = np.cos(x) -``` - -再试一次。 - -```{code-cell} -with qe.Timer(): - y = np.cos(x) -``` - -这里 NumPy 使用预先构建的二进制文件,该文件由精心编写的低级代码编译而成,用于对浮点数数组应用余弦函数。 - -这个二进制文件随 NumPy 一起发布。 - -#### 使用 JAX - -现在让我们用 JAX 试试。 - -```{code-cell} -x = jnp.linspace(0, 10, n) -``` - -让我们对相同的过程计时。 - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -```{note} -这里,为了测量实际速度,我们使用 `block_until_ready` 方法让解释器等待,直到计算结果返回。 - -这是必要的,因为 JAX 使用异步调度,允许 Python 解释器领先于数值计算。 - -对于非计时代码,您可以删除包含 `block_until_ready` 的那一行。 -``` - - -再次计时。 - - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -在 GPU 上,这段代码的运行速度远快于其 NumPy 等价代码。 +我们在 {ref}`上文 ` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 `cos` 函数。 -此外,通常第二次运行比第一次快,这是由于 JIT 编译的原因。 - -这是因为即使是像 `jnp.cos` 这样的内置函数也经过了 JIT 编译——第一次运行包含了编译时间。 - -为什么 JAX 要对像 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样提供预编译版本? - -原因是 JIT 编译器希望针对正在使用的数组的*大小*(以及数据类型)进行专门优化。 - -大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。 - -这就是为什么 JAX 要等到看到数组大小后再进行编译——这需要 JIT 编译方法,而不是提供预编译的二进制文件。 - -#### 更改数组大小 - -这里我们更改输入大小并观察运行时间。 - -```{code-cell} -x = jnp.linspace(0, 10, n + 1) -``` - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -通常,运行时间会先增加然后再减少(这在 GPU 上会更加明显)。 - -这是因为 JIT 编译器针对数组大小进行专门优化以利用并行化——因此当数组大小改变时会生成新的编译代码。 +让我们用一个更复杂的函数尝试同样的操作。 ### 评估更复杂的函数 -考虑如下函数: +考虑以下函数: ```{code-cell} def f(x): @@ -774,6 +754,7 @@ x = np.linspace(0, 10, n) ```{code-cell} with qe.Timer(): + # Time NumPy code y = f(x) ``` @@ -787,23 +768,26 @@ with qe.Timer(): def f(x): y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2 return y -``` -现在让我们计时。 -```{code-cell} x = jnp.linspace(0, 10, n) ``` +现在让我们计时。 + ```{code-cell} with qe.Timer(): + # First call y = f(x) + # Hold interpreter jax.block_until_ready(y); ``` ```{code-cell} with qe.Timer(): + # Second call y = f(x) + # Hold interpreter jax.block_until_ready(y); ``` @@ -823,13 +807,17 @@ f_jax = jax.jit(f) ```{code-cell} with qe.Timer(): + # First run y = f_jax(x) + # Hold interpreter jax.block_until_ready(y); ``` ```{code-cell} with qe.Timer(): + # Second run y = f_jax(x) + # Hold interpreter jax.block_until_ready(y); ``` @@ -837,7 +825,6 @@ with qe.Timer(): 例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。 - 顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是: ```{code-cell} ipython3 @@ -902,21 +889,64 @@ f(x) 这个故事的寓意:使用 JAX 时请编写纯函数! -### 总结 +## 使用 `vmap` 进行向量化 -现在我们可以理解为什么开发者和编译器都受益于纯函数。 +JAX 的另一个强大变换是 `jax.vmap`,它能自动将一个针对单个输入编写的函数向量化,使其可以在批量数据上运行。 -我们喜欢纯函数,因为它们: +这避免了手动编写向量化代码或使用显式循环的需要。 -* 有助于测试:每个函数可以独立运行 -* 促进确定性行为,从而实现可复现性 -* 防止由于修改共享状态而产生的错误 +### 一个简单的示例 -编译器喜欢纯函数和函数式编程,因为: +假设我们有一个函数,用于计算一组数字的均值与中位数之差。 -* 数据依赖关系是显式的,有助于优化复杂计算 -* 纯函数更容易进行微分(自动微分) -* 纯函数更容易并行化和优化(不依赖于共享可变状态) +```{code-cell} ipython3 +def mm_diff(x): + return jnp.mean(x) - jnp.median(x) +``` + +我们可以将其应用于单个向量: + +```{code-cell} ipython3 +x = jnp.array([1.0, 2.0, 5.0]) +mm_diff(x) +``` + +现在假设我们有一个矩阵,想要对每一行计算这些统计量。 + +不使用 `vmap` 时,我们需要显式循环: + +```{code-cell} ipython3 +X = jnp.array([[1.0, 2.0, 5.0], + [4.0, 5.0, 6.0], + [1.0, 8.0, 9.0]]) + +for row in X: + print(mm_diff(row)) +``` + +然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。 + +使用 `vmap` 可以将计算保留在加速器上,并与其他 JAX 变换(如 `jit` 和 `grad`)组合使用: + +```{code-cell} ipython3 +batch_mm_diff = jax.vmap(mm_diff) +batch_mm_diff(X) +``` + +函数 `mm_diff` 是针对单个数组编写的,而 `vmap` 自动将其提升为按行作用于矩阵的函数——无需循环,无需重新塑形。 + +### 组合变换 + +JAX 的优势之一在于各变换可以自然地组合使用。 + +例如,我们可以对向量化函数进行 JIT 编译: + +```{code-cell} ipython3 +fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff)) +fast_batch_mm_diff(X) +``` + +`jit`、`vmap` 以及(我们接下来将看到的)`grad` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。 ## 练习