diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index 414535f..09842f7 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,5 +1,5 @@ -source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56 -synced-at: "2026-04-13" +source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28 +synced-at: "2026-04-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 3f071e6..9091a5d 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,5 +1,5 @@ -source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56 -synced-at: "2026-04-13" +source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28 +synced-at: "2026-04-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 3 diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 46b6757..2a4d77a 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -24,13 +24,12 @@ translation: JAX as a NumPy Replacement::Differences::A Workaround: راه‌حل جایگزین Functional Programming: برنامه‌نویسی تابعی Functional Programming::Pure functions: توابع خالص - Functional Programming::Examples: مثال‌ها + Functional Programming::Examples -- Pure and Impure: مثال‌ها -- توابع خالص و ناخالص Functional Programming::Why Functional Programming?: چرا برنامه‌نویسی تابعی؟ Random numbers: اعداد تصادفی - Random numbers::Random number generation: تولید اعداد تصادفی - Random numbers::Why explicit random state?: چرا وضعیت تصادفی صریح؟ - Random numbers::Why explicit random state?::NumPy's approach: رویکرد NumPy - Random numbers::Why explicit random state?::JAX's approach: رویکرد JAX + Random numbers::NumPy / MATLAB Approach: رویکرد NumPy / MATLAB + Random numbers::JAX: JAX + Random numbers::Benefits: مزایا JIT Compilation: کامپایل JIT JIT Compilation::With NumPy: با NumPy JIT Compilation::With JAX: با JAX @@ -357,19 +356,20 @@ a * وضعیت سراسری را تغییر نمی‌دهد * داده‌های ارسال شده به تابع را تغییر نمی‌دهد (داده‌های تغییرناپذیر) -### مثال‌ها +### مثال‌ها -- توابع خالص و ناخالص -در اینجا مثالی از یک تابع *غیرخالص* آورده شده است +در اینجا مثالی از یک تابع *ناخالص* آورده شده است ```{code-cell} ipython3 tax_rate = 0.1 -prices = [10.0, 20.0] def add_tax(prices): for i, price in enumerate(prices): prices[i] = price * (1 + tax_rate) - print('Post-tax prices: ', prices) - return prices + +prices = [10.0, 20.0] +add_tax(prices) +prices ``` این تابع نمی‌تواند خالص باشد زیرا @@ -380,15 +380,21 @@ def add_tax(prices): در اینجا یک نسخه *خالص* آورده شده است ```{code-cell} ipython3 -tax_rate = 0.1 -prices = (10.0, 20.0) def add_tax_pure(prices, tax_rate): new_prices = [price * (1 + tax_rate) for price in prices] return new_prices + +tax_rate = 0.1 +prices = (10.0, 20.0) +after_tax_prices = add_tax_pure(prices, tax_rate) +after_tax_prices ``` -این نسخه خالص تمام وابستگی‌ها را از طریق آرگومان‌های تابع صریح می‌کند و هیچ وضعیت خارجی را تغییر نمی‌دهد. +این نسخه خالص است زیرا + +* تمام وابستگی‌ها از طریق آرگومان‌های تابع صریح هستند +* و هیچ وضعیت خارجی را تغییر نمی‌دهد ### چرا برنامه‌نویسی تابعی؟ @@ -416,15 +422,31 @@ JAX از سبک برنامه‌نویسی تابعی استفاده می‌کن ## اعداد تصادفی -اعداد تصادفی در JAX نسبت به آنچه در NumPy یا Matlab می‌یابید بسیار متفاوت هستند. +تولید اعداد تصادفی در JAX نسبت به الگوهای موجود در NumPy یا MATLAB بسیار متفاوت است. + +### رویکرد NumPy / MATLAB -در ابتدا ممکن است نحو را نسبتاً مفصل بیابید. +در NumPy / MATLAB، تولید با حفظ وضعیت سراسری پنهان کار می‌کند. + +```{code-cell} ipython3 +np.random.seed(42) +print(np.random.randn(2)) +``` + +هر بار که یک تابع تصادفی را فراخوانی می‌کنیم، وضعیت پنهان به‌روزرسانی می‌شود: + +```{code-cell} ipython3 +print(np.random.randn(2)) +``` + +این تابع *خالص نیست* زیرا: -اما به زودی متوجه خواهید شد که نحو و معناشناسی برای حفظ سبک برنامه‌نویسی تابعی که به تازگی مورد بحث قرار دادیم، ضروری است. +* غیرقطعی است: ورودی‌های یکسان، خروجی‌های متفاوت +* دارای عوارض جانبی است: وضعیت مولد اعداد تصادفی سراسری را تغییر می‌دهد -علاوه بر این، کنترل کامل وضعیت تصادفی برای برنامه‌نویسی موازی، مانند زمانی که می‌خواهیم آزمایش‌های مستقل را در چندین رشته اجرا کنیم، ضروری است. +این در موازی‌سازی خطرناک است --- باید به دقت کنترل کرد که در هر رشته چه اتفاقی می‌افتد. -### تولید اعداد تصادفی +### JAX در JAX، وضعیت مولد اعداد تصادفی به صورت صریح کنترل می‌شود. @@ -539,115 +561,50 @@ plt.show() تابع زیر `k` ماتریس تصادفی `n x n` (شبه) مستقل را با استفاده از `split` تولید می‌کند. ```{code-cell} ipython3 -def gen_random_matrices(key, n=2, k=3): +def gen_random_matrices( + key, # JAX key for random numbers + n=2, # Matrices will be n x n + k=3 # Number of matrices to generate + ): matrices = [] for _ in range(k): key, subkey = jax.random.split(key) A = jax.random.uniform(subkey, (n, n)) matrices.append(A) - print(A) return matrices ``` ```{code-cell} ipython3 seed = 42 key = jax.random.key(seed) -matrices = gen_random_matrices(key) -``` - -همچنین می‌توانیم هنگام تکرار در یک حلقه از `fold_in` استفاده کنیم: - -```{code-cell} ipython3 -def gen_random_matrices(key, n=2, k=3): - matrices = [] - for i in range(k): - step_key = jax.random.fold_in(key, i) - A = jax.random.uniform(step_key, (n, n)) - matrices.append(A) - print(A) - return matrices -``` - -```{code-cell} ipython3 -key = jax.random.key(seed) -matrices = gen_random_matrices(key) -``` - -### چرا وضعیت تصادفی صریح؟ - -چرا JAX به این رویکرد نسبتاً مفصل برای تولید اعداد تصادفی نیاز دارد؟ - -یکی از دلایل حفظ توابع خالص است. - -بیایید ببینیم که چگونه تولید اعداد تصادفی با توابع خالص با مقایسه NumPy و JAX مرتبط است. - -#### رویکرد NumPy - -در NumPy، تولید اعداد تصادفی با حفظ وضعیت سراسری پنهان کار می‌کند. - -هر بار که یک تابع تصادفی را فراخوانی می‌کنیم، این وضعیت به‌روزرسانی می‌شود: - -```{code-cell} ipython3 -np.random.seed(42) -print(np.random.randn()) # Updates state of random number generator -print(np.random.randn()) # Updates state of random number generator -``` - -هر فراخوانی یک مقدار متفاوت را برمی‌گرداند، حتی اگر ما همان تابع را با همان ورودی‌ها (بدون آرگومان، در این مورد) فراخوانی می‌کنیم. - -این تابع *خالص نیست* زیرا: - -* غیرقطعی است: ورودی‌های یکسان (در این مورد هیچ) خروجی‌های متفاوت می‌دهند -* دارای عوارض جانبی است: وضعیت مولد اعداد تصادفی سراسری را تغییر می‌دهد - -#### رویکرد JAX - -همانطور که در بالا دیدیم، JAX رویکرد متفاوتی اتخاذ می‌کند و تصادفی بودن را از طریق کلیدها صریح می‌کند. - -برای مثال، - -```{code-cell} ipython3 -def random_sum_jax(key): - key1, key2 = jax.random.split(key) - x = jax.random.normal(key1) - y = jax.random.normal(key2) - return x + y -``` - -با همان کلید، همیشه نتیجه یکسانی دریافت می‌کنیم: - -```{code-cell} ipython3 -key = jax.random.key(42) -random_sum_jax(key) -``` - -```{code-cell} ipython3 -random_sum_jax(key) +gen_random_matrices(key) ``` -برای دریافت نمونه‌های جدید باید یک کلید جدید ارائه دهیم. - -تابع `random_sum_jax` خالص است زیرا: +این تابع *خالص* است -* قطعی است: کلید یکسان همیشه خروجی یکسان تولید می‌کند +* قطعی است: ورودی‌های یکسان، خروجی یکسان * بدون عوارض جانبی: هیچ وضعیت پنهانی تغییر نمی‌کند -صریح بودن JAX مزایای قابل توجهی به همراه دارد: +### مزایا + +همانطور که در بالا ذکر شد، این صراحت ارزشمند است: * تکرارپذیری: با استفاده مجدد از کلیدها، تکرار نتایج آسان است -* موازی‌سازی: هر رشته می‌تواند کلید خاص خود را بدون تضاد داشته باشد -* اشکال‌زدایی: نبود وضعیت پنهان استدلال در مورد کد را آسان‌تر می‌کند +* موازی‌سازی: کنترل آنچه در رشته‌های جداگانه اتفاق می‌افتد +* اشکال‌زدایی: نبود وضعیت پنهان، آزمایش کد را آسان‌تر می‌کند * سازگاری با JIT: کامپایلر می‌تواند توابع خالص را به طور تهاجمی‌تری بهینه کند -نکته آخر در بخش بعدی گسترش داده می‌شود. - ## کامپایل JIT کامپایلر just-in-time (JIT) JAX اجرا را با تولید کد ماشین کارآمد که با هم اندازه وظیفه و هم سخت‌افزار متفاوت است، تسریع می‌کند. ما قدرت کامپایلر JIT JAX را در ترکیب با سخت‌افزار موازی {ref}`در بالا ` مشاهده کردیم، هنگامی که `cos` را روی یک آرایه بزرگ اعمال کردیم. -بیایید همان کار را با یک تابع پیچیده‌تر امتحان کنیم: +اینجا کامپایل JIT را برای توابع پیچیده‌تر بررسی می‌کنیم + +### با NumPy + +ابتدا با NumPy امتحان خواهیم کرد، با استفاده از ```{code-cell} def f(x): @@ -655,9 +612,7 @@ def f(x): return y ``` -### با NumPy - -ابتدا با NumPy امتحان خواهیم کرد +بیایید با `x` بزرگ اجرا کنیم ```{code-cell} n = 50_000_000 @@ -670,9 +625,17 @@ with qe.Timer(): y = f(x) ``` -### با JAX +مدل اجرای **Eager** -اکنون بیایید دوباره با JAX امتحان کنیم. +* هر عملیات بلافاصله پس از مواجهه اجرا می‌شود و نتیجه آن قبل از شروع عملیات بعدی مادی می‌شود. + +معایب + +* موازی‌سازی حداقلی +* ردپای حافظه سنگین --- آرایه‌های میانی زیادی تولید می‌کند +* خواندن/نوشتن حافظه زیاد + +### با JAX به عنوان اولین مرحله، `np` را در همه جا با `jnp` جایگزین می‌کنیم: @@ -703,14 +666,22 @@ with qe.Timer(): jax.block_until_ready(y); ``` -نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در -اجرای دوم پس از کامپایل JIT. +نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در اجرای دوم پس از کامپایل JIT. + +این به این دلیل است که عملیات‌های آرایه‌ای منفرد روی GPU موازی‌سازی می‌شوند -علاوه بر این، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد. +اما ما هنوز از اجرای eager استفاده می‌کنیم + +* حافظه زیاد به دلیل آرایه‌های میانی +* خواندن/نوشتن حافظه زیاد + +همچنین، هسته‌های جداگانه زیادی روی GPU راه‌اندازی می‌شوند ### کامپایل کل تابع -کامپایلر just-in-time (JIT) JAX می‌تواند اجرا را در درون توابع با ادغام عملیات آرایه‌ای در یک هسته بهینه شده واحد تسریع کند. +خوشبختانه، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد. + +کامپایلر تمام عملیات‌های آرایه‌ای را در یک هسته بهینه‌شده واحد ادغام می‌کند بیایید این را با تابع `f` امتحان کنیم: @@ -734,9 +705,12 @@ with qe.Timer(): jax.block_until_ready(y); ``` -زمان اجرا دوباره بهبود یافته است --- اکنون به این دلیل که تمام عملیات را ادغام کردیم و به کامپایلر اجازه دادیم به طور تهاجمی‌تری بهینه‌سازی کند. +زمان اجرا دوباره بهبود یافته است --- اکنون به این دلیل که تمام عملیات را ادغام کردیم -برای مثال، کامپایلر می‌تواند چندین فراخوانی به شتاب‌دهنده سخت‌افزاری و ایجاد تعدادی آرایه میانی را حذف کند. +* بهینه‌سازی تهاجمی بر اساس کل دنباله محاسباتی +* حذف چندین فراخوانی به شتاب‌دهنده سخت‌افزاری + +ردپای حافظه نیز بسیار کمتر است --- بدون ایجاد آرایه‌های میانی اتفاقاً، نحو رایج‌تر هنگام هدف قرار دادن یک تابع برای کامپایلر JIT این است @@ -750,17 +724,15 @@ def f(x): هنگامی که `jax.jit` را به یک تابع اعمال می‌کنیم، JAX آن را *ردیابی* می‌کند: به جای اجرای فوری عملیات‌ها، دنباله عملیات‌ها را به صورت یک گراف محاسباتی ثبت می‌کند و آن گراف را به کامپایلر [XLA](https://openxla.org/xla) تحویل می‌دهد. -سپس XLA عملیات‌ها را در یک هسته کامپایل شده واحد بهینه‌سازی و ادغام می‌کند که متناسب با سخت‌افزار موجود (CPU، GPU، یا TPU) طراحی شده است. +سپس XLA عملیات‌ها را در یک هسته کامپایل‌شده واحد بهینه‌سازی و ادغام می‌کند که متناسب با سخت‌افزار موجود (CPU، GPU، یا TPU) طراحی شده است. -اولین فراخوانی به یک تابع JIT-کامپایل شده سربار کامپایل دارد، اما فراخوانی‌های بعدی با همان شکل‌ها و نوع‌های ورودی از کد کامپایل شده کش‌شده استفاده می‌کنند و با سرعت کامل اجرا می‌شوند. +اولین فراخوانی به یک تابع JIT-کامپایل‌شده سربار کامپایل دارد، اما فراخوانی‌های بعدی با همان شکل‌ها و نوع‌های ورودی از کد کامپایل‌شده کش‌شده استفاده می‌کنند و با سرعت کامل اجرا می‌شوند. ### کامپایل توابع غیرخالص -اکنون که دیدیم کامپایل JIT چقدر قدرتمند می‌تواند باشد، درک رابطه آن با توابع خالص مهم است. - -در حالی که JAX معمولاً هنگام کامپایل توابع ناخالص خطا نمی‌دهد، اجرا غیرقابل پیش‌بینی می‌شود. +در حالی که JAX معمولاً هنگام کامپایل توابع ناخالص خطا نمی‌دهد، اجرا غیرقابل پیش‌بینی می‌شود! -در اینجا تصویری از این واقعیت با استفاده از متغیرهای سراسری آورده شده است: +در اینجا تصویری از این واقعیت آورده شده است: ```{code-cell} ipython3 a = 1 # global @@ -780,7 +752,7 @@ f(x) در کد بالا، مقدار سراسری `a=1` در تابع jitted ادغام می‌شود. -حتی اگر `a` را تغییر دهیم، خروجی `f` تحت تأثیر قرار نخواهد گرفت --- تا زمانی که همان نسخه کامپایل شده فراخوانی شود. +حتی اگر `a` را تغییر دهیم، خروجی `f` تحت تأثیر قرار نخواهد گرفت --- تا زمانی که همان نسخه کامپایل‌شده فراخوانی شود. ```{code-cell} ipython3 a = 42 diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index a5ee332..e522524 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -13,6 +13,7 @@ translation: Vectorized operations: عملیات برداری شده Vectorized operations::Problem Statement: بیان مسئله Vectorized operations::NumPy vectorization: برداری‌سازی NumPy + Vectorized operations::Memory Issues: مشکلات حافظه Vectorized operations::A Comparison with Numba: مقایسه با Numba Vectorized operations::Parallelized Numba: Numba موازی شده Vectorized operations::Vectorized code with JAX: کد برداری شده با JAX @@ -21,6 +22,8 @@ translation: Sequential operations: عملیات ترتیبی Sequential operations::Numba Version: نسخه Numba Sequential operations::JAX Version: نسخه JAX + Sequential operations::JAX Version::First Attempt: تلاش اول + Sequential operations::JAX Version::Second Attempt: تلاش دوم Sequential operations::Summary: خلاصه Overall recommendations: توصیه‌های کلی --- @@ -137,37 +140,75 @@ m = -np.inf for x in grid: for y in grid: z = f(x, y) - if z > m: - m = z + m = max(m, z) ``` ### برداری‌سازی NumPy -اگر به برداری‌سازی به سبک NumPy تغییر دهیم، می‌توانیم از یک شبکه بسیار بزرگتر استفاده کنیم و کد نسبتاً سریع اجرا می‌شود. +اجازه دهید به NumPy تغییر دهیم و از یک شبکه بزرگتر استفاده کنیم -در اینجا از `np.meshgrid` برای ایجاد شبکه‌های ورودی دوبعدی `x` و `y` استفاده می‌کنیم به گونه‌ای که `f(x, y)` تمام ارزیابی‌ها را روی شبکه حاصلضرب تولید می‌کند. +```{code-cell} ipython3 +grid = np.linspace(-3, 3, 3_000) # Large grid +``` + +به عنوان اولین گام برداری‌سازی، ممکن است چیزی شبیه به این را امتحان کنیم + +```{code-cell} ipython3 +# Large grid +z = np.max(f(grid, grid)) # This is wrong! +``` + +مشکل اینجاست که `f(grid, grid)` از حلقه تودرتو پیروی نمی‌کند. + +از نظر شکل بالا، این فقط مقادیر `f` را در امتداد قطر محاسبه می‌کند. -(این استراتژی به Matlab بازمی‌گردد.) +برای اینکه NumPy را مجبور کنیم `f(x,y)` را برای هر جفت `x,y` محاسبه کند، باید از `np.meshgrid` استفاده کنیم. + +در اینجا از `np.meshgrid` برای ایجاد شبکه‌های ورودی دوبعدی `x` و `y` استفاده می‌کنیم به گونه‌ای که `f(x, y)` تمام ارزیابی‌ها را روی شبکه حاصلضرب تولید می‌کند. ```{code-cell} ipython3 +# Large grid grid = np.linspace(-3, 3, 3_000) -x, y = np.meshgrid(grid, grid) + +x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid with qe.Timer(): - z_max_numpy = np.max(f(x, y)) + z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works +``` + +در نسخه برداری شده، تمام حلقه‌ها در کد کامپایل شده انجام می‌شوند. +استفاده از `meshgrid` به ما امکان می‌دهد حلقه `for` تودرتو را تکرار کنیم. + +خروجی باید نزدیک به یک باشد: + +```{code-cell} ipython3 print(f"NumPy result: {z_max_numpy:.6f}") ``` -در نسخه برداری شده، تمام حلقه‌ها در کد کامپایل شده انجام می‌شوند. +### مشکلات حافظه + +پس راه‌حل درست را در زمان معقول داریم --- اما مصرف حافظه بسیار زیاد است. + +در حالی که آرایه‌های تخت حافظه کمی دارند + +```{code-cell} ipython3 +grid.nbytes +``` + +شبکه‌های mesh دوبعدی هستند و از این رو بسیار فشرده از نظر حافظه + +```{code-cell} ipython3 +x_mesh.nbytes + y_mesh.nbytes +``` -علاوه بر این، NumPy از چندنخی ضمنی استفاده می‌کند، به طوری که حداقل مقداری موازی‌سازی رخ می‌دهد. +علاوه بر این، اجرای فوری NumPy آرایه‌های میانی زیادی با همان اندازه ایجاد می‌کند! -(موازی‌سازی نمی‌تواند بسیار کارآمد باشد زیرا فایل باینری قبل از اینکه اندازه آرایه‌های `x` و `y` را ببیند کامپایل می‌شود.) +این نوع مصرف حافظه می‌تواند یک مشکل بزرگ در محاسبات تحقیقاتی واقعی باشد. ### مقایسه با Numba -حالا بیایید ببینیم آیا می‌توانیم با استفاده از Numba با یک حلقه ساده به عملکرد بهتری دست یابیم. +بیایید ببینیم آیا می‌توانیم با استفاده از Numba با یک حلقه ساده به عملکرد بهتری دست یابیم. ```{code-cell} ipython3 @numba.jit @@ -188,8 +229,6 @@ grid = np.linspace(-3, 3, 3_000) with qe.Timer(): # First run z_max_numba = compute_max_numba(grid) - -print(f"Numba result: {z_max_numba:.6f}") ``` بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود. @@ -200,13 +239,13 @@ with qe.Timer(): compute_max_numba(grid) ``` -بسته به دستگاه شما، نسخه Numba ممکن است کندتر یا سریعتر از NumPy باشد. +توجه داشته باشید که تقریباً هیچ حافظه‌ای استفاده نمی‌کنیم --- فقط به `grid` یک‌بعدی نیاز داریم. -در اکثر موارد، Numba کمی بهتر است. +علاوه بر این، سرعت اجرا خوب است. -از یک طرف، NumPy محاسبات کارآمد را با مقداری چندنخی ترکیب می‌کند که مزیتی فراهم می‌کند. +در اکثر دستگاه‌ها، نسخه Numba کمی سریع‌تر از NumPy خواهد بود. -از طرف دیگر، روال Numba از حافظه بسیار کمتری استفاده می‌کند، زیرا ما فقط با یک شبکه یک‌بعدی کار می‌کنیم. +دلیل آن کد ماشین کارآمد به علاوه خواندن-نوشتن حافظه کمتر است. ### Numba موازی شده @@ -232,8 +271,6 @@ def compute_max_numba_parallel(grid): with qe.Timer(): # First run z_max_parallel = compute_max_numba_parallel(grid) - -print(f"Numba result: {z_max_parallel:.6f}") ``` در اینجا زمان‌بندی برای نسخه از پیش کامپایل شده آمده است. @@ -244,19 +281,22 @@ with qe.Timer(): compute_max_numba_parallel(grid) ``` -اگر چندین هسته دارید، باید حداقل برخی مزایا را از موازی‌سازی در اینجا ببینید. +اگر چندین هسته دارید، باید مزایایی از موازی‌سازی در اینجا ببینید. -برای دستگاه‌های قدرتمندتر و اندازه‌های شبکه بزرگتر، موازی‌سازی می‌تواند افزایش سرعت قابل توجهی ایجاد کند، حتی روی CPU. +بیایید مطمئن شویم که هنوز نتیجه درستی داریم (نزدیک به یک): -### کد برداری شده با JAX +```{code-cell} ipython3 +print(f"Numba result: {z_max_parallel:.6f}") +``` + +برای دستگاه‌های قدرتمند و اندازه‌های شبکه بزرگتر، موازی‌سازی می‌تواند افزایش سرعت مفیدی ایجاد کند، حتی روی CPU. -در ظاهر، کد برداری شده در JAX شبیه به کد NumPy است. +### کد برداری شده با JAX -اما تفاوت‌هایی نیز وجود دارد که در اینجا آنها را برجسته می‌کنیم. +بیایید رویکرد برداری شده NumPy را با JAX تکرار کنیم. بیایید با تابع شروع کنیم که `np` را به `jnp` تغییر می‌دهد و `jax.jit` را اضافه می‌کند. - ```{code-cell} ipython3 @jax.jit def f(x, y): @@ -264,7 +304,7 @@ def f(x, y): ``` -همانند NumPy، برای به دست آوردن شکل درست و محاسبه حلقه `for` تودرتوی صحیح، می‌توانیم از عملیات `meshgrid` طراحی شده برای این منظور استفاده کنیم: +از رویکرد meshgrid به سبک NumPy استفاده می‌کنیم: ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) @@ -299,82 +339,34 @@ with qe.Timer(): ### JAX به علاوه vmap -یک مشکل با کد NumPy و کد JAX وجود دارد: - -در حالی که آرایه‌های تخت حافظه کمی دارند +چون از `jax.jit` در بالا استفاده کردیم، از ایجاد آرایه‌های میانی زیاد جلوگیری کردیم. -```{code-cell} ipython3 -grid.nbytes -``` +اما هنوز آرایه‌های بزرگ `z_max`، `x_mesh` و `y_mesh` را ایجاد می‌کنیم. -شبکه‌های mesh فشرده از نظر حافظه هستند - -```{code-cell} ipython3 -x_mesh.nbytes + y_mesh.nbytes -``` - -این استفاده اضافی از حافظه می‌تواند یک مشکل بزرگ در محاسبات تحقیقاتی واقعی باشد. - -خوشبختانه، JAX رویکرد متفاوتی را با استفاده از [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) می‌پذیرد. - -ایده `vmap` این است که برداری‌سازی را به مراحل تقسیم کند و تابعی که روی مقادیر تکی عمل می‌کند را به تابعی تبدیل کند که روی آرایه‌ها عمل می‌کند. +خوشبختانه، می‌توانیم با استفاده از [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) از این امر اجتناب کنیم. در اینجا نحوه اعمال آن به مسئله ما آمده است. -```{code-cell} ipython3 -# f را تنظیم کنید تا f(x, y) را در هر x برای هر y داده شده محاسبه کند -f_vec_x = lambda y: f(grid, y) -# یک تابع دوم ایجاد کنید که این عملیات را روی تمام y برداری کند -f_vec = jax.vmap(f_vec_x) -``` - -اکنون `f_vec` هنگام فراخوانی با آرایه تخت `grid`، `f(x,y)` را در هر `x,y` محاسبه می‌کند. - -بیایید زمان‌بندی را ببینیم: - -```{code-cell} ipython3 -with qe.Timer(): - z_max = jnp.max(f_vec(grid)) - z_max.block_until_ready() - -print(f"JAX vmap v1 result: {z_max:.6f}") -``` - -```{code-cell} ipython3 -with qe.Timer(): - z_max = jnp.max(f_vec(grid)) - z_max.block_until_ready() -``` - -با اجتناب از آرایه‌های ورودی بزرگ `x_mesh` و `y_mesh`، این نسخه `vmap` از حافظه بسیار کمتری با زمان اجرای مشابه استفاده می‌کند. - -اما هنوز برخی بهره‌های سرعت را از دست می‌دهیم. - -کد فوق آرایه دوبعدی کامل `f(x,y)` را محاسبه می‌کند و سپس max را می‌گیرد. - -علاوه بر این، فراخوانی `jnp.max` خارج از تابع JIT-کامپایل شده `f` قرار دارد، بنابراین کامپایلر نمی‌تواند این عملیات را در یک kernel واحد ادغام کند. - -می‌توانیم هر دو مشکل را با انتقال max به داخل و پوشاندن همه چیز در یک `@jax.jit` واحد برطرف کنیم: - ```{code-cell} ipython3 @jax.jit def compute_max_vmap(grid): - # یک تابع بسازید که حداکثر را در امتداد هر سطر بگیرد - f_vec_x_max = lambda y: jnp.max(f(grid, y)) - # تابع را برداری کنید تا بتوانیم روی تمام سطرها همزمان فراخوانی کنیم - f_vec_max = jax.vmap(f_vec_x_max) - # تابع برداری شده را فراخوانی کنید و حداکثر را بگیرید - return jnp.max(f_vec_max(grid)) + # Construct a function that takes the max over all x for given y + compute_column_max = lambda y: jnp.max(f(grid, y)) + # Vectorize the function so we can call on all y simultaneously + vectorized_compute_column_max = jax.vmap(compute_column_max) + # Compute the column max at every row + column_maxes = vectorized_compute_column_max(grid) + # Compute the max of the column maxes and return + return jnp.max(column_maxes) ``` -در اینجا - -* `f_vec_x_max` حداکثر را در امتداد هر سطر داده شده محاسبه می‌کند -* `f_vec_max` یک نسخه برداری شده است که می‌تواند حداکثر تمام سطرها را به صورت موازی محاسبه کند. +توجه داشته باشید که هرگز ایجاد نمی‌کنیم -ما این تابع را روی تمام سطرها اعمال می‌کنیم و سپس حداکثر max های سطر را می‌گیریم. +* شبکه دوبعدی `x_mesh` +* شبکه دوبعدی `y_mesh` یا +* آرایه دوبعدی `f(x,y)` -چون max را به داخل منتقل می‌کنیم، هرگز آرایه دوبعدی کامل `f(x,y)` را نمی‌سازیم و حافظه بیشتری صرفه‌جویی می‌شود. +مانند Numba، فقط از آرایه تخت `grid` استفاده می‌کنیم. و چون همه چیز زیر یک `@jax.jit` واحد قرار دارد، کامپایلر می‌تواند تمام عملیات را در یک kernel بهینه ادغام کند. @@ -382,7 +374,10 @@ def compute_max_vmap(grid): ```{code-cell} ipython3 with qe.Timer(): - z_max = compute_max_vmap(grid).block_until_ready() + # First run + z_max = compute_max_vmap(grid) + # Hold interpreter + z_max.block_until_ready() print(f"JAX vmap result: {z_max:.6f}") ``` @@ -391,7 +386,10 @@ print(f"JAX vmap result: {z_max:.6f}") ```{code-cell} ipython3 with qe.Timer(): - z_max = compute_max_vmap(grid).block_until_ready() + # Second run + z_max = compute_max_vmap(grid) + # Hold interpreter + z_max.block_until_ready() ``` ### خلاصه @@ -400,13 +398,11 @@ with qe.Timer(): هم از نظر سرعت (از طریق JIT-compilation و موازی‌سازی) و هم از نظر کارایی حافظه (از طریق vmap) بر NumPy غلبه می‌کند. -علاوه بر این، رویکرد `vmap` گاهی اوقات می‌تواند منجر به کد به طور قابل توجهی واضح‌تری شود. +همچنین هنگام اجرا روی GPU بر Numba غلبه می‌کند. -در حالی که Numba چشمگیر است، زیبایی JAX این است که با عملیات کاملاً برداری شده، می‌توانیم دقیقاً همان کد را روی دستگاه‌های با شتاب‌دهنده سخت‌افزاری اجرا کنیم و بدون تلاش اضافی از تمام مزایا بهره‌مند شویم. - -علاوه بر این، JAX قبلاً می‌داند چگونه بسیاری از عملیات آرایه رایج را به طور مؤثر موازی کند، که کلید اجرای سریع است. - -برای اکثر موارد مواجه شده در اقتصاد، اقتصادسنجی و امور مالی، بسیار بهتر است که برای موازی‌سازی کارآمد به کامپایلر JAX تحویل دهیم تا اینکه سعی کنیم این روال‌ها را خودمان کدنویسی دستی کنیم. +```{note} +Numba می‌تواند از طریق `numba.cuda` از برنامه‌نویسی GPU پشتیبانی کند اما در آن صورت باید موازی‌سازی را به صورت دستی انجام دهیم. برای اکثر موارد مواجه شده در اقتصاد، اقتصادسنجی و امور مالی، بسیار بهتر است که برای موازی‌سازی کارآمد به کامپایلر JAX تحویل دهیم تا اینکه سعی کنیم این روال‌ها را خودمان به صورت دستی کدنویسی کنیم. +``` ## عملیات ترتیبی @@ -436,6 +432,7 @@ def qm(x0, n, α=4.0): n = 10_000_000 with qe.Timer(): + # First run x = qm(0.1, n) ``` @@ -443,18 +440,21 @@ with qe.Timer(): ```{code-cell} ipython3 with qe.Timer(): + # Second run x = qm(0.1, n) ``` Numba این عملیات ترتیبی را به طور بسیار کارآمد مدیریت می‌کند. -توجه کنید که اجرای دوم پس از تکمیل کامپایل JIT به طور قابل توجهی سریعتر است. +### نسخه JAX -کامپایل Numba معمولاً بسیار سریع است و عملکرد کد حاصل برای عملیات ترتیبی مانند این عالی است. +ما نمی‌توانیم مستقیماً `numba.jit` را با `jax.jit` جایگزین کنیم زیرا آرایه‌های JAX تغییرناپذیر هستند. -### نسخه JAX +اما ما همچنان می‌توانیم این عملیات را پیاده‌سازی کنیم. + +#### تلاش اول -حالا بیایید یک نسخه JAX با استفاده از سینتکس `at[t].set` ایجاد کنیم که، همان‌طور که {ref}`در درس JAX بحث شد `، راه‌حلی برای آرایه‌های تغییرناپذیر فراهم می‌کند. +در اینجا یک راه‌حل با استفاده از سینتکس `at[t].set` که {ref}`در درس JAX بحث شد ` آمده است. ما از `lax.fori_loop` استفاده می‌کنیم که نسخه‌ای از حلقه for است که می‌تواند توسط XLA کامپایل شود. @@ -477,7 +477,7 @@ def qm_jax_fori(x0, n, α=4.0): * ما `n` را ایستا نگه می‌داریم زیرا بر اندازه آرایه تأثیر می‌گذارد و از این رو JAX می‌خواهد روی مقدار آن در کد کامپایل شده تخصصی شود. * ما به CPU از طریق `device=cpu` متصل می‌مانیم زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازی‌سازی GPU باقی می‌گذارد. -اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد می‌کند، در داخل یک تابع کامپایل‌شده با JIT، کامپایلر تشخیص می‌دهد که آرایه قدیمی دیگر مورد نیاز نیست و به‌روزرسانی را در جا انجام می‌دهد. +مهم: اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد می‌کند، در داخل یک تابع کامپایل‌شده با JIT، کامپایلر تشخیص می‌دهد که آرایه قدیمی دیگر مورد نیاز نیست و به‌روزرسانی را در جا انجام می‌دهد! بیایید آن را با همان پارامترها زمان‌بندی کنیم: @@ -499,7 +499,9 @@ with qe.Timer(): x_jax.block_until_ready() ``` -JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است. +JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است! + +#### تلاش دوم روش دیگری برای پیاده‌سازی حلقه وجود دارد که از `lax.scan` استفاده می‌کند. @@ -538,11 +540,11 @@ with qe.Timer(): x_jax.block_until_ready() ``` -هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه می‌دهند. +شگفت‌انگیز است که JAX نیز پس از کامپایل عملکرد قوی ارائه می‌دهد. ### خلاصه -در حالی که هم Numba و هم JAX عملکرد قوی برای عملیات ترتیبی ارائه می‌دهند، *تفاوت‌های قابل توجهی در خوانایی کد و سهولت استفاده وجود دارد*. +در حالی که هم Numba و هم JAX عملکرد قوی برای عملیات ترتیبی ارائه می‌دهند، تفاوت‌هایی در خوانایی کد و سهولت استفاده وجود دارد. نسخه Numba ساده و طبیعی برای خواندن است: ما به سادگی یک آرایه اختصاص می‌دهیم و آن را عنصر به عنصر با استفاده از یک حلقه استاندارد Python پر می‌کنیم. @@ -552,8 +554,6 @@ with qe.Timer(): در حالی که سینتکس `at[t].set` در JAX به‌روزرسانی عنصر به عنصر را ممکن می‌سازد، کد کلی همچنان سخت‌تر از معادل Numba برای خواندن است. -برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی است. - ## توصیه‌های کلی حال قدمی به عقب بر می‌داریم و مبادلات را خلاصه می‌کنیم. @@ -566,17 +566,12 @@ with qe.Timer(): علاوه بر این، توابع JAX به‌صورت خودکار مشتق‌پذیر هستند، همان‌طور که در {doc}`autodiff` بررسی می‌کنیم. -برای **عملیات ترتیبی**، Numba مزایای آشکاری دارد. +برای **عملیات ترتیبی**، Numba نحو بهتری دارد. کد طبیعی و خوانا است --- صرفاً یک حلقه پایتون با یک decorator --- و کارایی آن عالی است. JAX می‌تواند مسائل ترتیبی را از طریق `lax.fori_loop` یا `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است. -```{note} -یک مزیت مهم `lax.fori_loop` و `lax.scan` این است که از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کنند، که Numba قادر به انجام آن نیست. -اگر نیاز دارید از طریق یک محاسبه ترتیبی مشتق بگیرید (مثلاً محاسبه حساسیت‌های یک مسیر نسبت به پارامترهای مدل)، JAX علی‌رغم نحو کمتر طبیعی‌اش، انتخاب بهتری است. -``` - -در عمل، بسیاری از مسائل ترکیبی از هر دو الگو هستند. +از سوی دیگر، نسخه‌های JAX از مشتق‌گیری خودکار پشتیبانی می‌کنند. -یک قاعده سرانگشتی مناسب: برای پروژه‌های جدید، به‌ویژه زمانی که شتاب‌دهی سخت‌افزاری یا مشتق‌پذیری ممکن است مفید باشد، به‌طور پیش‌فرض از JAX استفاده کنید، و هنگامی که یک حلقه ترتیبی فشرده نیاز به سرعت و خوانایی دارد، به Numba متوسل شوید. +این ممکن است مورد توجه باشد، مثلاً زمانی که می‌خواهیم حساسیت‌های یک مسیر را نسبت به پارامترهای مدل محاسبه کنیم.