From 6e17ee1139bf1c23c9f562cea6ec288dbe1b9779 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Tue, 14 Apr 2026 19:01:06 +0100 Subject: [PATCH 1/4] Update translation: lectures/jax_intro.md --- lectures/jax_intro.md | 183 ++++++++++++++++-------------------------- 1 file changed, 69 insertions(+), 114 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 2f21d1f..e09f761 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -21,22 +21,24 @@ translation: JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验 JAX as a NumPy Replacement::Differences::Precision: 精度 JAX as a NumPy Replacement::Differences::Immutability: 不可变性 - JAX as a NumPy Replacement::Differences::A workaround: 变通方法 + JAX as a NumPy Replacement::Differences::A Workaround: 变通方法 Functional Programming: 函数式编程 Functional Programming::Pure functions: 纯函数 Functional Programming::Examples: 示例 - Functional Programming::Why Functional Programming?: 为什么使用函数式编程? + Functional Programming::Why Functional Programming?: 为什么要函数式编程? Random numbers: 随机数 - Random numbers::Random number generation: 随机数生成 - Random numbers::Why explicit random state?: 为什么要显式随机状态? - Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法 - Random numbers::Why explicit random state?::JAX's approach: JAX 的方法 + Random numbers::NumPy / MATLAB Approach: NumPy / MATLAB 方法 + Random numbers::JAX: JAX + Random numbers::Benefits: 好处 JIT Compilation: JIT 编译 JIT Compilation::With NumPy: 使用 NumPy JIT Compilation::With JAX: 使用 JAX JIT Compilation::Compiling the Whole Function: 编译整个函数 JIT Compilation::How JIT compilation works: JIT 编译的工作原理 JIT Compilation::Compiling non-pure functions: 编译非纯函数 + Vectorization with `vmap`: 使用 `vmap` 进行向量化 + Vectorization with `vmap`::A simple example: 一个简单的示例 + Vectorization with `vmap`::Combining transformations: 组合变换 Exercises: 练习 --- @@ -404,13 +406,29 @@ JAX 使用函数式编程风格,以便用户构建的函数能够直接映射 JAX 中的随机数生成与 NumPy 或 MATLAB 中的模式有很大不同。 -起初,您可能会觉得语法相当冗长。 +### NumPy / MATLAB 方法 -但为了维护我们刚刚讨论的函数式编程风格,这种语法和语义是必要的。 +在 NumPy / MATLAB 中,生成通过维护隐藏的全局状态来工作。 -此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。 +```{code-cell} ipython3 +np.random.seed(42) +print(np.random.randn(2)) +``` + +每次我们调用随机函数时,隐藏状态都会被更新: + +```{code-cell} ipython3 +print(np.random.randn(2)) +``` + +这个函数*不是纯函数*,因为: -### 随机数生成 +* 它是非确定性的:相同的输入,不同的输出 +* 它有副作用:它修改了全局随机数生成器状态 + +在并行化下很危险——必须仔细控制每个线程中发生的事情! + +### JAX 在 JAX 中,随机数生成器的状态被显式控制。 @@ -531,109 +549,40 @@ def gen_random_matrices(key, n=2, k=3): key, subkey = jax.random.split(key) A = jax.random.uniform(subkey, (n, n)) matrices.append(A) - print(A) return matrices ``` ```{code-cell} ipython3 seed = 42 key = jax.random.key(seed) -matrices = gen_random_matrices(key) -``` - -我们也可以在循环迭代时使用 `fold_in`: - -```{code-cell} ipython3 -def gen_random_matrices(key, n=2, k=3): - matrices = [] - for i in range(k): - step_key = jax.random.fold_in(key, i) - A = jax.random.uniform(step_key, (n, n)) - matrices.append(A) - print(A) - return matrices -``` - -```{code-cell} ipython3 -key = jax.random.key(seed) -matrices = gen_random_matrices(key) -``` - -### 为什么要显式随机状态? - -为什么 JAX 需要这种相对冗长的随机数生成方法? - -一个原因是为了维护纯函数。 - -让我们通过比较 NumPy 和 JAX 来看看随机数生成与纯函数的关系。 - -#### NumPy 的方法 - -在 NumPy 的旧版随机数生成 API(模仿 MATLAB)中,生成通过维护隐藏的全局状态来工作。 - -每次我们调用随机函数时,这个状态都会被更新: - -```{code-cell} ipython3 -np.random.seed(42) -print(np.random.randn()) # Updates state of random number generator -print(np.random.randn()) # Updates state of random number generator +gen_random_matrices(key) ``` -每次调用都返回不同的值,即使我们用相同的输入(没有参数)调用相同的函数。 +这个函数是*纯函数*: -这个函数*不是纯函数*,因为: - -* 它是非确定性的:相同的输入(在这种情况下,没有输入)产生不同的输出 -* 它有副作用:它修改了全局随机数生成器状态 - -#### JAX 的方法 - -如上所示,JAX 采用了不同的方法,通过密钥使随机性显式化。 - -例如: - -```{code-cell} ipython3 -def random_sum_jax(key): - key1, key2 = jax.random.split(key) - x = jax.random.normal(key1) - y = jax.random.normal(key2) - return x + y -``` - -使用相同的密钥,我们总是得到相同的结果: - -```{code-cell} ipython3 -key = jax.random.key(42) -random_sum_jax(key) -``` - -```{code-cell} ipython3 -random_sum_jax(key) -``` - -要获得新的抽取,我们需要提供一个新密钥。 - -函数 `random_sum_jax` 是纯函数,因为: - -* 它是确定性的:相同的密钥总是产生相同的输出 +* 确定性的:相同的输入,相同的输出 * 无副作用:没有隐藏状态被修改 +### 好处 + JAX 的显式性带来了显著的好处: * 可复现性:通过重用密钥轻松重现结果 -* 并行化:每个线程可以拥有自己的密钥而不会产生冲突 -* 调试:没有隐藏状态使代码更容易推理 +* 并行化:控制各个线程上发生的事情 +* 调试:没有隐藏状态使代码更容易测试 * JIT 兼容性:编译器可以更积极地优化纯函数 -最后一点将在下一节中进行扩展。 - ## JIT 编译 JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。 我们在 {ref}`上文 ` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 `cos` 函数。 -让我们用一个更复杂的函数尝试同样的操作: +这里我们研究更复杂函数的 JIT 编译。 + +### 使用 NumPy + +我们先用 NumPy 试试,使用: ```{code-cell} def f(x): @@ -641,9 +590,7 @@ def f(x): return y ``` -### 使用 NumPy - -我们先用 NumPy 试试: +用较大的 `x` 运行: ```{code-cell} n = 50_000_000 @@ -656,9 +603,17 @@ with qe.Timer(): y = f(x) ``` -### 使用 JAX +**即时**执行模型 + +* 每个操作在遇到时立即执行,在下一个操作开始之前将其结果实体化。 + +缺点 + +* 并行化程度最低 +* 较大的内存占用——产生许多中间数组 +* 大量内存读写 -现在让我们用 JAX 再试一次。 +### 使用 JAX 作为第一步,我们将整个代码中的 `np` 替换为 `jnp`: @@ -691,11 +646,13 @@ with qe.Timer(): 结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。 -然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。 +但我们仍在使用即时执行——大量内存和读写开销。 ### 编译整个函数 -JAX 即时(JIT)编译器可以通过将数组运算融合到单个优化内核中来加速函数内部的执行。 +幸运的是,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。 + +编译器将所有数组运算融合到单个优化内核中。 让我们用函数 `f` 来试试这个: @@ -719,9 +676,11 @@ with qe.Timer(): jax.block_until_ready(y); ``` -运行时间再次改善——现在是因为我们融合了所有操作,使编译器能够更积极地进行优化。 +运行时间再次改善——现在是因为我们融合了所有操作: -例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。 +* 基于整个计算序列的积极优化 +* 消除对硬件加速器的多次调用 +* 不创建中间数组 顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是: @@ -741,11 +700,9 @@ XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TP ### 编译非纯函数 -现在我们已经看到了 JIT 编译的强大之处,理解它与纯函数的关系非常重要。 +虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测! -虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测。 - -以下是一个使用全局变量的例子: +以下是一个例子: ```{code-cell} ipython3 a = 1 # global @@ -789,9 +746,9 @@ f(x) ## 使用 `vmap` 进行向量化 -JAX 的另一个强大变换是 `jax.vmap`,它能自动将一个针对单个输入编写的函数向量化,使其可以在批量数据上运行。 +JAX 的另一个强大变换是 `jax.vmap`,它能够自动将针对单个输入编写的函数向量化,使其可以对批量数据进行操作。 -这避免了手动编写向量化代码或使用显式循环的需要。 +这样就无需手动编写向量化代码或使用显式循环。 ### 一个简单的示例 @@ -809,7 +766,7 @@ x = jnp.array([1.0, 2.0, 5.0]) mm_diff(x) ``` -现在假设我们有一个矩阵,想要对每一行计算这些统计量。 +现在假设我们有一个矩阵,希望对每一行计算这些统计量。 不使用 `vmap` 时,我们需要显式循环: @@ -824,18 +781,16 @@ for row in X: 然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。 -使用 `vmap` 可以将计算保留在加速器上,并与其他 JAX 变换(如 `jit` 和 `grad`)组合使用: +使用 `vmap`,我们可以避免循环,并将计算保留在加速器上: ```{code-cell} ipython3 -batch_mm_diff = jax.vmap(mm_diff) -batch_mm_diff(X) +batch_mm_diff = jax.vmap(mm_diff) # Create a new "vectorized" version +batch_mm_diff(X) # Apply to each row of X ``` -函数 `mm_diff` 是针对单个数组编写的,而 `vmap` 自动将其提升为按行作用于矩阵的函数——无需循环,无需重新塑形。 - ### 组合变换 -JAX 的优势之一在于各变换可以自然地组合使用。 +JAX 的优势之一在于各种变换可以自然地组合使用。 例如,我们可以对向量化函数进行 JIT 编译: @@ -844,7 +799,7 @@ fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff)) fast_batch_mm_diff(X) ``` -`jit`、`vmap` 以及(我们接下来将看到的)`grad` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。 +`jit`、`vmap` 以及(我们接下来将看到的)`grad` 的这种组合是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。 ## 练习 From 70d05ea225b111e0d6c39f8ebc25d1de5515ce8e Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Tue, 14 Apr 2026 19:01:08 +0100 Subject: [PATCH 2/4] Update translation: .translate/state/jax_intro.md.yml --- .translate/state/jax_intro.md.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index e8f1756..4f0ca12 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,5 +1,5 @@ -source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f -synced-at: "2026-04-13" +source-sha: 450bafecd23db638602150b47f4272b98aad3146 +synced-at: "2026-04-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 From 2d0d7546aae77b11f3a6591f9a9ce15c41e9c325 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Tue, 14 Apr 2026 19:01:09 +0100 Subject: [PATCH 3/4] Update translation: lectures/numpy_vs_numba_vs_jax.md --- lectures/numpy_vs_numba_vs_jax.md | 180 ++++++++++++++++-------------- 1 file changed, 95 insertions(+), 85 deletions(-) diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index bb1a6e0..5621d66 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -21,6 +21,8 @@ translation: Sequential operations: 顺序运算 Sequential operations::Numba Version: Numba 版本 Sequential operations::JAX Version: JAX 版本 + Sequential operations::JAX Version::First Attempt: 第一种尝试 + Sequential operations::JAX Version::Second Attempt: 第二种尝试 Sequential operations::Summary: 总结 Overall recommendations: 总体建议 --- @@ -143,33 +145,34 @@ m = -np.inf for x in grid: for y in grid: z = f(x, y) - if z > m: - m = z + m = max(m, z) ``` ### NumPy 向量化 -如果我们切换到 NumPy 风格的向量化,就可以使用更大的网格,并且代码执行速度相对较快。 +让我们切换到 NumPy 并使用更大的网格。 这里我们使用 `np.meshgrid` 来创建二维输入网格 `x` 和 `y`,使得 `f(x, y)` 能生成乘积网格上的所有计算结果。 -(这一策略可以追溯到 MATLAB。) - ```{code-cell} ipython3 +# Large grid grid = np.linspace(-3, 3, 3_000) -x, y = np.meshgrid(grid, grid) + +x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid with qe.Timer(): z_max_numpy = np.max(f(x, y)) - -print(f"NumPy result: {z_max_numpy:.6f}") ``` 在向量化版本中,所有循环都在编译后的代码中执行。 -此外,NumPy 使用隐式多线程,因此至少会发生一定程度的并行化。 +使用 `meshgrid` 可以复现嵌套的 for 循环。 + +输出结果应接近于 1: -(并行化效率不高,因为二进制文件在看到数组 `x` 和 `y` 的大小之前就已经被编译了。) +```{code-cell} ipython3 +print(f"NumPy result: {z_max_numpy:.6f}") +``` ### 与 Numba 的比较 @@ -194,8 +197,6 @@ grid = np.linspace(-3, 3, 3_000) with qe.Timer(): # First run z_max_numba = compute_max_numba(grid) - -print(f"Numba result: {z_max_numba:.6f}") ``` 让我们再次运行以消除编译时间。 @@ -238,8 +239,6 @@ def compute_max_numba_parallel(grid): with qe.Timer(): # First run z_max_parallel = compute_max_numba_parallel(grid) - -print(f"Numba result: {z_max_parallel:.6f}") ``` 以下是预编译版本的计时结果。 @@ -250,15 +249,19 @@ with qe.Timer(): compute_max_numba_parallel(grid) ``` -如果您有多个核心,您应该能在此处看到并行化带来的一定收益。 +如果您有多个核心,您应该能在此处看到并行化带来的收益。 -对于更强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来显著的速度提升。 +让我们确认结果仍然正确(接近于 1): -### 使用 JAX 的向量化代码 +```{code-cell} ipython3 +print(f"Numba result: {z_max_parallel:.6f}") +``` -表面上,JAX 中的向量化代码与 NumPy 代码类似。 +对于强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来有用的速度提升。 -但两者之间也存在一些差异,我们在这里加以强调。 +### 使用 JAX 的向量化代码 + +让我们尝试用 JAX 复现 NumPy 的向量化方法。 让我们从函数开始,将 `np` 替换为 `jnp` 并添加 `jax.jit` @@ -269,7 +272,7 @@ def f(x, y): ``` -与 NumPy 一样,为了获得正确的形状和正确的嵌套 `for` 循环计算,我们可以使用专为此目的设计的 `meshgrid` 操作: +我们使用 NumPy 风格的 meshgrid 方法: ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) @@ -326,60 +329,24 @@ x_mesh.nbytes + y_mesh.nbytes 以下是我们将其应用于当前问题的方式。 -```{code-cell} ipython3 -# 设置 f,使其在给定任意 y 时,对所有 x 计算 f(x, y) -f_vec_x = lambda y: f(grid, y) -# 创建第二个函数,将此操作在所有 y 上向量化 -f_vec = jax.vmap(f_vec_x) -``` - -现在,当以扁平数组 `grid` 调用时,`f_vec` 将在每个 `x,y` 处计算 `f(x,y)`。 - -让我们看看计时结果: - -```{code-cell} ipython3 -with qe.Timer(): - z_max = jnp.max(f_vec(grid)) - z_max.block_until_ready() - -print(f"JAX vmap v1 result: {z_max:.6f}") -``` - -```{code-cell} ipython3 -with qe.Timer(): - z_max = jnp.max(f_vec(grid)) - z_max.block_until_ready() -``` - -通过避免使用大型输入数组 `x_mesh` 和 `y_mesh`,这个 `vmap` 版本使用的内存少得多,运行时间变化不大。 - -这很好——但我们还有进一步提升速度的空间! - -首先请注意,上面的代码计算了完整的二维数组 `f(x,y)`,这会产生开销,然后再取最大值。 - -其次,`jnp.max` 调用位于 JIT 编译函数 `f` 之外,因此编译器无法将这些操作融合为单个内核。 - -我们可以通过将最大值操作移到内部并将所有内容包装在一个 `@jax.jit` 中来解决这两个问题: - ```{code-cell} ipython3 @jax.jit def compute_max_vmap(grid): - # 构建一个沿每行取最大值的函数 + # 构建一个对给定 y,在所有 x 上取最大值的函数 f_vec_x_max = lambda y: jnp.max(f(grid, y)) - # 向量化该函数,以便我们可以同时对所有行调用 + # 向量化该函数,以便我们可以同时对所有 y 调用 f_vec_max = jax.vmap(f_vec_x_max) - # 调用向量化函数并取最大值 - return jnp.max(f_vec_max(grid)) + # 在每个 y 处计算所有 x 上的最大值 + maxes = f_vec_max(grid) + # 计算最大值的最大值并返回 + return jnp.max(maxes) ``` -其中 - -* `f_vec_x_max` 计算任意给定行的最大值 -* `f_vec_max` 是一个向量化版本,可以并行计算所有行的最大值。 - -我们将此函数应用于所有行,然后取各行最大值中的最大值。 +注意我们从不创建 -由于将最大值操作移到内部,我们永远不会构建完整的二维数组 `f(x,y)`,从而节省了更多内存。 +* 二维网格 `x_mesh` +* 二维网格 `y_mesh` 或 +* 二维数组 `f(x,y)` 并且由于所有内容都在单个 `@jax.jit` 下,编译器可以将所有操作融合为一个优化的内核。 @@ -461,21 +428,70 @@ with qe.Timer(): Numba 非常高效地处理了这个顺序运算。 -注意,JIT 编译完成后,第二次运行明显更快。 +### JAX 版本 -Numba 的编译通常相当快,对于像这样的顺序运算,生成的代码性能非常出色。 +我们不能直接用 `jax.jit` 替换 `numba.jit`,因为 JAX 数组是不可变的。 -### JAX 版本 +但我们仍然可以实现这一运算。 + +#### 第一种尝试 -现在让我们使用 `lax.scan` 创建一个 JAX 版本: +以下是使用 `at[t].set` 语法的变通方案,我们在 {ref}`JAX 讲座中讨论过 `。 -(我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。) +我们将应用 `lax.fori_loop`,这是一种可以被 XLA 编译的 for 循环版本。 ```{code-cell} ipython3 cpu = jax.devices("cpu")[0] -@partial(jax.jit, static_argnames=('n',), device=cpu) -def qm_jax(x0, n, α=4.0): +@partial(jax.jit, static_argnames=("n",), device=cpu) +def qm_jax_fori(x0, n, α=4.0): + + x = jnp.empty(n + 1).at[0].set(x0) + + def update(t, x): + return x.at[t + 1].set(α * x[t] * (1 - x[t])) + + x = lax.fori_loop(0, n, update, x) + return x + +``` + +* 我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。 +* 我们通过 `device=cpu` 将计算固定到 CPU,因为这种顺序工作负载由许多小型运算组成,几乎没有机会利用 GPU 并行性。 + +重要提示:虽然 `at[t].set` 看起来在每一步都创建了一个新数组,但在 JIT 编译的函数内部,编译器会检测到旧数组不再需要,并就地执行更新! + +让我们使用相同的参数计时: + +```{code-cell} ipython3 +with qe.Timer(): + # First run + x_jax = qm_jax_fori(0.1, n) + # Hold interpreter + x_jax.block_until_ready() +``` + +让我们再次运行以消除编译开销: + +```{code-cell} ipython3 +with qe.Timer(): + # Second run + x_jax = qm_jax_fori(0.1, n) + # Hold interpreter + x_jax.block_until_ready() +``` + +JAX 对于这种顺序运算也相当高效! + +#### 第二种尝试 + +还有另一种使用 `lax.scan` 实现该循环的方式。 + +这种替代方案可以说更符合 JAX 的函数式风格——尽管语法难以记忆。 + +```{code-cell} ipython3 +@partial(jax.jit, static_argnames=("n",), device=cpu) +def qm_jax_scan(x0, n, α=4.0): def update(x, t): x_new = α * x * (1 - x) return x_new, x_new @@ -486,16 +502,12 @@ def qm_jax(x0, n, α=4.0): 这段代码不易阅读,但本质上,`lax.scan` 反复调用 `update` 并将返回值 `x_new` 累积到一个数组中。 -```{note} -我们在 `jax.jit` 装饰器中指定了 `device=cpu`,因为该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。 -``` - 让我们使用相同的参数计时: ```{code-cell} ipython3 with qe.Timer(): # First run - x_jax = qm_jax(0.1, n) + x_jax = qm_jax_scan(0.1, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -505,26 +517,24 @@ with qe.Timer(): ```{code-cell} ipython3 with qe.Timer(): # Second run - x_jax = qm_jax(0.1, n) + x_jax = qm_jax_scan(0.1, n) # Hold interpreter x_jax.block_until_ready() ``` -JAX 对于这种顺序运算也相当高效。 - -JAX 和 Numba 在编译后都能提供出色的性能。 +令人惊讶的是,JAX 在编译后也能提供出色的性能。 ### 总结 -虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,*在代码可读性和易用性方面存在显著差异*。 +虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,但在代码可读性和易用性方面存在差异。 Numba 版本简单直观,易于阅读:我们只需分配一个数组,然后使用标准 Python 循环逐元素填充它。 这正是大多数程序员思考该算法的方式。 -另一方面,JAX 版本需要使用 `lax.scan`,这明显不够直观。 +另一方面,JAX 版本需要使用 `lax.fori_loop` 或 `lax.scan`,这两者都不如标准 Python 循环直观。 -此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难。 +虽然 JAX 的 `at[t].set` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读。 对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。 From 88b6332b3be6d19f9b8f443896ac8a2cccb7cbd6 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Tue, 14 Apr 2026 19:01:09 +0100 Subject: [PATCH 4/4] Update translation: .translate/state/numpy_vs_numba_vs_jax.md.yml --- .translate/state/numpy_vs_numba_vs_jax.md.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 34ec88f..e904fbe 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,5 +1,5 @@ -source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f -synced-at: "2026-04-13" +source-sha: 450bafecd23db638602150b47f4272b98aad3146 +synced-at: "2026-04-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 3