Skip to content

Commit 5e34ef9

Browse files
authored
Merge pull request #49 from QuantEcon/translation-sync-2026-04-12T13-13-55-pr-525
🌐 [translation-sync] Improve NumPy vs Numba vs JAX lecture
2 parents 7a64951 + 429b7b4 commit 5e34ef9

File tree

2 files changed

+61
-54
lines changed

2 files changed

+61
-54
lines changed

.translate/state/numpy_vs_numba_vs_jax.md.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
source-sha: 05ce95691fd97e48da39dd6d58fe032c03e8813d
2-
synced-at: "2026-04-09"
1+
source-sha: 94dd7d22385ec46d740db1fc2cddf05c29377594
2+
synced-at: "2026-04-12"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 3

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ translation:
1717
Vectorized operations::Parallelized Numba: 并行化的 Numba
1818
Vectorized operations::Vectorized code with JAX: 使用 JAX 的向量化代码
1919
Vectorized operations::JAX plus vmap: JAX 加 vmap
20-
Vectorized operations::JAX plus vmap::Version 1: 版本 1
21-
Vectorized operations::vmap version 2: vmap 版本 2
2220
Vectorized operations::Summary: 总结
2321
Sequential operations: 顺序运算
2422
Sequential operations::Numba Version: Numba 版本
@@ -27,7 +25,7 @@ translation:
2725
Overall recommendations: 总体建议
2826
---
2927

