Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .translate/state/jax_intro.md.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
synced-at: "2026-04-13"
source-sha: 450bafecd23db638602150b47f4272b98aad3146
synced-at: "2026-04-14"
model: claude-sonnet-4-6
mode: UPDATE
section-count: 7
Expand Down
4 changes: 2 additions & 2 deletions .translate/state/numpy_vs_numba_vs_jax.md.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
synced-at: "2026-04-13"
source-sha: 450bafecd23db638602150b47f4272b98aad3146
synced-at: "2026-04-14"
model: claude-sonnet-4-6
mode: UPDATE
section-count: 3
Expand Down
183 changes: 69 additions & 114 deletions lectures/jax_intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,24 @@ translation:
JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验
JAX as a NumPy Replacement::Differences::Precision: 精度
JAX as a NumPy Replacement::Differences::Immutability: 不可变性
JAX as a NumPy Replacement::Differences::A workaround: 变通方法
JAX as a NumPy Replacement::Differences::A Workaround: 变通方法
Functional Programming: 函数式编程
Functional Programming::Pure functions: 纯函数
Functional Programming::Examples: 示例
Functional Programming::Why Functional Programming?: 为什么使用函数式编程
Functional Programming::Why Functional Programming?: 为什么要函数式编程
Random numbers: 随机数
Random numbers::Random number generation: 随机数生成
Random numbers::Why explicit random state?: 为什么要显式随机状态?
Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法
Random numbers::Why explicit random state?::JAX's approach: JAX 的方法
Random numbers::NumPy / MATLAB Approach: NumPy / MATLAB 方法
Random numbers::JAX: JAX
Random numbers::Benefits: 好处
JIT Compilation: JIT 编译
JIT Compilation::With NumPy: 使用 NumPy
JIT Compilation::With JAX: 使用 JAX
JIT Compilation::Compiling the Whole Function: 编译整个函数
JIT Compilation::How JIT compilation works: JIT 编译的工作原理
JIT Compilation::Compiling non-pure functions: 编译非纯函数
Vectorization with `vmap`: 使用 `vmap` 进行向量化
Vectorization with `vmap`::A simple example: 一个简单的示例
Vectorization with `vmap`::Combining transformations: 组合变换
Exercises: 练习
---

Expand Down Expand Up @@ -404,13 +406,29 @@ JAX 使用函数式编程风格,以便用户构建的函数能够直接映射

JAX 中的随机数生成与 NumPy 或 MATLAB 中的模式有很大不同。

起初,您可能会觉得语法相当冗长。
### NumPy / MATLAB 方法

但为了维护我们刚刚讨论的函数式编程风格,这种语法和语义是必要的
在 NumPy / MATLAB 中,生成通过维护隐藏的全局状态来工作

此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。
```{code-cell} ipython3
np.random.seed(42)
print(np.random.randn(2))
```

每次我们调用随机函数时,隐藏状态都会被更新:

```{code-cell} ipython3
print(np.random.randn(2))
```

这个函数*不是纯函数*,因为:

### 随机数生成
* 它是非确定性的:相同的输入,不同的输出
* 它有副作用:它修改了全局随机数生成器状态

在并行化下很危险——必须仔细控制每个线程中发生的事情!

### JAX

在 JAX 中,随机数生成器的状态被显式控制。

