fix: numerical stability for GLA chunk_size=128 (#118)
Conversation
The backward kernel and reference implementations overflow when chunk_size > 64 because exp(-b_g) and exp(b_g) produce Inf for large |gamma| * position values (e.g. 0.707 * 128 = 90.5 > 88.72, the f32 exp limit). Fixes:

- Backward kernel: replace separate exp(b_g)/exp(-b_g) with the decay matrix exp(b_g[i] - b_g[j]), whose argument is always <= 0 for causal entries
- Forward intra kernel: clamp the upper triangle and apply the causal mask early
- Backward A computation: use the overflow-safe simple GLA intra kernel instead of chunk_gla_fwd_intra_gk_ref with synthetic per-K gates
- Reference fwd/bwd: clamp ramp_diff before exp() in masked regions

All fixes are transparent for chunk_size=64 (max exp arg = 44.5, well within f32 range). Tested on TPU v7x with B=2..128, T=1k..8k, H=16, K=V=128 — zero NaN/Inf in forward and backward.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
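The failure mode and the fix can be reproduced outside the kernel in a few lines of NumPy (a standalone sketch; `b_g` stands in for the cumulative log-decay `gamma * position` of one head, and the names are illustrative, not repo code):

```python
import numpy as np

# Cumulative log-decay for one head: b_g[t] = gamma * t, with gamma < 0.
gamma = -0.707
pos = np.arange(128, dtype=np.float32)
b_g = (gamma * pos).astype(np.float32)  # b_g[127] ≈ -89.8

with np.errstate(over="ignore", under="ignore", invalid="ignore"):
    # Naive factorization exp(b_g[i]) * exp(-b_g[j]): exp(-b_g[127]) = exp(89.8)
    # exceeds the float32 exp limit (~88.72) and overflows to Inf
    # (and 0 * Inf entries become NaN).
    naive = np.exp(b_g[:, None]) * np.exp(-b_g[None, :])
assert np.isinf(naive).any()

# Pairwise form exp(b_g[i] - b_g[j]): for causal entries (i >= j) the argument
# is gamma * (i - j) <= 0, so every value lies in (0, 1] and never overflows.
diff = np.minimum(b_g[:, None] - b_g[None, :], 0.0)
decay = np.exp(diff)
assert np.isfinite(decay).all()
assert decay.max() <= 1.0
```

Clamping the difference to 0 before exponentiating makes even the (later masked) upper triangle safe, which is exactly the trick the kernel changes below apply.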
📝 Walkthrough

Replaced per-vector exponential scalings with pairwise exp(g_i - g_j) decays and added clamped exponent arguments before exp in multiple forward/backward GLA chunk kernels; refactored intra-chunk A recomputation and added a new test exercising per-head g_gamma slopes with forward/backward checks.

Sequence diagrams skipped — the changes are numerical-stability updates and do not introduce new multi-component control flow.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 passed
Code Review
This pull request enhances the numerical stability of GLA kernels by clamping exponent arguments to prevent overflows during decay matrix calculations. It refactors the backward pass and intra-chunk attention mechanisms to incorporate these safety measures across several functions. Feedback suggests further optimizing the decay matrix application in chunk_o.py by reusing the decayed attention matrix and removing redundant masking operations.
```diff
+# === Precompute decay matrices (numerically stable for large BT) ===
+# decay_lower[i, j] = exp(b_g[i] - b_g[j]) for i >= j; 0 otherwise
+# Since gamma < 0 and i >= j: b_g[i] - b_g[j] = gamma*(i-j) <= 0 => exp in (0, 1]
+b_g_diff = b_g[:, None] - b_g[None, :]      # [BT, BT]
+b_g_diff_safe = jnp.minimum(b_g_diff, 0.0)  # clamp upper tri to avoid overflow
+decay_lower = jnp.where(mask, jnp.exp(b_g_diff_safe), 0.0)  # [BT, BT]
+decay_upper = decay_lower.T  # decay_upper[j, i] = exp(b_g[i] - b_g[j]) for i >= j

 # 3. dq = (dA * decay_lower) @ k + do @ h^T * scale * exp(b_g)
 # dq_intra[i] = sum_j dA[i,j] * k[j] * exp(b_g[i] - b_g[j])
 b_dq_intra = jnp.dot((b_dA * decay_lower).astype(b_k.dtype), b_k,
                      precision=jax.lax.Precision.HIGHEST,
-                     preferred_element_type=jnp.float32) * jnp.exp(b_g)[:, None]
+                     preferred_element_type=jnp.float32)
 # dq_inter: exp(b_g) is safe (b_g <= 0 => exp <= 1, may underflow but no overflow)
 b_dq_inter = jnp.dot(b_do, b_h.astype(b_do.dtype).T,
                      precision=jax.lax.Precision.HIGHEST,
                      preferred_element_type=jnp.float32) * (scale * jnp.exp(b_g)[:, None])
 b_dq = b_dq_intra + b_dq_inter
 dq_ref[0, 0] = b_dq.astype(dq_ref.dtype)

-# 4. dk = dA^T @ q_pos * exp(-b_g) + v @ dh^T * exp(b_gn - b_g)
-q_pos = (b_q * jnp.exp(b_g)[:, None]).astype(b_q.dtype)
-b_dk_intra = jnp.dot(b_dA.T.astype(q_pos.dtype), q_pos,
+# 4. dk = (dA^T * decay_upper) @ q + v @ dh^T * exp(b_gn - b_g)
+# dk_intra[j] = sum_i dA[i,j] * q[i] * exp(b_g[i] - b_g[j])
+b_dk_intra = jnp.dot((b_dA.T * decay_upper).astype(b_q.dtype), b_q,
                      precision=jax.lax.Precision.HIGHEST,
-                     preferred_element_type=jnp.float32) * jnp.exp(-b_g)[:, None]
+                     preferred_element_type=jnp.float32)
 # dk_inter: exp(b_gn - b_g) is safe (b_gn <= b_g[j] => b_gn - b_g <= 0)
 b_dk_inter = jnp.dot(b_v, b_dh.astype(b_v.dtype).T,
                      precision=jax.lax.Precision.HIGHEST,
                      preferred_element_type=jnp.float32) * jnp.exp(b_gn - b_g)[:, None]
```
The computation of decay matrices and their application to dq and dk can be optimized. By reusing the decayed attention matrix, you can avoid redundant element-wise multiplications and transpositions. Additionally, since b_dA is already masked at line 59, the explicit masking in decay_lower is redundant as long as we clamp the exponent to avoid overflow in the upper triangle.
```python
# === Precompute decay matrix (numerically stable for large BT) ===
# decay[i, j] = exp(b_g[i] - b_g[j]). For causal entries (i >= j), b_g[i] - b_g[j] <= 0.
# We clamp to 0.0 to avoid overflow in the upper triangle (i < j).
b_g_diff = b_g[:, None] - b_g[None, :]
decay = jnp.exp(jnp.minimum(b_g_diff, 0.0))
b_dA_decayed = b_dA * decay

# 3. dq = (dA * decay) @ k + do @ h^T * scale * exp(b_g)
b_dq_intra = jnp.dot(b_dA_decayed.astype(b_k.dtype), b_k,
                     precision=jax.lax.Precision.HIGHEST,
                     preferred_element_type=jnp.float32)
# dq_inter: exp(b_g) is safe (b_g <= 0 => exp <= 1)
b_dq_inter = jnp.dot(b_do, b_h.astype(b_do.dtype).T,
                     precision=jax.lax.Precision.HIGHEST,
                     preferred_element_type=jnp.float32) * (scale * jnp.exp(b_g)[:, None])
b_dq = b_dq_intra + b_dq_inter
dq_ref[0, 0] = b_dq.astype(dq_ref.dtype)

# 4. dk = (dA * decay)^T @ q + v @ dh^T * exp(b_gn - b_g)
b_dk_intra = jnp.dot(b_dA_decayed.T.astype(b_q.dtype), b_q,
                     precision=jax.lax.Precision.HIGHEST,
                     preferred_element_type=jnp.float32)
# dk_inter: exp(b_gn - b_g) is safe (b_gn <= b_g[j] => b_gn - b_g <= 0)
b_dk_inter = jnp.dot(b_v, b_dh.astype(b_v.dtype).T,
                     precision=jax.lax.Precision.HIGHEST,
                     preferred_element_type=jnp.float32) * jnp.exp(b_gn - b_g)[:, None]
```
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tops/ops/common/chunk_o.py`:
- Around line 251-254: The clamped decay path uses jnp.minimum(..., 0.0) which
assumes g_gamma is non-positive; add an explicit validation at each public
entrypoint that can reach this clamped Toeplitz branch (specifically in
chunk_fwd_o, chunk_bwd_dv, and chunk_bwd_dqkwg) to assert or raise if g_gamma >
0 so callers fail fast instead of silently producing incorrect amplification;
apply the same guard where the ramp_diff -> ramp_diff_safe computation occurs
(the occurrences around the shown block and the locations referenced near lines
340-343 and 524-529) and ensure the check runs before computing
ramp_diff/ramp_diff_safe.
In `@tops/ops/simple_gla/chunk.py`:
- Around line 628-630: The code currently lets chunk_simple_gla_bwd call
_chunk_simple_gla_fwd_intra and later per-head index g_gamma even when g_gamma
is None; add an explicit guard at the start of chunk_simple_gla_bwd that checks
if g_gamma is None and raises a clear ValueError (or TypeError) describing that
g_gamma must be provided and non-None, so you fail fast before the
g_gamma.reshape / per-head indexing in _chunk_simple_gla_fwd_intra or the fused
backward kernel; reference g_gamma, chunk_simple_gla_bwd, and
_chunk_simple_gla_fwd_intra in the error message to help debugging.
📒 Files selected for processing (2)
tops/ops/common/chunk_o.py, tops/ops/simple_gla/chunk.py
```python
# Clamp ramp diff to avoid overflow in upper triangle (masked later)
ramp_diff = ramp[..., :, None] - ramp[..., None, :]  # [H, C, C]
ramp_diff_safe = jnp.minimum(ramp_diff, 0.0)
A = A * exp(ramp_diff_safe)[None, None]
```
Validate non-positive g_gamma before the clamped decay path.
jnp.minimum(..., 0.0) is only algebraically correct for decay heads. If any caller passes g_gamma > 0, these branches replace the required causal amplification with 1, so chunk_fwd_o, chunk_bwd_dv, and chunk_bwd_dqkwg silently compute different outputs/gradients instead of failing fast.
Suggested guard

```diff
 assert_shape_or_none(g_gamma, (H,))
+if g_gamma is not None:
+    assert (g_gamma <= 0).all(), "g_gamma must be non-positive for simple GLA decay"
```

Please add the same guard in each public entrypoint that feeds this clamped Toeplitz path.
As per coding guidelines, public functions must enforce strict constraints and validation on the shape and types of input variables before executing the main logic.
Also applies to: 340-343, 524-529
```python
# 3. Compute A via simple GLA intra-chunk attention (overflow-safe)
g_gamma_4d = g_gamma.reshape(1, 1, H, 1) if g_gamma is not None else None
A = _chunk_simple_gla_fwd_intra(q, k, g_gamma_4d, scale, chunk_size=C)
```
Fail fast when g_gamma is absent.
chunk_simple_gla_bwd still advertises g_gamma: jax.Array | None, but this path now always calls _chunk_simple_gla_fwd_intra, and the fused backward kernel below also indexes g_gamma per head. With g_gamma=None, callers fall through to an internal reshape/indexing failure instead of a clear contract error.
Suggested guard

```diff
 assert g is None, "per-element gating not supported in chunk_simple_gla_bwd"
+assert g_gamma is not None, "chunk_simple_gla_bwd requires g_gamma"
 assert scale is not None  # fix pylance check
```

As per coding guidelines, public functions must enforce strict constraints and validation on the shape and types of input variables before executing the main logic.
Uses build_slope_tensor(16) with actual al_model.yml config (H=16, num_decoder_layers=20) to test chunk_size=128 stability. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Actionable comments posted: 2
🧹 Nitpick comments (1)

tests/ops/simple_gla/test_actual_slopes_chunk128.py (1)

Line 64: Clean up the lint-only nits in this new test (`_ht`, `_dh0`, and `zip(..., strict=True)`). Low-risk cleanup, but worth doing now to keep the test file warning-free.

Patch sketch

```diff
-o, ht = chunk_simple_gla_fwd(
+o, _ht = chunk_simple_gla_fwd(
 ...
-dq, dk, dv, dh0 = chunk_simple_gla_bwd(
+dq, dk, dv, _dh0 = chunk_simple_gla_bwd(
 ...
-for name, g in zip(["dq", "dk", "dv"], grads):
+for name, g in zip(["dq", "dk", "dv"], grads, strict=True):
```

Also applies to: 89-89, 188-188
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py` at line 64, Rename unused unpacked variables from ht and dh0 to _ht and _dh0 where chunk_simple_gla_fwd (and similar calls) return values you don't use, and update any zip(...) calls in this test to zip(..., strict=True) to silence lint warnings and enforce equal-length iteration; search for the symbols chunk_simple_gla_fwd, ht, dh0, and zip in this file and make those three small edits at each occurrence (including the other noted lines around 89 and 188).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py`:
- Around line 34-115: The test currently only exposes run_test(...) and a
main-guard script without pytest assertions or reference comparisons; convert it
into pytest-style tests by creating one or more test_* functions that call
run_test parameters or inline the logic, implement a simple CPU/naive reference
for the forward and backward computations (matching chunk_simple_gla_fwd and
chunk_simple_gla_bwd semantics) and use compare_tensor(output, reference,
atol=..., rtol=...) to assert numerical parity for both outputs and gradients,
then remove the if __name__ == "__main__" guard so pytest can collect the tests.
Ensure you reference the actual kernel entrypoints chunk_simple_gla_fwd and
chunk_simple_gla_bwd when computing the reference and assert with compare_tensor
for o, dq, dk, dv.
- Around line 81-83: Replace the broad "except Exception as e: print(...);
return False" blocks in the test (the handlers that print "Forward FAILED" and
similar at the shown locations) so exceptions are not swallowed: either remove
the try/except entirely to let the original exception and traceback propagate,
or replace the handler with pytest.fail(str(e), pytrace=True) or re-raise the
exception with "raise" or "raise ... from e" to preserve chaining; update the
error-handling around the Forward block and the other two similar blocks (the
ones printing failure and returning False) to use one of these approaches.
📒 Files selected for processing (1)
tests/ops/simple_gla/test_actual_slopes_chunk128.py
```python
def run_test(B, T, H, K, V, chunk_size, num_decoder_layers, layer_idx, seed=42):
    """Run forward + backward with actual slope values."""
    np.random.seed(seed)
    key = jax.random.PRNGKey(seed)

    # Build actual g_gamma exactly as MaxText does
    slope_base = build_slope_tensor(H)
    denom = max(num_decoder_layers - 1, 1)
    slope_scale = 1.0 - layer_idx / denom + 1e-5
    g_gamma = -slope_base * slope_scale  # ALWAYS NEGATIVE

    print(f"\n{'='*70}")
    print(f"Config: B={B}, T={T}, H={H}, K={K}, V={V}, chunk_size={chunk_size}")
    print(f"Layer: {layer_idx}/{num_decoder_layers}, slope_scale={slope_scale:.4f}")
    print(f"g_gamma range: [{g_gamma.min():.6f}, {g_gamma.max():.6f}]")
    print(f"head 0: g_gamma={g_gamma[0]:.6f}, max exp arg in chunk = {abs(g_gamma[0]) * chunk_size:.2f}")
    print(f"head 1: g_gamma={g_gamma[1]:.6f}, max exp arg in chunk = {abs(g_gamma[1]) * chunk_size:.2f}")
    print(f"{'='*70}")

    g_gamma_jax = jnp.array(g_gamma)

    # Random inputs in bf16 (like actual training)
    k1, k2, k3, k4 = jax.random.split(key, 4)
    q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16)
    k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16)
    v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16)

    # Forward
    print("\n--- Forward ---")
    try:
        o, ht = chunk_simple_gla_fwd(
            q, k_arr, v,
            g_gamma=g_gamma_jax,
            chunk_size=chunk_size,
        )
        has_nan = bool(jnp.any(jnp.isnan(o)))
        has_inf = bool(jnp.any(jnp.isinf(o)))
        print(f"Output shape: {o.shape}, dtype: {o.dtype}")
        print(f"Output range: [{float(jnp.min(o)):.6f}, {float(jnp.max(o)):.6f}]")
        print(f"NaN: {has_nan}, Inf: {has_inf}")
        if has_nan or has_inf:
            # Check per-head
            for h in range(min(H, 8)):
                oh = o[:, :, h, :]
                print(f"  head {h}: NaN={bool(jnp.any(jnp.isnan(oh)))}, "
                      f"Inf={bool(jnp.any(jnp.isinf(oh)))}, "
                      f"range=[{float(jnp.min(oh)):.4f}, {float(jnp.max(oh)):.4f}]")
    except Exception as e:
        print(f"Forward FAILED: {e}")
        return False

    # Backward
    print("\n--- Backward ---")
    try:
        do = jax.random.normal(k4, (B, T, H, V), dtype=jnp.bfloat16)
        dq, dk, dv, dh0 = chunk_simple_gla_bwd(
            q, k_arr, v, do,
            g_gamma=g_gamma_jax,
            chunk_size=chunk_size,
        )
        for name, grad in [("dq", dq), ("dk", dk), ("dv", dv)]:
            has_nan = bool(jnp.any(jnp.isnan(grad)))
            has_inf = bool(jnp.any(jnp.isinf(grad)))
            print(f"{name}: shape={grad.shape}, dtype={grad.dtype}, "
                  f"range=[{float(jnp.min(grad)):.6f}, {float(jnp.max(grad)):.6f}], "
                  f"NaN={has_nan}, Inf={has_inf}")
            if has_nan or has_inf:
                for h in range(min(H, 8)):
                    gh = grad[:, :, h, :]
                    print(f"  head {h}: NaN={bool(jnp.any(jnp.isnan(gh)))}, "
                          f"Inf={bool(jnp.any(jnp.isinf(gh)))}, "
                          f"range=[{float(jnp.min(gh)):.4f}, {float(jnp.max(gh)):.4f}]")
    except Exception as e:
        print(f"Backward FAILED: {e}")
        return False

    fwd_ok = not (bool(jnp.any(jnp.isnan(o))) or bool(jnp.any(jnp.isinf(o))))
    bwd_ok = all(
        not (bool(jnp.any(jnp.isnan(g))) or bool(jnp.any(jnp.isinf(g))))
        for g in [dq, dk, dv]
    )
    return fwd_ok and bwd_ok
```
🧩 Analysis chain

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
file="tests/ops/simple_gla/test_actual_slopes_chunk128.py"
echo "[1] pytest-discoverable tests in file:"
rg -n '^\s*def\s+test_' "$file" || true
echo "[2] main-guard usage:"
rg -n '__name__\s*==\s*["'"'"']__main__["'"'"']' "$file" || true
echo "[3] compare_tensor usage:"
rg -n '\bcompare_tensor\b' "$file" || true
```
Restructure as pytest reference-parity test (currently bypasses pytest collection and lacks correctness verification).

This file is structured as a script under `if __name__ == "__main__"` and has no pytest-discoverable functions (`def test_...`). More critically, it lacks numerical correctness verification: the required `compare_tensor` utility is not used to compare optimized kernel output against a naive/CPU reference implementation.

Without reference-parity assertions, the test can report success while outputs diverge numerically. Per coding guidelines for tests/ops/**/*.py and tests/**/*.py, tests must use `compare_tensor` with tolerance-based assertions and be executable as standard pytest.

Minimal refactoring needed:

- Move logic from `run_test()` into proper `test_*` functions
- Implement a naive/CPU reference kernel implementation
- Add `compare_tensor(output, reference, atol=..., rtol=...)` assertions for forward and backward passes
- Remove the `if __name__ == "__main__"` guard

Also applies to: 118-201
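The requested shape can be sketched in plain NumPy with a scalar per-head decay: an O(T²) recurrence as the reference and a chunked evaluation using the overflow-safe pairwise decay. This is illustrative only — `np.allclose` stands in for the repo's `compare_tensor`, the real `chunk_simple_gla_fwd` signature is not reproduced, and all helper names here are hypothetical:

```python
import numpy as np

def naive_simple_gla_fwd(q, k, v, g_gamma, scale):
    """O(T^2) reference: o[t] = sum_{s<=t} (q[t]·k[s]) * exp(gamma*(t-s)) * scale * v[s]."""
    T = q.shape[0]
    o = np.zeros((T, v.shape[1]))
    for t in range(T):
        for s in range(t + 1):
            o[t] += float(q[t] @ k[s]) * np.exp(g_gamma * (t - s)) * scale * v[s]
    return o

def chunked_simple_gla_fwd(q, k, v, g_gamma, scale, chunk_size):
    """Chunked evaluation with the overflow-safe pairwise decay exp(g_i - g_j)."""
    T, K = q.shape
    V = v.shape[1]
    h = np.zeros((K, V))  # recurrent state: sum_{s<start} exp(gamma*(start-s)) k_s v_s^T
    o = np.zeros((T, V))
    for start in range(0, T, chunk_size):
        end = min(start + chunk_size, T)
        C = end - start
        qc, kc, vc = q[start:end], k[start:end], v[start:end]
        g = g_gamma * np.arange(C)  # local cumulative log-decay, all <= 0
        # inter-chunk: state decayed by exp(g[i]) <= 1, no overflow possible
        o_inter = (qc * np.exp(g)[:, None]) @ h * scale
        # intra-chunk: pairwise decay, argument clamped to <= 0 before exp
        diff = np.minimum(g[:, None] - g[None, :], 0.0)
        A = np.tril((qc @ kc.T) * np.exp(diff) * scale)  # causal mask
        o[start:end] = o_inter + A @ vc
        # carry state forward: old state decays by a full chunk, keys by chunk-end distance
        h = np.exp(g_gamma * C) * h \
            + (kc * np.exp(g_gamma * (C - np.arange(C)))[:, None]).T @ vc
    return o

def test_chunked_matches_naive():
    rng = np.random.default_rng(0)
    T, K, V = 32, 8, 8
    q, k, v = (rng.standard_normal((T, d)) for d in (K, K, V))
    ref = naive_simple_gla_fwd(q, k, v, g_gamma=-0.5, scale=K ** -0.5)
    out = chunked_simple_gla_fwd(q, k, v, g_gamma=-0.5, scale=K ** -0.5, chunk_size=8)
    assert np.allclose(out, ref, atol=1e-8)  # compare_tensor-style parity check
```

A real version would call the kernel entrypoints under test instead of `chunked_simple_gla_fwd` and extend the same parity check to `dq`, `dk`, `dv`.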
🧰 Tools

🪛 Ruff (0.15.7)

- [warning] 64-64: Unpacked variable `ht` is never used. Prefix it with an underscore or any other dummy variable pattern (RUF059)
- [warning] 81-81: Do not catch blind exception: `Exception` (BLE001)
- [warning] 89-89: Unpacked variable `dh0` is never used. Prefix it with an underscore or any other dummy variable pattern (RUF059)
- [warning] 106-106: Do not catch blind exception: `Exception` (BLE001)
```python
except Exception as e:
    print(f"Forward FAILED: {e}")
    return False
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
# First, check if the file exists and get its basic info
wc -l tests/ops/simple_gla/test_actual_slopes_chunk128.py

# Read the file to examine the exception handlers at the specified lines
cat -n tests/ops/simple_gla/test_actual_slopes_chunk128.py | sed -n '70,120p'

# Also check lines around 195-197
cat -n tests/ops/simple_gla/test_actual_slopes_chunk128.py | sed -n '185,210p'

# Check the function definition to understand the test structure
cat -n tests/ops/simple_gla/test_actual_slopes_chunk128.py | head -70

# Check imports to confirm pytest usage
head -20 tests/ops/simple_gla/test_actual_slopes_chunk128.py

# Check the main block to see how run_test is being used
cat -n tests/ops/simple_gla/test_actual_slopes_chunk128.py | tail -30

# Search for pytest tests that might call this run_test function
rg "test_" tests/ops/simple_gla/test_actual_slopes_chunk128.py

# Check if there are other test files in the same directory
fd "test_.*\.py" tests/ops/simple_gla/

# Check the main block structure to see how run_test is being called
cat -n tests/ops/simple_gla/test_actual_slopes_chunk128.py | sed -n '118,145p'

# Check if pytest has any configuration that runs this as a test
find . -name "pytest.ini" -o -name "pyproject.toml" -o -name "conftest.py" | head -5

# Check if there are any pytest tests that wrap or call run_test
rg "run_test\|test_actual_slopes" --type py tests/
```
Avoid broad except Exception handlers in test code; let failures propagate with full traceback.
These handlers suppress exceptions and return False or set flags, which weakens failure signals and makes debugging harder. Replace with proper error propagation: either re-raise exceptions with context, or if using pytest, apply pytest.fail() with exception chaining.
Applies to lines 81-83, 106-108, and 195-197.
🧰 Tools

🪛 Ruff (0.15.7)

- [warning] 81-81: Do not catch blind exception: `Exception` (BLE001)
Matches the actual production path: bf16 input -> pallas_chunk_gla (custom_vjp) -> f32 conversion -> padding -> chunk_simple_gla_fwd/bwd Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
🧹 Nitpick comments (2)

tests/ops/simple_gla/test_actual_slopes_chunk128.py (2)

Line 193: Add `strict=True` to `zip()` for safety. Using `zip()` without `strict=` can silently truncate if the iterables have different lengths. Since both lists should always have exactly 3 elements, add `strict=True` to catch mismatches.

♻️ Proposed fix

```diff
-for name, g in zip(["dq", "dk", "dv"], grads):
+for name, g in zip(["dq", "dk", "dv"], grads, strict=True):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py` at line 193, the for-loop using zip(["dq", "dk", "dv"], grads) should use zip(..., strict=True) to ensure the two iterables have the same length and fail fast on mismatches; update the loop where the test iterates over name and g to call zip with strict=True so any length mismatch raises immediately.

Lines 179-182: Avoid broad `except Exception` handlers; let test failures propagate. Lines 179 and 204 catch all exceptions and swallow them with a `return False`, which obscures the root cause of failures. Additionally, lines 181 and 206 have multiple statements on one line (semicolon), which violates style conventions.

♻️ Proposed fix: remove try/except or re-raise with context

```diff
-    try:
-        o, _ = pallas_chunk_gla(q, k_arr, v, g_gamma_jax, chunk_size=chunk_size)
-        has_nan = bool(jnp.any(jnp.isnan(o)))
-        has_inf = bool(jnp.any(jnp.isinf(o)))
-        print(f"Output: {o.shape} {o.dtype}, range=[{float(jnp.min(o)):.4f}, {float(jnp.max(o)):.4f}]")
-        print(f"NaN: {has_nan}, Inf: {has_inf}")
-        if has_nan or has_inf:
-            for h in range(min(H, 8)):
-                oh = o[:, :, h, :]
-                print(f"  head {h}: NaN={bool(jnp.any(jnp.isnan(oh)))}, "
-                      f"Inf={bool(jnp.any(jnp.isinf(oh)))}")
-    except Exception as e:
-        print(f"Forward FAILED: {e}")
-        import traceback; traceback.print_exc()
-        return False
+    o, _ = pallas_chunk_gla(q, k_arr, v, g_gamma_jax, chunk_size=chunk_size)
+    has_nan = bool(jnp.any(jnp.isnan(o)))
+    has_inf = bool(jnp.any(jnp.isinf(o)))
+    print(f"Output: {o.shape} {o.dtype}, range=[{float(jnp.min(o)):.4f}, {float(jnp.max(o)):.4f}]")
+    print(f"NaN: {has_nan}, Inf: {has_inf}")
+    if has_nan or has_inf:
+        for h in range(min(H, 8)):
+            oh = o[:, :, h, :]
+            print(f"  head {h}: NaN={bool(jnp.any(jnp.isnan(oh)))}, "
+                  f"Inf={bool(jnp.any(jnp.isinf(oh)))}")
```

Apply the same pattern to the backward pass (lines 186-207).

Also applies to: 204-207

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py` around lines 179 - 182, remove the broad try/except blocks that swallow errors in the forward and backward sections (the blocks that print "Forward FAILED" and "Backward FAILED"); either delete the try/except so exceptions propagate and fail the test, or catch specific exceptions and re-raise them with additional context (e.g., raise RuntimeError("Forward failed") from e). Also split any multiple statements on one line (remove semicolons used for traceback.print_exc()) so each statement is on its own line.
📒 Files selected for processing (1)
tests/ops/simple_gla/test_actual_slopes_chunk128.py
Summary
Root Cause

For 16 heads, max |gamma| = 0.707. At chunk_size=128: |gamma| * 128 = 90.5 > 88.72 (the f32 exp overflow threshold), causing exp(-b_g) → Inf → NaN.

Fix Approach

Replace exp(b_g[i]) * exp(-b_g[j]) with exp(b_g[i] - b_g[j]), whose argument is always ≤ 0 for causal (i ≥ j) entries. All fixes are transparent for chunk_size=64.

Test Results (TPU v7x, 8 devices)
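The thresholds quoted in the root cause can be checked directly (a quick standalone sanity sketch, not repo code):

```python
import math
import numpy as np

# float32 exp overflows once its argument exceeds log(float32 max) ≈ 88.72.
limit = math.log(np.finfo(np.float32).max)
assert 88.7 < limit < 88.8

# Steepest 16-head slope |gamma| = 0.707 at chunk_size=128: 0.707 * 128 = 90.5.
with np.errstate(over="ignore"):
    assert 0.707 * 128 > limit
    assert np.isinf(np.exp(np.float32(0.707 * 128)))  # Inf, then NaN downstream

# At chunk_size=64 the max intra-chunk distance is 63: 0.707 * 63 ≈ 44.5.
assert 0.707 * 63 < limit
assert np.isfinite(np.exp(np.float32(0.707 * 63)))
```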
🤖 Generated with Claude Code
Summary by CodeRabbit
Bug Fixes
Refactor
Tests