Skip to content

feat(gla): fused forward+backward kernels for g_gamma mode#122

Open
sii-xinglong wants to merge 16 commits intomainfrom
feat/chunk-gla-fused-kernels
Open

feat(gla): fused forward+backward kernels for g_gamma mode#122
sii-xinglong wants to merge 16 commits intomainfrom
feat/chunk-gla-fused-kernels

Conversation

@sii-xinglong
Copy link
Copy Markdown
Contributor

@sii-xinglong sii-xinglong commented Mar 31, 2026

Summary

  • Adds fused forward kernel merging h propagation + A recomputation + output into 1 pallas_call
  • Adds fused backward kernel merging dh propagation + dq/dk/dv into 1 pallas_call
  • Dispatches to fused path when: g_gamma mode, fixed-length sequences, TPU runtime
  • All existing code paths remain unchanged — fused path is additive

Optimizations

Ported from Glaucis kernel-evolve optimization achieving 9.005x speedup on TPU v7x:

Optimization Technique Expected Impact
Forward fusion (SO15) Merge 3 pallas_calls into 1 ~14.6% latency reduction
Backward fusion (SO16) Merge 2 operations into 1 pallas_call ~14.6% latency reduction
A recomputation (SO11/12) Eliminate separate intra_gk kernel Removes 1 kernel launch
Dead dg elimination (SO13) Skip dg for g_gamma mode Removes 1 matmul + VPU ops
g_cumsum elimination Recompute from scalar in kernel Saves ~67MB HBM

Scope

  • Gate mode: g_gamma only (per-head constant). g and gk modes use existing path.
  • Sequence type: Fixed-length only. Varlen uses existing path.
  • Hardware: TPU only. CPU/GPU use existing path.
  • Initial state: Not supported (falls back to existing path).

Files

File Description
tops/ops/gla/chunk_fused_kernels.py New: fused forward + backward Pallas kernels
tops/ops/gla/chunk.py Modified: dispatch logic in chunk_gla_fwd and chunk_gla_bwd_with_pl
tests/ops/gla/test_pallas_chunk_fused_fwd.py New: fused forward vs naive recurrent reference
tests/ops/gla/test_pallas_chunk_fused_bwd.py New: fused backward vs jax.grad on naive reference
tests/ops/gla/test_pallas_chunk_gla_fused_e2e.py New: end-to-end dispatch path tests
benchmarks/ops/benchmark_gla.py Modified: chunk_fused and chunk_fused_bwd providers

Dispatch conditions

The fused path is used when ALL of:

  • g_gamma is not None and g_gamma.ndim == 1
  • g is None (no per-element gates)
  • cu_seqlens is None (fixed-length)
  • initial_state is None
  • is_tpu_runtime() returns True
  • Forward only: output_final_state is False
  • Backward only: dht is None

TPU v7x verification (2026-03-31)

Tested on GKE TPU v7x (Ironwood) pod, 2x2x1 topology, JAX 0.10.0.dev20260330.

Correctness — all fused kernel tests pass

Test Result
Fused forward vs naive reference 2/2 PASSED
Fused backward vs jax.grad reference 2/2 PASSED
E2E dispatch path 2/2 PASSED

Benchmark — fused vs non-fused (B=2, H=4, D=128, bf16)

T chunk (non-fused fwd) chunk_fused (fwd) chunk_fused_bwd
256 318.5 ms 0.203 ms 0.204 ms
512 404.4 ms 0.245 ms 0.288 ms
1024 403.7 ms 0.278 ms 0.320 ms
2048 430.5 ms 0.374 ms 0.488 ms
4096 OOM (VMEM) 0.588 ms 0.788 ms

Note: non-fused chunk times include JIT compilation overhead. The fused kernels also handle T=4096 where non-fused hits VMEM OOM.

Pre-existing issues (not introduced by this PR)

  • test_pallas_chunk_bwd.py: all fail — chunk_gla_bwd() missing g_gamma argument
  • test_pallas_chunk_bwd_dh.py: Mosaic bf16 compilation errors on v7x
  • test_pallas_fused_recurrent_bwd.py: Mosaic alignment errors on v7x
  • chunk.py: fixed cu_seqlenscu_seqlens_cpu/cu_seqlens_dev parameter naming (pre-existing bug)

