From 2989842e6f2539eb2057ff3d5ce1b8b3b4528272 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Tue, 14 Apr 2026 21:36:27 +0100 Subject: [PATCH 1/4] Update translation: lectures/jax_intro.md --- lectures/jax_intro.md | 260 ++++++++++++++---------------------------- 1 file changed, 87 insertions(+), 173 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 2f21d1f..0b0285d 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -21,16 +21,15 @@ translation: JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验 JAX as a NumPy Replacement::Differences::Precision: 精度 JAX as a NumPy Replacement::Differences::Immutability: 不可变性 - JAX as a NumPy Replacement::Differences::A workaround: 变通方法 + JAX as a NumPy Replacement::Differences::A Workaround: 变通方法 Functional Programming: 函数式编程 Functional Programming::Pure functions: 纯函数 - Functional Programming::Examples: 示例 - Functional Programming::Why Functional Programming?: 为什么使用函数式编程? + Functional Programming::Examples -- Pure and Impure: 示例——纯函数与非纯函数 + Functional Programming::Why Functional Programming?: 为什么要函数式编程? Random numbers: 随机数 - Random numbers::Random number generation: 随机数生成 - Random numbers::Why explicit random state?: 为什么要显式随机状态? - Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法 - Random numbers::Why explicit random state?::JAX's approach: JAX 的方法 + Random numbers::NumPy / MATLAB Approach: NumPy / MATLAB 方法 + Random numbers::JAX: JAX + Random numbers::Benefits: 优势 JIT Compilation: JIT 编译 JIT Compilation::With NumPy: 使用 NumPy JIT Compilation::With JAX: 使用 JAX @@ -343,19 +342,20 @@ a * 不会改变全局状态 * 不会修改传递给函数的数据(不可变数据) -### 示例 +### 示例——纯函数与非纯函数 以下是一个*非纯*函数的示例: ```{code-cell} ipython3 tax_rate = 0.1 -prices = [10.0, 20.0] def add_tax(prices): for i, price in enumerate(prices): prices[i] = price * (1 + tax_rate) - print('Post-tax prices: ', prices) - return prices + +prices = [10.0, 20.0] +add_tax(prices) +prices ``` 这个函数不是纯函数,因为: @@ -366,15 +366,21 @@ def add_tax(prices): 以下是一个*纯*版本: ```{code-cell} ipython3 -tax_rate = 0.1 -prices = (10.0, 20.0) def add_tax_pure(prices, tax_rate): new_prices = [price * (1 + tax_rate) for price in prices] return new_prices + +tax_rate = 0.1 +prices = (10.0, 20.0) +after_tax_prices = add_tax_pure(prices, tax_rate) +after_tax_prices ``` -这个纯版本通过函数参数使所有依赖关系变得明确,并且不修改任何外部状态。 +这是纯函数,因为: + +* 所有依赖关系通过函数参数显式表达 +* 不修改任何外部状态 ### 为什么要函数式编程? @@ -404,13 +410,29 @@ JAX 使用函数式编程风格,以便用户构建的函数能够直接映射 JAX 中的随机数生成与 NumPy 或 MATLAB 中的模式有很大不同。 -起初,您可能会觉得语法相当冗长。 +### NumPy / MATLAB 方法 -但为了维护我们刚刚讨论的函数式编程风格,这种语法和语义是必要的。 +在 NumPy / MATLAB 中,生成通过维护隐藏的全局状态来工作。 -此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。 +```{code-cell} ipython3 +np.random.seed(42) +print(np.random.randn(2)) +``` -### 随机数生成 +每次我们调用随机函数时,隐藏状态都会被更新: + +```{code-cell} ipython3 +print(np.random.randn(2)) +``` + +这个函数*不是纯函数*,因为: + +* 它是非确定性的:相同的输入,不同的输出 +* 它有副作用:它修改了全局随机数生成器状态 + +这在并行化下是危险的——必须仔细控制每个线程中发生的事情。 + +### JAX 在 JAX 中,随机数生成器的状态被显式控制。 @@ -525,115 +547,50 @@ plt.show() 下面的函数使用 `split` 生成 `k` 个(准)独立的随机 `n x n` 矩阵。 ```{code-cell} ipython3 -def gen_random_matrices(key, n=2, k=3): +def gen_random_matrices( + key, # JAX key for random numbers + n=2, # Matrices will be n x n + k=3 # Number of matrices to generate + ): matrices = [] for _ in range(k): key, subkey = jax.random.split(key) A = jax.random.uniform(subkey, (n, n)) matrices.append(A) - print(A) return matrices ``` ```{code-cell} ipython3 seed = 42 key = jax.random.key(seed) -matrices = gen_random_matrices(key) -``` - -我们也可以在循环迭代时使用 `fold_in`: - -```{code-cell} ipython3 -def gen_random_matrices(key, n=2, k=3): - matrices = [] - for i in range(k): - step_key = jax.random.fold_in(key, i) - A = jax.random.uniform(step_key, (n, n)) - matrices.append(A) - print(A) - return matrices +gen_random_matrices(key) ``` -```{code-cell} ipython3 -key = jax.random.key(seed) -matrices = gen_random_matrices(key) -``` - -### 为什么要显式随机状态? - -为什么 JAX 需要这种相对冗长的随机数生成方法? - -一个原因是为了维护纯函数。 - -让我们通过比较 NumPy 和 JAX 来看看随机数生成与纯函数的关系。 - -#### NumPy 的方法 - -在 NumPy 的旧版随机数生成 API(模仿 MATLAB)中,生成通过维护隐藏的全局状态来工作。 - -每次我们调用随机函数时,这个状态都会被更新: - -```{code-cell} ipython3 -np.random.seed(42) -print(np.random.randn()) # Updates state of random number generator -print(np.random.randn()) # Updates state of random number generator -``` - -每次调用都返回不同的值,即使我们用相同的输入(没有参数)调用相同的函数。 - -这个函数*不是纯函数*,因为: - -* 它是非确定性的:相同的输入(在这种情况下,没有输入)产生不同的输出 -* 它有副作用:它修改了全局随机数生成器状态 - -#### JAX 的方法 - -如上所示,JAX 采用了不同的方法,通过密钥使随机性显式化。 - -例如: - -```{code-cell} ipython3 -def random_sum_jax(key): - key1, key2 = jax.random.split(key) - x = jax.random.normal(key1) - y = jax.random.normal(key2) - return x + y -``` - -使用相同的密钥,我们总是得到相同的结果: - -```{code-cell} ipython3 -key = jax.random.key(42) -random_sum_jax(key) -``` +这个函数是*纯函数* -```{code-cell} ipython3 -random_sum_jax(key) -``` - -要获得新的抽取,我们需要提供一个新密钥。 - -函数 `random_sum_jax` 是纯函数,因为: - -* 它是确定性的:相同的密钥总是产生相同的输出 +* 确定性:相同的输入,相同的输出 * 无副作用:没有隐藏状态被修改 -JAX 的显式性带来了显著的好处: +### 优势 + +如上所述,这种显式性是有价值的: * 可复现性:通过重用密钥轻松重现结果 -* 并行化:每个线程可以拥有自己的密钥而不会产生冲突 -* 调试:没有隐藏状态使代码更容易推理 +* 并行化:控制每个独立线程中发生的事情 +* 调试:没有隐藏状态使代码更容易测试 * JIT 兼容性:编译器可以更积极地优化纯函数 -最后一点将在下一节中进行扩展。 - ## JIT 编译 JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。 我们在 {ref}`上文 ` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 `cos` 函数。 -让我们用一个更复杂的函数尝试同样的操作: +这里我们研究针对更复杂函数的 JIT 编译。 + +### 使用 NumPy + +我们先用 NumPy 试试: ```{code-cell} def f(x): @@ -641,9 +598,7 @@ def f(x): return y ``` -### 使用 NumPy - -我们先用 NumPy 试试: +让我们用较大的 `x` 运行: ```{code-cell} n = 50_000_000 @@ -656,9 +611,17 @@ with qe.Timer(): y = f(x) ``` -### 使用 JAX +**即时执行**模型 + +* 每个操作在遇到时立即执行,在下一个操作开始之前将其结果实体化。 -现在让我们用 JAX 再试一次。 +缺点 + +* 并行化程度极低 +* 内存占用大——产生许多中间数组 +* 大量内存读写 + +### 使用 JAX 作为第一步,我们将整个代码中的 `np` 替换为 `jnp`: @@ -691,11 +654,20 @@ with qe.Timer(): 结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。 -然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。 +这是因为单个数组操作在 GPU 上并行化了。 + +但我们仍然在使用即时执行模式 + +* 由于中间数组导致大量内存占用 +* 大量内存读写 + +此外,在 GPU 上还启动了许多独立的内核。 ### 编译整个函数 -JAX 即时(JIT)编译器可以通过将数组运算融合到单个优化内核中来加速函数内部的执行。 +幸运的是,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。 + +编译器将所有数组操作融合到单个优化内核中。 让我们用函数 `f` 来试试这个: @@ -719,9 +691,12 @@ with qe.Timer(): jax.block_until_ready(y); ``` -运行时间再次改善——现在是因为我们融合了所有操作,使编译器能够更积极地进行优化。 +运行时间再次改善——现在是因为我们融合了所有操作 -例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。 +* 基于整个计算序列的激进优化 +* 消除对硬件加速器的多次调用 + +内存占用也大幅降低——不再创建中间数组。 顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是: @@ -741,11 +716,9 @@ XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TP ### 编译非纯函数 -现在我们已经看到了 JIT 编译的强大之处,理解它与纯函数的关系非常重要。 - -虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测。 +虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测! -以下是一个使用全局变量的例子: +以下是一个示例: ```{code-cell} ipython3 a = 1 # global @@ -787,65 +760,6 @@ f(x) 这个故事的寓意:使用 JAX 时请编写纯函数! -## 使用 `vmap` 进行向量化 - -JAX 的另一个强大变换是 `jax.vmap`,它能自动将一个针对单个输入编写的函数向量化,使其可以在批量数据上运行。 - -这避免了手动编写向量化代码或使用显式循环的需要。 - -### 一个简单的示例 - -假设我们有一个函数,用于计算一组数字的均值与中位数之差。 - -```{code-cell} ipython3 -def mm_diff(x): - return jnp.mean(x) - jnp.median(x) -``` - -我们可以将其应用于单个向量: - -```{code-cell} ipython3 -x = jnp.array([1.0, 2.0, 5.0]) -mm_diff(x) -``` - -现在假设我们有一个矩阵,想要对每一行计算这些统计量。 - -不使用 `vmap` 时,我们需要显式循环: - -```{code-cell} ipython3 -X = jnp.array([[1.0, 2.0, 5.0], - [4.0, 5.0, 6.0], - [1.0, 8.0, 9.0]]) - -for row in X: - print(mm_diff(row)) -``` - -然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。 - -使用 `vmap` 可以将计算保留在加速器上,并与其他 JAX 变换(如 `jit` 和 `grad`)组合使用: - -```{code-cell} ipython3 -batch_mm_diff = jax.vmap(mm_diff) -batch_mm_diff(X) -``` - -函数 `mm_diff` 是针对单个数组编写的,而 `vmap` 自动将其提升为按行作用于矩阵的函数——无需循环,无需重新塑形。 - -### 组合变换 - -JAX 的优势之一在于各变换可以自然地组合使用。 - -例如,我们可以对向量化函数进行 JIT 编译: - -```{code-cell} ipython3 -fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff)) -fast_batch_mm_diff(X) -``` - -`jit`、`vmap` 以及(我们接下来将看到的)`grad` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。 - ## 练习 From 7c6e71675f62e9926b135cc030f004d8d451c50b Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Tue, 14 Apr 2026 21:36:27 +0100 Subject: [PATCH 2/4] Update translation: .translate/state/jax_intro.md.yml --- .translate/state/jax_intro.md.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index e8f1756..09842f7 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,5 +1,5 @@ -source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f -synced-at: "2026-04-13" +source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28 +synced-at: "2026-04-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 From c9ca8ac3b03d8e0020b01d3a1db6c73f2872aed1 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Tue, 14 Apr 2026 21:36:28 +0100 Subject: [PATCH 3/4] Update translation: lectures/numpy_vs_numba_vs_jax.md --- lectures/numpy_vs_numba_vs_jax.md | 275 +++++++++++++++++------------- 1 file changed, 152 insertions(+), 123 deletions(-) diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index bb1a6e0..86eb02d 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -13,6 +13,7 @@ translation: Vectorized operations: 向量化运算 Vectorized operations::Problem Statement: 问题陈述 Vectorized operations::NumPy vectorization: NumPy 向量化 + Vectorized operations::Memory Issues: 内存问题 Vectorized operations::A Comparison with Numba: 与 Numba 的比较 Vectorized operations::Parallelized Numba: 并行化的 Numba Vectorized operations::Vectorized code with JAX: 使用 JAX 的向量化代码 @@ -21,6 +22,8 @@ translation: Sequential operations: 顺序运算 Sequential operations::Numba Version: Numba 版本 Sequential operations::JAX Version: JAX 版本 + Sequential operations::JAX Version::First Attempt: 第一种尝试 + Sequential operations::JAX Version::Second Attempt: 第二种尝试 Sequential operations::Summary: 总结 Overall recommendations: 总体建议 --- @@ -143,37 +146,75 @@ m = -np.inf for x in grid: for y in grid: z = f(x, y) - if z > m: - m = z + m = max(m, z) ``` ### NumPy 向量化 -如果我们切换到 NumPy 风格的向量化,就可以使用更大的网格,并且代码执行速度相对较快。 +让我们切换到 NumPy 并使用更大的网格 -这里我们使用 `np.meshgrid` 来创建二维输入网格 `x` 和 `y`,使得 `f(x, y)` 能生成乘积网格上的所有计算结果。 +```{code-cell} ipython3 +grid = np.linspace(-3, 3, 3_000) # Large grid +``` -(这一策略可以追溯到 MATLAB。) +作为向量化的第一次尝试,我们可能会这样做 ```{code-cell} ipython3 +# Large grid +z = np.max(f(grid, grid)) # This is wrong! +``` + +这里的问题是 `f(grid, grid)` 不符合嵌套循环的逻辑。 + +就上图而言,它只计算了 `f` 在对角线上的值。 + +要让 NumPy 在每个 `x,y` 对上计算 `f(x,y)`,我们需要使用 `np.meshgrid`。 + +这里我们使用 `np.meshgrid` 来创建二维输入网格 `x` 和 `y`,使得 `f(x, y)` 能生成乘积网格上的所有计算结果。 + +```{code-cell} ipython3 +# Large grid grid = np.linspace(-3, 3, 3_000) -x, y = np.meshgrid(grid, grid) + +x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid with qe.Timer(): - z_max_numpy = np.max(f(x, y)) + z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works +``` + +在向量化版本中,所有循环都在编译后的代码中执行。 +`meshgrid` 的使用使我们能够复现嵌套的 for 循环。 + +输出结果应接近于一: + +```{code-cell} ipython3 print(f"NumPy result: {z_max_numpy:.6f}") ``` -在向量化版本中,所有循环都在编译后的代码中执行。 +### 内存问题 + +我们在合理的时间内得到了正确的解——但内存使用量巨大。 + +虽然扁平数组占用内存较少 + +```{code-cell} ipython3 +grid.nbytes +``` + +但网格矩阵是二维的,因此内存占用非常大 + +```{code-cell} ipython3 +x_mesh.nbytes + y_mesh.nbytes +``` -此外,NumPy 使用隐式多线程,因此至少会发生一定程度的并行化。 +此外,NumPy 的即时执行会创建许多相同大小的中间数组! -(并行化效率不高,因为二进制文件在看到数组 `x` 和 `y` 的大小之前就已经被编译了。) +在实际研究计算中,这种内存使用可能是一个大问题。 ### 与 Numba 的比较 -现在让我们看看能否使用简单循环的 Numba 获得更好的性能。 +让我们看看能否使用简单循环的 Numba 获得更好的性能。 ```{code-cell} ipython3 @numba.jit @@ -194,8 +235,6 @@ grid = np.linspace(-3, 3, 3_000) with qe.Timer(): # First run z_max_numba = compute_max_numba(grid) - -print(f"Numba result: {z_max_numba:.6f}") ``` 让我们再次运行以消除编译时间。 @@ -206,13 +245,13 @@ with qe.Timer(): compute_max_numba(grid) ``` -根据您的机器,Numba 版本可能比 NumPy 稍慢或稍快。 +注意我们几乎不使用任何内存——我们只需要一维的 `grid`。 -在大多数情况下,我们发现 Numba 略胜一筹。 +此外,执行速度也很好。 -一方面,NumPy 将高效的算术运算与一定程度的多线程结合在一起,这提供了优势。 +在大多数机器上,Numba 版本会比 NumPy 略快。 -另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格。 +原因在于高效的机器码加上更少的内存读写。 ### 并行化的 Numba @@ -238,8 +277,6 @@ def compute_max_numba_parallel(grid): with qe.Timer(): # First run z_max_parallel = compute_max_numba_parallel(grid) - -print(f"Numba result: {z_max_parallel:.6f}") ``` 以下是预编译版本的计时结果。 @@ -250,15 +287,19 @@ with qe.Timer(): compute_max_numba_parallel(grid) ``` -如果您有多个核心,您应该能在此处看到并行化带来的一定收益。 +如果您有多个核心,您应该能在此处看到并行化带来的收益。 + +让我们确认结果仍然正确(接近于一): + +```{code-cell} ipython3 +print(f"Numba result: {z_max_parallel:.6f}") +``` 对于更强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来显著的速度提升。 ### 使用 JAX 的向量化代码 -表面上,JAX 中的向量化代码与 NumPy 代码类似。 - -但两者之间也存在一些差异,我们在这里加以强调。 +让我们尝试用 JAX 复现 NumPy 向量化方法。 让我们从函数开始,将 `np` 替换为 `jnp` 并添加 `jax.jit` @@ -269,7 +310,7 @@ def f(x, y): ``` -与 NumPy 一样,为了获得正确的形状和正确的嵌套 `for` 循环计算,我们可以使用专为此目的设计的 `meshgrid` 操作: +我们使用 NumPy 风格的 meshgrid 方法: ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) @@ -304,84 +345,36 @@ with qe.Timer(): ### JAX 加 vmap -NumPy 代码和上述 JAX 代码都存在一个问题: - -虽然扁平数组占用内存较少 - -```{code-cell} ipython3 -grid.nbytes -``` - -但网格矩阵的内存占用很大 - -```{code-cell} ipython3 -x_mesh.nbytes + y_mesh.nbytes -``` - -在实际研究计算中,这种额外的内存使用可能是一个大问题。 +由于我们在上面使用了 `jax.jit`,我们避免了创建许多中间数组。 -幸运的是,JAX 提供了一种使用 [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) 的不同方法。 +但我们仍然创建了大数组 `z_max`、`x_mesh` 和 `y_mesh`。 -`vmap` 的思路是将向量化分阶段进行,将一个对单个值进行操作的函数转化为对数组进行操作的函数。 +幸运的是,我们可以通过使用 [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) 来避免这种情况。 以下是我们将其应用于当前问题的方式。 -```{code-cell} ipython3 -# 设置 f,使其在给定任意 y 时,对所有 x 计算 f(x, y) -f_vec_x = lambda y: f(grid, y) -# 创建第二个函数,将此操作在所有 y 上向量化 -f_vec = jax.vmap(f_vec_x) -``` - -现在,当以扁平数组 `grid` 调用时,`f_vec` 将在每个 `x,y` 处计算 `f(x,y)`。 - -让我们看看计时结果: - -```{code-cell} ipython3 -with qe.Timer(): - z_max = jnp.max(f_vec(grid)) - z_max.block_until_ready() - -print(f"JAX vmap v1 result: {z_max:.6f}") -``` - -```{code-cell} ipython3 -with qe.Timer(): - z_max = jnp.max(f_vec(grid)) - z_max.block_until_ready() -``` - -通过避免使用大型输入数组 `x_mesh` 和 `y_mesh`,这个 `vmap` 版本使用的内存少得多,运行时间变化不大。 - -这很好——但我们还有进一步提升速度的空间! - -首先请注意,上面的代码计算了完整的二维数组 `f(x,y)`,这会产生开销,然后再取最大值。 - -其次,`jnp.max` 调用位于 JIT 编译函数 `f` 之外,因此编译器无法将这些操作融合为单个内核。 - -我们可以通过将最大值操作移到内部并将所有内容包装在一个 `@jax.jit` 中来解决这两个问题: - ```{code-cell} ipython3 @jax.jit def compute_max_vmap(grid): - # 构建一个沿每行取最大值的函数 - f_vec_x_max = lambda y: jnp.max(f(grid, y)) - # 向量化该函数,以便我们可以同时对所有行调用 - f_vec_max = jax.vmap(f_vec_x_max) - # 调用向量化函数并取最大值 - return jnp.max(f_vec_max(grid)) + # Construct a function that takes the max over all x for given y + compute_column_max = lambda y: jnp.max(f(grid, y)) + # Vectorize the function so we can call on all y simultaneously + vectorized_compute_column_max = jax.vmap(compute_column_max) + # Compute the column max at every row + column_maxes = vectorized_compute_column_max(grid) + # Compute the max of the column maxes and return + return jnp.max(column_maxes) ``` -其中 - -* `f_vec_x_max` 计算任意给定行的最大值 -* `f_vec_max` 是一个向量化版本,可以并行计算所有行的最大值。 +注意我们从未创建 -我们将此函数应用于所有行,然后取各行最大值中的最大值。 +* 二维网格 `x_mesh` +* 二维网格 `y_mesh` 或 +* 二维数组 `f(x,y)` -由于将最大值操作移到内部,我们永远不会构建完整的二维数组 `f(x,y)`,从而节省了更多内存。 +与 Numba 类似,我们只使用扁平数组 `grid`。 -并且由于所有内容都在单个 `@jax.jit` 下,编译器可以将所有操作融合为一个优化的内核。 +由于所有内容都在单个 `@jax.jit` 下,编译器可以将所有操作融合为一个优化的内核。 让我们试试。 @@ -411,13 +404,11 @@ with qe.Timer(): 它在速度(通过 JIT 编译和并行化)和内存效率(通过 vmap)两方面都优于 NumPy。 -此外,`vmap` 方法有时可以带来更清晰的代码。 - -虽然 Numba 令人印象深刻,但 JAX 的优势在于,对于完全向量化的运算,我们可以在配备硬件加速器的机器上运行完全相同的代码,并在无需额外努力的情况下获得所有收益。 +在 GPU 上运行时,它也优于 Numba。 -此外,JAX 已经知道如何有效地并行化许多常见的数组运算,这是快速执行的关键。 - -对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。 +```{note} +Numba 可以通过 `numba.cuda` 支持 GPU 编程,但届时我们需要手动进行并行化。对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。 +``` ## 顺序运算 @@ -461,21 +452,70 @@ with qe.Timer(): Numba 非常高效地处理了这个顺序运算。 -注意,JIT 编译完成后,第二次运行明显更快。 +### JAX 版本 + +我们不能直接用 `jax.jit` 替换 `numba.jit`,因为 JAX 数组是不可变的。 -Numba 的编译通常相当快,对于像这样的顺序运算,生成的代码性能非常出色。 +但我们仍然可以实现这个运算。 -### JAX 版本 +#### 第一种尝试 -现在让我们使用 `lax.scan` 创建一个 JAX 版本: +以下是一种使用 `at[t].set` 语法的变通方法,我们在 {ref}`JAX 讲座中讨论过 `。 -(我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。) +我们将使用 `lax.fori_loop`,它是一种可以被 XLA 编译的 for 循环版本。 ```{code-cell} ipython3 cpu = jax.devices("cpu")[0] -@partial(jax.jit, static_argnames=('n',), device=cpu) -def qm_jax(x0, n, α=4.0): +@partial(jax.jit, static_argnames=("n",), device=cpu) +def qm_jax_fori(x0, n, α=4.0): + + x = jnp.empty(n + 1).at[0].set(x0) + + def update(t, x): + return x.at[t + 1].set(α * x[t] * (1 - x[t])) + + x = lax.fori_loop(0, n, update, x) + return x + +``` + +* 我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。 +* 我们通过 `device=cpu` 将计算固定在 CPU 上,因为这种顺序工作负载由许多小操作组成,几乎没有机会利用 GPU 并行性。 + +重要提示:尽管 `at[t].set` 看起来在每一步都会创建一个新数组,但在 JIT 编译函数内部,编译器会检测到旧数组不再需要并就地执行更新! + +让我们使用相同的参数计时: + +```{code-cell} ipython3 +with qe.Timer(): + # First run + x_jax = qm_jax_fori(0.1, n) + # Hold interpreter + x_jax.block_until_ready() +``` + +让我们再次运行以消除编译开销: + +```{code-cell} ipython3 +with qe.Timer(): + # Second run + x_jax = qm_jax_fori(0.1, n) + # Hold interpreter + x_jax.block_until_ready() +``` + +JAX 对于这种顺序运算也相当高效! + +#### 第二种尝试 + +我们还可以用另一种方式实现循环,使用 `lax.scan`。 + +这种替代方案可以说更符合 JAX 的函数式风格——尽管语法不容易记住。 + +```{code-cell} ipython3 +@partial(jax.jit, static_argnames=("n",), device=cpu) +def qm_jax_scan(x0, n, α=4.0): def update(x, t): x_new = α * x * (1 - x) return x_new, x_new @@ -486,16 +526,12 @@ def qm_jax(x0, n, α=4.0): 这段代码不易阅读,但本质上,`lax.scan` 反复调用 `update` 并将返回值 `x_new` 累积到一个数组中。 -```{note} -我们在 `jax.jit` 装饰器中指定了 `device=cpu`,因为该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。 -``` - 让我们使用相同的参数计时: ```{code-cell} ipython3 with qe.Timer(): # First run - x_jax = qm_jax(0.1, n) + x_jax = qm_jax_scan(0.1, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -505,26 +541,24 @@ with qe.Timer(): ```{code-cell} ipython3 with qe.Timer(): # Second run - x_jax = qm_jax(0.1, n) + x_jax = qm_jax_scan(0.1, n) # Hold interpreter x_jax.block_until_ready() ``` -JAX 对于这种顺序运算也相当高效。 - -JAX 和 Numba 在编译后都能提供出色的性能。 +令人惊讶的是,JAX 在编译后也能提供出色的性能。 ### 总结 -虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,*在代码可读性和易用性方面存在显著差异*。 +虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,但在代码可读性和易用性方面存在差异。 Numba 版本简单直观,易于阅读:我们只需分配一个数组,然后使用标准 Python 循环逐元素填充它。 这正是大多数程序员思考该算法的方式。 -另一方面,JAX 版本需要使用 `lax.scan`,这明显不够直观。 +另一方面,JAX 版本需要使用 `lax.fori_loop` 或 `lax.scan`,两者都不如标准 Python 循环直观。 -此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难。 +虽然 JAX 的 `at[t].set` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等效版本更难阅读。 对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。 @@ -540,17 +574,12 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然 此外,JAX 函数支持自动微分,我们将在 {doc}`autodiff` 中进一步探讨。 -对于**顺序操作**,Numba 具有明显优势。 +对于**顺序操作**,Numba 具有更好的语法。 代码自然且可读——只需一个带有装饰器的 Python 循环——性能也非常出色。 -JAX 可以通过 `lax.scan` 处理顺序问题,但语法不够直观。 - -```{note} -`lax.scan` 的一个重要优势是它支持通过循环进行自动微分,而 Numba 无法做到这一点。 -如果您需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。 -``` +JAX 可以通过 `lax.fori_loop` 或 `lax.scan` 处理顺序问题,但语法不够直观。 -在实践中,许多问题涉及两种模式的混合。 +另一方面,JAX 版本支持自动微分。 -一个实用的经验法则是:对于新项目默认使用 JAX,尤其是当硬件加速或可微分性可能有用时,而当您有一个需要快速且可读的紧凑顺序循环时,则选用 Numba。 +如果我们想要计算轨迹对模型参数的敏感性,这可能会很有用。 From 8533fe98727c967c061b655c62d94d2de03b27b8 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Tue, 14 Apr 2026 21:36:29 +0100 Subject: [PATCH 4/4] Update translation: .translate/state/numpy_vs_numba_vs_jax.md.yml --- .translate/state/numpy_vs_numba_vs_jax.md.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 34ec88f..9091a5d 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,5 +1,5 @@ -source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f -synced-at: "2026-04-13" +source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28 +synced-at: "2026-04-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 3