Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .translate/state/jax_intro.md.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56
synced-at: "2026-04-13"
source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28
synced-at: "2026-04-14"
model: claude-sonnet-4-6
mode: UPDATE
section-count: 7
Expand Down
4 changes: 2 additions & 2 deletions .translate/state/numpy_vs_numba_vs_jax.md.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56
synced-at: "2026-04-13"
source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28
synced-at: "2026-04-14"
model: claude-sonnet-4-6
mode: UPDATE
section-count: 3
Expand Down
210 changes: 91 additions & 119 deletions lectures/jax_intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@ translation:
JAX as a NumPy Replacement::Differences::A Workaround: راه‌حل جایگزین
Functional Programming: برنامه‌نویسی تابعی
Functional Programming::Pure functions: توابع خالص
Functional Programming::Examples: مثال‌ها
Functional Programming::Examples -- Pure and Impure: مثال‌ها -- توابع خالص و ناخالص
Functional Programming::Why Functional Programming?: چرا برنامه‌نویسی تابعی؟
Random numbers: اعداد تصادفی
Random numbers::Random number generation: تولید اعداد تصادفی
Random numbers::Why explicit random state?: چرا وضعیت تصادفی صریح؟
Random numbers::Why explicit random state?::NumPy's approach: رویکرد NumPy
Random numbers::Why explicit random state?::JAX's approach: رویکرد JAX
Random numbers::NumPy / MATLAB Approach: رویکرد NumPy / MATLAB
Random numbers::JAX: JAX
Random numbers::Benefits: مزایا
JIT Compilation: کامپایل JIT
JIT Compilation::With NumPy: با NumPy
JIT Compilation::With JAX: با JAX
Expand Down Expand Up @@ -357,19 +356,20 @@ a
* وضعیت سراسری را تغییر نمی‌دهد
* داده‌های ارسال شده به تابع را تغییر نمی‌دهد (داده‌های تغییرناپذیر)

### مثال‌ها
### مثال‌ها -- توابع خالص و ناخالص

در اینجا مثالی از یک تابع *غیرخالص* آورده شده است
در اینجا مثالی از یک تابع *ناخالص* آورده شده است

```{code-cell} ipython3
tax_rate = 0.1
prices = [10.0, 20.0]

def add_tax(prices):
for i, price in enumerate(prices):
prices[i] = price * (1 + tax_rate)
print('Post-tax prices: ', prices)
return prices

prices = [10.0, 20.0]
add_tax(prices)
prices
```

این تابع نمی‌تواند خالص باشد زیرا
Expand All @@ -380,15 +380,21 @@ def add_tax(prices):
در اینجا یک نسخه *خالص* آورده شده است

```{code-cell} ipython3
tax_rate = 0.1
prices = (10.0, 20.0)

def add_tax_pure(prices, tax_rate):
new_prices = [price * (1 + tax_rate) for price in prices]
return new_prices

tax_rate = 0.1
prices = (10.0, 20.0)
after_tax_prices = add_tax_pure(prices, tax_rate)
after_tax_prices
```

این نسخه خالص تمام وابستگی‌ها را از طریق آرگومان‌های تابع صریح می‌کند و هیچ وضعیت خارجی را تغییر نمی‌دهد.
این نسخه خالص است زیرا

* تمام وابستگی‌ها از طریق آرگومان‌های تابع صریح هستند
* و هیچ وضعیت خارجی را تغییر نمی‌دهد

### چرا برنامه‌نویسی تابعی؟

Expand Down Expand Up @@ -416,15 +422,31 @@ JAX از سبک برنامه‌نویسی تابعی استفاده می‌کن

## اعداد تصادفی

اعداد تصادفی در JAX نسبت به آنچه در NumPy یا Matlab می‌یابید بسیار متفاوت هستند.
تولید اعداد تصادفی در JAX نسبت به الگوهای موجود در NumPy یا MATLAB بسیار متفاوت است.

### رویکرد NumPy / MATLAB

در ابتدا ممکن است نحو را نسبتاً مفصل بیابید.
در NumPy / MATLAB، تولید با حفظ وضعیت سراسری پنهان کار می‌کند.

