fix(gla): midpoint stabilization in chunk_gla_fwd_intra_gk_ref for BT=128 #133
jimoosciuc wants to merge 8 commits into main
Conversation
The reference intra-chunk attention function used first-position stabilization (g_n = g_c[0]), causing exp(g_n - g_c[-1]) to overflow float32 when |g_gamma| * (chunk_size - 1) > 88.7. At chunk_size=128 with |g_gamma|=0.8, this produced exp(101.6) → Inf, then 0 * Inf = NaN on diagonal entries of the attention matrix. Switch to midpoint stabilization (already used by the Pallas kernel variant), halving the max exponent from 101.6 to 50.8 at BT=128. Also fix chunk_gla_bwd call to chunk_bwd_dh_ref to use keyword args matching the updated signature. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
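The effect of the anchor choice can be sketched in plain NumPy (illustrative helper names, not the kernel code; the real functions operate on per-chunk cumulative gate sums):

```python
import numpy as np

def max_abs_exponent(g_cumsum: np.ndarray, anchor: str) -> float:
    """Worst-case |g_n - g_c[i]| that reaches exp() for a given anchor g_n."""
    if anchor == "first":          # old behavior: g_n = g_c[0]
        g_n = g_cumsum[0]
    else:                          # midpoint: halves the worst-case exponent
        g_n = 0.5 * (g_cumsum[0] + g_cumsum[-1])
    return float(np.max(np.abs(g_n - g_cumsum)))

# Constant decay g_gamma = -0.8 accumulated over a BT=128 chunk.
BT, g_gamma = 128, -0.8
g_cumsum = (g_gamma * np.arange(BT)).astype(np.float32)

print(max_abs_exponent(g_cumsum, "first"))  # ~101.6: exp() overflows float32
print(max_abs_exponent(g_cumsum, "mid"))    # ~50.8: comfortably below ~88.7
```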
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. Use the following commands to manage reviews, or the checkboxes below for quick actions.
📝 Walkthrough

Replace single-row intra-chunk gate references with midpoint-based gating in forward kernels, clamp exponent differences to 0 on the causal upper triangle before exp() in chunk kernels, and route g_cumsum as gk with adjusted chunk_gla_bwd -> chunk_bwd_dh_ref args; add tests asserting no NaN/Inf in backward gradients.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test Harness
    participant Input as Input Tensors
    participant Fwd as chunk_gla_fwd_intra_gk_*
    participant Bwd as chunk_gla_bwd
    participant DhRef as chunk_bwd_dh_ref
    Test->>Input: build fixed-size GLA inputs + constant g_gamma
    Input->>Fwd: forward (use g_mid = (first+last)/2)
    Fwd-->>Input: produce gated q/k with midpoint-based exponents
    Test->>Bwd: call chunk_gla_bwd(g=None, g_gamma=None, gk=g_cumsum, ...)
    Bwd->>DhRef: route backward dh with output_dh0 flag
    DhRef-->>Bwd: return dh gradients
    Bwd-->>Test: return dq, dk, dv, dg (checked for NaN/Inf)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: 2 passed ✅, 1 failed ❌ (1 warning)
Code Review
This pull request introduces midpoint stabilization in the Gated Linear Attention (GLA) forward pass to prevent exponent overflow when using large chunk sizes. It also adds a regression test to ensure numerical stability for constant g_gamma values in the backward pass. A bug was identified in the backward pass implementation where the 'output_dh0' argument was omitted, leading to the loss of initial state gradients; a code suggestion was provided to resolve this.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/ops/gla/test_exp_overflow_stabilization.py`:
- Line 206: The test unpacks five values from chunk_gla_bwd into dq, dk, dv, dg,
dh0 but dh0 is unused which triggers Ruff RUF059; update the unpacking in the
test to use a deliberately ignored variable name (e.g., _dh0) instead of dh0 so
the intent is clear and the linter warning is silenced—locate the call to
chunk_gla_bwd in the test and change the last target from dh0 to _dh0.
- Around line 174-183: The file calls pytest.main([__file__, "-v"]) before new
tests are defined, causing tests like test_constant_g_gamma_bwd_no_nan to be
skipped when running the module directly; move (or remove) the pytest.main
invocation so it executes only after all tests are declared — e.g., place the if
__name__ == "__main__": pytest.main([...]) block at the end of the file (below
test_constant_g_gamma_bwd_no_nan and any other tests), ensuring pytest.main runs
after test definitions.
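The ordering issue generalizes: module-level code only sees names bound above the point where it runs. A placeholder sketch (names illustrative, `collect` is a stand-in for pytest's collection, not its API):

```python
# Sketch of the bug: a mid-file "run the suite" call only sees tests
# defined above it. Placeholder names, not the real test file's contents.

collected = []

def collect(fn):
    """Stand-in for pytest collecting module-level test functions."""
    collected.append(fn.__name__)
    return fn

@collect
def test_exp_overflow_fwd():
    assert True

# If pytest.main([__file__, "-v"]) executed here, only one test would
# have been defined at this point in module execution:
visible_mid_file = list(collected)

@collect
def test_constant_g_gamma_bwd_no_nan():
    assert True

# The fix: keep the __main__ guard below every test definition, so the
# whole suite has been bound by the time it runs.
visible_at_end = list(collected)
```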
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 7b2ed14a-5332-4bf3-b8c7-d9ce701212b0
📒 Files selected for processing (2)
tests/ops/gla/test_exp_overflow_stabilization.py
tops/ops/gla/chunk.py
- Add output_dh0 flag to chunk_bwd_dh_ref call so dh0 is only computed when initial_state or dht is provided
- Move pytest.main guard to end of test file (after all test functions)
- Mark unused dh0 as _dh0 to satisfy linter

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The Toeplitz decay exp(gamma*(i-j)) overflows float32 in the upper triangle when |gamma|*(BT-1) > 88.7 (e.g. BT=128, |gamma|=0.8 gives exp(101.6) = Inf). On TPU, 0 * Inf = NaN leaks through jnp.where because XLA may compile select as mask*x + (1-mask)*y instead of a true IEEE select. Fix: clamp the exponent to 0.0 in the upper triangle *before* exp(), so exp(0)=1 replaces the overflow. The upper triangle is masked to zero anyway, so the result is mathematically identical. Applied to both forward (_chunk_fwd_o_kernel) and backward (chunk_simple_gla_bwd_kernel) kernels. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
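The clamp can be reproduced outside Pallas with plain NumPy (a minimal sketch under the commit's stated semantics; `toeplitz_decay` is an illustrative helper, not the kernel):

```python
import numpy as np

def toeplitz_decay(gamma: float, bt: int) -> np.ndarray:
    """Causal decay matrix exp(gamma*(i-j)) without Inf/NaN leakage.

    gamma <= 0, so the exponent is positive (and can overflow float32)
    only in the upper triangle (j > i), which is masked to zero anyway.
    """
    i = np.arange(bt, dtype=np.float32)[:, None]
    j = np.arange(bt, dtype=np.float32)[None, :]
    diff = (gamma * (i - j)).astype(np.float32)
    causal = j <= i
    # Clamp BEFORE exp: exp(0)=1 replaces the would-be Inf, and the upper
    # triangle is zeroed afterwards, so the result is mathematically identical.
    safe = np.where(causal, diff, np.float32(0.0))
    return np.where(causal, np.exp(safe), np.float32(0.0))

A = toeplitz_decay(-0.8, 128)
assert np.isfinite(A).all() and A[0, 0] == 1.0

# The failure mode being fixed: when a compiled select is lowered to the
# mask*x + (1-mask)*y form, the masked Inf becomes NaN (0 * Inf = NaN).
with np.errstate(over="ignore", invalid="ignore"):
    i = np.arange(128, dtype=np.float32)[:, None]
    j = np.arange(128, dtype=np.float32)[None, :]
    mask = (j <= i).astype(np.float32)
    leaky = mask * np.exp((-0.8 * (i - j)).astype(np.float32))
assert np.isnan(leaky[0, -1])  # upper-triangle entry became 0 * Inf
```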
This PR [https://github.com//pull/131] has been merged. Try pointing
Actionable comments posted: 1
🧹 Nitpick comments (1)
tops/ops/common/chunk_o.py (1)

66-69: Add a direct regression for the `chunk_fwd_o` masked-exp path.

The new stabilization tests called out in this PR exercise `chunk_gla_bwd` and `chunk_gla_fwd_intra_gk_ref`, but not this `_chunk_fwd_o_kernel`/`chunk_fwd_o` path. A `chunk_size=128`, `g_gamma=-0.8` case here would keep this exact clamp from regressing independently of the backward fix.

Also applies to: 76-79
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tops/ops/common/chunk_o.py`:
- Around line 218-222: The kernel assumes decay rates are non-positive but only
masks positive values locally; enforce the invariant at the public entry
point(s) that supply the gamma/ b_gamma arrays (or immediately before any exp()
use) by asserting all entries are <= 0 or by clamping b_gamma =
jnp.minimum(b_gamma, 0.0); apply this fix for every location that exponentiates
gamma-like values (symbols: b_gamma, g_gamma, g_gn, safe_diff, decay) so no
unmasked positive values reach jnp.exp and cause overflow.
📒 Files selected for processing (1)
tops/ops/common/chunk_o.py
Add test_chunk_bwd_large_gamma_no_nan with g_gamma=-0.5/-0.8/-1.0 at chunk_size=128. This test runs in pallas-kernel TPU CI (not ignored like tests/ops/gla/) and exercises the Toeplitz decay overflow path that causes NaN in production. Also add chunk_size=128 cases to existing fwd/bwd test suites. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
🧹 Nitpick comments (1)
tests/ops/simple_gla/test_chunk_simple_gla_tpu.py (1)
306-337: Consider extending the NaN regression to also validate `dh0` finiteness.

This test currently checks `dq`/`dk`/`dv` only. Adding an `h0=True` branch would directly cover the `output_dh0` backward path as well.

Suggested extension
```diff
-@pytest.mark.parametrize("g_gamma_val", [-0.5, -0.8, -1.0])
-def test_chunk_bwd_large_gamma_no_nan(g_gamma_val):
+@pytest.mark.parametrize("g_gamma_val", [-0.5, -0.8, -1.0])
+@pytest.mark.parametrize("use_h0", [False, True])
+def test_chunk_bwd_large_gamma_no_nan(g_gamma_val, use_h0):
@@
-    q = jax.random.normal(keys[0], (B, T, H, K), dtype=jnp.bfloat16)
-    k = jax.random.normal(keys[1], (B, T, H, K), dtype=jnp.bfloat16)
-    v = jax.random.normal(keys[2], (B, T, H, V), dtype=jnp.bfloat16)
-    do = jax.random.normal(keys[3], (B, T, H, V), dtype=jnp.bfloat16)
+    q = jax.random.normal(keys[0], (B, T, H, K), dtype=jnp.bfloat16)
+    k = jax.random.normal(keys[1], (B, T, H, K), dtype=jnp.bfloat16)
+    v = jax.random.normal(keys[2], (B, T, H, V), dtype=jnp.bfloat16)
+    do = jax.random.normal(keys[3], (B, T, H, V), dtype=jnp.bfloat16)
+    h0 = jax.random.normal(keys[4], (B, H, K, V), dtype=jnp.bfloat16) if use_h0 else None
@@
-    dq, dk, dv, _dh0 = chunk_simple_gla_bwd(
+    dq, dk, dv, dh0 = chunk_simple_gla_bwd(
         q, k, v, do,
         g_gamma=g_gamma,
         scale=scale,
+        h0=h0,
         chunk_size=C,
     )
@@
-    for name, arr in [("dq", dq), ("dk", dk), ("dv", dv)]:
+    grads = [("dq", dq), ("dk", dk), ("dv", dv)]
+    if dh0 is not None:
+        grads.append(("dh0", dh0))
+    for name, arr in grads:
         assert not jnp.any(jnp.isnan(arr)), (
             f"{name} contains NaN (g_gamma={g_gamma_val}, chunk_size={C})"
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/simple_gla/test_chunk_simple_gla_tpu.py` around lines 306 - 337, The test test_chunk_bwd_large_gamma_no_nan currently asserts finiteness for dq/dk/dv but ignores dh0 returned by chunk_simple_gla_bwd; update the test to also check the dh0 output (variable _dh0 or rename to dh0) for NaN/Inf so the output_dh0 backward path is covered. Specifically, capture the fourth return value from chunk_simple_gla_bwd as dh0 and add the same jnp.any(jnp.isnan(...)) and jnp.any(jnp.isinf(...)) assertions used for dq/dk/dv, with a clear failure message referencing g_gamma_val and chunk_size.
📒 Files selected for processing (1)
tests/ops/simple_gla/test_chunk_simple_gla_tpu.py
Enforces the decay-rate invariant (g_gamma must be non-positive) at chunk_simple_gla_fwd and chunk_simple_gla_bwd entry points, rather than relying on kernel-local masking assumptions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tops/ops/simple_gla/chunk.py (1)
741-750: 🛠️ Refactor suggestion | 🟠 Major

Consider adding validation at the public entry point `chunk_simple_gla` before tracing.

The g_gamma validation assertions are placed in `chunk_simple_gla_fwd` (line 489) and `chunk_simple_gla_bwd` (line 585), but these are called during tracing by `_pallas_chunk_gla_fwd` and `_pallas_chunk_gla_bwd` respectively. The safest location for the assertion is in the `chunk_simple_gla` function body itself (before line 769), which executes eagerly before `custom_vjp` begins tracing the forward/backward rules.

Proposed fix: add validation at entry point

```diff
 def chunk_simple_gla(
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
     g_gamma: jax.Array,
     initial_state: jax.Array | None = None,
     scale: float | None = None,
     output_final_state: bool = False,
     chunk_size: int = 64,
 ) -> tuple[jax.Array, jax.Array | None]:
     """Chunked Simple GLA using pallas-kernel's Pallas TPU kernels.
     ...
     """
+    assert g_gamma is None or (g_gamma <= 0).all(), \
+        "g_gamma must be <= 0 (decay rate); positive values overflow exp()"
     dtype = q.dtype
```

Then consider whether the assertions in `chunk_simple_gla_fwd` and `chunk_simple_gla_bwd` are still needed (they may cause tracing issues if those functions are ever called in a traced context).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/simple_gla/chunk.py` around lines 741 - 750, Add eager validation of g_gamma at the public entry point chunk_simple_gla before any tracing begins: check that g_gamma has the expected shape/dtype and non-null constraints at the top of chunk_simple_gla (before calls to _pallas_chunk_gla_fwd/_pallas_chunk_gla_bwd), mirroring the existing assertions in chunk_simple_gla_fwd and chunk_simple_gla_bwd; after adding this pre-check you can optionally remove or relax the redundant assertions inside chunk_simple_gla_fwd and chunk_simple_gla_bwd to avoid tracing-time issues.
🧹 Nitpick comments (1)
tops/ops/simple_gla/chunk.py (1)
467-480: Missing docstring for public function `chunk_simple_gla_fwd`.

Per coding guidelines, public functions must have a docstring explaining business semantics and detailing tensor shapes for all inputs/outputs. This function lacks any documentation.

While not part of the current PR's changes, consider adding a docstring for maintainability:

Example docstring

```python
def chunk_simple_gla_fwd(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    *,
    g: jax.Array | None = None,
    g_gamma: jax.Array | None = None,
    scale: float | None = None,
    h0: jax.Array | None = None,
    use_ht: bool = False,
    cu_seqlens_cpu: jax.Array | None = None,
    cu_seqlens_dev: jax.Array | None = None,
    chunk_size: int = 64,
) -> tuple[jax.Array, jax.Array | None]:
    """Simple GLA forward pass using chunked computation.

    Args:
        q: Query tensor [B, T, H, K]
        k: Key tensor [B, T, H, K]
        v: Value tensor [B, T, H, V]
        g: Per-element gate [B, T, H] or None
        g_gamma: Per-head decay rate [H], must be <= 0
        scale: Attention scale factor (default: K^-0.5)
        h0: Initial state [N, H, K, V] or None
        use_ht: Whether to return final state
        cu_seqlens_cpu: Cumulative sequence lengths on CPU [B+1]
        cu_seqlens_dev: Cumulative sequence lengths on device [B+1]
        chunk_size: Chunk size for block computation

    Returns:
        (o, ht): Output [B, T, H, V] and optional final state [N, H, K, V]
    """
```

As per coding guidelines: "All public functions must have a clear docstring... that explicitly explains the business semantics of the function and clearly details the tensor shape and dimension meaning for every input and output argument."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/simple_gla/chunk.py` around lines 467 - 480, Add a descriptive docstring to the public function chunk_simple_gla_fwd that explains its business semantics (chunked/simple GLA forward pass) and documents the tensor shapes and meanings for every parameter and return value; explicitly describe q, k, v shapes (e.g. [B, T, H, K] or whatever this module expects), optional tensors g, g_gamma, scale, h0, booleans like use_ht, cu_seqlens_cpu/cu_seqlens_dev shapes, chunk_size semantics, and the returned tuple (output o shape and optional final state ht shape), so consumers and maintainers can unambiguously understand inputs/outputs and behavior.
📒 Files selected for processing (1)
tops/ops/simple_gla/chunk.py
- test_chunk_simple_gla_grad_no_nan: uses chunk_simple_gla (custom_vjp) with jax.grad to test the full forward+backward pipeline
- test_chunk_bwd_components_no_nan: tests each backward stage individually (chunk_fwd_h, chunk_bwd_dh, bwd_o_pl)

Both pass on CPU interpret mode. Need TPU CI to verify.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
- At `chunk_size=128`: `chunk_gla_fwd_intra_gk_ref` used first-position stabilization (g_n = g_c[0]), causing `exp(|g_gamma|*127)` to overflow float32 when `|g_gamma| > 0.69`
- Switch to midpoint stabilization (already used by `chunk_gla_fwd_intra_gk_pl`), halving the max exponent from 101.6 to 50.8
- Fix `chunk_gla_bwd` call to `chunk_bwd_dh_ref` to use keyword args matching the updated signature

Root Cause
At `chunk_size=128` with `|g_gamma|=0.8`:
- `exp(g_n - g_c[-1]) = exp(0.8 * 127) = exp(101.6)` → float32 overflow (Inf)
- `k_gated[-1] = k * Inf`, `q_gated[-1] = q * 0` → `0 * Inf = NaN` on diagonal of A matrix
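The float32 arithmetic can be checked directly (NumPy, mirroring the numbers above):

```python
import numpy as np

with np.errstate(over="ignore", invalid="ignore"):
    # chunk_size=128: first-position anchor reaches exp(0.8 * 127) = exp(101.6)
    big = np.exp(np.float32(0.8 * 127))
    # chunk_size=64 stays below the float32 limit (max exponent ~88.7)
    ok = np.exp(np.float32(0.8 * 63))
    # the Inf then leaks as NaN through 0 * Inf on the diagonal
    diag = np.float32(0.0) * big

print(np.isinf(big), np.isfinite(ok), np.isnan(diag))  # True True True
```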
At `chunk_size=64`: `exp(0.8 * 63) = exp(50.4)` → safe (float32 max exponent ≈ 88.7)

Test plan
- New regression test (`test_exp_overflow_stabilization.py`)
- `test_constant_g_gamma_bwd_no_nan` covers the exact AL model scenario (BT=128, g_gamma=-0.8)

🤖 Generated with Claude Code
Summary by CodeRabbit
Bug Fixes
Tests