Skip to content

Commit 95378b8

Browse files
authored
misc (#528)
1 parent 8d73de3 commit 95378b8

File tree

3 files changed

+81
-122
lines changed

3 files changed

+81
-122
lines changed

lectures/jax_intro.md

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -645,20 +645,15 @@ efficient machine code that varies with both task size and hardware.
645645
We saw the power of JAX's JIT compiler combined with parallel hardware when we
646646
{ref}`above <jax_speed>`, when we applied `cos` to a large array.
647647

648-
Let's try the same thing with a more complex function.
649-
650-
651-
### Evaluating a more complicated function
652-
653-
Consider the function
648+
Let's try the same thing with a more complex function:
654649

655650
```{code-cell}
656651
def f(x):
657652
y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
658653
return y
659654
```
660655

661-
#### With NumPy
656+
### With NumPy
662657

663658
We'll try first with NumPy
664659

@@ -675,7 +670,7 @@ with qe.Timer():
675670

676671

677672

678-
#### With JAX
673+
### With JAX
679674

680675
Now let's try again with JAX.
681676

@@ -712,10 +707,10 @@ The outcome is similar to the `cos` example --- JAX is faster, especially on the
712707
second run after JIT compilation.
713708

714709
However, with JAX, we have another trick up our sleeve --- we can JIT-compile
715-
the *entire* function, not just individual operations.
710+
the entire function, not just individual operations.
716711

717712

718-
### Compiling the whole function
713+
### Compiling the Whole Function
719714

720715
The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing array
721716
operations into a single optimized kernel.

lectures/numba.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ n = 10_000_000
124124
125125
with qe.Timer() as timer1:
126126
# Time Python base version
127-
x = qm(0.1, int(n))
127+
x = qm(0.1, n)
128128
129129
```
130130

@@ -154,7 +154,7 @@ Let's time this new version:
154154
```{code-cell} ipython3
155155
with qe.Timer() as timer2:
156156
# Time jitted version
157-
x = qm_numba(0.1, int(n))
157+
x = qm_numba(0.1, n)
158158
```
159159

160160
This is a large speed gain.
@@ -167,7 +167,7 @@ function has been compiled and is in memory:
167167
```{code-cell} ipython3
168168
with qe.Timer() as timer3:
169169
# Second run
170-
x = qm_numba(0.1, int(n))
170+
x = qm_numba(0.1, n)
171171
```
172172

173173
Here's the speed gain

0 commit comments

Comments
 (0)