```{code-cell} ipython3
np.random.seed(42)
print(np.random.randn(2))
```

هر بار که یک تابع تصادفی را فراخوانی می‌کنیم، وضعیت پنهان به‌روزرسانی می‌شود:

```{code-cell} ipython3
print(np.random.randn(2))
```

این تابع *خالص نیست* زیرا:

اما به زودی متوجه خواهید شد که نحو و معناشناسی برای حفظ سبک برنامه‌نویسی تابعی که به تازگی مورد بحث قرار دادیم، ضروری است.
* غیرقطعی است: ورودی‌های یکسان، خروجی‌های متفاوت
* دارای عوارض جانبی است: وضعیت مولد اعداد تصادفی سراسری را تغییر می‌دهد

علاوه بر این، کنترل کامل وضعیت تصادفی برای برنامه‌نویسی موازی، مانند زمانی که می‌خواهیم آزمایش‌های مستقل را در چندین رشته اجرا کنیم، ضروری است.
این در موازی‌سازی خطرناک است --- باید به دقت کنترل کرد که در هر رشته چه اتفاقی می‌افتد.

### تولید اعداد تصادفی
### JAX

در JAX، وضعیت مولد اعداد تصادفی به صورت صریح کنترل می‌شود.

Expand Down Expand Up @@ -539,125 +561,58 @@ plt.show()
تابع زیر `k` ماتریس تصادفی `n x n` (شبه) مستقل را با استفاده از `split` تولید می‌کند.

```{code-cell} ipython3
def gen_random_matrices(key, n=2, k=3):
def gen_random_matrices(
key, # JAX key for random numbers
n=2, # Matrices will be n x n
k=3 # Number of matrices to generate
):
matrices = []
for _ in range(k):
key, subkey = jax.random.split(key)
A = jax.random.uniform(subkey, (n, n))
matrices.append(A)
print(A)
return matrices
```

```{code-cell} ipython3
seed = 42
key = jax.random.key(seed)
matrices = gen_random_matrices(key)
```

همچنین می‌توانیم هنگام تکرار در یک حلقه از `fold_in` استفاده کنیم:

```{code-cell} ipython3
def gen_random_matrices(key, n=2, k=3):
matrices = []
for i in range(k):
step_key = jax.random.fold_in(key, i)
A = jax.random.uniform(step_key, (n, n))
matrices.append(A)
print(A)
return matrices
```

```{code-cell} ipython3
key = jax.random.key(seed)
matrices = gen_random_matrices(key)
```

### چرا وضعیت تصادفی صریح؟

چرا JAX به این رویکرد نسبتاً مفصل برای تولید اعداد تصادفی نیاز دارد؟

یکی از دلایل حفظ توابع خالص است.

بیایید ببینیم که چگونه تولید اعداد تصادفی با توابع خالص با مقایسه NumPy و JAX مرتبط است.

#### رویکرد NumPy

در NumPy، تولید اعداد تصادفی با حفظ وضعیت سراسری پنهان کار می‌کند.

هر بار که یک تابع تصادفی را فراخوانی می‌کنیم، این وضعیت به‌روزرسانی می‌شود:

```{code-cell} ipython3
np.random.seed(42)
print(np.random.randn()) # Updates state of random number generator
print(np.random.randn()) # Updates state of random number generator
```

هر فراخوانی یک مقدار متفاوت را برمی‌گرداند، حتی اگر ما همان تابع را با همان ورودی‌ها (بدون آرگومان، در این مورد) فراخوانی می‌کنیم.

این تابع *خالص نیست* زیرا:

* غیرقطعی است: ورودی‌های یکسان (در این مورد هیچ) خروجی‌های متفاوت می‌دهند
* دارای عوارض جانبی است: وضعیت مولد اعداد تصادفی سراسری را تغییر می‌دهد

#### رویکرد JAX

همانطور که در بالا دیدیم، JAX رویکرد متفاوتی اتخاذ می‌کند و تصادفی بودن را از طریق کلیدها صریح می‌کند.

برای مثال،

```{code-cell} ipython3
def random_sum_jax(key):
key1, key2 = jax.random.split(key)
x = jax.random.normal(key1)
y = jax.random.normal(key2)
return x + y
```