Test plan

  • test_pallas_chunk_fused_fwd.py — fused forward vs naive recurrent reference
  • test_pallas_chunk_fused_bwd.py — fused backward vs jax.grad on naive reference
  • test_pallas_chunk_gla_fused_e2e.py — end-to-end dispatch path
  • Existing tests — no regressions introduced (231 pre-existing failures unchanged)
  • Benchmark shows fused kernels ~1500x faster than non-fused path

🤖 Generated with Claude Code

sii-xinglong and others added 5 commits March 31, 2026 16:26
Add chunk_fwd_fused_g_gamma which merges three separate pallas_calls
(h propagation + A recomputation + output computation) into a single
pallas_call. The hidden state stays in VMEM scratch instead of making
an HBM round-trip, and the attention matrix A is recomputed inline
rather than materialised.

Includes tests comparing fused output against the non-fused
chunk_gla_fwd reference at basic and AL model dimensions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add _chunk_bwd_fused_kernel and chunk_bwd_fused_g_gamma to
chunk_fused_kernels.py, merging dh reverse propagation and dq/dk/dv
gradient computation into a single pallas_call. The kernel eliminates
g_cumsum by recomputing gating from the g_gamma scalar, and keeps dh
state in VMEM scratch across the reverse time scan.

Add test_pallas_chunk_fused_bwd.py comparing against
chunk_gla_bwd_with_pl reference at two scales (basic + AL dims).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…d kernel

Replace raw jnp.exp() calls with exp() from tops.ops.utils for consistency
with the forward kernel. Remove redundant b_gn alias for b_g_last.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
chunk_gla_fwd and chunk_gla_bwd_with_pl now use the fused single-
pallas_call path when: g_gamma is provided (1D shape), g is None,
cu_seqlens is None, initial_state is None, and running on TPU.

Falls back to existing multi-kernel path for all other configurations
(varlen, initial_state, non-TPU, output_final_state, etc.).

Both dispatch paths handle K/V dimension padding to multiples of 128
and strip the padding from outputs before returning.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Enables benchmarking fused g_gamma kernels (single pallas_call forward
and backward) against existing multi-kernel implementations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@chatgpt-codex-connector
Copy link
Copy Markdown

You have reached your Codex usage limits for code reviews. You can see your limits in the Codex usage dashboard.

@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Mar 31, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds TPU Pallas fused "g_gamma" chunked GLA forward and backward kernels, integrates TPU-only fused fast-paths into chunk dispatchers with padding/unpadding and alignment, registers two new benchmark providers, and adds TPU-only forward, backward, and end-to-end tests verifying fused kernels vs. reference implementations.

Changes

Cohort / File(s) Summary
Fused kernel implementation
tops/ops/gla/chunk_fused_kernels.py
New module providing TPU Pallas fused forward chunk_fwd_fused_g_gamma and backward chunk_bwd_fused_g_gamma wrappers/kernels, VMEM scratch usage, chunking/tiling constraints, I/O transposes, and public exports.
Dispatcher / runtime integration
tops/ops/gla/chunk.py
TPU-only fast-paths in chunk_gla_fwd and chunk_gla_bwd_with_pl to call fused kernels when g_gamma is 1D; pads/unpads q/k/v/do to chunk/K/V alignment (C,128), optionally computes fused forward state in backward, and adjusts kernel kwarg names (cu_seqlens_cpu / cu_seqlens_dev).
Benchmarks registry
benchmarks/ops/benchmark_gla.py
Adds "chunk_fused" and "chunk_fused_bwd" to ALL_PROVIDERS; extends _run_provider with branches that import fused kernels, enforce divisibility constraints (T%64==0, K/V%128==0), build partially-applied forward/backward calls, and skip incompatible shapes.
Tests — fused forward / backward / e2e
tests/ops/gla/test_pallas_chunk_fused_fwd.py, tests/ops/gla/test_pallas_chunk_fused_bwd.py, tests/ops/gla/test_pallas_chunk_gla_fused_e2e.py
New TPU-only tests generating deterministic bf16 inputs and float32 g_gamma, comparing fused forward/backward/e2e outputs against naive/JAX references using numeric tolerances and including pytest entrypoints.

Sequence Diagram(s)

