fix: stabilize simple_gla backward Pallas kernels for chunk_size=128#129
jimoosciuc wants to merge 7 commits into main
Conversation
The previous fix (35ca41d, 81f02b8) only stabilized the gla/chunk.py kernels, but the actual training code uses the simple_gla path (chunk_o.py, chunk_h.py). Apply midpoint stabilization to:

- chunk_simple_gla_bwd_kernel (chunk_o.py): factored decay matrix exp(g_i - g_j) = exp(g_i - g_mid) * exp(g_mid - g_j); stabilized dq_inter via exp(g) = exp(g - g_mid) * exp(g_mid), with exp(g_mid) absorbed into h
- _chunk_bwd_dh_kernel (chunk_h.py): gk path exp(g_last) and exp(gk) split via the midpoint, restored after the dot product

This fixes NaN in the backward pass when |g_gamma| * chunk_size > ~88 (the fp32 exp overflow threshold).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
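The factoring in the commit message can be checked in isolation. The NumPy sketch below is illustrative, not the Pallas kernel code: it shows that the midpoint split keeps each factor's exponent to half the original range, so the causal (lower-triangle) entries stay finite where the naive form already overflows fp32.

```python
import numpy as np

# Midpoint factoring: exp(g_i - g_j) = exp(g_i - g_mid) * exp(g_mid - g_j).
# Each factor's exponent spans at most half the original range, so it stays
# under the fp32 exp() overflow threshold (~88) for twice the gate magnitude.
BT = 128
gamma = np.float32(-0.9)                              # illustrative log-space decay
g = gamma * np.arange(1, BT + 1, dtype=np.float32)    # cumulative gate per step
g_mid = np.float32(0.5) * (g[0] + g[-1])

i, j = np.meshgrid(np.arange(BT), np.arange(BT), indexing="ij")
with np.errstate(over="ignore"):
    naive = np.exp(g[i] - g[j])                       # exponent peaks near 114 -> inf
    stable = np.exp(g - g_mid)[:, None] * np.exp(g_mid - g)[None, :]

mask = i >= j                                         # causal entries actually used
assert not np.isfinite(naive).all()                   # naive form overflows fp32
assert np.isfinite(stable[mask]).all()                # factored form stays finite
assert np.allclose(stable[mask], naive[mask], rtol=1e-5)
```

Note that materializing the factored product still overflows for the non-causal upper triangle; the kernels keep the two factors on separate operands instead of forming the full matrix.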
📝 Walkthrough

Numeric-stability edits replace direct per-gate exponentials with midpoint or masked-difference formulations across the chunked forward/backward kernels; the SimpleGLA chunk backward adds a NaN-check helper.
Code Review
This pull request implements midpoint stabilization in the chunk_h.py and chunk_o.py backward kernels to prevent exponent overflow by factoring exponential terms. The review feedback points out that similar stabilization is needed for scalar gate and per-head decay paths, as well as inter-chunk terms for dv and dk. The reviewer also suggests several code improvements, including the removal of redundant type casts and the reuse of precomputed stable exponential variables to enhance performance and clarity.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tops/ops/common/chunk_h.py (1)
`748-751`: ⚠️ Potential issue | 🔴 Critical

**The simple_gla `g_gamma` carry path is still unstabilized.**

This midpoint split only protects `gk_ref`. When the caller uses `g_gamma`, lines 749-751 still do `exp(b_g_last)` and `exp(b_g_ramp)` directly, so `dh` can overflow at `chunk_size=128` before `chunk_simple_gla_bwd_kernel` even runs. Apply the same midpoint split to `g_gamma` and absorb that midpoint after the dot as well.

Patch sketch:

```diff
 if g_gamma is not None:
     b_g_last = g_gamma[head_index] * jnp.minimum(BT, eos - t0)
-    b_dh *= exp(b_g_last)
-    b_q = (b_q * exp(b_g_ramp)[:, None]).astype(b_q.dtype)
+    g_gamma_mid = (b_g_ramp[0] + b_g_last) * 0.5
+    b_dh = b_dh * exp(b_g_last - g_gamma_mid)
+    b_q = (b_q * exp(b_g_ramp - g_gamma_mid)[:, None]).astype(b_q.dtype)
 b_dh = b_dh + jax.lax.dot(
     b_q.astype(jnp.float32).T,
     b_do.astype(jnp.float32),
     precision=lax.Precision.HIGHEST,
     preferred_element_type=jnp.float32,
 )
+if g_gamma is not None:
+    b_dh = b_dh * exp(g_gamma_mid)
 if gk_ref is not None:
     b_dh = b_dh * exp(gk_mid)[:, None]
```

Also applies to: 757-772
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/common/chunk_h.py` around lines 748 - 751, The g_gamma carry path still applies exp(b_g_last) and exp(b_g_ramp) directly and can overflow before chunk_simple_gla_bwd_kernel runs; mirror the midpoint-splitting used for gk_ref: split the exponentials for g_gamma into two halves (e.g., exp_half = exp(b_g_* / 2)), apply one half before the dot that produces dh (so intermediate values are stabilized), and multiply/absorb the other half into b_q and b_dh immediately after the dot (so final values are equivalent). Update the operations involving g_gamma, b_g_last, b_g_ramp, b_dh and b_q (and the analogous block later that spans the other g_gamma usage) to use this two-step apply/absorb pattern.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tops/ops/common/chunk_o.py`:
- Around line 210-214: The computation still uses exp(b_gn - b_g) which
reintroduces full-range exponentials into dv/dk; instead reuse the
already-computed stabilized exp_neg_g_s (exp(g_mid - b_g)) and factor out
exp(b_gn - g_mid) once to keep exponentials bounded. Concretely, where dv/dk
currently multiply by exp(b_gn - b_g), replace that factor with exp_neg_g_s *
jnp.exp(b_gn - g_mid) (i.e., absorb g_mid) so the per-row dependence uses the
bounded exp_neg_g_s and only a single exp(b_gn - g_mid) remains; update all
occurrences (the expressions producing dv and dk that reference b_gn and b_g) to
follow this pattern and remove any direct exp(b_gn - b_g) calls.
- Around line 258-276: The current code reconstructs the full exp(g_i - g_j)
inside b_dA_gated which reintroduces the overflow window; instead remove the
midpoint factors from b_dA_gated (keep it as the masked dA only) and push the
exp(...) factors into the matrix multiplications: when computing b_dq_intra use
left = (b_dA * jnp.exp(b_g[:, None] - g_mid)) as the left operand and right =
(b_k * jnp.exp(g_mid - b_g[None, :])) as the right operand for the jnp.dot, and
for b_dq_inter absorb/shift the exp(g_mid) factors into b_h (or b_do) so the dot
multiplies already-scaled operands rather than materializing exp(g_i-g_j) in
b_dA_gated; update uses of b_dA_gated, b_dq_intra, b_dq_inter, b_h_scaled, b_k,
b_do and g_mid accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a638c50e-dabd-4281-bed5-3e4ac862c1d3
📒 Files selected for processing (2)
tops/ops/common/chunk_h.py, tops/ops/common/chunk_o.py
At chunk_size=128, exp(|gamma| * (j - i)) overflows to inf for upper-triangle entries. In the backward kernel, dA is pre-masked to 0 for the upper triangle, so 0 * inf = NaN. On TPU, jnp.where(False, NaN, 0.0) may not reliably return 0.0 due to XLA compiler optimizations.

Fix: zero the exponent for upper-triangle entries BEFORE calling exp(), so exp(0) = 1 instead of exp(overflow) = inf. For lower-triangle entries (i >= j), gamma * (i - j) <= 0 (since gamma < 0), so exp <= 1 and there is no overflow. The upper triangle is masked to 0 anyway, so its exponent value doesn't matter.

This replaces the previous midpoint stabilization approach, which failed because the factored dot product reconstructed the full exp(g[i] - g[j]) during MXU accumulation, defeating the stabilization. Also applies the same fix to the forward intra-chunk kernel for consistency.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
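Both halves of the failure mode above reproduce directly in NumPy. This is a standalone sketch of the commit's reasoning, not the kernel code:

```python
import numpy as np

# Why masking after exp() fails: an overflowed exp() yields inf, and the
# pre-masked zero in dA then produces 0 * inf = NaN instead of 0.
with np.errstate(over="ignore", invalid="ignore"):
    overflowed = np.exp(np.float32(200.0))     # 200 > ~88, the fp32 exp threshold
    poisoned = np.float32(0.0) * overflowed
assert np.isinf(overflowed) and np.isnan(poisoned)

# Fix: zero the exponent for masked (upper-triangle) entries BEFORE exp(),
# so they compute exp(0) = 1 and the pre-masked zeros stay zeros.
BT = 8
gamma = np.float32(-30.0)                      # large decay to force overflow
i, j = np.meshgrid(np.arange(BT), np.arange(BT), indexing="ij")
mask = i >= j                                  # causal lower triangle
g_diff = (gamma * (i - j)).astype(np.float32)  # > 0 above the diagonal

safe = np.where(mask, g_diff, np.float32(0.0))
decay = np.where(mask, np.exp(safe), np.float32(0.0))
assert np.isfinite(decay).all() and (decay <= 1.0).all()
```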
Instrument each backward stage to identify the exact NaN source:

- Stage 1: h recomputed via chunk_fwd_h
- gk synthetic from _build_gk_from_gamma
- Stage 2: dh via chunk_bwd_dh
- Stage 3: dq/dk/dv via chunk_simple_gla_bwd_o_pl

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tops/ops/simple_gla/chunk.py (1)
`593-646`: ⚠️ Potential issue | 🟠 Major

**Guard backward NaN diagnostics behind an opt-in flag.**

The `_bwd_nan_check` helper currently executes unconditionally on every backward pass across 5 primary call sites (h_recomputed, dh, dq_stage3, dk_stage3, dv_stage3), running expensive `isnan`, `where`, `max`, `min` reductions and `jax.debug.print` in the hot path. This is useful for debugging but should be opt-in to avoid overhead during normal training.

Add a `debug_nan_checks: bool = False` parameter to the function and guard the helper with an early return when disabled.

♻️ Proposed change:

```diff
 def chunk_simple_gla_bwd(
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
     do: jax.Array,
     *,
     g: jax.Array | None = None,
     g_gamma: jax.Array | None = None,
     scale: float | None = None,
     h0: jax.Array | None = None,
     dht: jax.Array | None = None,
     cu_seqlens_cpu: jax.Array | None = None,
     cu_seqlens_dev: jax.Array | None = None,
     chunk_size: int = 64,
+    debug_nan_checks: bool = False,
 ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None]:
     def _bwd_nan_check(name, x):
+        if not debug_nan_checks:
+            return
         has_nan = jnp.any(jnp.isnan(x))
         jax.debug.print(
             "GLA_BWD_STAGE {name}: has_nan={has_nan} shape={shape} max={mx} min={mn}",
             name=name,
             has_nan=has_nan,
             shape=x.shape,
             mx=jnp.max(jnp.where(jnp.isnan(x), -jnp.inf, x)),
             mn=jnp.min(jnp.where(jnp.isnan(x), jnp.inf, x)),
         )
```
Verify each finding against the current code and only fix it if needed. In `@tops/ops/simple_gla/chunk.py` around lines 593 - 646, The backward NaN diagnostics run unconditionally and are expensive; add a debug_nan_checks: bool = False parameter to the enclosing function and guard the helper _bwd_nan_check with an early return when debug_nan_checks is False so the isnan/where/max/min and jax.debug.print calls are skipped in the hot path; update callers inside this scope (where chunk_fwd_h, _build_gk_from_gamma, chunk_bwd_dh, and chunk_simple_gla_bwd_o_pl are used) to respect the new flag (leave default False) and ensure the five _bwd_nan_check calls (h_recomputed, gk_synthetic, dh, dq_stage3, dk_stage3, dv_stage3) are only executed when debug_nan_checks is True.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bd460849-f9ae-4b3f-95dd-089722f9d459
📒 Files selected for processing (1)
tops/ops/simple_gla/chunk.py
The backward dh computation was converting g_gamma to a synthetic gk and using the gk path with midpoint stabilization. At BT=128, the midpoint stabilization can overflow when |gamma| > ~1.4 because:

    exp(gk - gk_mid) = exp(|gamma| * BT/2) > fp32 max

The g_gamma path computes exp(gamma * (t+1)) directly, which is always <= 1 (since gamma < 0), eliminating overflow at any chunk size.

The forward h recomputation already uses g_gamma directly (safe); this change makes the backward dh match, removing the asymmetry.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
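The overflow arithmetic above can be reproduced directly; this is a NumPy stand-in for the kernel math, not the kernel itself:

```python
import numpy as np

BT = 128
gamma = np.float32(-1.5)   # |gamma| > ~1.4, the regime described above

# gk path: midpoint stabilization materializes exp(gk - gk_mid), whose
# exponent reaches |gamma| * BT / 2 = 96 here, past the fp32 exp() threshold.
mid_exponent = np.float32(abs(gamma) * BT / 2)
with np.errstate(over="ignore"):
    factor = np.exp(mid_exponent)
assert np.isinf(factor)

# g_gamma path: exp(gamma * (t + 1)) with gamma < 0 is always <= 1, so it
# cannot overflow at any chunk size.
t = np.arange(BT, dtype=np.float32)
direct = np.exp(gamma * (t + 1))
assert np.isfinite(direct).all() and (direct <= 1.0).all()
```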
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tops/ops/simple_gla/chunk.py`:
- Around line 618-621: The comment claims safety of using g_gamma relies on
gamma being non-positive, but there is no runtime check; add a validation that
enforces g_gamma <= 0 at the start of any public API that accepts g_gamma
(specifically in chunk_simple_gla_fwd and chunk_simple_gla_bwd and any other
exported functions taking g_gamma) by asserting or raising a clear error when
jnp.any(g_gamma > 0) (e.g., assert jnp.all(g_gamma <= 0) or raise
ValueError("g_gamma must be <= 0")), so invalid positive gammas are caught early
with a helpful message referencing the overflow stability requirement.
- Around line 593-600: The helper _bwd_nan_check unconditionally computes
jnp.any/jnp.max/jnp.min and calls jax.debug.print in the hot backward path; wrap
the whole computation and print behind a boolean debug flag (e.g.,
GLA_DEBUG_BWD_NAN) so that when the flag is false nothing is computed or
printed. Update _bwd_nan_check to first test the flag and return immediately if
false; apply the same guard to the other identical places noted (the other
invocations at the locations corresponding to lines ~615 and ~643–645) so none
of the NaN-check expressions evaluate unless the debug flag is enabled. Ensure
the flag is cheaply accessible (module-level constant or passed in once) to
avoid extra overhead when disabled.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 813af908-2b68-4509-9ee1-8a1686292480
📒 Files selected for processing (1)
tops/ops/simple_gla/chunk.py
Add BT=128 test cases to prevent regression of the gk overflow issue. Tests both forward and backward with g_gamma gating at chunk_size=128. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…validation

Address review feedback:

- NaN checks in the backward path are now gated behind the GLA_DEBUG_NAN=1 env var to avoid overhead in production
- Add a g_gamma <= 0 assertion to catch invalid positive gamma values early

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
Stabilize simple_gla backward Pallas kernels for chunk_size=128 (BT=128).
Root cause: The backward `chunk_bwd_dh` kernel was using a synthetic `gk` path (via `_build_gk_from_gamma()`) instead of the native `g_gamma` path. In the gk path, midpoint stabilization computes `exp(gk - gk_mid)` where `gk_mid ≈ gamma * BT/2`. At BT=128 with gamma ≈ -0.707 (typical for GLA): `exp(0.707 * 63.5) ≈ 2.8e19`. This intermediate value overflows the bfloat16 max (65504) during the `.astype(b_q.dtype)` cast in the kernel, producing `inf` and then NaN in the subsequent dot product.

Fix (primary): Use `g_gamma` directly in `chunk_bwd_dh` instead of the synthetic `gk`. The g_gamma path computes `exp(gamma*(t+1))`, which is always ≤ 1 (since gamma < 0), eliminating overflow at any chunk size. This matches the forward `chunk_fwd_h` kernel, which already uses `g_gamma` directly.

Additional hardening: `chunk_simple_gla_bwd_kernel` (chunk_o.py): zero upper-triangle exponents before `exp()` to prevent `inf` in the causal decay matrix.

Changes

- `tops/ops/simple_gla/chunk.py`: pass `g_gamma` directly to `chunk_bwd_dh` instead of converting to a synthetic `gk`; removes the forward/backward asymmetry.
- `tops/ops/common/chunk_o.py`: safe exponent masking in both forward and backward intra-chunk kernels: `jnp.where(mask, g_diff, 0.0)` before `exp()`.

Test plan
🤖 Generated with Claude Code