با همان کلید، همیشه نتیجه یکسانی دریافت می‌کنیم:

```{code-cell} ipython3
key = jax.random.key(42)
random_sum_jax(key)
```

```{code-cell} ipython3
random_sum_jax(key)
gen_random_matrices(key)
```

برای دریافت نمونه‌های جدید باید یک کلید جدید ارائه دهیم.

تابع `random_sum_jax` خالص است زیرا:
این تابع *خالص* است

* قطعی است: کلید یکسان همیشه خروجی یکسان تولید می‌کند
* قطعی است: ورودی‌های یکسان، خروجی یکسان
* بدون عوارض جانبی: هیچ وضعیت پنهانی تغییر نمی‌کند

صریح بودن JAX مزایای قابل توجهی به همراه دارد:
### مزایا

همانطور که در بالا ذکر شد، این صراحت ارزشمند است:

* تکرارپذیری: با استفاده مجدد از کلیدها، تکرار نتایج آسان است
* موازی‌سازی: هر رشته می‌تواند کلید خاص خود را بدون تضاد داشته باشد
* اشکال‌زدایی: نبود وضعیت پنهان استدلال در مورد کد را آسان‌تر می‌کند
* موازی‌سازی: کنترل آنچه در رشته‌های جداگانه اتفاق می‌افتد
* اشکال‌زدایی: نبود وضعیت پنهان، آزمایش کد را آسان‌تر می‌کند
* سازگاری با JIT: کامپایلر می‌تواند توابع خالص را به طور تهاجمی‌تری بهینه کند

نکته آخر در بخش بعدی گسترش داده می‌شود.

## کامپایل JIT

کامپایلر just-in-time (JIT) JAX اجرا را با تولید کد ماشین کارآمد که با هم اندازه وظیفه و هم سخت‌افزار متفاوت است، تسریع می‌کند.

ما قدرت کامپایلر JIT JAX را در ترکیب با سخت‌افزار موازی {ref}`در بالا <jax_speed>` مشاهده کردیم، هنگامی که `cos` را روی یک آرایه بزرگ اعمال کردیم.

بیایید همان کار را با یک تابع پیچیده‌تر امتحان کنیم:
اینجا کامپایل JIT را برای توابع پیچیده‌تر بررسی می‌کنیم

### با NumPy

ابتدا با NumPy امتحان خواهیم کرد، با استفاده از

```{code-cell}
def f(x):
y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
return y
```

### با NumPy

ابتدا با NumPy امتحان خواهیم کرد
بیایید با `x` بزرگ اجرا کنیم

```{code-cell}
n = 50_000_000
Expand All @@ -670,9 +625,17 @@ with qe.Timer():
y = f(x)
```

### با JAX
مدل اجرای **Eager**

اکنون بیایید دوباره با JAX امتحان کنیم.
* هر عملیات بلافاصله پس از مواجهه اجرا می‌شود و نتیجه آن قبل از شروع عملیات بعدی مادی می‌شود.

معایب

* موازی‌سازی حداقلی
* ردپای حافظه سنگین --- آرایه‌های میانی زیادی تولید می‌کند
* خواندن/نوشتن حافظه زیاد

### با JAX

به عنوان اولین مرحله، `np` را در همه جا با `jnp` جایگزین می‌کنیم:

Expand Down Expand Up @@ -703,14 +666,22 @@ with qe.Timer():
jax.block_until_ready(y);
```

نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در
اجرای دوم پس از کامپایل JIT.
نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در اجرای دوم پس از کامپایل JIT.

این به این دلیل است که عملیات‌های آرایه‌ای منفرد روی GPU موازی‌سازی می‌شوند

علاوه بر این، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد.
اما ما هنوز از اجرای eager استفاده می‌کنیم

* حافظه زیاد به دلیل آرایه‌های میانی
* خواندن/نوشتن حافظه زیاد

همچنین، هسته‌های جداگانه زیادی روی GPU راه‌اندازی می‌شوند

### کامپایل کل تابع

کامپایلر just-in-time (JIT) JAX می‌تواند اجرا را در درون توابع با ادغام عملیات آرایه‌ای در یک هسته بهینه شده واحد تسریع کند.
خوشبختانه، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد.

کامپایلر تمام عملیات‌های آرایه‌ای را در یک هسته بهینه‌شده واحد ادغام می‌کند

بیایید این را با تابع `f` امتحان کنیم:

Expand All @@ -734,9 +705,12 @@ with qe.Timer():
jax.block_until_ready(y);
```

