diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index bbee313..34786c5 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,5 +1,5 @@ -source-sha: 05ce95691fd97e48da39dd6d58fe032c03e8813d -synced-at: "2026-04-09" +source-sha: 4c025df0d52a5b6546938952294776d4bd4908ce +synced-at: "2026-04-10" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 81ed1eb..9731ffe 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -15,33 +15,29 @@ translation: JAX as a NumPy Replacement: JAX 作为 NumPy 的替代品 JAX as a NumPy Replacement::Similarities: 相似之处 JAX as a NumPy Replacement::Differences: 差异 + JAX as a NumPy Replacement::Differences::Speed!: 速度! + JAX as a NumPy Replacement::Differences::Speed!::With NumPy: 使用 NumPy + JAX as a NumPy Replacement::Differences::Speed!::With JAX: 使用 JAX JAX as a NumPy Replacement::Differences::Precision: 精度 JAX as a NumPy Replacement::Differences::Immutability: 不可变性 JAX as a NumPy Replacement::Differences::A workaround: 变通方法 Functional Programming: 函数式编程 Functional Programming::Pure functions: 纯函数 Functional Programming::Examples: 示例 + Functional Programming::Why Functional Programming?: 为什么使用函数式编程? Random numbers: 随机数 Random numbers::Random number generation: 随机数生成 Random numbers::Why explicit random state?: 为什么要显式随机状态? Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法 Random numbers::Why explicit random state?::JAX's approach: JAX 的方法 - JIT compilation: JIT 编译 - JIT compilation::A simple example: 一个简单的示例 - JIT compilation::A simple example::With NumPy: 使用 NumPy - JIT compilation::A simple example::With JAX: 使用 JAX - JIT compilation::A simple example::Changing array sizes: 更改数组大小 - JIT compilation::Evaluating a more complicated function: 评估更复杂的函数 - JIT compilation::Evaluating a more complicated function::With NumPy: 使用 NumPy - JIT compilation::Evaluating a more complicated function::With JAX: 使用 JAX - JIT compilation::How JIT compilation works: JIT 编译的工作原理 - JIT compilation::Compiling the whole function: 编译整个函数 - JIT compilation::Compiling non-pure functions: 编译非纯函数 - JIT compilation::Summary: 总结 - Vectorization with `vmap`: 使用 `vmap` 进行向量化 - Vectorization with `vmap`::A simple example: 一个简单的示例 - Vectorization with `vmap`::Combining transformations: 组合变换 - 'Automatic differentiation: a preview': 自动微分:预览 + JIT Compilation: JIT 编译 + JIT Compilation::Evaluating a more complicated function: 评估更复杂的函数 + JIT Compilation::Evaluating a more complicated function::With NumPy: 使用 NumPy + JIT Compilation::Evaluating a more complicated function::With JAX: 使用 JAX + JIT Compilation::Compiling the whole function: 编译整个函数 + JIT Compilation::How JIT compilation works: JIT 编译的工作原理 + JIT Compilation::Compiling non-pure functions: 编译非纯函数 + JIT Compilation::Summary: 总结 Exercises: 练习 --- @@ -84,14 +80,13 @@ JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情 import jax import jax.numpy as jnp import matplotlib.pyplot as plt -import matplotlib.patches as mpatches +import matplotlib as mpl +import matplotlib.font_manager +FONTPATH = "_fonts/SourceHanSerifSC-SemiBold.otf" +mpl.font_manager.fontManager.addfont(FONTPATH) +mpl.rcParams['font.family'] = ['Source Han Serif SC'] import numpy as np import quantecon as qe -import matplotlib as mpl # i18n -import matplotlib.font_manager # i18n -FONTPATH = "_fonts/SourceHanSerifSC-SemiBold.otf" # i18n -mpl.font_manager.fontManager.addfont(FONTPATH) # i18n -mpl.rcParams['font.family'] = ['Source Han Serif SC'] # i18n ``` 注意我们导入了 `jax.numpy as jnp`,它提供了类似 NumPy 的接口。 @@ -149,16 +144,116 @@ jnp.linalg.inv(B) # Inverse of identity is identity ``` ```{code-cell} ipython3 -jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors +eigvals, eigvecs = jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors +eigvals ``` ### 差异 现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。 +(jax_speed)= +#### 速度! + +假设我们想在许多点上求余弦函数的值。 + +```{code-cell} +n = 50_000_000 +x = np.linspace(0, 10, n) +``` + +##### 使用 NumPy + +让我们先用 NumPy 试试: + +```{code-cell} +with qe.Timer(): + y = np.cos(x) +``` + +再来一次。 + +```{code-cell} +with qe.Timer(): + y = np.cos(x) +``` + +这里: + +* NumPy 使用预编译的二进制文件对浮点数组应用余弦函数 +* 该二进制文件在本地机器的 CPU 上运行 + +##### 使用 JAX + +现在让我们用 JAX 试试。 + +```{code-cell} +x = jnp.linspace(0, 10, n) +``` + +对相同的过程计时。 + +```{code-cell} +with qe.Timer(): + y = jnp.cos(x) + jax.block_until_ready(y); +``` + +```{note} +这里,为了测量实际速度,我们使用 `block_until_ready` 方法让解释器等待,直到计算结果返回。 + +这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。 + +对于非计时代码,可以省略包含 `block_until_ready` 的那行。 +``` + +再计时一次。 + +```{code-cell} +with qe.Timer(): + y = jnp.cos(x) + jax.block_until_ready(y); +``` + +在 GPU 上,这段代码的运行速度远快于等效的 NumPy 代码。 + +此外,通常第二次运行比第一次更快,这是由于 JIT 编译的原因。 + +这是因为即使是 `jnp.cos` 这样的内置函数也会被 JIT 编译——第一次运行包含了编译时间。 + +为什么 JAX 要对 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本? + +原因是 JIT 编译器希望针对所使用数组的*大小*(以及数据类型)进行专门优化。 + +大小对于生成优化代码很重要,因为高效并行化需要将任务大小与可用硬件相匹配。 + +我们可以通过改变输入大小并观察运行时间来验证 JAX 针对数组大小进行专门优化的说法。 + +```{code-cell} +x = jnp.linspace(0, 10, n + 1) +``` + +```{code-cell} +with qe.Timer(): + y = jnp.cos(x) + jax.block_until_ready(y); +``` + +```{code-cell} +with qe.Timer(): + y = jnp.cos(x) + jax.block_until_ready(y); +``` + +运行时间先增加后再次下降(在 GPU 上这一现象会更明显)。 + +这与上面的讨论一致——改变数组大小后的第一次运行显示了编译开销。 + +关于 JIT 编译的进一步讨论见下文。 + #### 精度 -NumPy 和 JAX 之间的一个差异是 JAX 默认使用 32 位浮点数。 +NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。 这是因为 JAX 经常用于 GPU 计算,而大多数 GPU 计算使用 32 位浮点数。 @@ -196,7 +291,7 @@ a[0] = 1 a ``` -在 JAX 中,这会失败: +在 JAX 中,这会失败! ```{code-cell} ipython3 a = jnp.linspace(0, 1, 3) @@ -204,9 +299,11 @@ a ``` ```{code-cell} ipython3 -:tags: [raises-exception] +try: + a[0] = 1 +except Exception as e: + print(e) -a[0] = 1 ``` 与不可变性一致,JAX 不支持原地操作: @@ -257,7 +354,7 @@ a *当在意大利乡间漫步时,当地人会毫不犹豫地告诉你 JAX 有"una anima di pura programmazione funzionale"(纯函数式编程的灵魂)。* -换句话说,JAX 假设采用函数式编程风格。 +换句话说,JAX 假设采用 [函数式编程](https://en.wikipedia.org/wiki/Functional_programming) 风格。 ### 纯函数 @@ -313,15 +410,23 @@ def add_tax_pure(prices, tax_rate): 这个纯版本通过函数参数使所有依赖关系变得明确,并且不修改任何外部状态。 -现在我们理解了什么是纯函数,让我们探索 JAX 处理随机数的方法如何维护这种纯粹性。 +### 为什么使用函数式编程? + +JAX 将函数表示为计算图,然后对其进行编译或变换(例如,求导)。 + +这些计算图描述了给定的一组输入如何被转换为输出。 + +它们在构造上就是纯粹的。 + +JAX 使用函数式编程风格,使得用户构建的函数能够直接映射到 JAX 所支持的图论表示中。 ## 随机数 -与 NumPy 或 Matlab 中的随机数相比,JAX 中的随机数有很大不同。 +JAX 中的随机数生成与 NumPy 或 MATLAB 中的模式有很大不同。 起初,您可能会觉得语法相当冗长。 -但您很快就会意识到,为了维护我们刚刚讨论的函数式编程风格,这种语法和语义是必要的。 +但为了维护我们刚刚讨论的函数式编程风格,这种语法和语义是必要的。 此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。 @@ -430,7 +535,7 @@ ax.text(0, 0.65, "split", ha='center', va='center', fontsize=9, ax.text(3, -0.5, "⋮", ha='center', va='center', fontsize=14) -ax.set_title("PRNG Key Splitting Tree", fontsize=13, pad=10) +ax.set_title("PRNG 密钥拆分树", fontsize=13, pad=10) plt.tight_layout() plt.show() ``` @@ -484,7 +589,7 @@ matrices = gen_random_matrices(key) #### NumPy 的方法 -在 NumPy 中,随机数生成通过维护隐藏的全局状态来工作。 +在 NumPy 的旧版随机数生成 API(模仿 MATLAB)中,生成通过维护隐藏的全局状态来工作。 每次我们调用随机函数时,这个状态都会被更新: @@ -650,7 +755,7 @@ with qe.Timer(): ### 评估更复杂的函数 -让我们用一个更复杂的函数尝试同样的操作。 +考虑如下函数: ```{code-cell} def f(x): @@ -704,64 +809,11 @@ with qe.Timer(): 结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。 -此外,使用 JAX,我们还有另一个技巧——我们可以对*整个*函数进行 JIT 编译,而不仅仅是单个操作。 - -### JIT 编译的工作原理 - -当我们对一个函数应用 `jax.jit` 时,JAX 会对其进行*追踪*:它不会立即执行操作,而是将操作序列记录为计算图,并将该计算图交给 [XLA](https://openxla.org/xla) 编译器。 - -XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TPU)定制的单个编译内核。 - -下图展示了一个简单函数的编译流程: - -```{code-cell} ipython3 -:tags: [hide-input] - -fig, ax = plt.subplots(figsize=(7, 2)) -ax.set_xlim(-0.2, 7.2) -ax.set_ylim(0.2, 2.2) -ax.axis('off') - -# Boxes for pipeline stages -stages = [ - (0.7, 1.2, "Python\nfunction"), - (2.6, 1.2, "computational\ngraph"), - (4.5, 1.2, "optimized\nkernel"), - (6.4, 1.2, "fast\nexecution"), -] - -colors = ["#e3f2fd", "#fff9c4", "#f3e5f5", "#d4edda"] - -for (x, y, label), color in zip(stages, colors): - box = mpatches.FancyBboxPatch( - (x - 0.7, y - 0.5), 1.4, 1.0, - boxstyle="round,pad=0.15", - facecolor=color, edgecolor="black", linewidth=1.5) - ax.add_patch(box) - ax.text(x, y, label, ha='center', va='center', fontsize=9) - -# Arrows with labels -arrows = [ - (1.4, 1.9, "trace"), - (3.3, 3.8, "XLA"), - (5.2, 5.7, "run"), -] - -for x_start, x_end, label in arrows: - ax.annotate("", xy=(x_end, 1.2), xytext=(x_start, 1.2), - arrowprops=dict(arrowstyle="->", lw=1.5, color="gray")) - ax.text((x_start + x_end) / 2, 1.55, label, - ha='center', fontsize=8, color='gray') - -plt.tight_layout() -plt.show() -``` - -对 JIT 编译函数的第一次调用会产生编译开销,但后续使用相同输入形状和类型的调用将复用缓存的编译代码,以全速运行。 +然而,使用 JAX,我们还有另一个技巧——我们可以对*整个*函数进行 JIT 编译,而不仅仅是单个操作。 ### 编译整个函数 -JAX 即时(JIT)编译器可以通过将线性代数运算融合到单个优化内核中来加速函数内部的执行。 +JAX 即时(JIT)编译器可以通过将数组运算融合到单个优化内核中来加速函数内部的执行。 让我们用函数 `f` 来试试这个: @@ -794,6 +846,14 @@ def f(x): pass # put function body here ``` +### JIT 编译的工作原理 + +当我们对一个函数应用 `jax.jit` 时,JAX 会对其进行*追踪*:它不会立即执行操作,而是将操作序列记录为计算图,并将该图交给 [XLA](https://openxla.org/xla) 编译器。 + +XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TPU)定制的单个编译内核。 + +对 JIT 编译函数的第一次调用会产生编译开销,但对于具有相同输入形状和类型的后续调用,将重用缓存的编译代码并以全速运行。 + ### 编译非纯函数 现在我们已经看到了 JIT 编译的强大之处,理解它与纯函数的关系非常重要。 @@ -858,97 +918,6 @@ f(x) * 纯函数更容易进行微分(自动微分) * 纯函数更容易并行化和优化(不依赖于共享可变状态) -## 使用 `vmap` 进行向量化 - -另一个强大的 JAX 变换是 `jax.vmap`,它能自动将针对单个输入编写的函数 向量化,使其可以在批量数据上运行。 - -这避免了手动编写向量化代码或使用显式循环的需要。 - -### 一个简单的示例 - -假设我们有一个函数,用于计算单个数组的汇总统计量: - -```{code-cell} ipython3 -def summary(x): - return jnp.mean(x), jnp.median(x) -``` - -我们可以将其应用于单个向量: - -```{code-cell} ipython3 -x = jnp.array([1.0, 2.0, 5.0]) -summary(x) -``` - -现在假设我们有一个矩阵,并希望对每一行计算这些统计量。 - -不使用 `vmap` 时,我们需要一个显式循环: - -```{code-cell} ipython3 -X = jnp.array([[1.0, 2.0, 5.0], - [4.0, 5.0, 6.0], - [1.0, 8.0, 9.0]]) - -for row in X: - print(summary(row)) -``` - -然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。 - -使用 `vmap` 可以让计算保留在加速器上,并与其他 JAX 变换(如 `jit` 和 `grad`)组合使用: - -```{code-cell} ipython3 -batch_summary = jax.vmap(summary) -batch_summary(X) -``` - -函数 `summary` 是针对单个数组编写的,而 `vmap` 自动将其提升为按行作用于矩阵——无需循环,无需重塑。 - -### 组合变换 - -JAX 的一大优势在于变换可以自然地组合。 - -例如,我们可以对向量化函数进行 JIT 编译: - -```{code-cell} ipython3 -fast_batch_summary = jax.jit(jax.vmap(summary)) -fast_batch_summary(X) -``` - -`jit`、`vmap` 以及(我们接下来将看到的)`grad` 的这种组合是 JAX 设计的核心,使其在科学计算和机器学习中尤为强大。 - -## 自动微分:预览 - -JAX 可以使用自动微分来计算梯度。 - -这对于优化和求解非线性系统非常有用。 - -以下是一个涉及函数 $f(x) = x^2 / 2$ 的简单示例: - -```{code-cell} ipython3 -def f(x): - return (x**2) / 2 - -f_prime = jax.grad(f) -``` - -```{code-cell} ipython3 -f_prime(10.0) -``` - -让我们绘制函数及其导数,注意 $f'(x) = x$。 - -```{code-cell} ipython3 -fig, ax = plt.subplots() -x_grid = jnp.linspace(-4, 4, 200) -ax.plot(x_grid, f(x_grid), label="$f$") -ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") -ax.legend(loc='upper center') -plt.show() -``` - -自动微分是一个深刻的话题,在经济学和金融领域有许多应用。我们在 {doc}`关于自动微分的讲座 ` 中提供了更为深入的介绍。 - ## 练习