feat(gla): fused forward+backward kernels for g_gamma mode #122
sii-xinglong wants to merge 16 commits into main
Conversation
Add chunk_fwd_fused_g_gamma which merges three separate pallas_calls (h propagation + A recomputation + output computation) into a single pallas_call. The hidden state stays in VMEM scratch instead of making an HBM round-trip, and the attention matrix A is recomputed inline rather than materialised. Includes tests comparing fused output against the non-fused chunk_gla_fwd reference at basic and AL model dimensions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
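The inter-chunk (h propagation) / intra-chunk (attention) split that this commit fuses can be sketched end-to-end in NumPy. This is a toy single-head model with a scalar per-head decay under my own conventions (state is updated before each output read; `gla_recurrent` and `gla_chunked` are illustrative names, not the repo's API) — the real kernels tile this over heads and 128-wide K/V blocks in bf16:

```python
import numpy as np

def gla_recurrent(q, k, v, gamma):
    """Naive recurrence: S_t = a*S_{t-1} + k_t^T v_t, o_t = q_t @ S_t."""
    T, K = q.shape
    V = v.shape[1]
    a = np.exp(gamma)              # scalar decay factor per step
    S = np.zeros((K, V))
    o = np.zeros((T, V))
    for t in range(T):
        S = a * S + np.outer(k[t], v[t])
        o[t] = q[t] @ S
    return o, S

def gla_chunked(q, k, v, gamma, C):
    """Chunked form: inter-chunk via the carried state, intra-chunk via causal attention."""
    T, K = q.shape
    V = v.shape[1]
    a = np.exp(gamma)
    S = np.zeros((K, V))
    o = np.zeros((T, V))
    for c in range(0, T, C):
        qc, kc, vc = q[c:c+C], k[c:c+C], v[c:c+C]
        for i in range(C):
            # inter-chunk: decayed contribution of the carried state
            o[c+i] = qc[i] @ (a ** (i + 1) * S)
            # intra-chunk: causal attention inside the chunk (the matrix A)
            for j in range(i + 1):
                o[c+i] += a ** (i - j) * (qc[i] @ kc[j]) * vc[j]
        # advance the carried state by one chunk
        S = a ** C * S
        for j in range(C):
            S += a ** (C - 1 - j) * np.outer(kc[j], vc[j])
    return o, S

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
o_ref, S_ref = gla_recurrent(q, k, v, -0.1)
o_chk, S_chk = gla_chunked(q, k, v, -0.1, C=4)
assert np.allclose(o_ref, o_chk) and np.allclose(S_ref, S_chk)
```

The chunked form is what lets the fused kernel keep `S` in VMEM scratch across the time grid instead of round-tripping it through HBM.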
Add _chunk_bwd_fused_kernel and chunk_bwd_fused_g_gamma to chunk_fused_kernels.py, merging dh reverse propagation and dq/dk/dv gradient computation into a single pallas_call. The kernel eliminates g_cumsum by recomputing gating from the g_gamma scalar, and keeps dh state in VMEM scratch across the reverse time scan. Add test_pallas_chunk_fused_bwd.py comparing against chunk_gla_bwd_with_pl reference at two scales (basic + AL dims). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
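The g_cumsum elimination mentioned above is exact because a constant per-step gate has a closed-form cumulative sum; a quick NumPy check (the 1-based ramp convention is an assumption about the kernel, not taken from its code):

```python
import numpy as np

T = 64
g_gamma = -0.05  # constant log-decay per step for one head

# Materialised variant: cumulative sum of a per-step gate tensor.
g_cumsum = np.cumsum(np.full(T, g_gamma))

# Recomputed variant: a closed-form ramp, no [T]-sized tensor kept around.
g_ramp = g_gamma * np.arange(1, T + 1)

assert np.allclose(g_cumsum, g_ramp)
```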
…d kernel Replace raw jnp.exp() calls with exp() from tops.ops.utils for consistency with the forward kernel. Remove redundant b_gn alias for b_g_last. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
chunk_gla_fwd and chunk_gla_bwd_with_pl now use the fused single- pallas_call path when: g_gamma is provided (1D shape), g is None, cu_seqlens is None, initial_state is None, and running on TPU. Falls back to existing multi-kernel path for all other configurations (varlen, initial_state, non-TPU, output_final_state, etc.). Both dispatch paths handle K/V dimension padding to multiples of 128 and strip the padding from outputs before returning. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
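The dispatch guard and the pad/strip step described above can be sketched in plain Python/NumPy. This is a hypothetical condensation — `use_fused_g_gamma_path` and `pad_last_dim` are illustrative names, and the TPU runtime check is reduced to a boolean flag:

```python
import numpy as np

def use_fused_g_gamma_path(g_gamma, g, cu_seqlens, initial_state,
                           output_final_state=False, dht=None, is_tpu=True):
    """All conditions from the dispatch description must hold."""
    return (
        g_gamma is not None and getattr(g_gamma, "ndim", None) == 1
        and g is None
        and cu_seqlens is None
        and initial_state is None
        and is_tpu
        and not output_final_state
        and dht is None
    )

def pad_last_dim(x, multiple=128):
    """Zero-pad the last axis to a multiple of 128; return the original width."""
    d = x.shape[-1]
    pad = (-d) % multiple
    widths = [(0, 0)] * (x.ndim - 1) + [(0, pad)]
    return np.pad(x, widths), d

gamma = np.full(4, -0.05)                      # 1-D per-head gate -> fused path
assert use_fused_g_gamma_path(gamma, None, None, None)
assert not use_fused_g_gamma_path(gamma, None, None, None, is_tpu=False)
assert not use_fused_g_gamma_path(None, None, None, None)

x = np.ones((2, 256, 4, 96))                   # K=96 is padded up to 128
x_pad, K = pad_last_dim(x)
assert x_pad.shape[-1] == 128
assert (x_pad[..., :K] == x).all()             # stripping recovers the input
```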
Enables benchmarking fused g_gamma kernels (single pallas_call forward and backward) against existing multi-kernel implementations. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Note: Reviews paused. This branch is under active development; to avoid an influx of review comments from new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough — Adds TPU Pallas fused "g_gamma" chunked GLA forward and backward kernels, integrates TPU-only fused fast-paths into the chunk dispatchers with padding/unpadding and alignment, registers two new benchmark providers, and adds TPU-only forward, backward, and end-to-end tests verifying the fused kernels against reference implementations.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Tester as Tests/Benchmarks
    participant Dispatcher as chunk_gla_fwd / chunk_gla_bwd_with_pl
    participant Wrapper as chunk_*_fused_g_gamma
    participant Pallas as Pallas Fused Kernel
    participant VMEM as VMEM Scratch
    Tester->>Dispatcher: call with (q,k,v,g_gamma,do?,...)
    Dispatcher->>Wrapper: pad/transpose, validate tiling (C,128), choose fused path
    Wrapper->>Pallas: launch fused kernel with chunked tensors
    Pallas->>VMEM: read/write h state across time (VMEM scratch)
    Pallas-->>Wrapper: return (h,o) or (dq,dk,dv)
    Wrapper->>Dispatcher: unpad/transpose outputs
    Dispatcher-->>Tester: return results
```
Code Review
This pull request introduces fused Pallas TPU kernels for the Gated Linear Attention (GLA) mechanism in g_gamma mode, aiming to optimize performance by merging multiple kernel launches and reducing HBM round-trips. The changes include new forward and backward fused kernels, dispatch logic in the main GLA operations, and comprehensive unit and end-to-end tests. However, several critical issues were identified regarding the scalability of the kernels for dimensions larger than 128, as the current tiling logic leads to output overwrites rather than proper reductions. Additionally, the backward pass currently returns zero gradients for the g_gamma parameter, which would cause silent failures in training scenarios where the gate is a learnable parameter.
```python
    preferred_element_type=jnp.float32,
)

o_ref[0, 0] = (b_o_inter + b_o_intra).astype(o_ref.dtype)
```
This kernel will produce incorrect results if K > 128. The grid tiles the K dimension (pl.cdiv(K_dim, BK)), but the o_map (line 204) ignores the ki index, causing multiple tiles to overwrite the same output block in o_ref. Since the output o is a reduction over K, these partial results must be summed. If the kernel is intended only for K <= 128, please add an explicit assertion. Otherwise, the output mapping and reduction logic need to be corrected.
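Why overwriting per-tile results corrupts the output when K spans more than one tile: the output is a reduction over K, so each K tile contributes a partial product that must be summed. A NumPy sketch of both behaviors:

```python
import numpy as np

rng = np.random.default_rng(0)
BT, K, V, BK = 4, 256, 8, 128  # K spans two tiles of width BK
q = rng.standard_normal((BT, K))
h = rng.standard_normal((K, V))

o_ref = q @ h  # the output is a reduction over the full K dimension

# Overwriting per-tile results (what an index map that ignores ki does):
o_overwrite = q[:, BK:] @ h[BK:]          # only the last tile's write survives
assert not np.allclose(o_overwrite, o_ref)

# Correct handling: sum the partial products from each K tile.
o_sum = sum(q[:, s:s+BK] @ h[s:s+BK] for s in range(0, K, BK))
assert np.allclose(o_sum, o_ref)
```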
tops/ops/gla/chunk_fused_kernels.py (Outdated)

```python
)  # [BT, K]

# Phase 5 (VPU): Combine results and write to 5D output slots
dv_ref[0, i_t, 0] = (b_dv_intra + b_dv_inter).astype(dv_ref.dtype)
```
tops/ops/gla/chunk_fused_kernels.py (Outdated)

```python
dq_ref[0, i_t, 0] = b_dq.astype(dq_ref.dtype)

b_dk = b_dk_intra_raw * exp_neg[:, None] + b_dk_inter * exp_gn_minus[:, None]
dk_ref[0, i_t, 0] = b_dk.astype(dk_ref.dtype)
```
```python
dv = dv[..., :V]

# dg: zeros (g_gamma gradient handled by caller via sum-reduction)
dg = jnp.zeros_like(g_gamma)
```
The fused backward path currently returns zero gradients for g_gamma. While the PR description mentions "Dead dg elimination" as an optimization, this will break training for models where g_gamma is a learnable parameter (which is the standard case for Simple GLA). If this optimization is intended only for inference or fixed-gate scenarios, the dispatch logic should be more restrictive. Otherwise, dg should be correctly computed and returned to avoid silent training failure.
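Under the assumption (taken from the fallback code's `g_cumsum = g_gamma * pos` construction) that the gate is a per-head scalar times a position ramp, the chain rule reduces the upstream `[B, T, H, K]` gradient to a per-head scalar by a weighted sum. A NumPy sketch of that reduction with a finite-difference check (shapes and the 1-based ramp are assumptions, not the repo's exact code):

```python
import numpy as np

B, T, H, K = 2, 16, 4, 8
rng = np.random.default_rng(0)
pos = np.arange(1, T + 1).reshape(1, T, 1, 1)   # assumed 1-based position ramp
d_gcumsum = rng.standard_normal((B, T, H, K))   # upstream gradient w.r.t. g_cumsum

# g_cumsum = g_gamma.reshape(1, 1, H, 1) * pos, so the per-head scalar
# gradient is a sum-reduction over the batch, time, and feature axes:
dg_gamma = (d_gcumsum * pos).sum(axis=(0, 1, 3))   # shape [H]
assert dg_gamma.shape == (H,)

# Finite-difference check of head 0 against the broadcast definition:
g = rng.standard_normal(H)
f = lambda g_: float(np.sum(d_gcumsum * (g_.reshape(1, 1, H, 1) * pos)))
eps = 1e-6
e0 = np.zeros(H); e0[0] = eps
assert np.isclose((f(g + e0) - f(g - e0)) / (2 * eps), dg_gamma[0])
```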
The forward kernel writes the float32 VMEM scratch state to the bfloat16 h_ref output without casting, triggering a dtype mismatch error in JAX's Pallas swap primitive. Add .astype(h_ref.dtype) to match the pattern already used for o_ref and dq/dk/dv_ref writes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pallas does not allow kernel functions to capture traced JAX constants. When scale is not static, JIT traces it as a JAX scalar which then gets closed over by functools.partial, triggering "captures constants [f32[]]". Mark scale as static_argnames in both forward and backward JIT wrappers. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/ops/gla/test_pallas_chunk_fused_bwd.py`:
- Around line 44-119: The tests currently only compare two TPU/Pallas paths
(chunk_bwd_fused_g_gamma vs chunk_gla_bwd_with_pl) so add a CPU/naive oracle for
dq, dk, dv and assert fused outputs against it; implement or call a pure-CPU
reference (e.g., cpu_chunk_gla_bwd_naive or similar) that computes the
non-optimized backward (using plain NumPy/JAX CPU ops) and in both
test_fused_bwd_basic and test_fused_bwd_al_dims obtain cpu_ref_dq, cpu_ref_dk,
cpu_ref_dv and add compare_tensor assertions (same atol/rtol/dtype) comparing
dq_fused/dk_fused/dv_fused to the CPU references in addition to the existing TPU
comparisons; update or add a helper (analogous to _run_reference) that returns
the CPU oracle so tests call that helper.
In `@tests/ops/gla/test_pallas_chunk_fused_fwd.py`:
- Around line 39-87: The tests call _run_reference which uses the optimized
chunk_gla_fwd (TPU/Pallas) so add a true CPU/JAX-naive oracle and assert fused
outputs against it; implement a new function (e.g., naive_chunk_gla_fwd or
cpu_chunk_gla_fwd) that computes h and o using plain JAX/Numpy ops (explicit
chunk loop or full attention with g_gamma applied and the same scaling) and
replace calls in test_fused_fwd_basic and test_fused_fwd_al_dims to compare
h_fused and o_fused against the CPU oracle (still keep existing fused vs
chunk_gla_fwd checks if desired), using compare_tensor with the same atol/rtol
and dtype to satisfy the guideline.
In `@tops/ops/gla/chunk_fused_kernels.py`:
- Around line 180-182: The fused launcher kernels must be restricted to
single-tile shapes to avoid silent overwrites; add hard assertions ensuring
K_dim == 128, K == 128 and V == 128 (in addition to the existing modulo checks)
so the forward launcher (which drops ki in o_map) and the backward launcher
(which uses out_k_map/out_v_map hardcoded to (b,0,h,0,0)) never span multiple
tiles; update the assert block that currently checks T%BT, K_dim%BK and V%BV to
also assert K_dim == 128, K == 128 and V == 128 to prevent cross-tile launches
that cause data corruption.
In `@tops/ops/gla/chunk.py`:
- Around line 1049-1051: The fused-TPU branch currently returns a hard-coded
zero tensor for dg (dg = jnp.zeros_like(g_gamma)) which drops gradients for the
learnable g_gamma; replace this with the same gradient computation and reduction
used in the non-fused path so dg contains the real gradient reduced back to
g_gamma.shape (or call the existing reduction helper used by the non-fused
branch), and return that reduced dg instead of zeros so g_gamma receives correct
gradients.
- Around line 1018-1021: The 1-D g_gamma ([H]) must be reshaped to a
broadcast-compatible shape before fallback broadcasts: inside the fallback
branches of chunk_gla_bwd_with_pl and chunk_gla_fwd, detect if g_gamma is 1-D
(g_gamma.ndim == 1) and replace it with g_gamma.reshape((1, 1, H, 1)) (where H =
g_gamma.shape[0]) before any multiplication/broadcast (e.g., before computing
g_gamma * pos or jnp.broadcast_to calls) so subsequent operations like
jnp.broadcast_to(g_gamma * pos, q.shape) succeed for q.shape = [B, T_pad, H, K].
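The broadcasting failure this prompt describes is easy to reproduce in NumPy, which follows the same right-aligned broadcasting rules as JAX:

```python
import numpy as np

B, T, H, K = 2, 8, 4, 16
pos = np.arange(1, T + 1).reshape(1, T, 1, 1)
g_gamma = np.full(H, -0.05)                     # 1-D per-head gate, shape [H]

# Right-aligned broadcasting lines [H] up against the K axis, so this fails
# whenever H != K (and silently mis-broadcasts when H == K):
try:
    np.broadcast_to(g_gamma * pos, (B, T, H, K))
    raise AssertionError("expected broadcast failure for H != K")
except ValueError:
    pass

# Reshaping to (1, 1, H, 1) pins the head axis explicitly:
g_cumsum = np.broadcast_to(g_gamma.reshape(1, 1, H, 1) * pos, (B, T, H, K))
assert g_cumsum.shape == (B, T, H, K)
```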
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2b00d9e5-8b22-48c3-ba9a-7a1806650494
📒 Files selected for processing (6)
- benchmarks/ops/benchmark_gla.py
- tests/ops/gla/test_pallas_chunk_fused_bwd.py
- tests/ops/gla/test_pallas_chunk_fused_fwd.py
- tests/ops/gla/test_pallas_chunk_gla_fused_e2e.py
- tops/ops/gla/chunk.py
- tops/ops/gla/chunk_fused_kernels.py
tops/ops/gla/chunk.py (Outdated)

```python
# dg: zeros (g_gamma gradient handled by caller via sum-reduction)
dg = jnp.zeros_like(g_gamma)
return dq, dk, dv, dg, None
```
Don't return a fake zero gradient for g_gamma.
The non-fused path below computes a real dg and reduces it back to g_gamma.shape, but this branch hard-codes zeros and exits early. That silently drops gradients for learnable g_gamma whenever the fused TPU path is selected.
Mosaic compiler fails with "Bad lhs type" when K/V grid dimensions use "arbitrary" semantics, likely causing incorrect vectorization of bf16 matmuls. These dimensions are independent (each computes a separate BK x BV block of h), so "parallel" is semantically correct and avoids the compiler issue. Only the time dimension needs "arbitrary" for sequential h-state accumulation in VMEM scratch. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…rnels Mosaic compiler on TPU v7x fails with "Bad lhs type" when compiling bf16 matmuls inside Pallas kernels with PrefetchScalarGridSpec and VMEM scratch. The bf16 inputs appear to get incorrectly vectorized in the internal MLIR representation (128→256 dimension doubling). Fix by keeping all matmul inputs in float32 within the kernel body. This is semantically identical (all matmuls already accumulated in f32 via preferred_element_type) and the performance impact is negligible for these tile sizes (BT=64, BK=BV=128). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
chunk_fwd_h_kernel expects cu_seqlens_cpu (not cu_seqlens) and chunk_bwd_dh_kernel expects cu_seqlens_dev (not cu_seqlens). These mismatches caused TypeError when running the non-fused reference path. Pre-existing bug on main, fixed here. Also includes f32 matmul inputs fix for Mosaic v7x compatibility. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
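The class of bug fixed here — a renamed keyword that callers still pass under the old name — only surfaces at call time, never at import time. A minimal stand-in (the signature shape is taken from the commit message; the body is elided and purely illustrative):

```python
def chunk_fwd_h_kernel(q, k, v, cu_seqlens_cpu=None):
    """Stand-in with the post-rename signature (real body elided)."""
    return cu_seqlens_cpu

# Callers still passing the old keyword fail with TypeError at call time:
try:
    chunk_fwd_h_kernel(1, 2, 3, cu_seqlens=[0, 4])
    raise AssertionError("expected TypeError")
except TypeError:
    pass

# The fix is to use the keyword the kernel actually declares:
assert chunk_fwd_h_kernel(1, 2, 3, cu_seqlens_cpu=[0, 4]) == [0, 4]
```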
Actionable comments posted: 1
🧹 Nitpick comments (1)
tops/ops/gla/chunk_fused_kernels.py (1)
86-87: Inconsistent use of `exp` vs `jnp.exp` throughout kernels. The forward kernel mixes `jnp.exp` (lines 86-87, 132) and `exp` from `tops.ops.utils` (line 129). Similarly, the backward kernel (lines 320-322, 411, 413) uses the utility `exp`. If `exp` provides numerical stability benefits, consider using it consistently throughout both kernels. Also applies to: 129, 132.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/gla/chunk_fused_kernels.py` around lines 86 - 87, The forward and backward fused kernels mix jnp.exp and the utility exp from tops.ops.utils (e.g., in calculations producing exp_g/exp_neg_g around b_g_ramp and other uses in the backward pass), causing inconsistent numerical behavior; pick the stable implementation and use it consistently across both kernels by replacing all jnp.exp calls with the utility exp (or replacing the utility with jnp.exp if you decide to drop the helper), update imports accordingly, and ensure symbols like b_g_ramp, exp_g, exp_neg_g and any backward uses of exp refer to the same chosen implementation so both forward (exp_g/exp_neg_g) and backward computations use the identical exp function.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tops/ops/gla/chunk_fused_kernels.py`:
- Around line 67-69: b_g_last is computed without casting while b_g_ramp uses
g_gamma[i_h].astype(jnp.float32), causing potential dtype inconsistencies;
change the b_g_last computation to use the same float32 cast (e.g., compute
b_g_last from g_gamma[i_h].astype(jnp.float32) * BT) so both b_g_ramp and
b_g_last are consistently float32 when later used.
📒 Files selected for processing (1)
tops/ops/gla/chunk_fused_kernels.py
The reference path was casting inputs to float32 before calling the non-fused kernel, giving it higher precision than the fused kernel which operates on bfloat16 inputs. This caused numerical mismatches at tight tolerances. Pass bfloat16 inputs to both paths for fair comparison. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
♻️ Duplicate comments (2)
tops/ops/gla/chunk.py (2)
1019-1021: ⚠️ Potential issue | 🔴 Critical — Normalize 1-D `g_gamma` before the fallback `g_cumsum` broadcast.

The new fast-path guards accept 1-D `g_gamma`, but the fallback analytic cumsum still uses raw `g_gamma * pos` (Line 1059 and Line 1197). With JAX right-aligned broadcasting, this can fail when `H != K`.

```shell
#!/bin/bash
# Verify guard + fallback pattern coexist without 1-D reshape normalization.
rg -n -C2 'g_gamma\.ndim == 1|broadcast_to\(g_gamma \* pos, q\.shape\)' tops/ops/gla/chunk.py
```

Suggested fix:

```diff
@@
-    g_cumsum = jnp.broadcast_to(g_gamma * pos, q.shape)
+    gg = g_gamma.reshape((1, 1, g_gamma.shape[0], 1)) if g_gamma.ndim == 1 else g_gamma
+    g_cumsum = jnp.broadcast_to(gg * pos, q.shape)
@@
-    g_cumsum = jnp.broadcast_to(g_gamma * pos, q.shape)
+    gg = g_gamma.reshape((1, 1, g_gamma.shape[0], 1)) if g_gamma.ndim == 1 else g_gamma
+    g_cumsum = jnp.broadcast_to(gg * pos, q.shape)
```

Also applies to: 1146-1149
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/gla/chunk.py` around lines 1019 - 1021, The fallback analytic cumsum path uses raw g_gamma * pos with JAX right-aligned broadcasting which fails when g_gamma is 1-D and H != K; ensure g_gamma is normalized to a 2-D row before any fallback broadcast by inserting a reshape when g_gamma.ndim == 1 (e.g., g_gamma = g_gamma.reshape(1, -1)) so subsequent uses in g_cumsum, broadcast_to(g_gamma * pos, q.shape) and any multiplications with pos/q shapes work; apply this normalization in the same guarded branches that check g_gamma.ndim == 1 (affecting the code around the g_cumsum fallback and the other occurrences near the noted blocks).
1049-1051: ⚠️ Potential issue | 🔴 Critical — Don't zero out `dg` in the fused backward path.

Line 1049 currently returns a fake zero gradient for `g_gamma`; this silently drops gradients relative to the non-fused path semantics.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/gla/chunk.py` around lines 1049 - 1051, the fused backward path creates a fake zero gradient by assigning dg = jnp.zeros_like(g_gamma), which drops g_gamma gradients. Remove that zeroing and propagate the true gradient instead: locate the fused-backward return that yields "dq, dk, dv, dg, None" and replace the fake dg with the actual gradient produced earlier in this path, so semantics match the non-fused path.
🧹 Nitpick comments (1)
tests/ops/gla/test_pallas_chunk_fused_bwd.py (1)
94-119: Consider parameterizing the two test cases. Both `test_fused_bwd_basic` and `test_fused_bwd_al_dims` follow identical logic with different dimension parameters. Using `@pytest.mark.parametrize` would reduce duplication:

```python
@pytest.mark.parametrize("B,T,H,K,V,C,seed", [
    (2, 256, 4, 128, 128, 64, 42),     # basic
    (2, 4096, 16, 128, 128, 64, 123),  # AL dims
])
def test_fused_bwd(self, B, T, H, K, V, C, seed):
    ...
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gla/test_pallas_chunk_fused_bwd.py` around lines 94 - 119, These two tests (test_fused_bwd_basic and test_fused_bwd_al_dims) duplicate the same logic with only different dimension/seed values; refactor them into a single parameterized test using pytest.mark.parametrize so the same body runs for each tuple (B, T, H, K, V, C, seed), and update the test function name (e.g., test_fused_bwd) to accept those parameters and call _make_test_data, chunk_fwd_fused_g_gamma, chunk_bwd_fused_g_gamma and self._run_reference as before; keep the same assertions (compare_tensor) and tolerances but remove the duplicated test functions.
📒 Files selected for processing (2)
- tests/ops/gla/test_pallas_chunk_fused_bwd.py
- tests/ops/gla/test_pallas_chunk_fused_fwd.py
✅ Files skipped from review due to trivial changes (1)
- tests/ops/gla/test_pallas_chunk_fused_fwd.py
The non-fused Pallas kernels have pre-existing bf16 matmul compilation
issues on TPU v7x ("Bad lhs type" in Mosaic). Switch all fused kernel
tests to compare against the pure JAX naive_recurrent_gla reference
which uses no Pallas kernels and works on any backend.
Also remove the large T=4096 test cases for backward (too slow with
naive recurrent) and add small-dimension tests instead.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Unscaled g_gamma values (~-1.0) cause exp(g_gamma * BT) overflow with BT=64, producing NaN. Add * 0.1 scaling to match fwd/bwd tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
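The overflow mechanism described in this commit can be reproduced in float32 NumPy. The exact kernel expression that hits `inf * 0` is an assumption on my part, but the exponent arithmetic is generic — combining decay factors of size `exp(±|g|·BT)` overflows and underflows simultaneously, and their product is NaN:

```python
import numpy as np

BT = 64
g = np.float32(-1.0)                    # unscaled gate sample, |g| ~ 1

big = np.exp(np.float32(-g * BT))       # e^64  ~ 6e27, near the float32 ceiling
tiny = np.exp(np.float32(g * BT))       # e^-64 ~ 2e-28, near the float32 floor
assert np.isinf(big * big)              # e^128 overflows float32 (max ~3.4e38)
assert tiny * tiny == np.float32(0.0)   # e^-128 underflows to zero
assert np.isnan(big * big * (tiny * tiny))  # inf * 0 -> NaN

# Scaling the gate by 0.1 keeps the exponent range safe (e^6.4, e^12.8, ...):
gs = np.float32(g * 0.1)
assert np.isfinite(np.exp(np.float32(-gs * BT)) ** 2)
```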
- Enforce K==128, V==128 assertions in fused kernels to prevent silent data corruption from multi-tile overwrites (Gemini, CodeRabbit) - Fix b_g_last dtype: add .astype(jnp.float32) for consistency - Use exp() from tops.ops.utils consistently (replace jnp.exp) - Clarify dg=zeros semantics: g_gamma mode treats gate as constant; pass 4-D g_gamma to get real gradients via non-fused path - Fix 1-D g_gamma broadcast: reshape to (1,1,H,1) before fallback broadcast_to in chunk_gla_fwd and chunk_gla_bwd_with_pl - Hardcode grid K/V dims to 1 (matching the K==128, V==128 invariant) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…rse index_map Replace pre-flip/post-flip data transformation chain with reverse BlockSpec index_maps (t → NT-1-t) and in-kernel reverse output writes (NT-1-i_t). This eliminates 8 physical copies (5 input flips + 3 output flips) that previously consumed 288 MiB of HBM for B=2,T=4096,H=16,K=128,V=128. Training peak: 480 MiB → 288 MiB (-40%) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
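The flip-elimination idea can be illustrated with a reverse cumulative scan in NumPy: the same result computed once with physical flips (two extra copies) and once by indexing block `NT-1-t` directly, the way a reverse BlockSpec index_map would. This is a toy scan, not the actual dh recurrence:

```python
import numpy as np

NT, C, D = 8, 4, 16
x = np.arange(NT * C * D, dtype=np.float32).reshape(NT, C, D)

# Flip-based reverse scan: two physical copies (input flip + output flip).
y_flip = np.flip(np.cumsum(np.flip(x, axis=0), axis=0), axis=0)

# Index-map-based reverse scan: read/write block NT-1-t, no copies.
y_idx = np.empty_like(x)
acc = np.zeros((C, D), dtype=np.float32)
for t in range(NT):
    blk = NT - 1 - t          # reverse index map: t -> NT-1-t
    acc = acc + x[blk]
    y_idx[blk] = acc

assert np.allclose(y_flip, y_idx)
```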
Analytical HBM footprint comparison covering activation retention, forward/backward peak, and training peak. Shows ~40% memory reduction for fused kernels by eliminating g_cumsum, A, dh, dg tensors and jnp.flip copies via reverse index_maps. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary

- Fused forward kernel: a single `pallas_call`
- Fused backward kernel: a single `pallas_call`
- Scope: `g_gamma` mode, fixed-length sequences, TPU runtime

Optimizations
Ported from Glaucis kernel-evolve optimization achieving 9.005x speedup on TPU v7x:
Scope
Files
- tops/ops/gla/chunk_fused_kernels.py — new fused kernels
- tops/ops/gla/chunk.py — dispatch in `chunk_gla_fwd` and `chunk_gla_bwd_with_pl`
- tests/ops/gla/test_pallas_chunk_fused_fwd.py
- tests/ops/gla/test_pallas_chunk_fused_bwd.py
- tests/ops/gla/test_pallas_chunk_gla_fused_e2e.py
- benchmarks/ops/benchmark_gla.py — `chunk_fused` and `chunk_fused_bwd` providers

Dispatch conditions
The fused path is used when ALL of:

- `g_gamma is not None` and `g_gamma.ndim == 1`
- `g is None` (no per-element gates)
- `cu_seqlens is None` (fixed-length)
- `initial_state is None`
- `is_tpu_runtime()` returns True
- `output_final_state is False`
- `dht is None`

TPU v7x verification (2026-03-31)
Tested on GKE TPU v7x (Ironwood) pod, 2x2x1 topology, JAX 0.10.0.dev20260330.
Correctness — all fused kernel tests pass
Benchmark — fused vs non-fused (B=2, H=4, D=128, bf16)
Note: non-fused `chunk` times include JIT compilation overhead. The fused kernels also handle T=4096 where non-fused hits VMEM OOM.

Pre-existing issues (not introduced by this PR)
- test_pallas_chunk_bwd.py: all fail — `chunk_gla_bwd()` missing `g_gamma` argument
- test_pallas_chunk_bwd_dh.py: Mosaic bf16 compilation errors on v7x
- test_pallas_fused_recurrent_bwd.py: Mosaic alignment errors on v7x
- chunk.py: fixed `cu_seqlens` → `cu_seqlens_cpu`/`cu_seqlens_dev` parameter naming (pre-existing bug)

Test plan
- test_pallas_chunk_fused_fwd.py — fused forward vs naive recurrent reference
- test_pallas_chunk_fused_bwd.py — fused backward vs jax.grad on naive reference
- test_pallas_chunk_gla_fused_e2e.py — end-to-end dispatch path

🤖 Generated with Claude Code