زمان اجرا دوباره بهبود یافته است --- اکنون به این دلیل که تمام عملیات را ادغام کردیم و به کامپایلر اجازه دادیم به طور تهاجمی‌تری بهینه‌سازی کند.
زمان اجرا دوباره بهبود یافته است --- اکنون به این دلیل که تمام عملیات را ادغام کردیم

برای مثال، کامپایلر می‌تواند چندین فراخوانی به شتاب‌دهنده سخت‌افزاری و ایجاد تعدادی آرایه میانی را حذف کند.
* بهینه‌سازی تهاجمی بر اساس کل دنباله محاسباتی
* حذف چندین فراخوانی به شتاب‌دهنده سخت‌افزاری

ردپای حافظه نیز بسیار کمتر است --- بدون ایجاد آرایه‌های میانی

اتفاقاً، نحو رایج‌تر هنگام هدف قرار دادن یک تابع برای کامپایلر JIT این است

Expand All @@ -750,17 +724,15 @@ def f(x):

هنگامی که `jax.jit` را به یک تابع اعمال می‌کنیم، JAX آن را *ردیابی* می‌کند: به جای اجرای فوری عملیات‌ها، دنباله عملیات‌ها را به صورت یک گراف محاسباتی ثبت می‌کند و آن گراف را به کامپایلر [XLA](https://openxla.org/xla) تحویل می‌دهد.

سپس XLA عملیات‌ها را در یک هسته کامپایل شده واحد بهینه‌سازی و ادغام می‌کند که متناسب با سخت‌افزار موجود (CPU، GPU، یا TPU) طراحی شده است.
سپس XLA عملیات‌ها را در یک هسته کامپایل‌شده واحد بهینه‌سازی و ادغام می‌کند که متناسب با سخت‌افزار موجود (CPU، GPU، یا TPU) طراحی شده است.

اولین فراخوانی به یک تابع JIT-کامپایل شده سربار کامپایل دارد، اما فراخوانی‌های بعدی با همان شکل‌ها و نوع‌های ورودی از کد کامپایل شده کش‌شده استفاده می‌کنند و با سرعت کامل اجرا می‌شوند.
اولین فراخوانی به یک تابع JIT-کامپایل‌شده سربار کامپایل دارد، اما فراخوانی‌های بعدی با همان شکل‌ها و نوع‌های ورودی از کد کامپایل‌شده کش‌شده استفاده می‌کنند و با سرعت کامل اجرا می‌شوند.

### کامپایل توابع غیرخالص

اکنون که دیدیم کامپایل JIT چقدر قدرتمند می‌تواند باشد، درک رابطه آن با توابع خالص مهم است.

در حالی که JAX معمولاً هنگام کامپایل توابع ناخالص خطا نمی‌دهد، اجرا غیرقابل پیش‌بینی می‌شود.
در حالی که JAX معمولاً هنگام کامپایل توابع ناخالص خطا نمی‌دهد، اجرا غیرقابل پیش‌بینی می‌شود!

در اینجا تصویری از این واقعیت با استفاده از متغیرهای سراسری آورده شده است:
در اینجا تصویری از این واقعیت آورده شده است:

```{code-cell} ipython3
a = 1 # global
Expand All @@ -780,7 +752,7 @@ f(x)

در کد بالا، مقدار سراسری `a=1` در تابع jitted ادغام می‌شود.

حتی اگر `a` را تغییر دهیم، خروجی `f` تحت تأثیر قرار نخواهد گرفت --- تا زمانی که همان نسخه کامپایل شده فراخوانی شود.
حتی اگر `a` را تغییر دهیم، خروجی `f` تحت تأثیر قرار نخواهد گرفت --- تا زمانی که همان نسخه کامپایل‌شده فراخوانی شود.

```{code-cell} ipython3
a = 42
Expand Down
Loading
Loading