diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index 663bfe6..e8f1756 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,5 +1,5 @@ -source-sha: 8d73de367a7f160dac777aa557f1c26069f84ea5 -synced-at: "2026-04-12" +source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f +synced-at: "2026-04-13" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 diff --git a/.translate/state/numba.md.yml b/.translate/state/numba.md.yml index fbbdfa3..c7d5b0a 100644 --- a/.translate/state/numba.md.yml +++ b/.translate/state/numba.md.yml @@ -1,5 +1,5 @@ -source-sha: be6eeaee8db0c8bfea65b89d57ca8aecf7f96dff -synced-at: "2026-04-12" +source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f +synced-at: "2026-04-13" model: claude-sonnet-4-6 mode: UPDATE section-count: 5 diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 93adba6..34ec88f 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,5 +1,5 @@ -source-sha: 94dd7d22385ec46d740db1fc2cddf05c29377594 -synced-at: "2026-04-12" +source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f +synced-at: "2026-04-13" model: claude-sonnet-4-6 mode: UPDATE section-count: 3 diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index e50faa0..2f21d1f 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -18,6 +18,7 @@ translation: JAX as a NumPy Replacement::Differences::Speed!: 速度! JAX as a NumPy Replacement::Differences::Speed!::With NumPy: 使用 NumPy JAX as a NumPy Replacement::Differences::Speed!::With JAX: 使用 JAX + JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验 JAX as a NumPy Replacement::Differences::Precision: 精度 JAX as a NumPy Replacement::Differences::Immutability: 不可变性 JAX as a NumPy Replacement::Differences::A workaround: 变通方法 @@ -31,13 +32,11 @@ translation: Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法 Random numbers::Why explicit random state?::JAX's approach: JAX 的方法 JIT Compilation: JIT 编译 - JIT Compilation::Evaluating a more complicated function: 评估更复杂的函数 - JIT Compilation::Evaluating a more complicated function::With NumPy: 使用 NumPy - JIT Compilation::Evaluating a more complicated function::With JAX: 使用 JAX - JIT Compilation::Compiling the whole function: 编译整个函数 + JIT Compilation::With NumPy: 使用 NumPy + JIT Compilation::With JAX: 使用 JAX + JIT Compilation::Compiling the Whole Function: 编译整个函数 JIT Compilation::How JIT compilation works: JIT 编译的工作原理 JIT Compilation::Compiling non-pure functions: 编译非纯函数 - JIT Compilation::Summary: 总结 Exercises: 练习 --- @@ -205,6 +204,8 @@ with qe.Timer(): 大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。 +#### 大小实验 + 我们可以通过更改输入大小并观察运行时间来验证 JAX 针对数组大小进行专门化的说法。 ```{code-cell} @@ -233,105 +234,6 @@ with qe.Timer(): 关于 JIT 编译的进一步讨论见下文。 -(jax_speed)= -#### 速度! - -假设我们想在许多点上求余弦函数的值。 - -```{code-cell} -n = 50_000_000 -x = np.linspace(0, 10, n) -``` - -##### 使用 NumPy - -让我们先用 NumPy 试试: - -```{code-cell} -with qe.Timer(): - y = np.cos(x) -``` - -再来一次。 - -```{code-cell} -with qe.Timer(): - y = np.cos(x) -``` - -这里: - -* NumPy 使用预编译的二进制文件对浮点数组应用余弦函数 -* 该二进制文件在本地机器的 CPU 上运行 - -##### 使用 JAX - -现在让我们用 JAX 试试。 - -```{code-cell} -x = jnp.linspace(0, 10, n) -``` - -对相同的过程计时。 - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -```{note} -这里,为了测量实际速度,我们使用 `block_until_ready` 方法让解释器等待,直到计算结果返回。 - -这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。 - -对于非计时代码,可以省略包含 `block_until_ready` 的那行。 -``` - -再计时一次。 - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -在 GPU 上,这段代码的运行速度远快于等效的 NumPy 代码。 - -此外,通常第二次运行比第一次更快,这是由于 JIT 编译的原因。 - -这是因为即使是 `jnp.cos` 这样的内置函数也会被 JIT 编译——第一次运行包含了编译时间。 - -为什么 JAX 要对 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本? - -原因是 JIT 编译器希望针对所使用数组的*大小*(以及数据类型)进行专门优化。 - -大小对于生成优化代码很重要,因为高效并行化需要将任务大小与可用硬件相匹配。 - -我们可以通过改变输入大小并观察运行时间来验证 JAX 针对数组大小进行专门优化的说法。 - -```{code-cell} -x = jnp.linspace(0, 10, n + 1) -``` - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -运行时间先增加后再次下降(在 GPU 上这一现象会更明显)。 - -这与上面的讨论一致——改变数组大小后的第一次运行显示了编译开销。 - -关于 JIT 编译的进一步讨论见下文。 - #### 精度 NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。 @@ -731,11 +633,7 @@ JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高 我们在 {ref}`上文 ` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 `cos` 函数。 -让我们用一个更复杂的函数尝试同样的操作。 - -### 评估更复杂的函数 - -考虑以下函数: +让我们用一个更复杂的函数尝试同样的操作: ```{code-cell} def f(x): @@ -743,7 +641,7 @@ def f(x): return y ``` -#### 使用 NumPy +### 使用 NumPy 我们先用 NumPy 试试: @@ -758,7 +656,7 @@ with qe.Timer(): y = f(x) ``` -#### 使用 JAX +### 使用 JAX 现在让我们用 JAX 再试一次。 @@ -793,7 +691,7 @@ with qe.Timer(): 结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。 -然而,使用 JAX,我们还有另一个技巧——我们可以对*整个*函数进行 JIT 编译,而不仅仅是单个操作。 +然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。 ### 编译整个函数 diff --git a/lectures/numba.md b/lectures/numba.md index dcc7de0..46ebf29 100644 --- a/lectures/numba.md +++ b/lectures/numba.md @@ -132,7 +132,7 @@ n = 10_000_000 with qe.Timer() as timer1: # Time Python base version - x = qm(0.1, int(n)) + x = qm(0.1, n) ``` @@ -160,7 +160,7 @@ qm_numba = jit(qm) ```{code-cell} ipython3 with qe.Timer() as timer2: # Time jitted version - x = qm_numba(0.1, int(n)) + x = qm_numba(0.1, n) ``` 这已经是非常大的速度提升。 @@ -172,7 +172,7 @@ with qe.Timer() as timer2: ```{code-cell} ipython3 with qe.Timer() as timer3: # Second run - x = qm_numba(0.1, int(n)) + x = qm_numba(0.1, n) ``` 以下是速度提升 diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index 4df7c22..bb1a6e0 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -159,7 +159,7 @@ for x in grid: grid = np.linspace(-3, 3, 3_000) x, y = np.meshgrid(grid, grid) -with qe.Timer(precision=8): +with qe.Timer(): z_max_numpy = np.max(f(x, y)) print(f"NumPy result: {z_max_numpy:.6f}") @@ -182,13 +182,17 @@ def compute_max_numba(grid): for x in grid: for y in grid: z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > m: - m = z + m = max(m, z) return m +``` +让我们测试一下: + +```{code-cell} ipython3 grid = np.linspace(-3, 3, 3_000) -with qe.Timer(precision=8): +with qe.Timer(): + # First run z_max_numba = compute_max_numba(grid) print(f"Numba result: {z_max_numba:.6f}") @@ -197,13 +201,16 @@ print(f"Numba result: {z_max_numba:.6f}") 让我们再次运行以消除编译时间。 ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run compute_max_numba(grid) ``` 根据您的机器,Numba 版本可能比 NumPy 稍慢或稍快。 -一方面,NumPy 将高效的算术运算(类似 Numba)与一定程度的多线程(不同于这段 Numba 代码)结合在一起,这提供了优势。 +在大多数情况下,我们发现 Numba 略胜一筹。 + +一方面,NumPy 将高效的算术运算与一定程度的多线程结合在一起,这提供了优势。 另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格。 @@ -211,43 +218,6 @@ with qe.Timer(precision=8): 现在让我们使用 `prange` 尝试 Numba 的并行化: -这是一个简单但**不正确**的尝试。 - -```{code-cell} ipython3 -@numba.jit(parallel=True) -def compute_max_numba_parallel(grid): - n = len(grid) - m = -np.inf - for i in numba.prange(n): - for j in range(n): - x = grid[i] - y = grid[j] - z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > m: - m = z - return m - -``` - -这将返回 `-inf`——即 `m` 的初始值,仿佛它从未被更新过: - -```{code-cell} ipython3 -z_max_parallel_incorrect = compute_max_numba_parallel(grid) -print(f"Numba result: {z_max_parallel_incorrect} 😱") -``` - -要理解原因,请回忆 `prange` 会将外层循环拆分到各个线程中。 - -每个线程都会得到自己的 `m` 私有副本,初始化为 `-np.inf`,并在其负责的迭代块中正确地更新它。 - -但在循环结束时,Numba 需要将各线程的 `m` 副本合并为一个单一的值——即**归约**操作。 - -对于它能识别的模式,例如 `m += z`(求和)或 `m = max(m, z)`(求最大值),Numba 知道合并算子。 - -但它无法将 `if z > m: m = z` 识别为最大值归约,因此各线程的结果永远不会被合并,`m` 始终保持其初始值。 - -最简单的修复方法是将条件判断替换为 Numba 能识别的 `max`: - ```{code-cell} ipython3 @numba.jit(parallel=True) def compute_max_numba_parallel(grid): @@ -262,36 +232,21 @@ def compute_max_numba_parallel(grid): return m ``` -另一种方法是使循环体在不同 `i` 之间完全独立,并自行处理归约: +以下是预热运行和测试。 ```{code-cell} ipython3 -@numba.jit(parallel=True) -def compute_max_numba_parallel_v2(grid): - n = len(grid) - row_maxes = np.empty(n) - for i in numba.prange(n): - row_max = -np.inf - for j in range(n): - x = grid[i] - y = grid[j] - z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > row_max: - row_max = z - row_maxes[i] = row_max - return np.max(row_maxes) -``` - -在这里,每个线程写入 `row_maxes` 的不同元素,因此我们通过 `np.max` 自行处理归约。 +with qe.Timer(): + # First run + z_max_parallel = compute_max_numba_parallel(grid) -```{code-cell} ipython3 -z_max_parallel = compute_max_numba_parallel(grid) print(f"Numba result: {z_max_parallel:.6f}") ``` -以下是计时结果。 +以下是预编译版本的计时结果。 ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run compute_max_numba_parallel(grid) ``` @@ -305,8 +260,7 @@ with qe.Timer(precision=8): 但两者之间也存在一些差异,我们在这里加以强调。 -让我们从函数开始。 - +让我们从函数开始,将 `np` 替换为 `jnp` 并添加 `jax.jit` ```{code-cell} ipython3 @jax.jit @@ -320,9 +274,15 @@ def f(x, y): ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) x_mesh, y_mesh = jnp.meshgrid(grid, grid) +``` + +现在让我们运行并计时 -with qe.Timer(precision=8): +```{code-cell} ipython3 +with qe.Timer(): + # First run z_max = jnp.max(f(x_mesh, y_mesh)) + # Hold interpreter z_max.block_until_ready() print(f"Plain vanilla JAX result: {z_max:.6f}") @@ -331,8 +291,10 @@ print(f"Plain vanilla JAX result: {z_max:.6f}") 让我们再次运行以消除编译时间。 ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run z_max = jnp.max(f(x_mesh, y_mesh)) + # Hold interpreter z_max.block_until_ready() ``` @@ -376,7 +338,7 @@ f_vec = jax.vmap(f_vec_x) 让我们看看计时结果: ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): z_max = jnp.max(f_vec(grid)) z_max.block_until_ready() @@ -384,18 +346,18 @@ print(f"JAX vmap v1 result: {z_max:.6f}") ``` ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): z_max = jnp.max(f_vec(grid)) z_max.block_until_ready() ``` -通过避免使用大型输入数组 `x_mesh` 和 `y_mesh`,这个 `vmap` 版本使用的内存少得多,运行时间也相近。 +通过避免使用大型输入数组 `x_mesh` 和 `y_mesh`,这个 `vmap` 版本使用的内存少得多,运行时间变化不大。 -但我们仍然留有一些速度提升的空间未被利用。 +这很好——但我们还有进一步提升速度的空间! -上面的代码计算了完整的二维数组 `f(x,y)`,然后再取最大值。 +首先请注意,上面的代码计算了完整的二维数组 `f(x,y)`,这会产生开销,然后再取最大值。 -此外,`jnp.max` 调用位于 JIT 编译函数 `f` 之外,因此编译器无法将这些操作融合为单个内核。 +其次,`jnp.max` 调用位于 JIT 编译函数 `f` 之外,因此编译器无法将这些操作融合为单个内核。 我们可以通过将最大值操作移到内部并将所有内容包装在一个 `@jax.jit` 中来解决这两个问题: @@ -424,8 +386,11 @@ def compute_max_vmap(grid): 让我们试试。 ```{code-cell} ipython3 -with qe.Timer(precision=8): - z_max = compute_max_vmap(grid).block_until_ready() +with qe.Timer(): + # First run + z_max = compute_max_vmap(grid) + # Hold interpreter + z_max.block_until_ready() print(f"JAX vmap result: {z_max:.6f}") ``` @@ -433,8 +398,11 @@ print(f"JAX vmap result: {z_max:.6f}") 让我们再次运行以消除编译时间: ```{code-cell} ipython3 -with qe.Timer(precision=8): - z_max = compute_max_vmap(grid).block_until_ready() +with qe.Timer(): + # Second run + z_max = compute_max_vmap(grid) + # Hold interpreter + z_max.block_until_ready() ``` ### 总结 @@ -478,14 +446,16 @@ def qm(x0, n, α=4.0): ```{code-cell} ipython3 n = 10_000_000 -with qe.Timer(precision=8): +with qe.Timer(): + # First run x = qm(0.1, n) ``` 让我们再次运行以消除编译时间: ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run x = qm(0.1, n) ``` @@ -504,7 +474,7 @@ Numba 的编译通常相当快,对于像这样的顺序运算,生成的代 ```{code-cell} ipython3 cpu = jax.devices("cpu")[0] -@partial(jax.jit, static_argnums=(1,), device=cpu) +@partial(jax.jit, static_argnames=('n',), device=cpu) def qm_jax(x0, n, α=4.0): def update(x, t): x_new = α * x * (1 - x) @@ -517,27 +487,27 @@ def qm_jax(x0, n, α=4.0): 这段代码不易阅读,但本质上,`lax.scan` 反复调用 `update` 并将返回值 `x_new` 累积到一个数组中。 ```{note} -细心的读者会注意到,我们在 `jax.jit` 装饰器中指定了 `device=cpu`。 - -该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。 - -因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。 - -好奇的读者可以尝试删除此选项,看看性能如何变化。 +我们在 `jax.jit` 装饰器中指定了 `device=cpu`,因为该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。 ``` 让我们使用相同的参数计时: ```{code-cell} ipython3 -with qe.Timer(precision=8): - x_jax = qm_jax(0.1, n).block_until_ready() +with qe.Timer(): + # First run + x_jax = qm_jax(0.1, n) + # Hold interpreter + x_jax.block_until_ready() ``` 让我们再次运行以消除编译开销: ```{code-cell} ipython3 -with qe.Timer(precision=8): - x_jax = qm_jax(0.1, n).block_until_ready() +with qe.Timer(): + # Second run + x_jax = qm_jax(0.1, n) + # Hold interpreter + x_jax.block_until_ready() ``` JAX 对于这种顺序运算也相当高效。 @@ -546,7 +516,7 @@ JAX 和 Numba 在编译后都能提供出色的性能。 ### 总结 -虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,但**在代码可读性和易用性方面存在显著差异**。 +虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,*在代码可读性和易用性方面存在显著差异*。 Numba 版本简单直观,易于阅读:我们只需分配一个数组,然后使用标准 Python 循环逐元素填充它。