sequenceDiagram
  participant Tester as Tests/Benchmarks
  participant Dispatcher as chunk_gla_fwd / chunk_gla_bwd_with_pl
  participant Wrapper as chunk_*_fused_g_gamma
  participant Pallas as Pallas Fused Kernel
  participant VMEM as VMEM Scratch

  Tester->>Dispatcher: call with (q,k,v,g_gamma,do?,...)
  Dispatcher->>Wrapper: pad/transpose, validate tiling (C,128), choose fused path
  Wrapper->>Pallas: launch fused kernel with chunked tensors
  Pallas->>VMEM: read/write h state across time (VMEM scratch)
  Pallas-->>Wrapper: return (h,o) or (dq,dk,dv)
  Wrapper->>Dispatcher: unpad/transpose outputs
  Dispatcher-->>Tester: return results
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • labyrinth-ssr
  • 0xaskr

Poem

"I hopped through kernels, quick and spry,
Fused the chunks beneath the sky,
VMEM hummed a secret tune,
Tests checked sums by light of moon,
A rabbit's hop — fast paths fly! 🐰"

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 51.43% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'feat(gla): fused forward+backward kernels for g_gamma mode' directly and accurately describes the main change: adding fused forward and backward kernels for GLA's g_gamma mode.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/chunk-gla-fused-kernels

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces fused Pallas TPU kernels for the Gated Linear Attention (GLA) mechanism in g_gamma mode, aiming to optimize performance by merging multiple kernel launches and reducing HBM round-trips. The changes include new forward and backward fused kernels, dispatch logic in the main GLA operations, and comprehensive unit and end-to-end tests. However, several critical issues were identified regarding the scalability of the kernels for dimensions larger than 128, as the current tiling logic leads to output overwrites rather than proper reductions. Additionally, the backward pass currently returns zero gradients for the g_gamma parameter, which would cause silent failures in training scenarios where the gate is a learnable parameter.

preferred_element_type=jnp.float32,
)

o_ref[0, 0] = (b_o_inter + b_o_intra).astype(o_ref.dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This kernel will produce incorrect results if K > 128. The grid tiles the K dimension (pl.cdiv(K_dim, BK)), but the o_map (line 204) ignores the ki index, causing multiple tiles to overwrite the same output block in o_ref. Since the output o is a reduction over K, these partial results must be summed. If the kernel is intended only for K <= 128, please add an explicit assertion. Otherwise, the output mapping and reduction logic need to be corrected.

) # [BT, K]

