diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index 414535f..4f0ca12 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,5 +1,5 @@ -source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56 -synced-at: "2026-04-13" +source-sha: 450bafecd23db638602150b47f4272b98aad3146 +synced-at: "2026-04-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 3f071e6..e904fbe 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,5 +1,5 @@ -source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56 -synced-at: "2026-04-13" +source-sha: 450bafecd23db638602150b47f4272b98aad3146 +synced-at: "2026-04-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 3 diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 46b6757..a00de1b 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -27,10 +27,9 @@ translation: Functional Programming::Examples: مثال‌ها Functional Programming::Why Functional Programming?: چرا برنامه‌نویسی تابعی؟ Random numbers: اعداد تصادفی - Random numbers::Random number generation: تولید اعداد تصادفی - Random numbers::Why explicit random state?: چرا وضعیت تصادفی صریح؟ - Random numbers::Why explicit random state?::NumPy's approach: رویکرد NumPy - Random numbers::Why explicit random state?::JAX's approach: رویکرد JAX + Random numbers::NumPy / MATLAB Approach: رویکرد NumPy / MATLAB + Random numbers::JAX: JAX + Random numbers::Benefits: مزایا JIT Compilation: کامپایل JIT JIT Compilation::With NumPy: با NumPy JIT Compilation::With JAX: با JAX @@ -416,15 +415,31 @@ JAX از سبک برنامه‌نویسی تابعی استفاده می‌کن ## اعداد تصادفی -اعداد تصادفی در JAX نسبت به آنچه در NumPy یا Matlab می‌یابید بسیار متفاوت هستند. +اعداد تصادفی در JAX نسبت به آنچه در NumPy یا MATLAB می‌یابید بسیار متفاوت هستند. -در ابتدا ممکن است نحو را نسبتاً مفصل بیابید. +### رویکرد NumPy / MATLAB -اما به زودی متوجه خواهید شد که نحو و معناشناسی برای حفظ سبک برنامه‌نویسی تابعی که به تازگی مورد بحث قرار دادیم، ضروری است. +در NumPy / MATLAB، تولید اعداد تصادفی با حفظ وضعیت سراسری پنهان کار می‌کند. -علاوه بر این، کنترل کامل وضعیت تصادفی برای برنامه‌نویسی موازی، مانند زمانی که می‌خواهیم آزمایش‌های مستقل را در چندین رشته اجرا کنیم، ضروری است. +```{code-cell} ipython3 +np.random.seed(42) +print(np.random.randn(2)) +``` + +هر بار که یک تابع تصادفی را فراخوانی می‌کنیم، وضعیت پنهان به‌روزرسانی می‌شود: + +```{code-cell} ipython3 +print(np.random.randn(2)) +``` + +این تابع *خالص نیست* زیرا: -### تولید اعداد تصادفی +* غیرقطعی است: ورودی‌های یکسان، خروجی‌های متفاوت +* دارای عوارض جانبی است: وضعیت مولد اعداد تصادفی سراسری را تغییر می‌دهد + +در موازی‌سازی خطرناک است --- باید با دقت کنترل کرد که در هر رشته چه اتفاقی می‌افتد! + +### JAX در JAX، وضعیت مولد اعداد تصادفی به صورت صریح کنترل می‌شود. @@ -545,109 +560,40 @@ def gen_random_matrices(key, n=2, k=3): key, subkey = jax.random.split(key) A = jax.random.uniform(subkey, (n, n)) matrices.append(A) - print(A) return matrices ``` ```{code-cell} ipython3 seed = 42 key = jax.random.key(seed) -matrices = gen_random_matrices(key) -``` - -همچنین می‌توانیم هنگام تکرار در یک حلقه از `fold_in` استفاده کنیم: - -```{code-cell} ipython3 -def gen_random_matrices(key, n=2, k=3): - matrices = [] - for i in range(k): - step_key = jax.random.fold_in(key, i) - A = jax.random.uniform(step_key, (n, n)) - matrices.append(A) - print(A) - return matrices -``` - -```{code-cell} ipython3 -key = jax.random.key(seed) -matrices = gen_random_matrices(key) -``` - -### چرا وضعیت تصادفی صریح؟ - -چرا JAX به این رویکرد نسبتاً مفصل برای تولید اعداد تصادفی نیاز دارد؟ - -یکی از دلایل حفظ توابع خالص است. - -بیایید ببینیم که چگونه تولید اعداد تصادفی با توابع خالص با مقایسه NumPy و JAX مرتبط است. - -#### رویکرد NumPy - -در NumPy، تولید اعداد تصادفی با حفظ وضعیت سراسری پنهان کار می‌کند. - -هر بار که یک تابع تصادفی را فراخوانی می‌کنیم، این وضعیت به‌روزرسانی می‌شود: - -```{code-cell} ipython3 -np.random.seed(42) -print(np.random.randn()) # Updates state of random number generator -print(np.random.randn()) # Updates state of random number generator +gen_random_matrices(key) ``` -هر فراخوانی یک مقدار متفاوت را برمی‌گرداند، حتی اگر ما همان تابع را با همان ورودی‌ها (بدون آرگومان، در این مورد) فراخوانی می‌کنیم. +این تابع *خالص* است -این تابع *خالص نیست* زیرا: - -* غیرقطعی است: ورودی‌های یکسان (در این مورد هیچ) خروجی‌های متفاوت می‌دهند -* دارای عوارض جانبی است: وضعیت مولد اعداد تصادفی سراسری را تغییر می‌دهد - -#### رویکرد JAX - -همانطور که در بالا دیدیم، JAX رویکرد متفاوتی اتخاذ می‌کند و تصادفی بودن را از طریق کلیدها صریح می‌کند. - -برای مثال، - -```{code-cell} ipython3 -def random_sum_jax(key): - key1, key2 = jax.random.split(key) - x = jax.random.normal(key1) - y = jax.random.normal(key2) - return x + y -``` - -با همان کلید، همیشه نتیجه یکسانی دریافت می‌کنیم: - -```{code-cell} ipython3 -key = jax.random.key(42) -random_sum_jax(key) -``` - -```{code-cell} ipython3 -random_sum_jax(key) -``` - -برای دریافت نمونه‌های جدید باید یک کلید جدید ارائه دهیم. - -تابع `random_sum_jax` خالص است زیرا: - -* قطعی است: کلید یکسان همیشه خروجی یکسان تولید می‌کند +* قطعی است: ورودی‌های یکسان، خروجی یکسان * بدون عوارض جانبی: هیچ وضعیت پنهانی تغییر نمی‌کند +### مزایا + صریح بودن JAX مزایای قابل توجهی به همراه دارد: * تکرارپذیری: با استفاده مجدد از کلیدها، تکرار نتایج آسان است -* موازی‌سازی: هر رشته می‌تواند کلید خاص خود را بدون تضاد داشته باشد -* اشکال‌زدایی: نبود وضعیت پنهان استدلال در مورد کد را آسان‌تر می‌کند +* موازی‌سازی: کنترل آنچه در رشته‌های جداگانه اتفاق می‌افتد +* اشکال‌زدایی: نبود وضعیت پنهان آزمایش کد را آسان‌تر می‌کند * سازگاری با JIT: کامپایلر می‌تواند توابع خالص را به طور تهاجمی‌تری بهینه کند -نکته آخر در بخش بعدی گسترش داده می‌شود. - ## کامپایل JIT کامپایلر just-in-time (JIT) JAX اجرا را با تولید کد ماشین کارآمد که با هم اندازه وظیفه و هم سخت‌افزار متفاوت است، تسریع می‌کند. ما قدرت کامپایلر JIT JAX را در ترکیب با سخت‌افزار موازی {ref}`در بالا ` مشاهده کردیم، هنگامی که `cos` را روی یک آرایه بزرگ اعمال کردیم. -بیایید همان کار را با یک تابع پیچیده‌تر امتحان کنیم: +در اینجا کامپایل JIT را برای توابع پیچیده‌تر بررسی می‌کنیم. + +### با NumPy + +ابتدا با NumPy امتحان خواهیم کرد، با استفاده از ```{code-cell} def f(x): @@ -655,9 +601,7 @@ def f(x): return y ``` -### با NumPy - -ابتدا با NumPy امتحان خواهیم کرد +بیایید با `x` بزرگ اجرا کنیم ```{code-cell} n = 50_000_000 @@ -670,9 +614,17 @@ with qe.Timer(): y = f(x) ``` -### با JAX +مدل اجرای **Eager** + +* هر عملیات بلافاصله هنگامی که با آن مواجه می‌شود اجرا می‌شود و نتیجه آن را قبل از شروع عملیات بعدی مادی‌سازی می‌کند. + +معایب + +* موازی‌سازی حداقل +* ردپای حافظه سنگین --- آرایه‌های میانی زیادی تولید می‌کند +* خواندن/نوشتن حافظه زیاد -اکنون بیایید دوباره با JAX امتحان کنیم. +### با JAX به عنوان اولین مرحله، `np` را در همه جا با `jnp` جایگزین می‌کنیم: @@ -703,14 +655,15 @@ with qe.Timer(): jax.block_until_ready(y); ``` -نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در -اجرای دوم پس از کامپایل JIT. +نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در اجرای دوم پس از کامپایل JIT. -علاوه بر این، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد. +اما همچنان از اجرای eager استفاده می‌کنیم --- حافظه و خواندن/نوشتن زیاد. ### کامپایل کل تابع -کامپایلر just-in-time (JIT) JAX می‌تواند اجرا را در درون توابع با ادغام عملیات آرایه‌ای در یک هسته بهینه شده واحد تسریع کند. +خوشبختانه، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد. + +کامپایلر تمام عملیات آرایه‌ای را در یک هسته بهینه‌شده واحد ادغام می‌کند. بیایید این را با تابع `f` امتحان کنیم: @@ -734,9 +687,11 @@ with qe.Timer(): jax.block_until_ready(y); ``` -زمان اجرا دوباره بهبود یافته است --- اکنون به این دلیل که تمام عملیات را ادغام کردیم و به کامپایلر اجازه دادیم به طور تهاجمی‌تری بهینه‌سازی کند. +زمان اجرا دوباره بهبود یافته است --- اکنون به این دلیل که تمام عملیات را ادغام کردیم. -برای مثال، کامپایلر می‌تواند چندین فراخوانی به شتاب‌دهنده سخت‌افزاری و ایجاد تعدادی آرایه میانی را حذف کند. +* بهینه‌سازی تهاجمی بر اساس کل دنباله محاسباتی +* حذف چندین فراخوانی به شتاب‌دهنده سخت‌افزاری +* عدم ایجاد آرایه‌های میانی اتفاقاً، نحو رایج‌تر هنگام هدف قرار دادن یک تابع برای کامپایلر JIT این است @@ -756,11 +711,9 @@ def f(x): ### کامپایل توابع غیرخالص -اکنون که دیدیم کامپایل JIT چقدر قدرتمند می‌تواند باشد، درک رابطه آن با توابع خالص مهم است. +در حالی که JAX معمولاً هنگام کامپایل توابع ناخالص خطا نمی‌دهد، اجرا غیرقابل پیش‌بینی می‌شود! -در حالی که JAX معمولاً هنگام کامپایل توابع ناخالص خطا نمی‌دهد، اجرا غیرقابل پیش‌بینی می‌شود. - -در اینجا تصویری از این واقعیت با استفاده از متغیرهای سراسری آورده شده است: +در اینجا تصویری از این واقعیت آورده شده است: ```{code-cell} ipython3 a = 1 # global @@ -840,17 +793,13 @@ for row in X: با این حال، حلقه‌های Python کُند هستند و نمی‌توانند به‌طور کارآمد توسط JAX کامپایل یا موازی‌سازی شوند. -استفاده از `vmap` محاسبه را روی شتاب‌دهنده نگه می‌دارد و با سایر -تبدیل‌های JAX مانند `jit` و `grad` ترکیب می‌شود: +با استفاده از `vmap`، می‌توانیم از حلقه‌ها اجتناب کنیم و محاسبه را روی شتاب‌دهنده نگه داریم: ```{code-cell} ipython3 -batch_mm_diff = jax.vmap(mm_diff) -batch_mm_diff(X) +batch_mm_diff = jax.vmap(mm_diff) # Create a new "vectorized" version +batch_mm_diff(X) # Apply to each row of X ``` -تابع `mm_diff` برای یک آرایه منفرد نوشته شده بود، و `vmap` به‌طور خودکار -آن را برای عمل سطربه‌سطر روی یک ماتریس ارتقا داد --- بدون حلقه، بدون تغییر شکل. - ### ترکیب تبدیل‌ها یکی از نقاط قوت JAX این است که تبدیل‌ها به‌طور طبیعی با هم ترکیب می‌شوند. diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index a5ee332..d560e90 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -21,6 +21,8 @@ translation: Sequential operations: عملیات ترتیبی Sequential operations::Numba Version: نسخه Numba Sequential operations::JAX Version: نسخه JAX + Sequential operations::JAX Version::First Attempt: تلاش اول + Sequential operations::JAX Version::Second Attempt: تلاش دوم Sequential operations::Summary: خلاصه Overall recommendations: توصیه‌های کلی --- @@ -137,33 +139,34 @@ m = -np.inf for x in grid: for y in grid: z = f(x, y) - if z > m: - m = z + m = max(m, z) ``` ### برداری‌سازی NumPy -اگر به برداری‌سازی به سبک NumPy تغییر دهیم، می‌توانیم از یک شبکه بسیار بزرگتر استفاده کنیم و کد نسبتاً سریع اجرا می‌شود. +بیایید به NumPy تغییر دهیم و از یک شبکه بزرگتر استفاده کنیم در اینجا از `np.meshgrid` برای ایجاد شبکه‌های ورودی دوبعدی `x` و `y` استفاده می‌کنیم به گونه‌ای که `f(x, y)` تمام ارزیابی‌ها را روی شبکه حاصلضرب تولید می‌کند. -(این استراتژی به Matlab بازمی‌گردد.) - ```{code-cell} ipython3 +# Large grid grid = np.linspace(-3, 3, 3_000) -x, y = np.meshgrid(grid, grid) + +x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid with qe.Timer(): z_max_numpy = np.max(f(x, y)) - -print(f"NumPy result: {z_max_numpy:.6f}") ``` در نسخه برداری شده، تمام حلقه‌ها در کد کامپایل شده انجام می‌شوند. -علاوه بر این، NumPy از چندنخی ضمنی استفاده می‌کند، به طوری که حداقل مقداری موازی‌سازی رخ می‌دهد. +استفاده از `meshgrid` به ما امکان می‌دهد حلقه for تودرتو را تکرار کنیم. -(موازی‌سازی نمی‌تواند بسیار کارآمد باشد زیرا فایل باینری قبل از اینکه اندازه آرایه‌های `x` و `y` را ببیند کامپایل می‌شود.) +خروجی باید نزدیک به یک باشد: + +```{code-cell} ipython3 +print(f"NumPy result: {z_max_numpy:.6f}") +``` ### مقایسه با Numba @@ -188,8 +191,6 @@ grid = np.linspace(-3, 3, 3_000) with qe.Timer(): # First run z_max_numba = compute_max_numba(grid) - -print(f"Numba result: {z_max_numba:.6f}") ``` بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود. @@ -232,8 +233,6 @@ def compute_max_numba_parallel(grid): with qe.Timer(): # First run z_max_parallel = compute_max_numba_parallel(grid) - -print(f"Numba result: {z_max_parallel:.6f}") ``` در اینجا زمان‌بندی برای نسخه از پیش کامپایل شده آمده است. @@ -244,19 +243,22 @@ with qe.Timer(): compute_max_numba_parallel(grid) ``` -اگر چندین هسته دارید، باید حداقل برخی مزایا را از موازی‌سازی در اینجا ببینید. +اگر چندین هسته دارید، باید مزایایی از موازی‌سازی در اینجا ببینید. -برای دستگاه‌های قدرتمندتر و اندازه‌های شبکه بزرگتر، موازی‌سازی می‌تواند افزایش سرعت قابل توجهی ایجاد کند، حتی روی CPU. +بیایید مطمئن شویم که نتیجه صحیح را به دست می‌آوریم (نزدیک به یک): -### کد برداری شده با JAX +```{code-cell} ipython3 +print(f"Numba result: {z_max_parallel:.6f}") +``` + +برای دستگاه‌های قدرتمند و اندازه‌های شبکه بزرگتر، موازی‌سازی می‌تواند افزایش سرعت مفیدی ایجاد کند، حتی روی CPU. -در ظاهر، کد برداری شده در JAX شبیه به کد NumPy است. +### کد برداری شده با JAX -اما تفاوت‌هایی نیز وجود دارد که در اینجا آنها را برجسته می‌کنیم. +بیایید رویکرد برداری شده NumPy را با JAX تکرار کنیم. بیایید با تابع شروع کنیم که `np` را به `jnp` تغییر می‌دهد و `jax.jit` را اضافه می‌کند. - ```{code-cell} ipython3 @jax.jit def f(x, y): @@ -264,7 +266,7 @@ def f(x, y): ``` -همانند NumPy، برای به دست آوردن شکل درست و محاسبه حلقه `for` تودرتوی صحیح، می‌توانیم از عملیات `meshgrid` طراحی شده برای این منظور استفاده کنیم: +از رویکرد meshgrid به سبک NumPy استفاده می‌کنیم: ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) @@ -321,60 +323,26 @@ x_mesh.nbytes + y_mesh.nbytes در اینجا نحوه اعمال آن به مسئله ما آمده است. -```{code-cell} ipython3 -# f را تنظیم کنید تا f(x, y) را در هر x برای هر y داده شده محاسبه کند -f_vec_x = lambda y: f(grid, y) -# یک تابع دوم ایجاد کنید که این عملیات را روی تمام y برداری کند -f_vec = jax.vmap(f_vec_x) -``` - -اکنون `f_vec` هنگام فراخوانی با آرایه تخت `grid`، `f(x,y)` را در هر `x,y` محاسبه می‌کند. - -بیایید زمان‌بندی را ببینیم: - -```{code-cell} ipython3 -with qe.Timer(): - z_max = jnp.max(f_vec(grid)) - z_max.block_until_ready() - -print(f"JAX vmap v1 result: {z_max:.6f}") -``` - -```{code-cell} ipython3 -with qe.Timer(): - z_max = jnp.max(f_vec(grid)) - z_max.block_until_ready() -``` - -با اجتناب از آرایه‌های ورودی بزرگ `x_mesh` و `y_mesh`، این نسخه `vmap` از حافظه بسیار کمتری با زمان اجرای مشابه استفاده می‌کند. - -اما هنوز برخی بهره‌های سرعت را از دست می‌دهیم. - -کد فوق آرایه دوبعدی کامل `f(x,y)` را محاسبه می‌کند و سپس max را می‌گیرد. - -علاوه بر این، فراخوانی `jnp.max` خارج از تابع JIT-کامپایل شده `f` قرار دارد، بنابراین کامپایلر نمی‌تواند این عملیات را در یک kernel واحد ادغام کند. - -می‌توانیم هر دو مشکل را با انتقال max به داخل و پوشاندن همه چیز در یک `@jax.jit` واحد برطرف کنیم: - ```{code-cell} ipython3 @jax.jit def compute_max_vmap(grid): - # یک تابع بسازید که حداکثر را در امتداد هر سطر بگیرد + # Construct a function that takes the max over all x for given y f_vec_x_max = lambda y: jnp.max(f(grid, y)) - # تابع را برداری کنید تا بتوانیم روی تمام سطرها همزمان فراخوانی کنیم + # Vectorize the function so we can call on all y simultaneously f_vec_max = jax.vmap(f_vec_x_max) - # تابع برداری شده را فراخوانی کنید و حداکثر را بگیرید - return jnp.max(f_vec_max(grid)) + # Compute the max across x at every y + maxes = f_vec_max(grid) + # Compute the max of the maxes and return + return jnp.max(maxes) ``` -در اینجا - -* `f_vec_x_max` حداکثر را در امتداد هر سطر داده شده محاسبه می‌کند -* `f_vec_max` یک نسخه برداری شده است که می‌تواند حداکثر تمام سطرها را به صورت موازی محاسبه کند. +توجه کنید که هرگز -ما این تابع را روی تمام سطرها اعمال می‌کنیم و سپس حداکثر max های سطر را می‌گیریم. +* شبکه دوبعدی `x_mesh` +* شبکه دوبعدی `y_mesh` یا +* آرایه دوبعدی `f(x,y)` -چون max را به داخل منتقل می‌کنیم، هرگز آرایه دوبعدی کامل `f(x,y)` را نمی‌سازیم و حافظه بیشتری صرفه‌جویی می‌شود. +را نمی‌سازیم. و چون همه چیز زیر یک `@jax.jit` واحد قرار دارد، کامپایلر می‌تواند تمام عملیات را در یک kernel بهینه ادغام کند. @@ -382,7 +350,10 @@ def compute_max_vmap(grid): ```{code-cell} ipython3 with qe.Timer(): - z_max = compute_max_vmap(grid).block_until_ready() + # First run + z_max = compute_max_vmap(grid) + # Hold interpreter + z_max.block_until_ready() print(f"JAX vmap result: {z_max:.6f}") ``` @@ -391,7 +362,10 @@ print(f"JAX vmap result: {z_max:.6f}") ```{code-cell} ipython3 with qe.Timer(): - z_max = compute_max_vmap(grid).block_until_ready() + # Second run + z_max = compute_max_vmap(grid) + # Hold interpreter + z_max.block_until_ready() ``` ### خلاصه @@ -448,13 +422,15 @@ with qe.Timer(): Numba این عملیات ترتیبی را به طور بسیار کارآمد مدیریت می‌کند. -توجه کنید که اجرای دوم پس از تکمیل کامپایل JIT به طور قابل توجهی سریعتر است. +### نسخه JAX + +ما نمی‌توانیم مستقیماً `numba.jit` را با `jax.jit` جایگزین کنیم زیرا آرایه‌های JAX تغییرناپذیر هستند. -کامپایل Numba معمولاً بسیار سریع است و عملکرد کد حاصل برای عملیات ترتیبی مانند این عالی است. +اما می‌توانیم این عملیات را پیاده‌سازی کنیم. -### نسخه JAX +#### تلاش اول -حالا بیایید یک نسخه JAX با استفاده از سینتکس `at[t].set` ایجاد کنیم که، همان‌طور که {ref}`در درس JAX بحث شد `، راه‌حلی برای آرایه‌های تغییرناپذیر فراهم می‌کند. +در اینجا یک راه‌حل با استفاده از سینتکس `at[t].set` ارائه می‌شود که {ref}`در درس JAX بحث شد `. ما از `lax.fori_loop` استفاده می‌کنیم که نسخه‌ای از حلقه for است که می‌تواند توسط XLA کامپایل شود. @@ -477,7 +453,7 @@ def qm_jax_fori(x0, n, α=4.0): * ما `n` را ایستا نگه می‌داریم زیرا بر اندازه آرایه تأثیر می‌گذارد و از این رو JAX می‌خواهد روی مقدار آن در کد کامپایل شده تخصصی شود. * ما به CPU از طریق `device=cpu` متصل می‌مانیم زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازی‌سازی GPU باقی می‌گذارد. -اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد می‌کند، در داخل یک تابع کامپایل‌شده با JIT، کامپایلر تشخیص می‌دهد که آرایه قدیمی دیگر مورد نیاز نیست و به‌روزرسانی را در جا انجام می‌دهد. +مهم: اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد می‌کند، در داخل یک تابع کامپایل‌شده با JIT، کامپایلر تشخیص می‌دهد که آرایه قدیمی دیگر مورد نیاز نیست و به‌روزرسانی را در جا انجام می‌دهد! بیایید آن را با همان پارامترها زمان‌بندی کنیم: @@ -499,7 +475,9 @@ with qe.Timer(): x_jax.block_until_ready() ``` -JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است. +JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است! + +#### تلاش دوم روش دیگری برای پیاده‌سازی حلقه وجود دارد که از `lax.scan` استفاده می‌کند. @@ -538,11 +516,11 @@ with qe.Timer(): x_jax.block_until_ready() ``` -هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه می‌دهند. +شگفت‌انگیز است که JAX نیز پس از کامپایل عملکرد قوی ارائه می‌دهد. ### خلاصه -در حالی که هم Numba و هم JAX عملکرد قوی برای عملیات ترتیبی ارائه می‌دهند، *تفاوت‌های قابل توجهی در خوانایی کد و سهولت استفاده وجود دارد*. +در حالی که هم Numba و هم JAX عملکرد قوی برای عملیات ترتیبی ارائه می‌دهند، تفاوت‌هایی در خوانایی کد و سهولت استفاده وجود دارد. نسخه Numba ساده و طبیعی برای خواندن است: ما به سادگی یک آرایه اختصاص می‌دهیم و آن را عنصر به عنصر با استفاده از یک حلقه استاندارد Python پر می‌کنیم.