Skip to content
Merged
4 changes: 2 additions & 2 deletions .translate/state/jax_intro.md.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
source-sha: 8d73de367a7f160dac777aa557f1c26069f84ea5
synced-at: "2026-04-12"
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
synced-at: "2026-04-13"
model: claude-sonnet-4-6
mode: UPDATE
section-count: 7
Expand Down
4 changes: 2 additions & 2 deletions .translate/state/numba.md.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
source-sha: be6eeaee8db0c8bfea65b89d57ca8aecf7f96dff
synced-at: "2026-04-12"
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
synced-at: "2026-04-13"
model: claude-sonnet-4-6
mode: UPDATE
section-count: 5
Expand Down
4 changes: 2 additions & 2 deletions .translate/state/numpy_vs_numba_vs_jax.md.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
source-sha: 94dd7d22385ec46d740db1fc2cddf05c29377594
synced-at: "2026-04-12"
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
synced-at: "2026-04-13"
model: claude-sonnet-4-6
mode: UPDATE
section-count: 3
Expand Down
122 changes: 10 additions & 112 deletions lectures/jax_intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ translation:
JAX as a NumPy Replacement::Differences::Speed!: 速度!
JAX as a NumPy Replacement::Differences::Speed!::With NumPy: 使用 NumPy
JAX as a NumPy Replacement::Differences::Speed!::With JAX: 使用 JAX
JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验
JAX as a NumPy Replacement::Differences::Precision: 精度
JAX as a NumPy Replacement::Differences::Immutability: 不可变性
JAX as a NumPy Replacement::Differences::A workaround: 变通方法
Expand All @@ -31,13 +32,11 @@ translation:
Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法
Random numbers::Why explicit random state?::JAX's approach: JAX 的方法
JIT Compilation: JIT 编译
JIT Compilation::Evaluating a more complicated function: 评估更复杂的函数
JIT Compilation::Evaluating a more complicated function::With NumPy: 使用 NumPy
JIT Compilation::Evaluating a more complicated function::With JAX: 使用 JAX
JIT Compilation::Compiling the whole function: 编译整个函数
JIT Compilation::With NumPy: 使用 NumPy
JIT Compilation::With JAX: 使用 JAX
JIT Compilation::Compiling the Whole Function: 编译整个函数
JIT Compilation::How JIT compilation works: JIT 编译的工作原理
JIT Compilation::Compiling non-pure functions: 编译非纯函数
JIT Compilation::Summary: 总结
Exercises: 练习
---

Expand Down Expand Up @@ -205,6 +204,8 @@ with qe.Timer():

大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。

#### 大小实验

我们可以通过更改输入大小并观察运行时间来验证 JAX 针对数组大小进行专门化的说法。

```{code-cell}
Expand Down Expand Up @@ -233,105 +234,6 @@ with qe.Timer():

关于 JIT 编译的进一步讨论见下文。

(jax_speed)=
#### 速度!

假设我们想在许多点上求余弦函数的值。

```{code-cell}
n = 50_000_000
x = np.linspace(0, 10, n)
```

##### 使用 NumPy

让我们先用 NumPy 试试:

```{code-cell}
with qe.Timer():
y = np.cos(x)
```

再来一次。

```{code-cell}
with qe.Timer():
y = np.cos(x)
```

这里:

* NumPy 使用预编译的二进制文件对浮点数组应用余弦函数
* 该二进制文件在本地机器的 CPU 上运行

##### 使用 JAX

现在让我们用 JAX 试试。

```{code-cell}
x = jnp.linspace(0, 10, n)
```

对相同的过程计时。

```{code-cell}
with qe.Timer():
y = jnp.cos(x)
jax.block_until_ready(y);
```

```{note}
这里,为了测量实际速度,我们使用 `block_until_ready` 方法让解释器等待,直到计算结果返回。

这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。

对于非计时代码,可以省略包含 `block_until_ready` 的那行。
```

再计时一次。

```{code-cell}
with qe.Timer():
y = jnp.cos(x)
jax.block_until_ready(y);
```

在 GPU 上,这段代码的运行速度远快于等效的 NumPy 代码。

此外,通常第二次运行比第一次更快,这是由于 JIT 编译的原因。

这是因为即使是 `jnp.cos` 这样的内置函数也会被 JIT 编译——第一次运行包含了编译时间。

为什么 JAX 要对 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本?

原因是 JIT 编译器希望针对所使用数组的*大小*(以及数据类型)进行专门优化。

大小对于生成优化代码很重要,因为高效并行化需要将任务大小与可用硬件相匹配。

我们可以通过改变输入大小并观察运行时间来验证 JAX 针对数组大小进行专门优化的说法。

```{code-cell}
x = jnp.linspace(0, 10, n + 1)
```

```{code-cell}
with qe.Timer():
y = jnp.cos(x)
jax.block_until_ready(y);
```

```{code-cell}
with qe.Timer():
y = jnp.cos(x)
jax.block_until_ready(y);
```

运行时间先增加后再次下降(在 GPU 上这一现象会更明显)。

这与上面的讨论一致——改变数组大小后的第一次运行显示了编译开销。

关于 JIT 编译的进一步讨论见下文。

#### 精度

NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。
Expand Down Expand Up @@ -731,19 +633,15 @@ JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高

我们在 {ref}`上文 <jax_speed>` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 `cos` 函数。

让我们用一个更复杂的函数尝试同样的操作。

### 评估更复杂的函数

考虑以下函数:
让我们用一个更复杂的函数尝试同样的操作:

```{code-cell}
def f(x):
y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
return y
```

#### 使用 NumPy
### 使用 NumPy

我们先用 NumPy 试试:

Expand All @@ -758,7 +656,7 @@ with qe.Timer():
y = f(x)
```

#### 使用 JAX
### 使用 JAX

现在让我们用 JAX 再试一次。

Expand Down Expand Up @@ -793,7 +691,7 @@ with qe.Timer():

结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。

然而,使用 JAX,我们还有另一个技巧——我们可以对*整个*函数进行 JIT 编译,而不仅仅是单个操作。
然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。

### 编译整个函数

Expand Down
6 changes: 3 additions & 3 deletions lectures/numba.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ n = 10_000_000

with qe.Timer() as timer1:
# Time Python base version
x = qm(0.1, int(n))
x = qm(0.1, n)

```

Expand Down Expand Up @@ -160,7 +160,7 @@ qm_numba = jit(qm)
```{code-cell} ipython3
with qe.Timer() as timer2:
# Time jitted version
x = qm_numba(0.1, int(n))
x = qm_numba(0.1, n)
```

这已经是非常大的速度提升。
Expand All @@ -172,7 +172,7 @@ with qe.Timer() as timer2:
```{code-cell} ipython3
with qe.Timer() as timer3:
# Second run
x = qm_numba(0.1, int(n))
x = qm_numba(0.1, n)
```

以下是速度提升
Expand Down
Loading
Loading