From 4b45330363c835cdd0df3084885894dc13882547 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Mon, 13 Apr 2026 13:48:54 +0100 Subject: [PATCH 1/7] Update translation: lectures/jax_intro.md --- lectures/jax_intro.md | 220 ++++++++++++++++++++++++++++-------------- 1 file changed, 149 insertions(+), 71 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 81ed1eb..f904ed7 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -15,9 +15,13 @@ translation: JAX as a NumPy Replacement: JAX 作为 NumPy 的替代品 JAX as a NumPy Replacement::Similarities: 相似之处 JAX as a NumPy Replacement::Differences: 差异 + JAX as a NumPy Replacement::Differences::Speed!: 速度! + JAX as a NumPy Replacement::Differences::Speed!::With NumPy: 使用 NumPy + JAX as a NumPy Replacement::Differences::Speed!::With JAX: 使用 JAX + JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验 JAX as a NumPy Replacement::Differences::Precision: 精度 JAX as a NumPy Replacement::Differences::Immutability: 不可变性 - JAX as a NumPy Replacement::Differences::A workaround: 变通方法 + JAX as a NumPy Replacement::Differences::A Workaround: 变通方法 Functional Programming: 函数式编程 Functional Programming::Pure functions: 纯函数 Functional Programming::Examples: 示例 @@ -26,18 +30,12 @@ translation: Random numbers::Why explicit random state?: 为什么要显式随机状态? Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法 Random numbers::Why explicit random state?::JAX's approach: JAX 的方法 - JIT compilation: JIT 编译 - JIT compilation::A simple example: 一个简单的示例 - JIT compilation::A simple example::With NumPy: 使用 NumPy - JIT compilation::A simple example::With JAX: 使用 JAX - JIT compilation::A simple example::Changing array sizes: 更改数组大小 - JIT compilation::Evaluating a more complicated function: 评估更复杂的函数 - JIT compilation::Evaluating a more complicated function::With NumPy: 使用 NumPy - JIT compilation::Evaluating a more complicated function::With JAX: 使用 JAX - JIT compilation::How JIT compilation works: JIT 编译的工作原理 - JIT compilation::Compiling the whole function: 编译整个函数 - JIT compilation::Compiling non-pure functions: 编译非纯函数 - JIT compilation::Summary: 总结 + JIT Compilation: JIT 编译 + JIT Compilation::With NumPy: 一个简单的示例 + JIT Compilation::With JAX: 评估更复杂的函数 + JIT Compilation::Compiling the Whole Function: JIT 编译的工作原理 + JIT Compilation::How JIT compilation works: 编译整个函数 + JIT Compilation::Compiling non-pure functions: 编译非纯函数 Vectorization with `vmap`: 使用 `vmap` 进行向量化 Vectorization with `vmap`::A simple example: 一个简单的示例 Vectorization with `vmap`::Combining transformations: 组合变换 @@ -49,13 +47,16 @@ translation: 本讲座简要介绍 [Google JAX](https://github.com/jax-ml/jax)。 +```{include} _admonition/gpu.md +``` + JAX 是一个高性能科学计算库,提供以下功能: * 类似 [NumPy](https://en.wikipedia.org/wiki/NumPy) 的接口,可以在 CPU 和 GPU 上自动并行化, * 一个即时编译器,用于加速大量数值运算,以及 * [自动微分](https://en.wikipedia.org/wiki/Automatic_differentiation)。 -JAX 也在日益维护和提供[更多专业化的科学计算例程](https://docs.jax.dev/en/latest/jax.scipy.html),例如那些最初在 [SciPy](https://en.wikipedia.org/wiki/SciPy) 中找到的例程。 +JAX 也在日益维护和提供 [更多专业化的科学计算例程](https://docs.jax.dev/en/latest/jax.scipy.html),例如那些最初在 [SciPy](https://en.wikipedia.org/wiki/SciPy) 中找到的例程。 除了 Anaconda 中已有的内容外,本讲座还需要以下库: @@ -65,36 +66,27 @@ JAX 也在日益维护和提供[更多专业化的科学计算例程](https://do !pip install jax quantecon ``` -```{include} _admonition/gpu.md -``` - -## JAX 作为 NumPy 的替代品 - -JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API。 - -这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。 - -让我们来看看 JAX 和 NumPy 之间的异同。 - -### 相似之处 - 我们将使用以下导入: ```{code-cell} ipython3 import jax import jax.numpy as jnp import matplotlib.pyplot as plt -import matplotlib.patches as mpatches import numpy as np import quantecon as qe -import matplotlib as mpl # i18n -import matplotlib.font_manager # i18n -FONTPATH = "_fonts/SourceHanSerifSC-SemiBold.otf" # i18n -mpl.font_manager.fontManager.addfont(FONTPATH) # i18n -mpl.rcParams['font.family'] = ['Source Han Serif SC'] # i18n ``` -注意我们导入了 `jax.numpy as jnp`,它提供了类似 NumPy 的接口。 +## JAX 作为 NumPy 的替代品 + +让我们来看看 JAX 和 NumPy 之间的异同。 + +### 相似之处 + +上面我们导入了 `jax.numpy as jnp`,它提供了类似 NumPy 的数组操作接口。 + +JAX 的一个吸引人之处在于,这个接口在尽可能的情况下遵循 NumPy API。 + +因此,我们通常可以将 JAX 作为 NumPy 的直接替代品使用。 以下是使用 `jnp` 进行的一些标准数组操作: @@ -110,15 +102,11 @@ print(a) print(jnp.sum(a)) ``` -```{code-cell} ipython3 -print(jnp.mean(a)) -``` - ```{code-cell} ipython3 print(jnp.dot(a, a)) ``` -然而,数组对象 `a` 并不是 NumPy 数组: +但需要注意的是,数组对象 `a` 并不是 NumPy 数组: ```{code-cell} ipython3 a @@ -128,37 +116,131 @@ a type(a) ``` -即使是数组上的标量值映射也会返回 JAX 数组。 +即使是数组上的标量值映射也会返回 JAX 数组而非标量! ```{code-cell} ipython3 jnp.sum(a) ``` -对高维数组的操作也与 NumPy 类似: +### 差异 + +现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。 + +(jax_speed)= +#### 速度! -```{code-cell} ipython3 -A = jnp.ones((2, 2)) -B = jnp.identity(2) -A @ B +一个主要差异是 JAX 更快——有时快得多。 + +为了说明这一点,假设我们想在许多点处计算余弦函数。 + +```{code-cell} +n = 50_000_000 +x = np.linspace(0, 10, n) # NumPy array ``` -JAX 的数组接口也提供了 `linalg` 子包: +##### 使用 NumPy -```{code-cell} ipython3 -jnp.linalg.inv(B) # Inverse of identity is identity +让我们用 NumPy 来试试。 + +```{code-cell} +with qe.Timer(): + # First NumPy timing + y = np.cos(x) ``` -```{code-cell} ipython3 -jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors +再来一次。 + +```{code-cell} +with qe.Timer(): + # Second NumPy timing + y = np.cos(x) ``` -### 差异 +这里 -现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。 +* NumPy 使用预编译的二进制文件将余弦函数应用于浮点数数组 +* 该二进制文件在本地机器的 CPU 上运行 + +##### 使用 JAX + +现在让我们用 JAX 来试试。 + +```{code-cell} +x = jnp.linspace(0, 10, n) +``` + +让我们对同样的过程计时。 + +```{code-cell} +with qe.Timer(): + # First run + y = jnp.cos(x) + # Hold the interpreter until the array operation finishes + y.block_until_ready() +``` + +```{note} +上面,`block_until_ready` 方法会阻塞解释器,直到计算结果返回。 +这对于计时执行是必要的,因为 JAX 使用异步调度, +允许 Python 解释器在数值计算之前运行。 +``` + +现在让我们再次计时。 + +```{code-cell} +with qe.Timer(): + # Second run + y = jnp.cos(x) + # Hold interpreter + y.block_until_ready() +``` + +在 GPU 上,这段代码的运行速度远快于其 NumPy 等价代码。 + +另外,通常情况下,由于 JIT 编译,第二次运行比第一次更快。 + +这是因为即使是像 `jnp.cos` 这样的内置函数也会被 JIT 编译——第一次运行包含了编译时间。 + +为什么 JAX 要对 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本呢? + +原因是 JIT 编译器希望针对所使用数组的*大小*(以及数据类型)进行专门优化。 + +大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件匹配。 + +#### 大小实验 + +我们可以通过改变输入大小并观察运行时间来验证 JAX 针对数组大小进行专门优化的说法。 + +```{code-cell} +x = jnp.linspace(0, 10, n + 1) +``` + +```{code-cell} +with qe.Timer(): + # First run + y = jnp.cos(x) + # Hold interpreter + y.block_until_ready() +``` + + +```{code-cell} +with qe.Timer(): + # Second run + y = jnp.cos(x) + # Hold interpreter + y.block_until_ready() +``` + +运行时间先增加后减少(这在 GPU 上会更明显)。 + +这与上面的讨论一致——改变数组大小后的第一次运行显示了编译开销。 + +下面将进一步讨论 JIT 编译。 #### 精度 -NumPy 和 JAX 之间的一个差异是 JAX 默认使用 32 位浮点数。 +NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。 这是因为 JAX 经常用于 GPU 计算,而大多数 GPU 计算使用 32 位浮点数。 @@ -196,7 +278,8 @@ a[0] = 1 a ``` -在 JAX 中,这会失败: +在 JAX 中,这会失败 😱。 + ```{code-cell} ipython3 a = jnp.linspace(0, 1, 3) @@ -204,32 +287,25 @@ a ``` ```{code-cell} ipython3 -:tags: [raises-exception] +try: + a[0] = 1 +except Exception as e: + print(e) -a[0] = 1 ``` -与不可变性一致,JAX 不支持原地操作: - -```{code-cell} ipython3 -a = np.array((2, 1)) -a.sort() # Unlike NumPy, does not mutate a -a -``` +JAX 的设计者选择将数组设为不可变的,因为 -```{code-cell} ipython3 -a = jnp.array((2, 1)) -a_new = a.sort() # Instead, the sort method returns a new sorted array -a, a_new -``` +1. JAX 使用*函数式编程风格*,并且 +2. 函数式编程通常避免可变数据 -JAX 的设计者选择将数组设为不可变的,因为 JAX 使用 [函数式编程](https://en.wikipedia.org/wiki/Functional_programming) 风格。 +我们将在 {ref}`下面 ` 讨论这些思想。 -这个设计选择有重要的含义,我们接下来将对此进行探讨! +(jax_at_workaround)= #### 变通方法 -我们注意到 JAX 确实提供了一种使用 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) 进行原地数组修改的版本。 +JAX 确实通过 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) 提供了原地数组修改的直接替代方案。 ```{code-cell} ipython3 a = jnp.linspace(0, 1, 3) @@ -251,6 +327,8 @@ a (尽管它在 JIT 编译的函数中实际上可以很高效——但现在先把这个放在一边。) + +(jax_func)= ## 函数式编程 来自 JAX 的文档: From 0afe7073a7d674585cd7e69a21207a0cf0f1a80f Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Mon, 13 Apr 2026 13:48:55 +0100 Subject: [PATCH 2/7] Update translation: .translate/state/jax_intro.md.yml --- .translate/state/jax_intro.md.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index bbee313..414535f 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,5 +1,5 @@ -source-sha: 05ce95691fd97e48da39dd6d58fe032c03e8813d -synced-at: "2026-04-09" +source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56 +synced-at: "2026-04-13" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 From 5f70fd1a6f6b94eb5feb10f955be7e2c9f93ea5e Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Mon, 13 Apr 2026 13:48:56 +0100 Subject: [PATCH 3/7] Update translation: lectures/numpy_vs_numba_vs_jax.md --- lectures/numpy_vs_numba_vs_jax.md | 109 ++++++++++++++++++++---------- 1 file changed, 75 insertions(+), 34 deletions(-) diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index a7f322d..e3846ec 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -17,9 +17,7 @@ translation: Vectorized operations::Parallelized Numba: 并行化的 Numba Vectorized operations::Vectorized code with JAX: 使用 JAX 的向量化代码 Vectorized operations::JAX plus vmap: JAX 加 vmap - Vectorized operations::JAX plus vmap::Version 1: 版本 1 - Vectorized operations::vmap version 2: vmap 版本 2 - Vectorized operations::Summary: 总结 + Vectorized operations::Summary: vmap 版本 2 Sequential operations: 顺序运算 Sequential operations::Numba Version: Numba 版本 Sequential operations::JAX Version: JAX 版本 @@ -27,7 +25,7 @@ translation: Overall recommendations: 总体建议 --- -(parallel)= +(numpy_numba_jax)= ```{raw} jupyter
@@ -69,7 +67,6 @@ tags: [hide-output] 我们将使用以下导入。 ```{code-cell} ipython3 -import random from functools import partial import numpy as np @@ -472,14 +469,16 @@ def qm(x0, n, α=4.0): ```{code-cell} ipython3 n = 10_000_000 -with qe.Timer(precision=8): +with qe.Timer(): + # First run x = qm(0.1, n) ``` 让我们再次运行以消除编译时间: ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run x = qm(0.1, n) ``` @@ -491,15 +490,62 @@ Numba 的编译通常相当快,对于像这样的顺序运算,生成的代 ### JAX 版本 -现在让我们使用 `lax.scan` 创建一个 JAX 版本: +现在让我们使用 `at[t].set` 风格的语法创建一个 JAX 版本,正如 {ref}`JAX 讲座中讨论的 `,这为不可变数组提供了一种变通方法。 -(我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。) +我们将使用 `lax.fori_loop`,它是一种可以被 XLA 编译的 for 循环版本。 ```{code-cell} ipython3 cpu = jax.devices("cpu")[0] -@partial(jax.jit, static_argnums=(1,), device=cpu) -def qm_jax(x0, n, α=4.0): +@partial(jax.jit, static_argnames=("n",), device=cpu) +def qm_jax_fori(x0, n, α=4.0): + + x = jnp.empty(n + 1).at[0].set(x0) + + def update(t, x): + return x.at[t + 1].set(α * x[t] * (1 - x[t])) + + x = lax.fori_loop(0, n, update, x) + return x + +``` + +* 我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。 +* 我们通过 `device=cpu` 将计算固定在 CPU 上,因为这种顺序工作负载由许多小操作组成,几乎没有机会利用 GPU 并行性。 + +虽然 `at[t].set` 看起来在每一步都创建了一个新数组,但在 JIT 编译的函数内部,编译器会检测到旧数组不再需要,并就地执行更新。 + +让我们使用相同的参数计时: + +```{code-cell} ipython3 +with qe.Timer(): + # First run + x_jax = qm_jax_fori(0.1, n) + # Hold interpreter + x_jax.block_until_ready() +``` + +让我们再次运行以消除编译开销: + +```{code-cell} ipython3 +with qe.Timer(): + # Second run + x_jax = qm_jax_fori(0.1, n) + # Hold interpreter + x_jax.block_until_ready() +``` + +JAX 对于这种顺序运算也相当高效。 + + +我们还有另一种实现循环的方式,使用 `lax.scan`。 + +这种替代方案可以说更符合 JAX 的函数式方法——尽管语法难以记忆。 + + +```{code-cell} ipython3 +@partial(jax.jit, static_argnames=("n",), device=cpu) +def qm_jax_scan(x0, n, α=4.0): def update(x, t): x_new = α * x * (1 - x) return x_new, x_new @@ -510,33 +556,27 @@ def qm_jax(x0, n, α=4.0): 这段代码不易阅读,但本质上,`lax.scan` 反复调用 `update` 并将返回值 `x_new` 累积到一个数组中。 -```{note} -细心的读者会注意到,我们在 `jax.jit` 装饰器中指定了 `device=cpu`。 - -该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。 - -因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。 - -好奇的读者可以尝试删除此选项,看看性能如何变化。 -``` - 让我们使用相同的参数计时: ```{code-cell} ipython3 -with qe.Timer(precision=8): - x_jax = qm_jax(0.1, n).block_until_ready() +with qe.Timer(): + # First run + x_jax = qm_jax_scan(0.1, n) + # Hold interpreter + x_jax.block_until_ready() ``` 让我们再次运行以消除编译开销: ```{code-cell} ipython3 -with qe.Timer(precision=8): - x_jax = qm_jax(0.1, n).block_until_ready() +with qe.Timer(): + # Second run + x_jax = qm_jax_scan(0.1, n) + # Hold interpreter + x_jax.block_until_ready() ``` -JAX 对于这种顺序运算也相当高效。 - -JAX 和 Numba 在编译后都能提供出色的性能,对于纯顺序运算,Numba 通常(但并非总是)提供略快的速度。 +JAX 和 Numba 在编译后都能提供出色的性能。 ### 总结 @@ -546,11 +586,11 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然 这正是大多数程序员思考该算法的方式。 -另一方面,JAX 版本需要使用 `lax.scan`,这明显不够直观。 +另一方面,JAX 版本需要使用 `lax.fori_loop` 或 `lax.scan`,两者都比标准 Python 循环更不直观。 -此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难。 +虽然 JAX 的 `at[t].set` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读。 -对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。 +对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。 ## 总体建议 @@ -568,11 +608,12 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然 代码自然易读——只需一个带装饰器的 Python 循环——且性能出色。 -JAX 可以通过 `lax.scan` 处理顺序问题,但对于纯顺序工作而言,其语法不够直观,性能提升也十分有限。 - -话虽如此,`lax.scan` 有一个重要优势:它支持对循环进行自动微分,而 Numba 无法做到这一点。 +JAX 可以通过 `lax.fori_loop` 或 `lax.scan` 处理顺序问题,但语法不够直观。 +```{note} +`lax.fori_loop` 和 `lax.scan` 有一个重要优势:它们支持对循环进行自动微分,而 Numba 无法做到这一点。 如果需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。 +``` 在实践中,许多问题往往同时涉及两种模式。 From 8904ddca5b8a369a1ecb73cd0bccf28f12f6b646 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Mon, 13 Apr 2026 13:48:56 +0100 Subject: [PATCH 4/7] Update translation: .translate/state/numpy_vs_numba_vs_jax.md.yml --- .translate/state/numpy_vs_numba_vs_jax.md.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 0798448..3f071e6 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,5 +1,5 @@ -source-sha: 05ce95691fd97e48da39dd6d58fe032c03e8813d -synced-at: "2026-04-09" +source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56 +synced-at: "2026-04-13" model: claude-sonnet-4-6 mode: UPDATE section-count: 3 From 5f846b69a82e6ef0d8c8e9fa261eb5676112925d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 02:03:31 +0000 Subject: [PATCH 5/7] Fix translation inconsistencies with English source files Agent-Logs-Url: https://github.com/QuantEcon/lecture-python-programming.zh-cn/sessions/62d01b3e-e9e2-420f-9d68-71e64461260f Co-authored-by: HumphreyYang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/jax_intro.md | 57 ++++++++++++++++++++++--------- lectures/numpy_vs_numba_vs_jax.md | 6 ---- 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index a8c8938..55496be 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -25,7 +25,7 @@ translation: Functional Programming: 函数式编程 Functional Programming::Pure functions: 纯函数 Functional Programming::Examples: 示例 - Functional Programming::Why Functional Programming?: 为什么使用函数式编程? + Functional Programming::Why Functional Programming?: 为什么要函数式编程? Random numbers: 随机数 Random numbers::Random number generation: 随机数生成 Random numbers::Why explicit random state?: 为什么要显式随机状态? @@ -37,6 +37,10 @@ translation: JIT Compilation::Compiling the Whole Function: 编译整个函数 JIT Compilation::How JIT compilation works: JIT 编译的工作原理 JIT Compilation::Compiling non-pure functions: 编译非纯函数 + Vectorization with vmap: 使用 vmap 进行向量化 + Vectorization with vmap::A simple example: 一个简单的示例 + Vectorization with vmap::Combining transformations: 组合变换 + Automatic differentiation: a preview: 自动微分:预览 Exercises: 练习 --- @@ -85,16 +89,6 @@ JAX 的一个吸引人之处在于,这个接口在尽可能的情况下遵循 因此,我们通常可以将 JAX 作为 NumPy 的直接替代品使用。 -## JAX 作为 NumPy 的替代品 - -JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API。 - -这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。 - -让我们来看看 JAX 和 NumPy 之间的异同。 - -### 相似之处 - 以下是使用 `jnp` 进行的一些标准数组操作: ```{code-cell} ipython3 @@ -187,11 +181,8 @@ with qe.Timer(): ``` ```{note} -这里,为了测量实际速度,我们使用 `block_until_ready` 方法来阻塞解释器,直到计算结果返回。 - -这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。 - -对于非计时代码,可以删除包含 `block_until_ready` 的那一行。 +上面的 `block_until_ready` 方法会阻塞解释器,直到计算结果返回。 +这对于计时是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前继续运行。 ``` 再来计时一次。 @@ -869,6 +860,40 @@ fast_batch_mm_diff(X) `jit`、`vmap` 以及(我们接下来将看到的)`grad` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。 + +## 自动微分:预览 + +JAX 可以使用自动微分来计算梯度。 + +这对于优化和求解非线性系统非常有用。 + +以下是一个简单的示例,涉及函数 $f(x) = x^2 / 2$: + +```{code-cell} ipython3 +def f(x): + return (x**2) / 2 + +f_prime = jax.grad(f) +``` + +```{code-cell} ipython3 +f_prime(10.0) +``` + +让我们绘制函数和导数,注意 $f'(x) = x$。 + +```{code-cell} ipython3 +fig, ax = plt.subplots() +x_grid = jnp.linspace(-4, 4, 200) +ax.plot(x_grid, f(x_grid), label="$f$") +ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") +ax.legend(loc='upper center') +plt.show() +``` + +自动微分是一个有许多经济学和金融学应用的深层主题。我们在{doc}`自动微分讲座 `中提供了更深入的讨论。 + + ## 练习 diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index 6467d3e..25ee779 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -532,10 +532,6 @@ def qm_jax_scan(x0, n, α=4.0): 这段代码不易阅读,但本质上,`lax.scan` 反复调用 `update` 并将返回值 `x_new` 累积到一个数组中。 -```{note} -我们在 `jax.jit` 装饰器中指定了 `device=cpu`,因为该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。 -``` - 让我们使用相同的参数计时: ```{code-cell} ipython3 @@ -556,8 +552,6 @@ with qe.Timer(): x_jax.block_until_ready() ``` -JAX 对于这种顺序运算也相当高效。 - JAX 和 Numba 在编译后都能提供出色的性能。 ### 总结 From f6569d26567ddb549d66098f3e29d1a1d80bca98 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 02:09:21 +0000 Subject: [PATCH 6/7] Fix numba.md translation inconsistencies with English source Agent-Logs-Url: https://github.com/QuantEcon/lecture-python-programming.zh-cn/sessions/f1e1dea8-f49a-43ec-a291-272349dbe061 Co-authored-by: HumphreyYang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/numba.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lectures/numba.md b/lectures/numba.md index 46ebf29..5c8223d 100644 --- a/lectures/numba.md +++ b/lectures/numba.md @@ -16,11 +16,9 @@ translation: Compiling Functions: 编译函数 Compiling Functions::An Example: 示例 Compiling Functions::How and When it Works: 工作原理与适用时机 - Type Inference: 类型推断 - Dangers and Limitations: 危险与局限 - Dangers and Limitations::Limitations: 局限性 - 'Dangers and Limitations::A Gotcha: Global Variables': 一个陷阱:全局变量 - Dangers and Limitations::Caching Compiled Code: 缓存编译代码 + Sharp Bits: 注意事项 + Sharp Bits::Typing: 类型推断 + Sharp Bits::Global Variables: 全局变量 Multithreaded Loops in Numba: Numba 中的多线程循环 Exercises: 练习 --- @@ -185,7 +183,7 @@ timer1.elapsed / timer3.elapsed 让我们讨论一下这是如何工作的。 -### 工作原理与适用场景 +### 工作原理与适用时机 Numba 尝试使用 [LLVM Project](https://llvm.org/) 提供的基础设施生成快速机器码。 @@ -421,6 +419,8 @@ with qe.Timer(): 注意,我们是跨家庭进行并行化,而非跨时间——单个家庭跨时期的更新本质上是顺序的。 +关于基于 GPU 的并行化,请参阅我们{doc}`关于 JAX 的讲座 `。 + ## 练习 ```{exercise} From a73e80c11c9d6cd019d66902fc72b118265b71d5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 02:16:22 +0000 Subject: [PATCH 7/7] Fix need_for_speed.md: correct Multithreading section heading and garbled body text Agent-Logs-Url: https://github.com/QuantEcon/lecture-python-programming.zh-cn/sessions/336ee7bd-5e55-4f44-bccf-ce3f9734101f Co-authored-by: HumphreyYang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/need_for_speed.md | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/lectures/need_for_speed.md b/lectures/need_for_speed.md index ddfa453..624da84 100644 --- a/lectures/need_for_speed.md +++ b/lectures/need_for_speed.md @@ -393,13 +393,9 @@ with qe.Timer(): 让我们回顾一下科学计算中常用的两种主要 CPU 并行化方式,并讨论它们的优缺点。 -#### 多进程 - -多进程是指使用多个处理器并发执行多条逻辑线程。 - -多进程可以在一台拥有多个 CPU 的机器上进行,也可以在通过网络连接的机器集群上进行。 +#### 多线程 -在多进程中,*每个进程都有自己的内存空间*,尽管物理内存芯片可能是共享的。 +多线程是指在单个进程中运行多个执行线程。 所有线程共享同一内存空间,因此它们可以在不复制数据的情况下对同一数组进行读写。