Expand Down Expand Up @@ -531,119 +549,48 @@ def gen_random_matrices(key, n=2, k=3):
key, subkey = jax.random.split(key)
A = jax.random.uniform(subkey, (n, n))
matrices.append(A)
print(A)
return matrices
```

```{code-cell} ipython3
seed = 42
key = jax.random.key(seed)
matrices = gen_random_matrices(key)
```

我们也可以在循环迭代时使用 `fold_in`:

```{code-cell} ipython3
def gen_random_matrices(key, n=2, k=3):
matrices = []
for i in range(k):
step_key = jax.random.fold_in(key, i)
A = jax.random.uniform(step_key, (n, n))
matrices.append(A)
print(A)
return matrices
```

```{code-cell} ipython3
key = jax.random.key(seed)
matrices = gen_random_matrices(key)
```

### 为什么要显式随机状态?

为什么 JAX 需要这种相对冗长的随机数生成方法?

一个原因是为了维护纯函数。

让我们通过比较 NumPy 和 JAX 来看看随机数生成与纯函数的关系。

#### NumPy 的方法

在 NumPy 的旧版随机数生成 API(模仿 MATLAB)中,生成通过维护隐藏的全局状态来工作。

每次我们调用随机函数时,这个状态都会被更新:

```{code-cell} ipython3
np.random.seed(42)
print(np.random.randn()) # Updates state of random number generator
print(np.random.randn()) # Updates state of random number generator
gen_random_matrices(key)
```

每次调用都返回不同的值,即使我们用相同的输入(没有参数)调用相同的函数。
这个函数是*纯函数*:

这个函数*不是纯函数*,因为:

* 它是非确定性的:相同的输入(在这种情况下,没有输入)产生不同的输出
* 它有副作用:它修改了全局随机数生成器状态

#### JAX 的方法

如上所示,JAX 采用了不同的方法,通过密钥使随机性显式化。

例如:

```{code-cell} ipython3
def random_sum_jax(key):
key1, key2 = jax.random.split(key)
x = jax.random.normal(key1)
y = jax.random.normal(key2)
return x + y
```

使用相同的密钥,我们总是得到相同的结果:

```{code-cell} ipython3
key = jax.random.key(42)
random_sum_jax(key)
```

```{code-cell} ipython3
random_sum_jax(key)
```

要获得新的抽取,我们需要提供一个新密钥。

函数 `random_sum_jax` 是纯函数,因为:

* 它是确定性的:相同的密钥总是产生相同的输出
* 确定性的:相同的输入,相同的输出
* 无副作用:没有隐藏状态被修改

### 好处

JAX 的显式性带来了显著的好处:

* 可复现性:通过重用密钥轻松重现结果
* 并行化:每个线程可以拥有自己的密钥而不会产生冲突
* 调试:没有隐藏状态使代码更容易推理
* 并行化:控制各个线程上发生的事情
* 调试:没有隐藏状态使代码更容易测试
* JIT 兼容性:编译器可以更积极地优化纯函数

最后一点将在下一节中进行扩展。

## JIT 编译

JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。

我们在 {ref}`上文 <jax_speed>` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 `cos` 函数。

让我们用一个更复杂的函数尝试同样的操作:
这里我们研究更复杂函数的 JIT 编译。

### 使用 NumPy

我们先用 NumPy 试试,使用:

```{code-cell}
def f(x):
y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
return y
```

### 使用 NumPy

我们先用 NumPy 试试:
用较大的 `x` 运行:

```{code-cell}
n = 50_000_000
Expand All @@ -656,9 +603,17 @@ with qe.Timer():
y = f(x)
```

### 使用 JAX
**即时**执行模型

* 每个操作在遇到时立即执行,在下一个操作开始之前将其结果实体化。

缺点

* 并行化程度最低
* 较大的内存占用——产生许多中间数组
* 大量内存读写

现在让我们用 JAX 再试一次。
### 使用 JAX

作为第一步,我们将整个代码中的 `np` 替换为 `jnp`:

Expand Down Expand Up @@ -691,11 +646,13 @@ with qe.Timer():

结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。

然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作
但我们仍在使用即时执行——大量内存和读写开销

### 编译整个函数

JAX 即时(JIT)编译器可以通过将数组运算融合到单个优化内核中来加速函数内部的执行。
幸运的是,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。

编译器将所有数组运算融合到单个优化内核中。

让我们用函数 `f` 来试试这个:

Expand All @@ -719,9 +676,11 @@ with qe.Timer():
jax.block_until_ready(y);
```

运行时间再次改善——现在是因为我们融合了所有操作,使编译器能够更积极地进行优化。
运行时间再次改善——现在是因为我们融合了所有操作

例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。
* 基于整个计算序列的积极优化
* 消除对硬件加速器的多次调用
* 不创建中间数组

顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是:

Expand All @@ -741,11 +700,9 @@ XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TP

### 编译非纯函数

现在我们已经看到了 JIT 编译的强大之处,理解它与纯函数的关系非常重要。
虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测!

虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测。

以下是一个使用全局变量的例子:
以下是一个例子:

```{code-cell} ipython3
a = 1 # global
Expand Down Expand Up @@ -789,9 +746,9 @@ f(x)

## 使用 `vmap` 进行向量化

JAX 的另一个强大变换是 `jax.vmap`,它能自动将一个针对单个输入编写的函数向量化,使其可以在批量数据上运行
JAX 的另一个强大变换是 `jax.vmap`,它能够自动将针对单个输入编写的函数向量化,使其可以对批量数据进行操作

这避免了手动编写向量化代码或使用显式循环的需要
这样就无需手动编写向量化代码或使用显式循环

### 一个简单的示例

Expand All @@ -809,7 +766,7 @@ x = jnp.array([1.0, 2.0, 5.0])
mm_diff(x)
```

现在假设我们有一个矩阵,想要对每一行计算这些统计量
现在假设我们有一个矩阵,希望对每一行计算这些统计量

不使用 `vmap` 时,我们需要显式循环:

Expand All @@ -824,18 +781,16 @@ for row in X:

然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。

使用 `vmap` 可以将计算保留在加速器上,并与其他 JAX 变换(如 `jit` 和 `grad`)组合使用
使用 `vmap`,我们可以避免循环,并将计算保留在加速器上

```{code-cell} ipython3
batch_mm_diff = jax.vmap(mm_diff)
batch_mm_diff(X)
batch_mm_diff = jax.vmap(mm_diff) # Create a new "vectorized" version
batch_mm_diff(X) # Apply to each row of X
```

函数 `mm_diff` 是针对单个数组编写的,而 `vmap` 自动将其提升为按行作用于矩阵的函数——无需循环,无需重新塑形。

### 组合变换

JAX 的优势之一在于各变换可以自然地组合使用
JAX 的优势之一在于各种变换可以自然地组合使用

例如,我们可以对向量化函数进行 JIT 编译:

Expand All @@ -844,7 +799,7 @@ fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))
fast_batch_mm_diff(X)
```

`jit`、`vmap` 以及(我们接下来将看到的)`grad` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
`jit`、`vmap` 以及(我们接下来将看到的)`grad` 的这种组合是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。

## 练习

Expand Down
Loading
Loading