# Phase 5 (VPU): Combine results and write to 5D output slots
dv_ref[0, i_t, 0] = (b_dv_intra + b_dv_inter).astype(dv_ref.dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to the forward pass, this will produce incorrect results for dv if K > 128. The grid tiles K, but out_v_map (line 523) ignores the ki index, leading to overwrites of the partial sums for dv instead of accumulating them.

Comment on lines +402 to +405
dq_ref[0, i_t, 0] = b_dq.astype(dq_ref.dtype)

b_dk = b_dk_intra_raw * exp_neg[:, None] + b_dk_inter * exp_gn_minus[:, None]
dk_ref[0, i_t, 0] = b_dk.astype(dk_ref.dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

These gradient computations will produce incorrect results if V > 128. The grid tiles the V dimension, but out_k_map (line 520) ignores the vi index. This causes different vi tiles to overwrite each other's partial contributions to dq and dk.

dv = dv[..., :V]

# dg: zeros (g_gamma gradient handled by caller via sum-reduction)
dg = jnp.zeros_like(g_gamma)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The fused backward path currently returns zero gradients for g_gamma. While the PR description mentions "Dead dg elimination" as an optimization, this will break training for models where g_gamma is a learnable parameter (which is the standard case for Simple GLA). If this optimization is intended only for inference or fixed-gate scenarios, the dispatch logic should be more restrictive. Otherwise, dg should be correctly computed and returned to avoid silent training failure.

sii-xinglong and others added 2 commits March 31, 2026 16:58
The forward kernel writes the float32 VMEM scratch state to the
bfloat16 h_ref output without casting, triggering a dtype mismatch
error in JAX's Pallas swap primitive. Add .astype(h_ref.dtype) to
match the pattern already used for o_ref and dq/dk/dv_ref writes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pallas does not allow kernel functions to capture traced JAX constants.
When scale is not static, JIT traces it as a JAX scalar which then gets
closed over by functools.partial, triggering "captures constants [f32[]]".
Mark scale as static_argnames in both forward and backward JIT wrappers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ops/gla/test_pallas_chunk_fused_bwd.py`:
- Around line 44-119: The tests currently only compare two TPU/Pallas paths
(chunk_bwd_fused_g_gamma vs chunk_gla_bwd_with_pl) so add a CPU/naive oracle for
dq, dk, dv and assert fused outputs against it; implement or call a pure-CPU
reference (e.g., cpu_chunk_gla_bwd_naive or similar) that computes the
non-optimized backward (using plain NumPy/JAX CPU ops) and in both
test_fused_bwd_basic and test_fused_bwd_al_dims obtain cpu_ref_dq, cpu_ref_dk,
cpu_ref_dv and add compare_tensor assertions (same atol/rtol/dtype) comparing
dq_fused/dk_fused/dv_fused to the CPU references in addition to the existing TPU
comparisons; update or add a helper (analogous to _run_reference) that returns
the CPU oracle so tests call that helper.

In `@tests/ops/gla/test_pallas_chunk_fused_fwd.py`:
- Around line 39-87: The tests call _run_reference which uses the optimized
chunk_gla_fwd (TPU/Pallas) so add a true CPU/JAX-naive oracle and assert fused
outputs against it; implement a new function (e.g., naive_chunk_gla_fwd or
cpu_chunk_gla_fwd) that computes h and o using plain JAX/Numpy ops (explicit
chunk loop or full attention with g_gamma applied and the same scaling) and
replace calls in test_fused_fwd_basic and test_fused_fwd_al_dims to compare
h_fused and o_fused against the CPU oracle (still keep existing fused vs
chunk_gla_fwd checks if desired), using compare_tensor with the same atol/rtol
and dtype to satisfy the guideline.

In `@tops/ops/gla/chunk_fused_kernels.py`:
- Around line 180-182: The fused launcher kernels must be restricted to
single-tile shapes to avoid silent overwrites; add hard assertions ensuring
K_dim == 128, K == 128 and V == 128 (in addition to the existing modulo checks)
so the forward launcher (which drops ki in o_map) and the backward launcher
(which uses out_k_map/out_v_map hardcoded to (b,0,h,0,0)) never span multiple
tiles; update the assert block that currently checks T%BT, K_dim%BK and V%BV to
also assert K_dim == 128, K == 128 and V == 128 to prevent cross-tile launches
that cause data corruption.

In `@tops/ops/gla/chunk.py`:
- Around line 1049-1051: The fused-TPU branch currently returns a hard-coded
zero tensor for dg (dg = jnp.zeros_like(g_gamma)) which drops gradients for the
learnable g_gamma; replace this with the same gradient computation and reduction
used in the non-fused path so dg contains the real gradient reduced back to
g_gamma.shape (or call the existing reduction helper used by the non-fused
branch), and return that reduced dg instead of zeros so g_gamma receives correct
gradients.
- Around line 1018-1021: The 1-D g_gamma ([H]) must be reshaped to a
broadcast-compatible shape before fallback broadcasts: inside the fallback
branches of chunk_gla_bwd_with_pl and chunk_gla_fwd, detect if g_gamma is 1-D
(g_gamma.ndim == 1) and replace it with g_gamma.reshape((1, 1, H, 1)) (where H =
g_gamma.shape[0]) before any multiplication/broadcast (e.g., before computing
g_gamma * pos or jnp.broadcast_to calls) so subsequent operations like
jnp.broadcast_to(g_gamma * pos, q.shape) succeed for q.shape = [B, T_pad, H, K].
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2b00d9e5-8b22-48c3-ba9a-7a1806650494

📥 Commits

Reviewing files that changed from the base of the PR and between bdf1f5a and d8d3523.

📒 Files selected for processing (6)
  • benchmarks/ops/benchmark_gla.py
  • tests/ops/gla/test_pallas_chunk_fused_bwd.py
  • tests/ops/gla/test_pallas_chunk_fused_fwd.py
  • tests/ops/gla/test_pallas_chunk_gla_fused_e2e.py
  • tops/ops/gla/chunk.py
  • tops/ops/gla/chunk_fused_kernels.py

Comment on lines +1049 to +1051
# dg: zeros (g_gamma gradient handled by caller via sum-reduction)
dg = jnp.zeros_like(g_gamma)
return dq, dk, dv, dg, None
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Don't return a fake zero gradient for g_gamma.

The non-fused path below computes a real dg and reduces it back to g_gamma.shape, but this branch hard-codes zeros and exits early. That silently drops gradients for learnable g_gamma whenever the fused TPU path is selected.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gla/chunk.py` around lines 1049 - 1051, The fused-TPU branch
currently returns a hard-coded zero tensor for dg (dg = jnp.zeros_like(g_gamma))
which drops gradients for the learnable g_gamma; replace this with the same
gradient computation and reduction used in the non-fused path so dg contains the
real gradient reduced back to g_gamma.shape (or call the existing reduction
helper used by the non-fused branch), and return that reduced dg instead of
zeros so g_gamma receives correct gradients.

sii-xinglong and others added 3 commits March 31, 2026 17:08
Mosaic compiler fails with "Bad lhs type" when K/V grid dimensions use
"arbitrary" semantics, likely causing incorrect vectorization of bf16
matmuls. These dimensions are independent (each computes a separate
BK x BV block of h), so "parallel" is semantically correct and avoids
the compiler issue. Only the time dimension needs "arbitrary" for
sequential h-state accumulation in VMEM scratch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…rnels

Mosaic compiler on TPU v7x fails with "Bad lhs type" when compiling
bf16 matmuls inside Pallas kernels with PrefetchScalarGridSpec and
VMEM scratch. The bf16 inputs appear to get incorrectly vectorized
in the internal MLIR representation (128→256 dimension doubling).

Fix by keeping all matmul inputs in float32 within the kernel body.
This is semantically identical (all matmuls already accumulated in
f32 via preferred_element_type) and the performance impact is
negligible for these tile sizes (BT=64, BK=BV=128).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
chunk_fwd_h_kernel expects cu_seqlens_cpu (not cu_seqlens) and
chunk_bwd_dh_kernel expects cu_seqlens_dev (not cu_seqlens). These
mismatches caused TypeError when running the non-fused reference
path. Pre-existing bug on main, fixed here.

Also includes f32 matmul inputs fix for Mosaic v7x compatibility.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
tops/ops/gla/chunk_fused_kernels.py (1)

86-87: Inconsistent use of exp vs jnp.exp throughout kernels.

The forward kernel mixes jnp.exp (lines 86-87, 132) and exp from tops.ops.utils (line 129). Similarly, the backward kernel (lines 320-322, 411, 413) uses the utility exp. If exp provides numerical stability benefits, consider using it consistently throughout both kernels.

Also applies to: 129-129, 132-132

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gla/chunk_fused_kernels.py` around lines 86 - 87, The forward and
backward fused kernels mix jnp.exp and the utility exp from tops.ops.utils
(e.g., in calculations producing exp_g/exp_neg_g around b_g_ramp and other uses
in the backward pass), causing inconsistent numerical behavior; pick the stable
implementation and use it consistently across both kernels by replacing all
jnp.exp calls with the utility exp (or replacing the utility with jnp.exp if you
decide to drop the helper), update imports accordingly, and ensure symbols like
b_g_ramp, exp_g, exp_neg_g and any backward uses of exp refer to the same chosen
implementation so both forward (exp_g/exp_neg_g) and backward computations use
the identical exp function.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/ops/gla/chunk_fused_kernels.py`:
- Around line 67-69: b_g_last is computed without casting while b_g_ramp uses
g_gamma[i_h].astype(jnp.float32), causing potential dtype inconsistencies;
change the b_g_last computation to use the same float32 cast (e.g., compute
b_g_last from g_gamma[i_h].astype(jnp.float32) * BT) so both b_g_ramp and
b_g_last are consistently float32 when later used.

---

Nitpick comments:
In `@tops/ops/gla/chunk_fused_kernels.py`:
- Around line 86-87: The forward and backward fused kernels mix jnp.exp and the
utility exp from tops.ops.utils (e.g., in calculations producing exp_g/exp_neg_g
around b_g_ramp and other uses in the backward pass), causing inconsistent
numerical behavior; pick the stable implementation and use it consistently
across both kernels by replacing all jnp.exp calls with the utility exp (or
replacing the utility with jnp.exp if you decide to drop the helper), update
imports accordingly, and ensure symbols like b_g_ramp, exp_g, exp_neg_g and any
backward uses of exp refer to the same chosen implementation so both forward
(exp_g/exp_neg_g) and backward computations use the identical exp function.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 01d3b524-7866-455f-97d7-1e9ff2c5667c

📥 Commits

Reviewing files that changed from the base of the PR and between d8d3523 and 2b71a64.

📒 Files selected for processing (1)
  • tops/ops/gla/chunk_fused_kernels.py

The reference path was casting inputs to float32 before calling the
non-fused kernel, giving it higher precision than the fused kernel
which operates on bfloat16 inputs. This caused numerical mismatches
at tight tolerances. Pass bfloat16 inputs to both paths for fair
comparison.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (2)
tops/ops/gla/chunk.py (2)

1019-1021: ⚠️ Potential issue | 🔴 Critical

Normalize 1-D g_gamma before fallback g_cumsum broadcast.

The new fast-path guards accept 1-D g_gamma, but fallback analytic cumsum still uses raw g_gamma * pos (Line 1059 and Line 1197). With JAX right-aligned broadcasting, this can fail when H != K.

#!/bin/bash
# Verify guard + fallback pattern coexist without 1-D reshape normalization.
rg -n -C2 'g_gamma\.ndim == 1|broadcast_to\(g_gamma \* pos, q\.shape\)' tops/ops/gla/chunk.py
Suggested fix
@@
-            g_cumsum = jnp.broadcast_to(g_gamma * pos, q.shape)
+            gg = g_gamma.reshape((1, 1, g_gamma.shape[0], 1)) if g_gamma.ndim == 1 else g_gamma
+            g_cumsum = jnp.broadcast_to(gg * pos, q.shape)
@@
-            g_cumsum = jnp.broadcast_to(g_gamma * pos, q.shape)
+            gg = g_gamma.reshape((1, 1, g_gamma.shape[0], 1)) if g_gamma.ndim == 1 else g_gamma
+            g_cumsum = jnp.broadcast_to(gg * pos, q.shape)

Also applies to: 1146-1149

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gla/chunk.py` around lines 1019 - 1021, The fallback analytic cumsum
path uses raw g_gamma * pos with JAX right-aligned broadcasting which fails when
g_gamma is 1-D and H != K; ensure g_gamma is normalized to a 2-D row before any
fallback broadcast by inserting a reshape when g_gamma.ndim == 1 (e.g., g_gamma
= g_gamma.reshape(1, -1)) so subsequent uses in g_cumsum, broadcast_to(g_gamma *
pos, q.shape) and any multiplications with pos/q shapes work; apply this
normalization in the same guarded branches that check g_gamma.ndim == 1
(affecting the code around the g_cumsum fallback and the other occurrences near
the noted blocks).

1049-1051: ⚠️ Potential issue | 🔴 Critical

Don’t zero-out dg in fused backward path.

Line 1049 currently returns a fake zero gradient for g_gamma; this silently drops gradients relative to the non-fused path semantics.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gla/chunk.py` around lines 1049 - 1051, The fused backward path
currently creates a fake zero gradient by assigning dg = jnp.zeros_like(g_gamma)
which drops g_gamma gradients; remove that zeroing and propagate the true
gradient instead (e.g., stop returning zeros and return the computed g_gamma
gradient itself). Locate the fused-backward return that yields "dq, dk, dv, dg,
None" and replace the fake dg with the actual gradient (use g_gamma or the
computed gradient value produced earlier in this fused path) so semantics match
the non-fused path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@tops/ops/gla/chunk.py`:
- Around line 1019-1021: The fallback analytic cumsum path uses raw g_gamma *
pos with JAX right-aligned broadcasting which fails when g_gamma is 1-D and H !=
K; ensure g_gamma is normalized to a 2-D row before any fallback broadcast by
inserting a reshape when g_gamma.ndim == 1 (e.g., g_gamma = g_gamma.reshape(1,
-1)) so subsequent uses in g_cumsum, broadcast_to(g_gamma * pos, q.shape) and
any multiplications with pos/q shapes work; apply this normalization in the same
guarded branches that check g_gamma.ndim == 1 (affecting the code around the
g_cumsum fallback and the other occurrences near the noted blocks).
- Around line 1049-1051: The fused backward path currently creates a fake zero
gradient by assigning dg = jnp.zeros_like(g_gamma) which drops g_gamma
gradients; remove that zeroing and propagate the true gradient instead (e.g.,
stop returning zeros and return the computed g_gamma gradient itself). Locate
the fused-backward return that yields "dq, dk, dv, dg, None" and replace the
fake dg with the actual gradient (use g_gamma or the computed gradient value
produced earlier in this fused path) so semantics match the non-fused path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0a83219e-32d5-4037-af5b-32965cd276c4

📥 Commits

Reviewing files that changed from the base of the PR and between 2b71a64 and 1247d97.

📒 Files selected for processing (1)
  • tops/ops/gla/chunk.py

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
tests/ops/gla/test_pallas_chunk_fused_bwd.py (1)

94-119: Consider parameterizing the two test cases.

Both test_fused_bwd_basic and test_fused_bwd_al_dims follow identical logic with different dimension parameters. Using @pytest.mark.parametrize would reduce duplication:

`@pytest.mark.parametrize`("B,T,H,K,V,C,seed", [
    (2, 256, 4, 128, 128, 64, 42),   # basic
    (2, 4096, 16, 128, 128, 64, 123), # AL dims
])
def test_fused_bwd(self, B, T, H, K, V, C, seed):
    ...
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gla/test_pallas_chunk_fused_bwd.py` around lines 94 - 119, These
two tests (test_fused_bwd_basic and test_fused_bwd_al_dims) duplicate the same
logic with only different dimension/seed values; refactor them into a single
parameterized test using pytest.mark.parametrize so the same body runs for each
tuple (B, T, H, K, V, C, seed), and update the test function name (e.g.,
test_fused_bwd) to accept those parameters and call _make_test_data,
chunk_fwd_fused_g_gamma, chunk_bwd_fused_g_gamma and self._run_reference as
before; keep the same assertions (compare_tensor) and tolerances but remove the
duplicated test functions.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@tests/ops/gla/test_pallas_chunk_fused_bwd.py`:
- Around line 94-119: These two tests (test_fused_bwd_basic and
test_fused_bwd_al_dims) duplicate the same logic with only different
dimension/seed values; refactor them into a single parameterized test using
pytest.mark.parametrize so the same body runs for each tuple (B, T, H, K, V, C,
seed), and update the test function name (e.g., test_fused_bwd) to accept those
parameters and call _make_test_data, chunk_fwd_fused_g_gamma,
chunk_bwd_fused_g_gamma and self._run_reference as before; keep the same
assertions (compare_tensor) and tolerances but remove the duplicated test
functions.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 09d0abee-adcc-42f1-bdd0-bd9614e9119c

📥 Commits

Reviewing files that changed from the base of the PR and between 1247d97 and 66c19d8.

📒 Files selected for processing (2)
  • tests/ops/gla/test_pallas_chunk_fused_bwd.py
  • tests/ops/gla/test_pallas_chunk_fused_fwd.py
✅ Files skipped from review due to trivial changes (1)
  • tests/ops/gla/test_pallas_chunk_fused_fwd.py

sii-xinglong and others added 5 commits March 31, 2026 17:41
The non-fused Pallas kernels have pre-existing bf16 matmul compilation
issues on TPU v7x ("Bad lhs type" in Mosaic). Switch all fused kernel
tests to compare against the pure JAX naive_recurrent_gla reference
which uses no Pallas kernels and works on any backend.

Also remove the large T=4096 test cases for backward (too slow with
naive recurrent) and add small-dimension tests instead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Unscaled g_gamma values (~-1.0) cause exp(g_gamma * BT) overflow with
BT=64, producing NaN. Add * 0.1 scaling to match fwd/bwd tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Enforce K==128, V==128 assertions in fused kernels to prevent
  silent data corruption from multi-tile overwrites (Gemini, CodeRabbit)
- Fix b_g_last dtype: add .astype(jnp.float32) for consistency
- Use exp() from tops.ops.utils consistently (replace jnp.exp)
- Clarify dg=zeros semantics: g_gamma mode treats gate as constant;
  pass 4-D g_gamma to get real gradients via non-fused path
- Fix 1-D g_gamma broadcast: reshape to (1,1,H,1) before fallback
  broadcast_to in chunk_gla_fwd and chunk_gla_bwd_with_pl
- Hardcode grid K/V dims to 1 (matching the K==128, V==128 invariant)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…rse index_map

Replace pre-flip/post-flip data transformation chain with reverse
BlockSpec index_maps (t → NT-1-t) and in-kernel reverse output
writes (NT-1-i_t). This eliminates 8 physical copies (5 input
flips + 3 output flips) that previously consumed 288 MiB of HBM
for B=2,T=4096,H=16,K=128,V=128.

Training peak: 480 MiB → 288 MiB (-40%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Analytical HBM footprint comparison covering activation retention,
forward/backward peak, and training peak. Shows ~40% memory reduction
for fused kernels by eliminating g_cumsum, A, dh, dg tensors and
jnp.flip copies via reverse index_maps.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant