From cee064887191ab36b6f375019f247cd3d7c79a3a Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Mon, 13 Apr 2026 01:09:38 +0100 Subject: [PATCH 1/8] Update translation: lectures/jax_intro.md --- lectures/jax_intro.md | 238 +++++++----------------------------------- 1 file changed, 36 insertions(+), 202 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 4318167..15c873c 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -15,9 +15,9 @@ translation: JAX as a NumPy Replacement: JAX 作为 NumPy 的替代品 JAX as a NumPy Replacement::Similarities: 相似之处 JAX as a NumPy Replacement::Differences: 差异 - JAX as a NumPy Replacement::Differences::Precision: 精度 - JAX as a NumPy Replacement::Differences::Immutability: 不可变性 - JAX as a NumPy Replacement::Differences::A workaround: 变通方法 + JAX as a NumPy Replacement::Differences::Speed!: 精度 + JAX as a NumPy Replacement::Differences::Precision: 不可变性 + JAX as a NumPy Replacement::Differences::Immutability: 变通方法 Functional Programming: 函数式编程 Functional Programming::Pure functions: 纯函数 Functional Programming::Examples: 示例 @@ -26,18 +26,12 @@ translation: Random numbers::Why explicit random state?: 为什么要显式随机状态? Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法 Random numbers::Why explicit random state?::JAX's approach: JAX 的方法 - JIT compilation: JIT 编译 - JIT compilation::A simple example: 一个简单的示例 - JIT compilation::A simple example::With NumPy: 使用 NumPy - JIT compilation::A simple example::With JAX: 使用 JAX - JIT compilation::A simple example::Changing array sizes: 更改数组大小 - JIT compilation::Evaluating a more complicated function: 评估更复杂的函数 - JIT compilation::Evaluating a more complicated function::With NumPy: 使用 NumPy - JIT compilation::Evaluating a more complicated function::With JAX: 使用 JAX - JIT compilation::Compiling the Whole Function: 编译整个函数 - JIT compilation::Compiling non-pure functions: 编译非纯函数 - JIT compilation::Summary: 总结 - Gradients: 梯度 + JIT Compilation: JIT 编译 + JIT Compilation::With NumPy: 使用 NumPy + JIT Compilation::With JAX: 使用 JAX + JIT Compilation::Compiling the Whole Function: 编译整个函数 + JIT Compilation::How JIT compilation works: JIT 编译的工作原理 + JIT Compilation::Compiling non-pure functions: 编译非纯函数 Exercises: 练习 --- @@ -148,7 +142,6 @@ jnp.linalg.inv(B) # Inverse of identity is identity jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors ``` - ### 差异 现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。 @@ -249,7 +242,6 @@ a (尽管它在 JIT 编译的函数中实际上可以很高效——但现在先把这个放在一边。) - ## 函数式编程 来自 JAX 的文档: @@ -279,8 +271,6 @@ a * 不会改变全局状态 * 不会修改传递给函数的数据(不可变数据) - - ### 示例 以下是一个*非纯*函数的示例: @@ -316,7 +306,6 @@ def add_tax_pure(prices, tax_rate): 现在我们理解了什么是纯函数,让我们探索 JAX 处理随机数的方法如何维护这种纯粹性。 - ## 随机数 与 NumPy 或 Matlab 中的随机数相比,JAX 中的随机数有很大不同。 @@ -327,7 +316,6 @@ def add_tax_pure(prices, tax_rate): 此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。 - ### 随机数生成 在 JAX 中,随机数生成器的状态被显式控制。 @@ -405,7 +393,6 @@ key = jax.random.PRNGKey(seed) matrices = gen_random_matrices(key) ``` - ### 为什么要显式随机状态? 为什么 JAX 需要这种相对冗长的随机数生成方法? @@ -433,7 +420,6 @@ print(np.random.randn()) # Updates state of random number generator * 它是非确定性的:相同的输入(在这种情况下,没有输入)产生不同的输出 * 它有副作用:它修改了全局随机数生成器状态 - #### JAX 的方法 如上所示,JAX 采用了不同的方法,通过密钥使随机性显式化。 @@ -475,126 +461,21 @@ JAX 的显式性带来了显著的好处: 最后一点将在下一节中进行扩展。 - ## JIT 编译 JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。 -### 一个简单的示例 - -假设我们想在许多点上求余弦函数的值。 - -```{code-cell} -n = 50_000_000 -x = np.linspace(0, 10, n) -``` - -#### 使用 NumPy - -让我们先用 NumPy 试试: - -```{code-cell} -with qe.Timer(): - y = np.cos(x) -``` - -再试一次。 - -```{code-cell} -with qe.Timer(): - y = np.cos(x) -``` - -这里 NumPy 使用预先构建的二进制文件,该文件由精心编写的低级代码编译而成,用于对浮点数数组应用余弦函数。 - -这个二进制文件随 NumPy 一起发布。 - -#### 使用 JAX - -现在让我们用 JAX 试试。 - -```{code-cell} -x = jnp.linspace(0, 10, n) -``` - -让我们对相同的过程计时。 - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -```{note} -这里,为了测量实际速度,我们使用 `block_until_ready` 方法让解释器等待,直到计算结果返回。 - -这是必要的,因为 JAX 使用异步调度,允许 Python 解释器领先于数值计算。 - -对于非计时代码,您可以删除包含 `block_until_ready` 的那一行。 -``` - - -再次计时。 - - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -在 GPU 上,这段代码的运行速度远快于其 NumPy 等价代码。 - -此外,通常第二次运行比第一次快,这是由于 JIT 编译的原因。 - -这是因为即使是像 `jnp.cos` 这样的内置函数也经过了 JIT 编译——第一次运行包含了编译时间。 - -为什么 JAX 要对像 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样提供预编译版本? - -原因是 JIT 编译器希望针对正在使用的数组的*大小*(以及数据类型)进行专门优化。 - -大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。 - -这就是为什么 JAX 要等到看到数组大小后再进行编译——这需要 JIT 编译方法,而不是提供预编译的二进制文件。 - - -#### 更改数组大小 - -这里我们更改输入大小并观察运行时间。 - -```{code-cell} -x = jnp.linspace(0, 10, n + 1) -``` - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -通常,运行时间会先增加然后再减少(这在 GPU 上会更加明显)。 - -这是因为 JIT 编译器针对数组大小进行专门优化以利用并行化——因此当数组大小改变时会生成新的编译代码。 +当我们在 {ref}`上面 ` 对一个大型数组应用 `cos` 时,我们看到了 JAX 的 JIT 编译器结合并行硬件的强大之处。 - -### 评估更复杂的函数 - -让我们用一个更复杂的函数尝试同样的操作。 +让我们用一个更复杂的函数尝试同样的操作: ```{code-cell} def f(x): - y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2 + y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2 return y ``` -#### 使用 NumPy +### 使用 NumPy 我们先用 NumPy 试试: @@ -605,12 +486,11 @@ x = np.linspace(0, 10, n) ```{code-cell} with qe.Timer(): + # Time NumPy code y = f(x) ``` - - -#### 使用 JAX +### 使用 JAX 现在让我们用 JAX 再试一次。 @@ -620,34 +500,36 @@ with qe.Timer(): def f(x): y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2 return y -``` -现在让我们计时。 -```{code-cell} x = jnp.linspace(0, 10, n) ``` +现在让我们计时。 + ```{code-cell} with qe.Timer(): + # First call y = f(x) + # Hold interpreter jax.block_until_ready(y); ``` ```{code-cell} with qe.Timer(): + # Second call y = f(x) + # Hold interpreter jax.block_until_ready(y); ``` 结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。 -此外,使用 JAX,我们还有另一个技巧: - +然而,使用 JAX,我们还有另一个技巧——我们可以对*整个*函数进行 JIT 编译,而不仅仅是单个操作。 ### 编译整个函数 -JAX 即时(JIT)编译器可以通过将线性代数运算融合到单个优化内核中来加速函数内部的执行。 +JAX 即时(JIT)编译器可以通过将数组操作融合到单个优化内核中来加速函数内部的执行。 让我们用函数 `f` 来试试这个: @@ -657,13 +539,17 @@ f_jax = jax.jit(f) ```{code-cell} with qe.Timer(): + # First run y = f_jax(x) + # Hold interpreter jax.block_until_ready(y); ``` ```{code-cell} with qe.Timer(): + # Second run y = f_jax(x) + # Hold interpreter jax.block_until_ready(y); ``` @@ -671,7 +557,6 @@ with qe.Timer(): 例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。 - 顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是: ```{code-cell} ipython3 @@ -680,6 +565,14 @@ def f(x): pass # put function body here ``` +### JIT 编译的工作原理 + +当我们对一个函数应用 `jax.jit` 时,JAX 会对其进行*追踪*:它不会立即执行操作,而是将操作序列记录为计算图,并将该图交给 [XLA](https://openxla.org/xla) 编译器。 + +XLA 随后将这些操作融合并优化为一个针对可用硬件(CPU、GPU 或 TPU)定制的单个编译内核。 + +对 JIT 编译函数的第一次调用会产生编译开销,但后续具有相同输入形状和类型的调用将复用缓存的编译代码并以全速运行。 + ### 编译非纯函数 现在我们已经看到了 JIT 编译的强大之处,理解它与纯函数的关系非常重要。 @@ -728,65 +621,6 @@ f(x) 这个故事的寓意:使用 JAX 时请编写纯函数! - -### 总结 - -现在我们可以理解为什么开发者和编译器都受益于纯函数。 - -我们喜欢纯函数,因为它们: - -* 有助于测试:每个函数可以独立运行 -* 促进确定性行为,从而实现可复现性 -* 防止由于修改共享状态而产生的错误 - -编译器喜欢纯函数和函数式编程,因为: - -* 数据依赖关系是显式的,有助于优化复杂计算 -* 纯函数更容易进行微分(自动微分) -* 纯函数更容易并行化和优化(不依赖于共享可变状态) - - -## 梯度 - -JAX 可以使用自动微分来计算梯度。 - -这对于优化和求解非线性系统非常有用。 - -我们将在本讲座系列后面看到重要的应用。 - -现在,这里有一个非常简单的说明,涉及函数: - -```{code-cell} ipython3 -def f(x): - return (x**2) / 2 -``` - -让我们求导数: - -```{code-cell} ipython3 -f_prime = jax.grad(f) -``` - -```{code-cell} ipython3 -f_prime(10.0) -``` - -让我们绘制函数和导数,注意 $f'(x) = x$。 - -```{code-cell} ipython3 -import matplotlib.pyplot as plt - -fig, ax = plt.subplots() -x_grid = jnp.linspace(-4, 4, 200) -ax.plot(x_grid, f(x_grid), label="$f$") -ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") -ax.legend(loc='upper center') -plt.show() -``` - -我们将进一步探索 JAX 自动微分的内容推迟到 {doc}`jax:autodiff`。 - - ## 练习 @@ -871,4 +705,4 @@ with qe.Timer(): ``` ```{solution-end} -``` \ No newline at end of file +``` From 3cca34918c83299fadca718991b37fc736720d47 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Mon, 13 Apr 2026 01:09:39 +0100 Subject: [PATCH 2/8] Update translation: .translate/state/jax_intro.md.yml --- .translate/state/jax_intro.md.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index baa2850..e8f1756 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,6 +1,6 @@ -source-sha: c4c03c80c1eb4318f627d869707d242d19c8cf09 -synced-at: "2026-03-20" +source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f +synced-at: "2026-04-13" model: claude-sonnet-4-6 -mode: NEW -section-count: 6 -tool-version: 0.13.0 +mode: UPDATE +section-count: 7 +tool-version: 0.14.1 From 14b109b4c13ab91a7fa9205607c0db5d9e7808a3 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Mon, 13 Apr 2026 01:09:39 +0100 Subject: [PATCH 3/8] Update translation: lectures/numba.md --- lectures/numba.md | 347 ++++++++-------------------------------------- 1 file changed, 54 insertions(+), 293 deletions(-) diff --git a/lectures/numba.md b/lectures/numba.md index e6739c7..6dfbb8d 100644 --- a/lectures/numba.md +++ b/lectures/numba.md @@ -14,19 +14,15 @@ translation: headings: Overview: 概述 Compiling Functions: 编译函数 - Compiling Functions::An Example: 一个示例 - Compiling Functions::How and When it Works: 工作原理及适用场景 - Decorator Notation: 装饰器语法 - Type Inference: 类型推断 - Compiling Classes: 编译类 - Dangers and Limitations: 危险与局限 - Dangers and Limitations::Limitations: 局限性 - 'Dangers and Limitations::A Gotcha: Global Variables': 一个陷阱:全局变量 + Compiling Functions::An Example: 示例 + Compiling Functions::An Example::Base Version: 基础版本 + Compiling Functions::An Example::Acceleration via Numba: 通过 Numba 加速 + Compiling Functions::How and When it Works: 工作原理与适用场景 Multithreaded Loops in Numba: Numba 中的多线程循环 Exercises: 练习 --- -(speed)= +(numba_lecture)= ```{raw} jupyter
@@ -92,29 +88,26 @@ mpl.rcParams['font.family'] = ['Source Han Serif SC'] # i18n ```{index} single: Python; Numba ``` -如上所述,Numba 的主要用途是在运行时将函数编译为快速的本地机器码。 (quad_map_eg)= -### 一个示例 +### 示例 -让我们考虑一个难以向量化的问题:给定初始条件,生成差分方程的轨迹。 +让我们考虑一个难以向量化(即交由数组处理操作完成)的问题。 -我们将采用的差分方程是二次映射 +该问题涉及通过二次映射生成轨迹 $$ -x_{t+1} = \alpha x_t (1 - x_t) + x_{t+1} = \alpha x_t (1 - x_t) $$ -在下文中,我们设定 +在以下内容中,我们设 $\alpha = 4$。 -```{code-cell} ipython3 -α = 4.0 -``` +#### 基础版本 -以下是一条典型轨迹的图像,从 $x_0 = 0.1$ 开始,以 $t$ 为横轴 +以下是从 $x_0 = 0.1$ 出发的典型轨迹图,x 轴为 $t$ ```{code-cell} ipython3 -def qm(x0, n): +def qm(x0, n, α=4.0): x = np.empty(n+1) x[0] = x0 for t in range(n): @@ -129,327 +122,95 @@ ax.set_ylabel('$x_{t}$', fontsize = 12) plt.show() ``` -要使用 Numba 加速函数 `qm`,我们的第一步是 - -```{code-cell} ipython3 -from numba import jit - -qm_numba = jit(qm) -``` - -函数 `qm_numba` 是 `qm` 的一个版本,它被"定向"为 JIT 编译。 - -我们稍后将解释这意味着什么。 - -让我们对这两个版本进行相同的函数调用计时和比较,首先从原始函数 `qm` 开始: +让我们看看当 $n$ 很大时运行需要多长时间 ```{code-cell} ipython3 n = 10_000_000 with qe.Timer() as timer1: - qm(0.1, int(n)) -time1 = timer1.elapsed -``` - -现在让我们试试 qm_numba + # Time Python base version + x = qm(0.1, n) -```{code-cell} ipython3 -with qe.Timer() as timer2: - qm_numba(0.1, int(n)) -time2 = timer2.elapsed ``` -这已经是一个非常大的速度提升。 - -事实上,下次及之后每次运行时速度都会更快,因为函数已经被编译并存储在内存中: +#### 通过 Numba 加速 -(qm_numba_result)= +要使用 Numba 加速函数 `qm`,我们首先导入 `jit` 函数 -```{code-cell} ipython3 -with qe.Timer() as timer3: - qm_numba(0.1, int(n)) -time3 = timer3.elapsed -``` ```{code-cell} ipython3 -time1 / time3 # 计算速度提升倍数 +from numba import jit ``` -相对于修改的简单和清晰程度,这种速度提升令人印象深刻。 - -### 工作原理及适用场景 - -Numba 尝试使用 [LLVM 项目](https://llvm.org/) 提供的基础设施生成快速的机器码。 - -它通过即时推断类型信息来实现这一点。 - -(有关类型的讨论,请参阅我们{doc}`之前关于科学计算的讲座 `。) - -基本思路如下: - -* Python 非常灵活,因此我们可以用多种类型调用函数 `qm`。 - * 例如,`x0` 可以是 NumPy 数组或列表,`n` 可以是整数或浮点数,等等。 -* 这使得*预*编译函数(即在运行时之前编译)变得困难。 -* 然而,当我们实际调用函数时,比如运行 `qm(0.5, 10)`,`x0` 和 `n` 的类型就变得清晰了。 -* 此外,一旦知道输入类型,`qm` 中其他变量的类型也可以被推断出来。 -* 因此,Numba 和其他 JIT 编译器的策略是等到这一时刻,*然后*再编译函数。 - -这就是为什么它被称为"即时"编译。 - -请注意,如果您调用 `qm(0.5, 10)`,然后紧跟着调用 `qm(0.9, 20)`,编译只发生在第一次调用时。 - -编译后的代码会被缓存并按需复用。 - -这就是为什么在上面的代码中,`time3` 比 `time2` 小。 - -## 装饰器语法 - -在上面的代码中,我们通过调用以下方式创建了 `qm` 的 JIT 编译版本 +现在我们将其应用于 `qm`,生成一个新函数: ```{code-cell} ipython3 qm_numba = jit(qm) ``` -在实践中,这通常使用另一种*装饰器*语法来完成。 - -(我们在{doc}`单独的讲座 `中讨论装饰器,但在此阶段您可以跳过细节。) - -让我们看看这是如何完成的。 - -要将函数定向为 JIT 编译,我们可以在函数定义前放置 `@jit`。 - -以下是 `qm` 的写法 - -```{code-cell} ipython3 -@jit -def qm(x0, n): - x = np.empty(n+1) - x[0] = x0 - for t in range(n): - x[t+1] = α * x[t] * (1 - x[t]) - return x -``` - -这等价于在函数定义后添加 `qm = jit(qm)`。 - -以下代码现在使用 JIT 编译版本: - -```{code-cell} ipython3 -with qe.Timer(precision=4): - qm(0.1, 100_000) -``` - -```{code-cell} ipython3 -with qe.Timer(precision=4): - qm(0.1, 100_000) -``` - -Numba 还为装饰器提供了几个参数以加速计算和缓存函数——请参阅[这里](https://numba.readthedocs.io/en/stable/user/performance-tips.html)。 - -## 类型推断 - -成功的类型推断是 JIT 编译的关键部分。 +函数 `qm_numba` 是 `qm` 的一个版本,专门针对 JIT 编译进行了"目标化"。 -可以想象,对于简单的 Python 对象(例如,浮点数和整数等简单标量数据类型),推断类型更为容易。 - -Numba 也与 NumPy 数组配合良好,因为它们具有明确定义的类型。 - -在理想情况下,Numba 可以推断出所有必要的类型信息。 - -这使它能够生成本地机器码,而无需调用 Python 运行时环境。 - -在这种情况下,Numba 将与低级语言的机器码相媲美。 - -当 Numba 无法推断所有类型信息时,它将引发错误。 - -例如,在下面这个(人为的)示例中,Numba 在编译函数 `bootstrap` 时无法确定函数 `mean` 的类型 - -```{code-cell} ipython3 -@jit -def bootstrap(data, statistics, n): - bootstrap_stat = np.empty(n) - n = len(data) - for i in range(n_resamples): - resample = np.random.choice(data, size=n, replace=True) - bootstrap_stat[i] = statistics(resample) - return bootstrap_stat - -# 这里没有装饰器。 -def mean(data): - return np.mean(data) - -data = np.array((2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2)) -n_resamples = 10 - -# 这段代码会抛出错误 -try: - bootstrap(data, mean, n_resamples) -except Exception as e: - print(e) -``` - -在这种情况下,我们可以通过编译 `mean` 来轻松修复这个错误。 - -```{code-cell} ipython3 -@jit -def mean(data): - return np.mean(data) - -with qe.Timer(): - bootstrap(data, mean, n_resamples) -``` - -## 编译类 - -如上所述,目前 Numba 只能编译 Python 的一个子集。 - -然而,这个子集一直在扩展。 - -值得注意的是,Numba 现在在编译类方面相当有效。 - -如果一个类被成功编译,那么它的方法就像 JIT 编译的函数一样运行。 - -举一个例子,让我们考虑在{doc}`本讲座 `中创建的用于分析索洛-斯旺增长模型的类。 +我们稍后将解释这意味着什么。 -要编译这个类,我们使用 `@jitclass` 装饰器: +让我们对这个新版本计时: ```{code-cell} ipython3 -from numba import float64 -from numba.experimental import jitclass +with qe.Timer() as timer2: + # Time jitted version + x = qm_numba(0.1, n) ``` -注意,我们还导入了一个叫做 `float64` 的东西。 - -这是一种表示标准浮点数的数据类型。 +这是一个很大的速度提升。 -我们在这里导入它是因为当 Numba 尝试处理类时,它需要一些关于类型的额外帮助。 +事实上,第二次及之后的运行速度会更快,因为函数已经被编译并驻留在内存中: -以下是我们的代码: +(qm_numba_result)= ```{code-cell} ipython3 -solow_data = [ - ('n', float64), - ('s', float64), - ('δ', float64), - ('α', float64), - ('z', float64), - ('k', float64) -] - -@jitclass(solow_data) -class Solow: - r""" - Implements the Solow growth model with the update rule - - k_{t+1} = [(s z k^α_t) + (1 - δ)k_t] /(1 + n) - - """ - def __init__(self, n=0.05, # population growth rate - s=0.25, # savings rate - δ=0.1, # depreciation rate - α=0.3, # share of labor - z=2.0, # productivity - k=1.0): # current capital stock - - self.n, self.s, self.δ, self.α, self.z = n, s, δ, α, z - self.k = k - - def h(self): - "Evaluate the h function" - # 解包参数(去掉 self 以简化符号) - n, s, δ, α, z = self.n, self.s, self.δ, self.α, self.z - # 应用更新规则 - return (s * z * self.k**α + (1 - δ) * self.k) / (1 + n) - - def update(self): - "Update the current state (i.e., the capital stock)." - self.k = self.h() - - def steady_state(self): - "Compute the steady state value of capital." - # 解包参数(去掉 self 以简化符号) - n, s, δ, α, z = self.n, self.s, self.δ, self.α, self.z - # 计算并返回稳态 - return ((s * z) / (n + δ))**(1 / (1 - α)) - - def generate_sequence(self, t): - "Generate and return a time series of length t" - path = [] - for i in range(t): - path.append(self.k) - self.update() - return path +with qe.Timer() as timer3: + # Second run + x = qm_numba(0.1, n) ``` -首先,我们在 `solow_data` 中指定了类的实例数据类型。 - -之后,将类定向为 JIT 编译只需在类定义前添加 `@jitclass(solow_data)` 即可。 - -当我们调用类中的方法时,这些方法就像函数一样被即时编译。 +速度提升如下 ```{code-cell} ipython3 -s1 = Solow() -s2 = Solow(k=8.0) - -T = 60 -fig, ax = plt.subplots() - -# 绘制共同的稳态资本值 -ax.plot([s1.steady_state()]*T, 'k-', label='稳态') - -# 为每个经济体绘制时间序列 -for s in s1, s2: - lb = f'从初始状态 {s.k} 出发的资本序列' - ax.plot(s.generate_sequence(T), 'o-', lw=2, alpha=0.6, label=lb) -ax.set_ylabel('$k_{t}$', fontsize=12) -ax.set_xlabel('$t$', fontsize=12) -ax.legend() -plt.show() +timer1.elapsed / timer3.elapsed ``` -## 危险与局限 - -让我们回顾上述内容并补充一些注意事项。 - -### 局限性 - -正如我们所见,Numba 需要推断所有变量的类型信息以生成快速的机器级指令。 +这对我们原始代码的一个小改动来说是一个很大的提升。 -对于简单的例程,Numba 推断类型非常出色。 +让我们来讨论这是如何工作的。 -对于较大的例程,或使用外部库的例程,它很容易失败。 +### 工作原理与适用场景 -因此,在使用 Numba 时,明智的做法是专注于加速代码中小而关键的片段。 +Numba 尝试使用 [LLVM Project](https://llvm.org/) 提供的基础设施生成快速机器码。 -这将比在 Python 程序中大量使用 `@njit` 语句带来更好的性能。 +它通过即时推断类型信息来实现这一点。 -### 一个陷阱:全局变量 +(有关类型的讨论,请参阅我们 {doc}`早期的讲座 `,内容涉及科学计算。) -以下是使用 Numba 时需要注意的另一件事。 +基本思想如下: -考虑以下示例 +* Python 非常灵活,因此我们可以用多种类型调用函数 `qm`。 + * 例如,`x0` 可以是 NumPy 数组或列表,`n` 可以是整数或浮点数,等等。 +* 这使得*提前*(即在运行时之前)生成高效机器码非常困难。 +* 然而,当我们实际*调用*函数时,例如运行 `qm(0.5, 10)`,`x0`、`α` 和 `n` 的类型就确定了。 +* 此外,一旦输入类型已知,`qm` 中*其他变量*的类型*可以被推断出来*。 +* 因此,Numba 和其他 JIT 编译器的策略是*等到函数被调用时*,再进行编译。 -```{code-cell} ipython3 -a = 1 +这被称为"即时"编译。 -@jit -def add_a(x): - return a + x +请注意,如果你调用 `qm_numba(0.5, 10)`,然后紧接着调用 `qm_numba(0.9, 20)`,编译只在第一次调用时发生。 -print(add_a(10)) -``` +这是因为编译后的代码会被缓存并按需复用。 -```{code-cell} ipython3 -a = 2 +这就是为什么在上面的代码中,`qm_numba` 的第二次运行更快。 -print(add_a(10)) +```{admonition} 备注 +在实践中,我们通常不写 `qm_numba = jit(qm)`,而是使用*装饰器*语法,在函数定义前加上 `@jit`。这等价于在定义之后添加 `qm = jit(qm)`。 ``` -注意,更改全局变量对函数返回值没有任何影响。 - -当 Numba 为函数编译机器码时,它将全局变量视为常量,以确保类型稳定性。 - -(multithreading)= ## Numba 中的多线程循环 除了 JIT 编译之外,Numba 还为 CPU 上的并行计算提供了强大支持。 @@ -922,4 +683,4 @@ def compute_call_price_parallel(β=β, 如果您使用的是具有多个 CPU 的机器,差异应该很显著。 ```{solution-end} -``` \ No newline at end of file +``` From 6c7593e8c684c98b599f72bc50309140fd0f63e7 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Mon, 13 Apr 2026 01:09:40 +0100 Subject: [PATCH 4/8] Update translation: .translate/state/numba.md.yml --- .translate/state/numba.md.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.translate/state/numba.md.yml b/.translate/state/numba.md.yml index c5f1f07..c7d5b0a 100644 --- a/.translate/state/numba.md.yml +++ b/.translate/state/numba.md.yml @@ -1,6 +1,6 @@ -source-sha: cc9c3256dc35bd277cb25d0089f0a0452c0fa94e -synced-at: "2026-03-20" +source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f +synced-at: "2026-04-13" model: claude-sonnet-4-6 -mode: NEW -section-count: 8 -tool-version: 0.13.1 +mode: UPDATE +section-count: 5 +tool-version: 0.14.1 From d02990b66a135e37d89891ab51c314f8de701c81 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Mon, 13 Apr 2026 01:09:40 +0100 Subject: [PATCH 5/8] Update translation: lectures/numpy_vs_numba_vs_jax.md --- lectures/numpy_vs_numba_vs_jax.md | 197 +++++++++++++----------------- 1 file changed, 82 insertions(+), 115 deletions(-) diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index af8c186..cb9933c 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -17,8 +17,6 @@ translation: Vectorized operations::Parallelized Numba: 并行化的 Numba Vectorized operations::Vectorized code with JAX: 使用 JAX 的向量化代码 Vectorized operations::JAX plus vmap: JAX 加 vmap - Vectorized operations::JAX plus vmap::Version 1: 版本 1 - Vectorized operations::vmap version 2: vmap 版本 2 Vectorized operations::Summary: 总结 Sequential operations: 顺序运算 Sequential operations::Numba Version: Numba 版本 @@ -26,7 +24,7 @@ translation: Sequential operations::Summary: 总结 --- -(parallel)= +(numpy_numba_jax)= ```{raw} jupyter
@@ -144,7 +142,6 @@ for x in grid: m = z ``` - ### NumPy 向量化 如果我们切换到 NumPy 风格的向量化,就可以使用更大的网格,并且代码执行速度相对较快。 @@ -157,7 +154,7 @@ for x in grid: grid = np.linspace(-3, 3, 3_000) x, y = np.meshgrid(grid, grid) -with qe.Timer(precision=8): +with qe.Timer(): z_max_numpy = np.max(f(x, y)) print(f"NumPy result: {z_max_numpy:.6f}") @@ -169,7 +166,6 @@ print(f"NumPy result: {z_max_numpy:.6f}") (并行化效率不高,因为二进制文件在看到数组 `x` 和 `y` 的大小之前就已经被编译了。) - ### 与 Numba 的比较 现在让我们看看能否使用简单循环的 Numba 获得更好的性能。 @@ -183,38 +179,42 @@ def compute_max_numba(grid): for x in grid: for y in grid: z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > m: - m = z + m = max(m, z) return m +``` + +让我们测试一下: +```{code-cell} ipython3 grid = np.linspace(-3, 3, 3_000) -with qe.Timer(precision=8): - z_max_numpy = compute_max_numba(grid) +with qe.Timer(): + # First run + z_max_numba = compute_max_numba(grid) -print(f"Numba result: {z_max_numpy:.6f}") +print(f"Numba result: {z_max_numba:.6f}") ``` 让我们再次运行以消除编译时间。 ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run compute_max_numba(grid) ``` 根据您的机器,Numba 版本可能比 NumPy 稍慢或稍快。 -一方面,NumPy 将高效的算术运算(类似 Numba)与一定程度的多线程(不同于这段 Numba 代码)结合在一起,这提供了优势。 +在大多数情况下,我们发现 Numba 略胜一筹。 -另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格。 +一方面,NumPy 将高效的算术运算与一定程度的多线程结合在一起,这提供了优势。 +另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格。 ### 并行化的 Numba 现在让我们使用 `prange` 尝试 Numba 的并行化: -这是一个简单但**不正确**的尝试。 - ```{code-cell} ipython3 @numba.jit(parallel=True) def compute_max_numba_parallel(grid): @@ -225,57 +225,25 @@ def compute_max_numba_parallel(grid): x = grid[i] y = grid[j] z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > m: - m = z + m = max(m, z) return m - ``` -这通常会返回不正确的结果: +以下是预热运行和测试。 ```{code-cell} ipython3 -z_max_parallel_incorrect = compute_max_numba_parallel(grid) -print(f"Numba result: {z_max_parallel_incorrect} 😱") -``` - -原因是变量 `m` 被多个线程共享,但没有得到正确控制。 - -当多个线程同时尝试读写 `m` 时,它们会相互干扰。 - -线程读取了 `m` 的过时值,或者相互覆盖了更新——或者 `m` 始终保持其初始值而从未被更新。 - -这里有一个更仔细编写的版本。 - -```{code-cell} ipython3 -@numba.jit(parallel=True) -def compute_max_numba_parallel(grid): - n = len(grid) - row_maxes = np.empty(n) - for i in numba.prange(n): - row_max = -np.inf - for j in range(n): - x = grid[i] - y = grid[j] - z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > row_max: - row_max = z - row_maxes[i] = row_max - return np.max(row_maxes) -``` - -现在 `for i in numba.prange(n)` 所作用的代码块在不同的 `i` 之间是独立的。 - -每个线程写入数组 `row_maxes` 的不同元素,并行化是安全的。 +with qe.Timer(): + # First run + z_max_parallel = compute_max_numba_parallel(grid) -```{code-cell} ipython3 -z_max_parallel = compute_max_numba_parallel(grid) print(f"Numba result: {z_max_parallel:.6f}") ``` -以下是计时结果。 +以下是预编译版本的计时结果。 ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run compute_max_numba_parallel(grid) ``` @@ -283,14 +251,13 @@ with qe.Timer(precision=8): 对于更强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来显著的速度提升。 - ### 使用 JAX 的向量化代码 表面上,JAX 中的向量化代码与 NumPy 代码类似。 但两者之间也存在一些差异,我们在这里加以强调。 -让我们从函数开始。 +让我们从函数开始,它将 `np` 替换为 `jnp` 并添加了 `jax.jit`。 ```{code-cell} ipython3 @@ -304,10 +271,16 @@ def f(x, y): ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) -x_mesh, y_mesh = np.meshgrid(grid, grid) +x_mesh, y_mesh = jnp.meshgrid(grid, grid) +``` + +现在让我们运行并计时 -with qe.Timer(precision=8): +```{code-cell} ipython3 +with qe.Timer(): + # First run z_max = jnp.max(f(x_mesh, y_mesh)) + # Hold interpreter z_max.block_until_ready() print(f"Plain vanilla JAX result: {z_max:.6f}") @@ -316,16 +289,17 @@ print(f"Plain vanilla JAX result: {z_max:.6f}") 让我们再次运行以消除编译时间。 ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run z_max = jnp.max(f(x_mesh, y_mesh)) + # Hold interpreter z_max.block_until_ready() ``` -编译完成后,由于 GPU 加速,JAX 明显快于 NumPy。 +编译完成后,JAX 明显快于 NumPy,尤其是在 GPU 上。 编译开销是一次性成本,当函数被反复调用时,这种开销是值得的。 - ### JAX 加 vmap NumPy 代码和 JAX 代码都存在一个问题: @@ -346,9 +320,9 @@ x_mesh.nbytes + y_mesh.nbytes 幸运的是,JAX 提供了一种使用 [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) 的不同方法。 -#### 版本 1 +`vmap` 的思路是将向量化分阶段进行,将一个对单个值进行操作的函数转化为对数组进行操作的函数。 -以下是我们应用 `vmap` 的一种方式。 +以下是我们将其应用于问题的方式。 ```{code-cell} ipython3 # 设置 f,使其在给定任意 y 时,对所有 x 计算 f(x, y) @@ -362,7 +336,7 @@ f_vec = jax.vmap(f_vec_x) 让我们看看计时结果: ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): z_max = jnp.max(f_vec(grid)) z_max.block_until_ready() @@ -370,37 +344,24 @@ print(f"JAX vmap v1 result: {z_max:.6f}") ``` ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): z_max = jnp.max(f_vec(grid)) z_max.block_until_ready() ``` -通过避免使用大型输入数组 `x_mesh` 和 `y_mesh`,这个 `vmap` 版本使用的内存少得多。 - -在 CPU 上运行时,其运行时间与网格版本相似。 - -在 GPU 上运行时,通常速度要快得多。 - -实际上,使用 `vmap` 还有另一个优势:它允许我们将向量化分阶段进行。 - -这往往会产生比传统向量化代码更易于理解的代码。 - -当我们处理更大的问题时,将进一步探讨这些想法。 - - -### vmap 版本 2 +通过避免使用大型输入数组 `x_mesh` 和 `y_mesh`,这个 `vmap` 版本使用的内存少得多,运行时间变化不大。 -我们可以使用 vmap 进一步提高内存效率。 +这很好——但我们还是遗漏了一些速度提升的空间! -在前一个版本中,虽然我们避免了大型输入数组,但在计算最大值之前仍然会创建大型输出数组 `f(x,y)`。 +首先,上面的代码在取最大值之前会计算完整的二维数组 `f(x,y)`,这会产生额外开销。 -让我们尝试一种略有不同的方法,将求最大值操作移到内部。 +其次,`jnp.max` 调用位于 JIT 编译函数 `f` 之外,因此编译器无法将这些操作融合为单个内核。 -由于这一改变,我们永远不会计算二维数组 `f(x,y)`。 +我们可以通过将求最大值操作移到内部并将所有内容包装在单个 `@jax.jit` 中来解决这两个问题: ```{code-cell} ipython3 @jax.jit -def compute_max_vmap_v2(grid): +def compute_max_vmap(grid): # 构建一个沿每行取最大值的函数 f_vec_x_max = lambda y: jnp.max(f(grid, y)) # 向量化该函数,以便我们可以同时对所有行调用 @@ -416,25 +377,32 @@ def compute_max_vmap_v2(grid): 我们将此函数应用于所有行,然后取各行最大值中的最大值。 +由于我们将求最大值操作移到内部,我们永远不会构建完整的二维数组 `f(x,y)`,从而节省了更多内存。 + +并且由于所有内容都在单个 `@jax.jit` 下,编译器可以将所有操作融合为一个优化的内核。 + 让我们试试。 ```{code-cell} ipython3 -with qe.Timer(precision=8): - z_max = compute_max_vmap_v2(grid).block_until_ready() +with qe.Timer(): + # First run + z_max = compute_max_vmap(grid) + # Hold interpreter + z_max.block_until_ready() -print(f"JAX vmap v1 result: {z_max:.6f}") +print(f"JAX vmap result: {z_max:.6f}") ``` 让我们再次运行以消除编译时间: ```{code-cell} ipython3 -with qe.Timer(precision=8): - z_max = compute_max_vmap_v2(grid).block_until_ready() +with qe.Timer(): + # Second run + z_max = compute_max_vmap(grid) + # Hold interpreter + z_max.block_until_ready() ``` -如果您像我们一样在 GPU 上运行,应该能看到又一个不小的速度提升。 - - ### 总结 在我们看来,JAX 是向量化运算的赢家。 @@ -449,15 +417,13 @@ with qe.Timer(precision=8): 对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。 - ## 顺序运算 某些运算本质上是顺序的——因此难以或不可能向量化。 在这种情况下,NumPy 是一个较差的选择,我们只剩下 Numba 或 JAX 可以选择。 -为了比较这两种选择,我们将重新回顾在{doc}`Numba 讲座 `中看到的迭代二次映射问题。 - +为了比较这两种选择,我们将重新回顾在 {doc}`Numba 讲座 ` 中看到的迭代二次映射问题。 ### Numba 版本 @@ -478,14 +444,16 @@ def qm(x0, n, α=4.0): ```{code-cell} ipython3 n = 10_000_000 -with qe.Timer(precision=8): +with qe.Timer(): + # First run x = qm(0.1, n) ``` 让我们再次运行以消除编译时间: ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run x = qm(0.1, n) ``` @@ -507,7 +475,7 @@ from functools import partial cpu = jax.devices("cpu")[0] -@partial(jax.jit, static_argnums=(1,), device=cpu) +@partial(jax.jit, static_argnames=('n',), device=cpu) def qm_jax(x0, n, α=4.0): def update(x, t): x_new = α * x * (1 - x) @@ -520,37 +488,36 @@ def qm_jax(x0, n, α=4.0): 这段代码不易阅读,但本质上,`lax.scan` 反复调用 `update` 并将返回值 `x_new` 累积到一个数组中。 ```{note} -细心的读者会注意到,我们在 `jax.jit` 装饰器中指定了 `device=cpu`。 - -该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。 - -因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。 - -好奇的读者可以尝试删除此选项,看看性能如何变化。 +我们在 `jax.jit` 装饰器中指定了 `device=cpu`,因为该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。 ``` 让我们使用相同的参数计时: ```{code-cell} ipython3 -with qe.Timer(precision=8): - x_jax = qm_jax(0.1, n).block_until_ready() +with qe.Timer(): + # First run + x_jax = qm_jax(0.1, n) + # Hold interpreter + x_jax.block_until_ready() ``` 让我们再次运行以消除编译开销: ```{code-cell} ipython3 -with qe.Timer(precision=8): - x_jax = qm_jax(0.1, n).block_until_ready() +with qe.Timer(): + # Second run + x_jax = qm_jax(0.1, n) + # Hold interpreter + x_jax.block_until_ready() ``` JAX 对于这种顺序运算也相当高效。 -JAX 和 Numba 在编译后都能提供出色的性能,对于纯顺序运算,Numba 通常(但并非总是)提供略快的速度。 - +JAX 和 Numba 在编译后都能提供出色的性能。 ### 总结 -虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,但**在代码可读性和易用性方面存在显著差异**。 +虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,但 *在代码可读性和易用性方面存在显著差异*。 Numba 版本简单直观,易于阅读:我们只需分配一个数组,然后使用标准 Python 循环逐元素填充它。 @@ -560,4 +527,4 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然 此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难。 -对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。 \ No newline at end of file +对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。 From 5e4d1ee4105a3efc551bf81f5d7f3651b97a05d3 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Mon, 13 Apr 2026 01:09:41 +0100 Subject: [PATCH 6/8] Update translation: .translate/state/numpy_vs_numba_vs_jax.md.yml --- .translate/state/numpy_vs_numba_vs_jax.md.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 61c6ef5..34ec88f 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,6 +1,6 @@ -source-sha: c4c03c80c1eb4318f627d869707d242d19c8cf09 -synced-at: "2026-03-20" +source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f +synced-at: "2026-04-13" model: claude-sonnet-4-6 -mode: NEW -section-count: 2 -tool-version: 0.13.0 +mode: UPDATE +section-count: 3 +tool-version: 0.14.1 From 41cc68f7d084d60ab1138322c77c15321818432f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 00:14:09 +0000 Subject: [PATCH 7/8] Fix missing newline between text and code block fence in numpy_vs_numba_vs_jax.md Agent-Logs-Url: https://github.com/QuantEcon/lecture-python-programming.zh-cn/sessions/92cf1688-2d57-44e9-8f1e-f49228181b4f Co-authored-by: HumphreyYang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/numpy_vs_numba_vs_jax.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index abd784e..914830a 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -260,7 +260,9 @@ with qe.Timer(): 但两者之间也存在一些差异,我们在这里加以强调。 -让我们从函数开始,将 `np` 替换为 `jnp` 并添加 `jax.jit````{code-cell} ipython3 +让我们从函数开始,将 `np` 替换为 `jnp` 并添加 `jax.jit` + +```{code-cell} ipython3 @jax.jit def f(x, y): return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2) From b299b21eb634570c9a18add1a890db7418cd24ce Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 01:37:08 +0000 Subject: [PATCH 8/8] Fix duplicate Speed! section in jax_intro.md, add Size Experiment heading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add #### 大小实验 heading to match English #### Size Experiment structure - Remove duplicate #### 速度! section (was duplicated from overlapping translation PRs) - Add Size Experiment entry to heading map in frontmatter Agent-Logs-Url: https://github.com/QuantEcon/lecture-python-programming.zh-cn/sessions/9216d7c8-4733-4e5e-a1c5-74a4d6018ce1 Co-authored-by: HumphreyYang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/jax_intro.md | 101 ++---------------------------------------- 1 file changed, 3 insertions(+), 98 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 1f6b0ee..2f21d1f 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -18,6 +18,7 @@ translation: JAX as a NumPy Replacement::Differences::Speed!: 速度! JAX as a NumPy Replacement::Differences::Speed!::With NumPy: 使用 NumPy JAX as a NumPy Replacement::Differences::Speed!::With JAX: 使用 JAX + JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验 JAX as a NumPy Replacement::Differences::Precision: 精度 JAX as a NumPy Replacement::Differences::Immutability: 不可变性 JAX as a NumPy Replacement::Differences::A workaround: 变通方法 @@ -203,6 +204,8 @@ with qe.Timer(): 大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。 +#### 大小实验 + 我们可以通过更改输入大小并观察运行时间来验证 JAX 针对数组大小进行专门化的说法。 ```{code-cell} @@ -231,104 +234,6 @@ with qe.Timer(): 关于 JIT 编译的进一步讨论见下文。 -#### 速度! - -假设我们想在许多点上求余弦函数的值。 - -```{code-cell} -n = 50_000_000 -x = np.linspace(0, 10, n) -``` - -##### 使用 NumPy - -让我们先用 NumPy 试试: - -```{code-cell} -with qe.Timer(): - y = np.cos(x) -``` - -再来一次。 - -```{code-cell} -with qe.Timer(): - y = np.cos(x) -``` - -这里: - -* NumPy 使用预编译的二进制文件对浮点数组应用余弦函数 -* 该二进制文件在本地机器的 CPU 上运行 - -##### 使用 JAX - -现在让我们用 JAX 试试。 - -```{code-cell} -x = jnp.linspace(0, 10, n) -``` - -对相同的过程计时。 - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -```{note} -这里,为了测量实际速度,我们使用 `block_until_ready` 方法让解释器等待,直到计算结果返回。 - -这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。 - -对于非计时代码,可以省略包含 `block_until_ready` 的那行。 -``` - -再计时一次。 - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -在 GPU 上,这段代码的运行速度远快于等效的 NumPy 代码。 - -此外,通常第二次运行比第一次更快,这是由于 JIT 编译的原因。 - -这是因为即使是 `jnp.cos` 这样的内置函数也会被 JIT 编译——第一次运行包含了编译时间。 - -为什么 JAX 要对 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本? - -原因是 JIT 编译器希望针对所使用数组的*大小*(以及数据类型)进行专门优化。 - -大小对于生成优化代码很重要,因为高效并行化需要将任务大小与可用硬件相匹配。 - -我们可以通过改变输入大小并观察运行时间来验证 JAX 针对数组大小进行专门优化的说法。 - -```{code-cell} -x = jnp.linspace(0, 10, n + 1) -``` - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -```{code-cell} -with qe.Timer(): - y = jnp.cos(x) - jax.block_until_ready(y); -``` - -运行时间先增加后再次下降(在 GPU 上这一现象会更明显)。 - -这与上面的讨论一致——改变数组大小后的第一次运行显示了编译开销。 - -关于 JIT 编译的进一步讨论见下文。 - #### 精度 NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。