
fix: stabilize simple_gla backward Pallas kernels for chunk_size=128#129

Open
jimoosciuc wants to merge 7 commits into main from fix/gla-bwd-simple-stabilize

Conversation

@jimoosciuc
Contributor

@jimoosciuc jimoosciuc commented Mar 31, 2026

Summary

Stabilize simple_gla backward Pallas kernels for chunk_size=128 (BT=128).

Root cause: The backward chunk_bwd_dh kernel was using a synthetic gk path
(via _build_gk_from_gamma()) instead of the native g_gamma path. In the gk path,
midpoint stabilization computes exp(gk - gk_mid) where gk_mid ≈ gamma * BT/2.
At BT=128 with gamma ≈ -0.707 (typical for GLA): exp(0.707 * 63.5) ≈ 2.8e19.
This intermediate value overflows the float16 max (65504) during the .astype(b_q.dtype)
cast in the kernel, producing inf/NaN in the subsequent dot product.

Fix (primary): Use g_gamma directly in chunk_bwd_dh instead of synthetic gk.
The g_gamma path computes exp(gamma*(t+1)) which is always ≤ 1 (since gamma < 0),
eliminating overflow at any chunk size. This matches the forward chunk_fwd_h kernel
which already uses g_gamma directly.
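The overflow can be reproduced with a small NumPy sketch (constants taken from the description above; the actual kernel operates on Pallas refs, so this is only a numeric illustration):

```python
import numpy as np

gamma, BT = -0.707, 128                     # typical GLA decay, chunk size
exponent = abs(gamma) * (BT / 2 - 0.5)      # ~44.9 at the chunk boundary
val = np.exp(np.float32(exponent))          # ~2.8e19, still finite in float32
assert np.isfinite(val)
assert val > np.finfo(np.float16).max       # but past the half-precision max (65504)
assert np.isinf(np.float16(val))            # so the low-precision cast yields inf
# the g_gamma path computes exp(gamma*(t+1)) <= 1 for gamma < 0: no overflow
assert np.exp(np.float32(gamma)) <= 1.0
```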

Additional hardening:

  • Safe exponent masking in chunk_simple_gla_bwd_kernel (chunk_o.py): zero upper-triangle
    exponents before exp() to prevent inf in the causal decay matrix.
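The masking order matters: the exponent must be zeroed before exp(), not the result after. A rough NumPy illustration (the actual kernels use jnp.where on Pallas refs; names and shapes here are illustrative):

```python
import numpy as np

def safe_causal_decay(g):
    """Zero upper-triangle exponents BEFORE exp(), as described above."""
    T = g.shape[0]
    g_diff = g[:, None] - g[None, :]            # want exp(g_i - g_j) for i >= j
    mask = np.tril(np.ones((T, T), dtype=bool))
    g_diff_safe = np.where(mask, g_diff, 0.0)   # upper triangle: exp(0) = 1, never inf
    return np.where(mask, np.exp(g_diff_safe), 0.0)

g = np.float32(-0.707) * np.arange(1, 129, dtype=np.float32)  # cumulative log-decay
A = safe_causal_decay(g)
assert np.all(np.isfinite(A))                   # no inf/NaN even at chunk_size=128
assert np.isclose(A[10, 3], np.exp(-0.707 * 7), rtol=1e-5)
```

Masking after exp() would instead produce 0 * inf = NaN in the upper triangle once the exponent exceeds the float32 overflow threshold.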

Changes

  • tops/ops/simple_gla/chunk.py: Pass g_gamma directly to chunk_bwd_dh instead of
    converting to synthetic gk. Removes the forward/backward asymmetry.
  • tops/ops/common/chunk_o.py: Safe exponent masking in both forward and backward
    intra-chunk kernels — jnp.where(mask, g_diff, 0.0) before exp().
  • Debug instrumentation (temporary): NaN checks at each backward stage.

Test plan

  • Pallas-kernel TPU unit tests pass (chunk_simple_gla forward + backward)
  • ant-pretrain 10-step E2E smoke test: no NaN at chunk_size=128
  • Loss matches Megatron reference within tolerance

🤖 Generated with Claude Code

The previous fix (35ca41d, 81f02b8) only stabilized gla/chunk.py kernels,
but the actual training code uses the simple_gla path (chunk_o.py, chunk_h.py).

Apply midpoint stabilization to:
- chunk_simple_gla_bwd_kernel (chunk_o.py): factored decay matrix
  exp(g_i-g_j) = exp(g_i-g_mid)*exp(g_mid-g_j), stabilized dq_inter
  exp(g) = exp(g-g_mid)*exp(g_mid) absorbed into h
- _chunk_bwd_dh_kernel (chunk_h.py): gk path exp(g_last) and exp(gk)
  split via midpoint, restored after dot product

This fixes NaN in backward pass when |g_gamma| * chunk_size > ~88
(fp32 exp overflow threshold).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
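The midpoint split referenced in the commit message can be checked numerically (a sketch under the same gamma/chunk-size assumptions, not the kernel code):

```python
import numpy as np

g = np.float32(-0.707) * np.arange(1, 129, dtype=np.float32)  # cumulative log-decay
g_mid = np.float32(0.5) * (g[0] + g[-1])
left = np.exp(g - g_mid)        # exponents bounded to roughly [-44.9, +44.9]
right = np.exp(g_mid - g)
assert np.all(np.isfinite(left)) and np.all(np.isfinite(right))
# the factors reconstruct exp(g_i - g_j) without ever forming exp(+-89.8)
i, j = 100, 3
ref = np.exp(np.float64(g[i]) - np.float64(g[j]))
assert np.isclose(left[i] * right[j], ref, rtol=1e-3)
```

Note that a later commit in this thread observes the limit of this approach for chunk_o.py: when the two factors are fed through a dot product, the accumulation can reconstruct the full exp(g[i]-g[j]), which is why that kernel moved to exponent masking instead.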
@coderabbitai

coderabbitai bot commented Mar 31, 2026


⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: aa5259ff-6c3f-4d0a-885c-6278836ebed9

📥 Commits

Reviewing files that changed from the base of the PR and between fd59501 and 567dfe6.

📒 Files selected for processing (4)
  • tests/ops/simple_gla/test_chunk_simple_gla_tpu.py
  • tops/ops/common/chunk_h.py
  • tops/ops/common/chunk_o.py
  • tops/ops/simple_gla/chunk.py
📝 Walkthrough


Numeric-stability edits replace direct per-gate exponentials with midpoint or masked-difference formulations across chunked forward/backward kernels; SimpleGLA chunk backward adds a NaN-check helper and now passes g_gamma directly into the backward-dh path instead of constructing a synthetic gk.

Changes

  • Backward hidden-state kernel — tops/ops/common/chunk_h.py: Replaced direct exp(g_last) / exp(b_gk) scaling with a midpoint decomposition: compute gk_mid = 0.5*(b_gk[0,:] + g_last), apply exp(g_last - gk_mid) and exp(b_gk - gk_mid) during accumulation, then multiply the final b_dh by exp(gk_mid) before storing the carry.
  • Forward/backward attention decay & gating — tops/ops/common/chunk_o.py: When g_gamma_ref is present, form the exponent-difference matrix, zero its upper-triangle (causal) entries before exponentiation to avoid overflow, apply exp(g_diff_safe) to b_A/decay, ensure the mask exists when g_gamma_ref is absent, and reuse the computed safe decay in the backward recompute/gating.
  • SimpleGLA chunk backward — tops/ops/simple_gla/chunk.py: Added an internal _bwd_nan_check(name, x) helper with NaN/shape/min/max checks at key backward stages; removed construction of a synthetic gk from g_gamma and now passes g_gamma (with gk=None) into chunk_bwd_dh, shifting the dh derivation to the g_gamma-based path. No public signatures changed.
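The chunk_h.py carry decomposition summarized above can be sketched in NumPy (shapes and the uniform gamma value are assumptions for illustration; the kernel works on per-head Pallas refs):

```python
import numpy as np

BT, K = 128, 4
# per-row, per-feature cumulative log-decay (assumed gamma = -0.707 everywhere)
ramp = np.arange(1, BT + 1, dtype=np.float32)[:, None]
b_gk = np.float32(-0.707) * ramp * np.ones((1, K), np.float32)
g_last = b_gk[-1]                         # log-decay across the whole chunk
gk_mid = np.float32(0.5) * (b_gk[0] + g_last)
carry_scale = np.exp(g_last - gk_mid)     # applied to b_dh during accumulation
row_scale = np.exp(b_gk - gk_mid)         # applied to the per-row terms
restore = np.exp(gk_mid)                  # multiplied back in before storing the carry
assert np.all(np.isfinite(carry_scale)) and np.all(np.isfinite(row_scale))
# sanity: the split recombines to the direct per-row factor exp(b_gk)
assert np.allclose(row_scale[5] * restore, np.exp(b_gk[5]), rtol=1e-3)
```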

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Poem

"I nibble midpoints soft and slow,
So exponents don't overflow.
Gates breathe calm, carries sleep,
Gradients wander, gentle, deep.
A rabbit hops — the tensors glow." 🐇✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 60.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title Check ✅ Passed — The title accurately summarizes the main fix: stabilizing simple_gla backward Pallas kernels for chunk_size=128 to address NaN issues in the backward pass.
  • Description Check ✅ Passed — Check skipped: CodeRabbit's high-level summary is enabled.



Comment @coderabbitai help to get the list of available commands and usage tips.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements midpoint stabilization in the chunk_h.py and chunk_o.py backward kernels to prevent exponent overflow by factoring exponential terms. The review feedback points out that similar stabilization is needed for scalar gate and per-head decay paths, as well as inter-chunk terms for dv and dk. The reviewer also suggests several code improvements, including the removal of redundant type casts and the reuse of precomputed stable exponential variables to enhance performance and clarity.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tops/ops/common/chunk_h.py (1)

748-751: ⚠️ Potential issue | 🔴 Critical

The simple_gla g_gamma carry path is still unstabilized.

This midpoint split only protects gk_ref. When the caller uses g_gamma, Lines 749-751 still do exp(b_g_last) and exp(b_g_ramp) directly, so dh can overflow at chunk_size=128 before chunk_simple_gla_bwd_kernel even runs. Apply the same midpoint split to g_gamma and absorb that midpoint after the dot as well.

Patch sketch
     if g_gamma is not None:
         b_g_last = g_gamma[head_index] * jnp.minimum(BT, eos - t0)
-        b_dh *= exp(b_g_last)
-        b_q = (b_q * exp(b_g_ramp)[:, None]).astype(b_q.dtype)
+        g_gamma_mid = (b_g_ramp[0] + b_g_last) * 0.5
+        b_dh = b_dh * exp(b_g_last - g_gamma_mid)
+        b_q = (b_q * exp(b_g_ramp - g_gamma_mid)[:, None]).astype(b_q.dtype)

     b_dh = b_dh + jax.lax.dot(
         b_q.astype(jnp.float32).T, b_do.astype(jnp.float32),
         precision=lax.Precision.HIGHEST,
         preferred_element_type=jnp.float32,
     )

+    if g_gamma is not None:
+        b_dh = b_dh * exp(g_gamma_mid)
     if gk_ref is not None:
         b_dh = b_dh * exp(gk_mid)[:, None]

Also applies to: 757-772

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/chunk_h.py` around lines 748 - 751, The g_gamma carry path
still applies exp(b_g_last) and exp(b_g_ramp) directly and can overflow before
chunk_simple_gla_bwd_kernel runs; mirror the midpoint-splitting used for gk_ref:
split the exponentials for g_gamma into two halves (e.g., exp_half = exp(b_g_* /
2)), apply one half before the dot that produces dh (so intermediate values are
stabilized), and multiply/absorb the other half into b_q and b_dh immediately
after the dot (so final values are equivalent). Update the operations involving
g_gamma, b_g_last, b_g_ramp, b_dh and b_q (and the analogous block later that
spans the other g_gamma usage) to use this two-step apply/absorb pattern.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/ops/common/chunk_o.py`:
- Around line 210-214: The computation still uses exp(b_gn - b_g) which
reintroduces full-range exponentials into dv/dk; instead reuse the
already-computed stabilized exp_neg_g_s (exp(g_mid - b_g)) and factor out
exp(b_gn - g_mid) once to keep exponentials bounded. Concretely, where dv/dk
currently multiply by exp(b_gn - b_g), replace that factor with exp_neg_g_s *
jnp.exp(b_gn - g_mid) (i.e., absorb g_mid) so the per-row dependence uses the
bounded exp_neg_g_s and only a single exp(b_gn - g_mid) remains; update all
occurrences (the expressions producing dv and dk that reference b_gn and b_g) to
follow this pattern and remove any direct exp(b_gn - b_g) calls.
- Around line 258-276: The current code reconstructs the full exp(g_i - g_j)
inside b_dA_gated which reintroduces the overflow window; instead remove the
midpoint factors from b_dA_gated (keep it as the masked dA only) and push the
exp(...) factors into the matrix multiplications: when computing b_dq_intra use
left = (b_dA * jnp.exp(b_g[:, None] - g_mid)) as the left operand and right =
(b_k * jnp.exp(g_mid - b_g[None, :])) as the right operand for the jnp.dot, and
for b_dq_inter absorb/shift the exp(g_mid) factors into b_h (or b_do) so the dot
multiplies already-scaled operands rather than materializing exp(g_i-g_j) in
b_dA_gated; update uses of b_dA_gated, b_dq_intra, b_dq_inter, b_h_scaled, b_k,
b_do and g_mid accordingly.

---

Outside diff comments:
In `@tops/ops/common/chunk_h.py`:
- Around line 748-751: The g_gamma carry path still applies exp(b_g_last) and
exp(b_g_ramp) directly and can overflow before chunk_simple_gla_bwd_kernel runs;
mirror the midpoint-splitting used for gk_ref: split the exponentials for
g_gamma into two halves (e.g., exp_half = exp(b_g_* / 2)), apply one half before
the dot that produces dh (so intermediate values are stabilized), and
multiply/absorb the other half into b_q and b_dh immediately after the dot (so
final values are equivalent). Update the operations involving g_gamma, b_g_last,
b_g_ramp, b_dh and b_q (and the analogous block later that spans the other
g_gamma usage) to use this two-step apply/absorb pattern.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a638c50e-dabd-4281-bed5-3e4ac862c1d3

📥 Commits

Reviewing files that changed from the base of the PR and between 2b47657 and bb16e06.

📒 Files selected for processing (2)
  • tops/ops/common/chunk_h.py
  • tops/ops/common/chunk_o.py

jimoosciuc and others added 2 commits March 31, 2026 20:33
At chunk_size=128, exp(|gamma|*(j-i)) overflows to inf for upper-triangle
entries. In the backward kernel, dA is pre-masked to 0 for upper triangle,
so 0 * inf = NaN. On TPU, jnp.where(False, NaN, 0.0) may not reliably
return 0.0 due to XLA compiler optimizations.

Fix: zero the exponent for upper-triangle entries BEFORE calling exp(),
so exp(0) = 1 instead of exp(overflow) = inf. For lower-triangle entries
(i>=j), gamma*(i-j) <= 0 (gamma<0), so exp <= 1, no overflow. The upper
triangle is masked to 0 anyway, so the exponent value doesn't matter.

This replaces the previous midpoint stabilization approach, which failed
because the factored dot product reconstructed the full exp(g[i]-g[j])
during MXU accumulation, defeating the stabilization.

Also applies the same fix to the forward intra-chunk kernel for consistency.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
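The 0 * inf = NaN failure mode this commit message describes can be shown directly (a minimal NumPy sketch; the concern about jnp.where reliability is specific to TPU/XLA optimizations):

```python
import numpy as np

with np.errstate(over="ignore", invalid="ignore"):
    big = np.exp(np.float32(0.707 * 127))    # exp(~89.8) overflows float32 -> inf
    assert np.isinf(big)
    # masking AFTER exp: 0 * inf = NaN, which then poisons the dot product
    assert np.isnan(np.float32(0.0) * big)
    # masking the exponent BEFORE exp: exp(0) = 1, finite, and 0 * 1 = 0
    assert np.float32(0.0) * np.exp(np.float32(0.0)) == 0.0
```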
Instrument each backward stage to identify exact NaN source:
- Stage 1: h recomputed via chunk_fwd_h
- gk synthetic from _build_gk_from_gamma
- Stage 2: dh via chunk_bwd_dh
- Stage 3: dq/dk/dv via chunk_simple_gla_bwd_o_pl

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tops/ops/simple_gla/chunk.py (1)

593-646: ⚠️ Potential issue | 🟠 Major

Guard backward NaN diagnostics behind an opt-in flag.

The _bwd_nan_check helper currently executes unconditionally on every backward pass across 5 primary call sites (h_recomputed, dh, dq_stage3, dk_stage3, dv_stage3), running expensive isnan, where, max, min reductions and jax.debug.print in the hot path. This is useful for debugging but should be opt-in to avoid overhead during normal training.

Add a debug_nan_checks: bool = False parameter to the function and guard the helper with an early return when disabled.

♻️ Proposed change
 def chunk_simple_gla_bwd(
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
     do: jax.Array,
     *,
     g: jax.Array | None = None,
     g_gamma: jax.Array | None = None,
     scale: float | None = None,
     h0: jax.Array | None = None,
     dht: jax.Array | None = None,
     cu_seqlens_cpu: jax.Array | None = None,
     cu_seqlens_dev: jax.Array | None = None,
     chunk_size: int = 64,
+    debug_nan_checks: bool = False,
 ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None]:
     
     def _bwd_nan_check(name, x):
+        if not debug_nan_checks:
+            return
         has_nan = jnp.any(jnp.isnan(x))
         jax.debug.print(
             "GLA_BWD_STAGE {name}: has_nan={has_nan} shape={shape} max={mx} min={mn}",
             name=name, has_nan=has_nan, shape=x.shape,
             mx=jnp.max(jnp.where(jnp.isnan(x), -jnp.inf, x)),
             mn=jnp.min(jnp.where(jnp.isnan(x), jnp.inf, x)),
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 593 - 646, The backward NaN
diagnostics run unconditionally and are expensive; add a debug_nan_checks: bool
= False parameter to the enclosing function and guard the helper _bwd_nan_check
with an early return when debug_nan_checks is False so the isnan/where/max/min
and jax.debug.print calls are skipped in the hot path; update callers inside
this scope (where chunk_fwd_h, _build_gk_from_gamma, chunk_bwd_dh, and
chunk_simple_gla_bwd_o_pl are used) to respect the new flag (leave default
False) and ensure the five _bwd_nan_check calls (h_recomputed, gk_synthetic, dh,
dq_stage3, dk_stage3, dv_stage3) are only executed when debug_nan_checks is
True.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@tops/ops/simple_gla/chunk.py`:
- Around line 593-646: The backward NaN diagnostics run unconditionally and are
expensive; add a debug_nan_checks: bool = False parameter to the enclosing
function and guard the helper _bwd_nan_check with an early return when
debug_nan_checks is False so the isnan/where/max/min and jax.debug.print calls
are skipped in the hot path; update callers inside this scope (where
chunk_fwd_h, _build_gk_from_gamma, chunk_bwd_dh, and chunk_simple_gla_bwd_o_pl
are used) to respect the new flag (leave default False) and ensure the five
_bwd_nan_check calls (h_recomputed, gk_synthetic, dh, dq_stage3, dk_stage3,
dv_stage3) are only executed when debug_nan_checks is True.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bd460849-f9ae-4b3f-95dd-089722f9d459

📥 Commits

Reviewing files that changed from the base of the PR and between db57d27 and 39b968c.

📒 Files selected for processing (1)
  • tops/ops/simple_gla/chunk.py

The backward dh computation was converting g_gamma to synthetic gk and
using the gk path with midpoint stabilization. At BT=128, the midpoint
stabilization can overflow when |gamma| > ~1.4 because:
  exp(gk - gk_mid) = exp(|gamma| * BT/2) > fp32 max

The g_gamma path computes exp(gamma*(t+1)) directly, which is always
<= 1 (since gamma < 0), eliminating overflow at any chunk size.

The forward h recomputation already uses g_gamma directly (safe); this
change makes the backward dh match, removing the asymmetry.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/ops/simple_gla/chunk.py`:
- Around line 618-621: The comment claims safety of using g_gamma relies on
gamma being non-positive, but there is no runtime check; add a validation that
enforces g_gamma <= 0 at the start of any public API that accepts g_gamma
(specifically in chunk_simple_gla_fwd and chunk_simple_gla_bwd and any other
exported functions taking g_gamma) by asserting or raising a clear error when
jnp.any(g_gamma > 0) (e.g., assert jnp.all(g_gamma <= 0) or raise
ValueError("g_gamma must be <= 0")), so invalid positive gammas are caught early
with a helpful message referencing the overflow stability requirement.
- Around line 593-600: The helper _bwd_nan_check unconditionally computes
jnp.any/jnp.max/jnp.min and calls jax.debug.print in the hot backward path; wrap
the whole computation and print behind a boolean debug flag (e.g.,
GLA_DEBUG_BWD_NAN) so that when the flag is false nothing is computed or
printed. Update _bwd_nan_check to first test the flag and return immediately if
false; apply the same guard to the other identical places noted (the other
invocations at the locations corresponding to lines ~615 and ~643–645) so none
of the NaN-check expressions evaluate unless the debug flag is enabled. Ensure
the flag is cheaply accessible (module-level constant or passed in once) to
avoid extra overhead when disabled.
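The suggested validation could look like the following sketch (names are illustrative, not the repo's actual API; note also that a plain Python raise only works on concrete inputs — inside jit, a host-side check or jax.experimental.checkify would be needed):

```python
import numpy as np

def validate_g_gamma(g_gamma):
    # stability of the g_gamma path relies on exp(gamma*(t+1)) <= 1, i.e. gamma <= 0
    if np.any(np.asarray(g_gamma) > 0):
        raise ValueError("g_gamma must be <= 0 (required for exp overflow stability)")

validate_g_gamma(np.array([-0.707, -1.4]))      # typical GLA decays: accepted
try:
    validate_g_gamma(np.array([0.5]))
    raise AssertionError("expected ValueError")
except ValueError:
    pass
```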

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 813af908-2b68-4509-9ee1-8a1686292480

📥 Commits

Reviewing files that changed from the base of the PR and between 39b968c and fd59501.

📒 Files selected for processing (1)
  • tops/ops/simple_gla/chunk.py

jimoosciuc and others added 3 commits March 31, 2026 21:15
Add BT=128 test cases to prevent regression of the gk overflow issue.
Tests both forward and backward with g_gamma gating at chunk_size=128.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…validation

Address review feedback:
- NaN checks in backward path are now gated behind GLA_DEBUG_NAN=1 env var
  to avoid overhead in production
- Add g_gamma <= 0 assertion to catch invalid positive gamma values early

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
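A rough, framework-agnostic sketch of the env-gated check described in this commit (the real helper presumably uses jax.debug.print inside the traced backward; the helper name and message format mirror the PR text):

```python
import os
import numpy as np

# read once at import time so the disabled path costs a single branch
GLA_DEBUG_NAN = os.environ.get("GLA_DEBUG_NAN") == "1"

def bwd_nan_check(name, x):
    # early return keeps the hot backward path free of extra reductions
    if not GLA_DEBUG_NAN:
        return
    mx = np.where(np.isnan(x), -np.inf, x).max()
    mn = np.where(np.isnan(x), np.inf, x).min()
    print(f"GLA_BWD_STAGE {name}: has_nan={bool(np.isnan(x).any())} "
          f"shape={x.shape} max={mx} min={mn}")

bwd_nan_check("dh", np.array([1.0, np.nan, 3.0]))  # silent unless GLA_DEBUG_NAN=1
```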