30-
(parallel)=
28+
(numpy_numba_jax)=
3129
```{raw} jupyter
3230
<div id="qe-notebook-header" align="right" style="text-align:right;">
3331
<a href="https://quantecon.org/" title="quantecon.org">
@@ -155,7 +153,7 @@ for x in grid:
155153

156154
这里我们使用 `np.meshgrid` 来创建二维输入网格 `x``y`,使得 `f(x, y)` 能生成乘积网格上的所有计算结果。
157155

158-
(这一策略可以追溯到 Matlab。)
156+
(这一策略可以追溯到 MATLAB。)
159157

160158
```{code-cell} ipython3
161159
grid = np.linspace(-3, 3, 3_000)
@@ -231,24 +229,44 @@ def compute_max_numba_parallel(grid):
231229
232230
```
233231

234-
这通常会返回不正确的结果
232+
这将返回 `-inf`——即 `m` 的初始值,仿佛它从未被更新过
235233

236234
```{code-cell} ipython3
237235
z_max_parallel_incorrect = compute_max_numba_parallel(grid)
238236
print(f"Numba result: {z_max_parallel_incorrect} 😱")
239237
```
240238

241-
原因是变量 `m` 被多个线程共享,但没有得到正确控制
239+
要理解原因,请回忆 `prange` 会将外层循环拆分到各个线程中
242240

243-
当多个线程同时尝试读写 `m` 时,它们会相互干扰
241+
每个线程都会得到自己的 `m` 私有副本,初始化为 `-np.inf`,并在其负责的迭代块中正确地更新它
244242

245-
线程读取了 `m` 的过时值,或者相互覆盖了更新——或者 `m` 始终保持其初始值而从未被更新
243+
但在循环结束时,Numba 需要将各线程的 `m` 副本合并为一个单一的值——即**归约**操作
246244

247-
这里有一个更仔细编写的版本。
245+
对于它能识别的模式,例如 `m += z`(求和)或 `m = max(m, z)`(求最大值),Numba 知道合并算子。
246+
247+
但它无法将 `if z > m: m = z` 识别为最大值归约,因此各线程的结果永远不会被合并,`m` 始终保持其初始值。
248+
249+
最简单的修复方法是将条件判断替换为 Numba 能识别的 `max`
248250

249251
```{code-cell} ipython3
250252
@numba.jit(parallel=True)
251253
def compute_max_numba_parallel(grid):
254+
n = len(grid)
255+
m = -np.inf
256+
for i in numba.prange(n):
257+
for j in range(n):
258+
x = grid[i]
259+
y = grid[j]
260+
z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
261+
m = max(m, z)
262+
return m
263+
```
264+
265+
另一种方法是使循环体在不同 `i` 之间完全独立,并自行处理归约:
266+
267+
```{code-cell} ipython3
268+
@numba.jit(parallel=True)
269+
def compute_max_numba_parallel_v2(grid):
252270
n = len(grid)
253271
row_maxes = np.empty(n)
254272
for i in numba.prange(n):
@@ -263,9 +281,7 @@ def compute_max_numba_parallel(grid):
263281
return np.max(row_maxes)
264282
```
265283

266-
现在 `for i in numba.prange(n)` 所作用的代码块在不同的 `i` 之间是独立的。
267-
268-
每个线程写入数组 `row_maxes` 的不同元素,并行化是安全的。
284+
在这里,每个线程写入 `row_maxes` 的不同元素,因此我们通过 `np.max` 自行处理归约。
269285

270286
```{code-cell} ipython3
271287
z_max_parallel = compute_max_numba_parallel(grid)
@@ -320,13 +336,13 @@ with qe.Timer(precision=8):
320336
z_max.block_until_ready()
321337
```
322338

323-
编译完成后,JAX 明显快于 NumPy,尤其是在 GPU
339+
编译完成后,JAX 明显快于 NumPy, GPU 上尤为如此
324340

325341
编译开销是一次性成本,当函数被反复调用时,这种开销是值得的。
326342

327343
### JAX 加 vmap
328344

329-
NumPy 代码和 JAX 代码都存在一个问题:
345+
NumPy 代码和上述 JAX 代码都存在一个问题:
330346

331347
虽然扁平数组占用内存较少
332348

@@ -344,9 +360,9 @@ x_mesh.nbytes + y_mesh.nbytes
344360

345361
幸运的是,JAX 提供了一种使用 [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) 的不同方法。
346362

347-
#### 版本 1
363+
`vmap` 的思路是将向量化分阶段进行,将一个对单个值进行操作的函数转化为对数组进行操作的函数。
348364

349-
以下是我们应用 `vmap` 的一种方式
365+
以下是我们将其应用于当前问题的方式
350366

351367
```{code-cell} ipython3
352368
# 设置 f,使其在给定任意 y 时,对所有 x 计算 f(x, y)
@@ -373,31 +389,19 @@ with qe.Timer(precision=8):
373389
z_max.block_until_ready()
374390
```
375391

376-
通过避免使用大型输入数组 `x_mesh``y_mesh`,这个 `vmap` 版本使用的内存少得多。
377-
378-
在 CPU 上运行时,其运行时间与网格版本相似。
379-
380-
在 GPU 上运行时,通常速度要快得多。
381-
382-
实际上,使用 `vmap` 还有另一个优势:它允许我们将向量化分阶段进行。
383-
384-
这往往会产生比传统向量化代码更易于理解的代码。
385-
386-
当我们处理更大的问题时,将进一步探讨这些想法。
392+
通过避免使用大型输入数组 `x_mesh``y_mesh`,这个 `vmap` 版本使用的内存少得多,运行时间也相近。
387393

388-
### vmap 版本 2
394+
但我们仍然留有一些速度提升的空间未被利用。
389395

390-
我们可以使用 vmap 进一步提高内存效率
396+
上面的代码计算了完整的二维数组 `f(x,y)`,然后再取最大值
391397

392-
在前一个版本中,虽然我们避免了大型输入数组,但在计算最大值之前仍然会创建大型输出数组 `f(x,y)`
398+
此外,`jnp.max` 调用位于 JIT 编译函数 `f` 之外,因此编译器无法将这些操作融合为单个内核
393399

394-
让我们尝试一种略有不同的方法,将求最大值操作移到内部。
395-
396-
由于这一改变,我们永远不会计算二维数组 `f(x,y)`
400+
我们可以通过将最大值操作移到内部并将所有内容包装在一个 `@jax.jit` 中来解决这两个问题:
397401

398402
```{code-cell} ipython3
399403
@jax.jit
400-
def compute_max_vmap_v2(grid):
404+
def compute_max_vmap(grid):
401405
# 构建一个沿每行取最大值的函数
402406
f_vec_x_max = lambda y: jnp.max(f(grid, y))
403407
# 向量化该函数,以便我们可以同时对所有行调用
@@ -413,24 +417,26 @@ def compute_max_vmap_v2(grid):
413417

414418
我们将此函数应用于所有行,然后取各行最大值中的最大值。
415419

420+
由于将最大值操作移到内部,我们永远不会构建完整的二维数组 `f(x,y)`,从而节省了更多内存。
421+
422+
并且由于所有内容都在单个 `@jax.jit` 下,编译器可以将所有操作融合为一个优化的内核。
423+
416424
让我们试试。
417425

418426
```{code-cell} ipython3
419427
with qe.Timer(precision=8):
420-
z_max = compute_max_vmap_v2(grid).block_until_ready()
428+
z_max = compute_max_vmap(grid).block_until_ready()
421429
422-
print(f"JAX vmap v2 result: {z_max:.6f}")
430+
print(f"JAX vmap result: {z_max:.6f}")
423431
```
424432

425433
让我们再次运行以消除编译时间:
426434

427435
```{code-cell} ipython3
428436
with qe.Timer(precision=8):
429-
z_max = compute_max_vmap_v2(grid).block_until_ready()
437+
z_max = compute_max_vmap(grid).block_until_ready()
430438
```
431439

432-
如果您像我们一样在 GPU 上运行,应该能看到又一个不小的速度提升。
433-
434440
### 总结
435441

436442
在我们看来,JAX 是向量化运算的赢家。
@@ -536,7 +542,7 @@ with qe.Timer(precision=8):
536542

537543
JAX 对于这种顺序运算也相当高效。
538544

539-
JAX 和 Numba 在编译后都能提供出色的性能,对于纯顺序运算,Numba 通常(但并非总是)提供略快的速度
545+
JAX 和 Numba 在编译后都能提供出色的性能。
540546

541547
### 总结
542548

@@ -550,30 +556,31 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
550556

551557
此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难。
552558

553-
对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。
559+
对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。
554560

555561
## 总体建议
556562

557-
让我们退一步,总结一下各方案的权衡取舍
563+
让我们退一步,总结一下各方的权衡取舍
558564

559565
对于**向量化操作**,JAX 是最强的选择。
560566

561-
得益于 JIT 编译和跨 CPU 与 GPU 的高效并行化,它在速度上与 NumPy 持平甚至超越 NumPy。
567+
得益于 JIT 编译和在 CPU 与 GPU 上的高效并行化,它在速度上与 NumPy 持平或超越 NumPy。
562568

563-
`vmap` 变换可以减少内存使用,并且通常比基于传统网格(meshgrid)的向量化方式产生更清晰的代码
569+
`vmap` 变换降低了内存使用量,并且通常比传统的基于网格的向量化产生更清晰的代码
564570

565-
此外,JAX 函数支持自动微分,我们将在 {doc}`autodiff` 中进行探讨
571+
此外,JAX 函数支持自动微分,我们将在 {doc}`autodiff` 中进一步探讨
566572

567573
对于**顺序操作**,Numba 具有明显优势。
568574

569-
代码自然易读——只需一个带装饰器的 Python 循环——且性能出色
575+
代码自然且可读——只需一个带有装饰器的 Python 循环——性能也非常出色
570576

571-
JAX 可以通过 `lax.scan` 处理顺序问题,但对于纯顺序工作而言,其语法不够直观,性能提升也十分有限
577+
JAX 可以通过 `lax.scan` 处理顺序问题,但语法不够直观
572578

573-
话虽如此,`lax.scan` 有一个重要优势:它支持对循环进行自动微分,而 Numba 无法做到这一点。
574-
575-
如果需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
579+
```{note}
580+
`lax.scan` 的一个重要优势是它支持通过循环进行自动微分,而 Numba 无法做到这一点。
581+
如果您需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
582+
```
576583

577-
在实践中,许多问题往往同时涉及两种模式
584+
在实践中,许多问题涉及两种模式的混合
578585

579-
一个实用的经验法则是:新项目默认使用 JAX,尤其是在硬件加速或可微分性可能有用的情况下;当需要一个快速且可读的紧凑顺序循环时,则选用 Numba。
586+
一个实用的经验法则是:对于新项目默认使用 JAX,尤其是当硬件加速或可微分性可能有用时,而当您有一个需要快速且可读的紧凑顺序循环时,则选用 Numba。

0 commit comments

Comments
 (0)