fix: numerical stability for GLA chunk_size=128 #118

Closed
Iamleos wants to merge 3 commits into main from fix/chunk128-nan-stability

Conversation


@Iamleos Iamleos commented Mar 31, 2026

Summary

  • Fix numerical overflow (NaN) in GLA Pallas kernel when chunk_size > 64
  • Backward kernel: decay matrix approach replacing separate exp(b_g)/exp(-b_g)
  • Forward intra kernel: clamp + early causal mask
  • Backward A computation: use overflow-safe simple GLA intra kernel
  • Reference fwd/bwd: clamp ramp_diff before exp() in masked regions

Root Cause

For 16 heads, max |gamma| = 0.707. At chunk_size=128: |gamma| * 128 = 90.5 > 88.72 (f32 exp overflow threshold), causing exp(-b_g) → Inf → NaN.
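The arithmetic can be checked directly. The PR quotes |gamma| * 128 = 90.5; taking the largest position difference actually reachable inside one chunk (chunk_size - 1) gives 89.8 instead, on the same side of the 88.72 threshold (and matching the "max exp arg = 44.5" figure for chunk_size=64 in the commit message). A small NumPy illustration:

```python
import numpy as np

# float32 exp overflows once its argument exceeds log(float32 max) ≈ 88.72
overflow_threshold = float(np.log(np.finfo(np.float32).max))

gamma = 0.707  # largest |gamma| for 16 heads (from the root cause above)
arg_64 = gamma * 63    # ≈ 44.5: largest in-chunk exponent at chunk_size=64
arg_128 = gamma * 127  # ≈ 89.8: already past the threshold at chunk_size=128

with np.errstate(over="ignore"):
    safe = np.exp(np.float32(arg_64))     # finite
    unsafe = np.exp(np.float32(arg_128))  # inf, which turns into NaN downstream
```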

Fix Approach

Replace exp(b_g[i]) * exp(-b_g[j]) with exp(b_g[i] - b_g[j]), whose exponent is always ≤ 0 for causal (i ≥ j) entries, so the result stays in (0, 1]. All fixes are behavior-preserving at chunk_size=64.
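A minimal NumPy sketch of this factorization (illustrative only; the variable names mirror the PR description, not the kernel's actual code, and b_g stands for the cumulative log-gate per position):

```python
import numpy as np

def causal_decay_matrix(b_g):
    """decay[i, j] = exp(b_g[i] - b_g[j]) for i >= j, else 0.

    With decaying gates (gamma < 0, so b_g is non-increasing), the causal
    exponent b_g[i] - b_g[j] is <= 0 and exp stays in (0, 1]. The clamp
    only touches the masked upper triangle, where the raw difference is
    positive and would otherwise overflow float32.
    """
    diff = b_g[:, None] - b_g[None, :]
    diff_safe = np.minimum(diff, 0.0)  # overflow-safe
    causal = np.tril(np.ones_like(diff, dtype=bool))
    return np.where(causal, np.exp(diff_safe), 0.0)

gamma = np.float32(-0.707)
b_g = gamma * np.arange(128, dtype=np.float32)  # chunk_size = 128

decay = causal_decay_matrix(b_g)  # finite everywhere

# The old factorization exp(b_g[i]) * exp(-b_g[j]) overflows instead:
with np.errstate(over="ignore"):
    old_factor = np.exp(-b_g)  # exp(+0.707 * 127) -> inf in float32
```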

Test Results (TPU v7x, 8 devices)

  • B∈{2,4,8,16,32,64,128} × T∈{1024,2048,4096,8192}: 27/28 pass (1 OOM at B=128,T=8192)
  • Forward + backward: zero NaN/Inf across all configs
  • Backward compatibility: existing test_chunk_gla_pallas.py 15/15 passed

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes

    • Added overflow-safe clamping for exponent arguments and reinforced causal masking to prevent unstable or incorrect attention contributions.
  • Refactor

    • Replaced per-vector exponentiation with precomputed decay matrices to streamline intra-chunk attention and simplified backward attention generation logic.
    • No public API/signatures changed.
  • Tests

    • Added end-to-end chunked forward/backward tests validating slopes, numerical stability, and gradients.

The backward kernel and reference implementations overflow when
chunk_size > 64 because exp(-b_g) and exp(b_g) produce Inf for
large |gamma| * position values (e.g. 0.707 * 128 = 90.5 > 88.72
f32 limit).

Fixes:
- Backward kernel: replace separate exp(b_g)/exp(-b_g) with the decay
  matrix exp(b_g[i]-b_g[j]), whose exponent is always <= 0 for causal entries
- Forward intra kernel: clamp upper triangle + early causal mask
- Backward A computation: use overflow-safe simple GLA intra kernel
  instead of chunk_gla_fwd_intra_gk_ref with synthetic per-K gates
- Reference fwd/bwd: clamp ramp_diff before exp() in masked regions

All fixes are transparent for chunk_size=64 (max exp arg = 44.5,
well within f32 range). Tested on TPU v7x with B=2..128, T=1k..8k,
H=16, K=V=128 — zero NaN/Inf in forward and backward.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

coderabbitai bot commented Mar 31, 2026

📝 Walkthrough

Replaced per-vector exponential scalings with pairwise exp(g_i - g_j) decays, and clamped exponent arguments before exp in several forward/backward GLA chunk kernels. Refactored intra-chunk A recomputation and added a new test exercising per-head g_gamma slopes with forward/backward checks.

Changes

Cohort / File(s) Summary
Common chunk output kernels
tops/ops/common/chunk_o.py
Replaced separate exp(g) factors with precomputed pairwise decay matrices (exp(b_g[i]-b_g[j]), masked lower/upper) in chunk_simple_gla_bwd_kernel. Added clamping (jnp.minimum(..., 0.0)) for exponent arguments before jnp.exp in chunk_fwd_o, chunk_bwd_dv, and chunk_bwd_dqkwg (g_gamma path) and applied causal masking consistently.
SimpleGLA intra-chunk kernels
tops/ops/simple_gla/chunk.py
Made Toeplitz decay computation overflow-safe by clamping exponent args before exp in _chunk_simple_gla_fwd_intra_kernel and enforced causal lower-triangular masking via jnp.where. chunk_simple_gla_bwd now reuses the intra-chunk forward launcher _chunk_simple_gla_fwd_intra (reshaped g_gamma) to produce A instead of a synthetic call.
Tests
tests/ops/simple_gla/test_actual_slopes_chunk128.py
Added a new test module that defines a pallas_chunk_gla custom_vjp wrapper, builds per-head slope tensors, pads inputs for chunking, runs forward/backward (BF16) checks across configurations, and validates absence of NaN/Inf in outputs and gradients.

Sequence Diagram(s)

(Skipped — changes are numerical/stability updates and do not introduce a new multi-component control-flow requiring sequence diagrams.)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • 0xaskr

Poem

🐰 I nibbled at exponents slow,

clipped rising tides so overflows go,
pairwise decays now mind their place,
causal masks keep order and grace,
gradients hop back safe in tow.

🚥 Pre-merge checks: ✅ 3 passed

  • Description Check — Passed (check skipped: CodeRabbit's high-level summary is enabled)
  • Title Check — Passed: the title 'fix: numerical stability for GLA chunk_size=128' directly and clearly addresses the main change, fixing numerical overflow issues in the GLA kernel for larger chunk sizes.
  • Docstring Coverage — Passed: coverage is 100.00%, above the 80.00% threshold.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enhances the numerical stability of GLA kernels by clamping exponent arguments to prevent overflows during decay matrix calculations. It refactors the backward pass and intra-chunk attention mechanisms to incorporate these safety measures across several functions. Feedback suggests further optimizing the decay matrix application in chunk_o.py by reusing the decayed attention matrix and removing redundant masking operations.

Comment on lines +73 to 101
    # === Precompute decay matrices (numerically stable for large BT) ===
    # decay_lower[i, j] = exp(b_g[i] - b_g[j]) for i >= j; 0 otherwise
    # Since gamma < 0 and i >= j: b_g[i] - b_g[j] = gamma*(i-j) <= 0 => exp in (0, 1]
    b_g_diff = b_g[:, None] - b_g[None, :]      # [BT, BT]
    b_g_diff_safe = jnp.minimum(b_g_diff, 0.0)  # clamp upper tri to avoid overflow
    decay_lower = jnp.where(mask, jnp.exp(b_g_diff_safe), 0.0)  # [BT, BT]
    decay_upper = decay_lower.T  # decay_upper[j, i] = exp(b_g[i] - b_g[j]) for i >= j

    # 3. dq = (dA * decay_lower) @ k + do @ h^T * scale * exp(b_g)
    # dq_intra[i] = sum_j dA[i,j] * k[j] * exp(b_g[i] - b_g[j])
    b_dq_intra = jnp.dot((b_dA * decay_lower).astype(b_k.dtype), b_k,
                         precision=jax.lax.Precision.HIGHEST,
-                        preferred_element_type=jnp.float32) * jnp.exp(b_g)[:, None]
+                        preferred_element_type=jnp.float32)
    # dq_inter: exp(b_g) is safe (b_g <= 0 => exp <= 1, may underflow but no overflow)
    b_dq_inter = jnp.dot(b_do, b_h.astype(b_do.dtype).T,
                         precision=jax.lax.Precision.HIGHEST,
                         preferred_element_type=jnp.float32) * (scale * jnp.exp(b_g)[:, None])
    b_dq = b_dq_intra + b_dq_inter
    dq_ref[0, 0] = b_dq.astype(dq_ref.dtype)

-   # 4. dk = dA^T @ q_pos * exp(-b_g) + v @ dh^T * exp(b_gn - b_g)
-   q_pos = (b_q * jnp.exp(b_g)[:, None]).astype(b_q.dtype)
-   b_dk_intra = jnp.dot(b_dA.T.astype(q_pos.dtype), q_pos,
+   # 4. dk = (dA^T * decay_upper) @ q + v @ dh^T * exp(b_gn - b_g)
+   # dk_intra[j] = sum_i dA[i,j] * q[i] * exp(b_g[i] - b_g[j])
+   b_dk_intra = jnp.dot((b_dA.T * decay_upper).astype(b_q.dtype), b_q,
                         precision=jax.lax.Precision.HIGHEST,
-                        preferred_element_type=jnp.float32) * jnp.exp(-b_g)[:, None]
+                        preferred_element_type=jnp.float32)
    # dk_inter: exp(b_gn - b_g) is safe (b_gn <= b_g[j] => b_gn - b_g <= 0)
    b_dk_inter = jnp.dot(b_v, b_dh.astype(b_v.dtype).T,
                         precision=jax.lax.Precision.HIGHEST,
                         preferred_element_type=jnp.float32) * jnp.exp(b_gn - b_g)[:, None]
Severity: medium

The computation of decay matrices and their application to dq and dk can be optimized. By reusing the decayed attention matrix, you can avoid redundant element-wise multiplications and transpositions. Additionally, since b_dA is already masked at line 59, the explicit masking in decay_lower is redundant as long as we clamp the exponent to avoid overflow in the upper triangle.

    # === Precompute decay matrix (numerically stable for large BT) ===
    # decay[i, j] = exp(b_g[i] - b_g[j]). For causal entries (i >= j), b_g[i] - b_g[j] <= 0.
    # We clamp to 0.0 to avoid overflow in the upper triangle (i < j).
    b_g_diff = b_g[:, None] - b_g[None, :]
    decay = jnp.exp(jnp.minimum(b_g_diff, 0.0))
    b_dA_decayed = b_dA * decay

    # 3. dq = (dA * decay) @ k + do @ h^T * scale * exp(b_g)
    b_dq_intra = jnp.dot(b_dA_decayed.astype(b_k.dtype), b_k,
                          precision=jax.lax.Precision.HIGHEST,
                          preferred_element_type=jnp.float32)
    # dq_inter: exp(b_g) is safe (b_g <= 0 => exp <= 1)
    b_dq_inter = jnp.dot(b_do, b_h.astype(b_do.dtype).T,
                          precision=jax.lax.Precision.HIGHEST,
                          preferred_element_type=jnp.float32) * (scale * jnp.exp(b_g)[:, None])
    b_dq = b_dq_intra + b_dq_inter
    dq_ref[0, 0] = b_dq.astype(dq_ref.dtype)

    # 4. dk = (dA * decay)^T @ q + v @ dh^T * exp(b_gn - b_g)
    b_dk_intra = jnp.dot(b_dA_decayed.T.astype(b_q.dtype), b_q,
                          precision=jax.lax.Precision.HIGHEST,
                          preferred_element_type=jnp.float32)
    # dk_inter: exp(b_gn - b_g) is safe (b_gn <= b_g[j] => b_gn - b_g <= 0)
    b_dk_inter = jnp.dot(b_v, b_dh.astype(b_v.dtype).T,
                          precision=jax.lax.Precision.HIGHEST,
                          preferred_element_type=jnp.float32) * jnp.exp(b_gn - b_g)[:, None]


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/ops/common/chunk_o.py`:
- Around line 251-254: The clamped decay path uses jnp.minimum(..., 0.0) which
assumes g_gamma is non-positive; add an explicit validation at each public
entrypoint that can reach this clamped Toeplitz branch (specifically in
chunk_fwd_o, chunk_bwd_dv, and chunk_bwd_dqkwg) to assert or raise if g_gamma >
0 so callers fail fast instead of silently producing incorrect amplification;
apply the same guard where the ramp_diff -> ramp_diff_safe computation occurs
(the occurrences around the shown block and the locations referenced near lines
340-343 and 524-529) and ensure the check runs before computing
ramp_diff/ramp_diff_safe.

In `@tops/ops/simple_gla/chunk.py`:
- Around line 628-630: The code currently lets chunk_simple_gla_bwd call
_chunk_simple_gla_fwd_intra and later per-head index g_gamma even when g_gamma
is None; add an explicit guard at the start of chunk_simple_gla_bwd that checks
if g_gamma is None and raises a clear ValueError (or TypeError) describing that
g_gamma must be provided and non-None, so you fail fast before the
g_gamma.reshape / per-head indexing in _chunk_simple_gla_fwd_intra or the fused
backward kernel; reference g_gamma, chunk_simple_gla_bwd, and
_chunk_simple_gla_fwd_intra in the error message to help debugging.

📥 Commits

Reviewing files that changed from the base of the PR and between abe6ccb and b5e2c7b.

📒 Files selected for processing (2)
  • tops/ops/common/chunk_o.py
  • tops/ops/simple_gla/chunk.py

Comment on lines +251 to +254
# Clamp ramp diff to avoid overflow in upper triangle (masked later)
ramp_diff = ramp[..., :, None] - ramp[..., None, :] # [H, C, C]
ramp_diff_safe = jnp.minimum(ramp_diff, 0.0)
A = A * exp(ramp_diff_safe)[None, None]

⚠️ Potential issue | 🟠 Major

Validate non-positive g_gamma before the clamped decay path.

jnp.minimum(..., 0.0) is only algebraically correct for decay heads. If any caller passes g_gamma > 0, these branches replace the required causal amplification with 1, so chunk_fwd_o, chunk_bwd_dv, and chunk_bwd_dqkwg silently compute different outputs/gradients instead of failing fast.

Suggested guard
     assert_shape_or_none(g_gamma, (H,))
+    if g_gamma is not None:
+        assert (g_gamma <= 0).all(), "g_gamma must be non-positive for simple GLA decay"

Please add the same guard in each public entrypoint that feeds this clamped Toeplitz path.

As per coding guidelines, public functions must enforce strict constraints and validation on the shape and types of input variables before executing the main logic.

Also applies to: 340-343, 524-529
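The reviewer's concern can be reproduced in a few lines (a toy NumPy illustration, not the kernel code): with a hypothetical positive gamma, the `jnp.minimum(..., 0.0)` clamp silently turns the required causal amplification into a factor of exactly 1, while for the legitimate negative-gamma case it is a no-op on causal entries.

```python
import numpy as np

gamma = 0.1  # hypothetical growing gate (violates the g_gamma <= 0 assumption)
b_g = gamma * np.arange(4, dtype=np.float64)
diff = b_g[:, None] - b_g[None, :]

exact = np.exp(diff)                     # causal (i > j) entries amplify: exp(gamma*(i-j)) > 1
clamped = np.exp(np.minimum(diff, 0.0))  # the overflow guard flattens them all to exactly 1

# This is why the review asks for an `assert (g_gamma <= 0).all()` guard at
# the public entrypoints: the clamped path is only correct for decay heads.
```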

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/chunk_o.py` around lines 251 - 254, The clamped decay path
uses jnp.minimum(..., 0.0) which assumes g_gamma is non-positive; add an
explicit validation at each public entrypoint that can reach this clamped
Toeplitz branch (specifically in chunk_fwd_o, chunk_bwd_dv, and chunk_bwd_dqkwg)
to assert or raise if g_gamma > 0 so callers fail fast instead of silently
producing incorrect amplification; apply the same guard where the ramp_diff ->
ramp_diff_safe computation occurs (the occurrences around the shown block and
the locations referenced near lines 340-343 and 524-529) and ensure the check
runs before computing ramp_diff/ramp_diff_safe.

Comment on lines +628 to +630
# 3. Compute A via simple GLA intra-chunk attention (overflow-safe)
g_gamma_4d = g_gamma.reshape(1, 1, H, 1) if g_gamma is not None else None
A = _chunk_simple_gla_fwd_intra(q, k, g_gamma_4d, scale, chunk_size=C)

⚠️ Potential issue | 🟠 Major

Fail fast when g_gamma is absent.

chunk_simple_gla_bwd still advertises g_gamma: jax.Array | None, but this path now always calls _chunk_simple_gla_fwd_intra, and the fused backward kernel below also indexes g_gamma per head. With g_gamma=None, callers fall through to an internal reshape/indexing failure instead of a clear contract error.

Suggested guard
     assert g is None, "per-element gating not supported in chunk_simple_gla_bwd"
+    assert g_gamma is not None, "chunk_simple_gla_bwd requires g_gamma"
     assert scale is not None # fix pylance check

As per coding guidelines, public functions must enforce strict constraints and validation on the shape and types of input variables before executing the main logic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 628 - 630, The code currently lets
chunk_simple_gla_bwd call _chunk_simple_gla_fwd_intra and later per-head index
g_gamma even when g_gamma is None; add an explicit guard at the start of
chunk_simple_gla_bwd that checks if g_gamma is None and raises a clear
ValueError (or TypeError) describing that g_gamma must be provided and non-None,
so you fail fast before the g_gamma.reshape / per-head indexing in
_chunk_simple_gla_fwd_intra or the fused backward kernel; reference g_gamma,
chunk_simple_gla_bwd, and _chunk_simple_gla_fwd_intra in the error message to
help debugging.

Uses build_slope_tensor(16) with actual al_model.yml config
(H=16, num_decoder_layers=20) to test chunk_size=128 stability.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
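For context, the 0.707 figure in the root-cause analysis is consistent with ALiBi-style slopes for 16 heads. A hypothetical reconstruction of build_slope_tensor under that assumption (the repo's actual implementation may differ):

```python
import numpy as np

def build_slope_tensor(n_heads):
    """ALiBi-style slopes 2^(-8k/n_heads), k = 1..n_heads (power-of-two head counts)."""
    ratio = 2.0 ** (-8.0 / n_heads)
    return np.array([ratio ** k for k in range(1, n_heads + 1)], dtype=np.float32)

slopes = build_slope_tensor(16)
# Largest slope: 2^(-0.5) ≈ 0.7071, so 0.7071 * 128 ≈ 90.5 — matching the
# overflow analysis in the PR description.
```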

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
tests/ops/simple_gla/test_actual_slopes_chunk128.py (1)

64-64: Clean up the lint-only nits in this new test (_ht, _dh0, and zip(..., strict=True)).

Low-risk cleanup, but worth doing now to keep the test file warning-free.

Patch sketch
-        o, ht = chunk_simple_gla_fwd(
+        o, _ht = chunk_simple_gla_fwd(
...
-        dq, dk, dv, dh0 = chunk_simple_gla_bwd(
+        dq, dk, dv, _dh0 = chunk_simple_gla_bwd(
...
-        for name, g in zip(["dq", "dk", "dv"], grads):
+        for name, g in zip(["dq", "dk", "dv"], grads, strict=True):

Also applies to: 89-89, 188-188

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py` at line 64, Rename
unused unpacked variables from ht and dh0 to _ht and _dh0 where
chunk_simple_gla_fwd (and similar calls) return values you don't use, and update
any zip(...) calls in this test to zip(..., strict=True) to silence lint
warnings and enforce equal-length iteration; search for the symbols
chunk_simple_gla_fwd, ht, dh0, and zip in this file and make those three small
edits at each occurrence (including the other noted lines around 89 and 188).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py`:
- Around line 34-115: The test currently only exposes run_test(...) and a
main-guard script without pytest assertions or reference comparisons; convert it
into pytest-style tests by creating one or more test_* functions that call
run_test parameters or inline the logic, implement a simple CPU/naive reference
for the forward and backward computations (matching chunk_simple_gla_fwd and
chunk_simple_gla_bwd semantics) and use compare_tensor(output, reference,
atol=..., rtol=...) to assert numerical parity for both outputs and gradients,
then remove the if __name__ == "__main__" guard so pytest can collect the tests.
Ensure you reference the actual kernel entrypoints chunk_simple_gla_fwd and
chunk_simple_gla_bwd when computing the reference and assert with compare_tensor
for o, dq, dk, dv.
- Around line 81-83: Replace the broad "except Exception as e: print(...);
return False" blocks in the test (the handlers that print "Forward FAILED" and
similar at the shown locations) so exceptions are not swallowed: either remove
the try/except entirely to let the original exception and traceback propagate,
or replace the handler with pytest.fail(str(e), pytrace=True) or re-raise the
exception with "raise" or "raise ... from e" to preserve chaining; update the
error-handling around the Forward block and the other two similar blocks (the
ones printing failure and returning False) to use one of these approaches.

---

Nitpick comments:
In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py`:
- Line 64: Rename unused unpacked variables from ht and dh0 to _ht and _dh0
where chunk_simple_gla_fwd (and similar calls) return values you don't use, and
update any zip(...) calls in this test to zip(..., strict=True) to silence lint
warnings and enforce equal-length iteration; search for the symbols
chunk_simple_gla_fwd, ht, dh0, and zip in this file and make those three small
edits at each occurrence (including the other noted lines around 89 and 188).

📥 Commits

Reviewing files that changed from the base of the PR and between b5e2c7b and 4ae7e74.

📒 Files selected for processing (1)
  • tests/ops/simple_gla/test_actual_slopes_chunk128.py

Comment on lines +34 to +115
def run_test(B, T, H, K, V, chunk_size, num_decoder_layers, layer_idx, seed=42):
    """Run forward + backward with actual slope values."""
    np.random.seed(seed)
    key = jax.random.PRNGKey(seed)

    # Build actual g_gamma exactly as MaxText does
    slope_base = build_slope_tensor(H)
    denom = max(num_decoder_layers - 1, 1)
    slope_scale = 1.0 - layer_idx / denom + 1e-5
    g_gamma = -slope_base * slope_scale  # ALWAYS NEGATIVE

    print(f"\n{'='*70}")
    print(f"Config: B={B}, T={T}, H={H}, K={K}, V={V}, chunk_size={chunk_size}")
    print(f"Layer: {layer_idx}/{num_decoder_layers}, slope_scale={slope_scale:.4f}")
    print(f"g_gamma range: [{g_gamma.min():.6f}, {g_gamma.max():.6f}]")
    print(f"head 0: g_gamma={g_gamma[0]:.6f}, max exp arg in chunk = {abs(g_gamma[0]) * chunk_size:.2f}")
    print(f"head 1: g_gamma={g_gamma[1]:.6f}, max exp arg in chunk = {abs(g_gamma[1]) * chunk_size:.2f}")
    print(f"{'='*70}")

    g_gamma_jax = jnp.array(g_gamma)

    # Random inputs in bf16 (like actual training)
    k1, k2, k3, k4 = jax.random.split(key, 4)
    q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16)
    k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16)
    v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16)

    # Forward
    print("\n--- Forward ---")
    try:
        o, ht = chunk_simple_gla_fwd(
            q, k_arr, v,
            g_gamma=g_gamma_jax,
            chunk_size=chunk_size,
        )
        has_nan = bool(jnp.any(jnp.isnan(o)))
        has_inf = bool(jnp.any(jnp.isinf(o)))
        print(f"Output shape: {o.shape}, dtype: {o.dtype}")
        print(f"Output range: [{float(jnp.min(o)):.6f}, {float(jnp.max(o)):.6f}]")
        print(f"NaN: {has_nan}, Inf: {has_inf}")
        if has_nan or has_inf:
            # Check per-head
            for h in range(min(H, 8)):
                oh = o[:, :, h, :]
                print(f"  head {h}: NaN={bool(jnp.any(jnp.isnan(oh)))}, "
                      f"Inf={bool(jnp.any(jnp.isinf(oh)))}, "
                      f"range=[{float(jnp.min(oh)):.4f}, {float(jnp.max(oh)):.4f}]")
    except Exception as e:
        print(f"Forward FAILED: {e}")
        return False

    # Backward
    print("\n--- Backward ---")
    try:
        do = jax.random.normal(k4, (B, T, H, V), dtype=jnp.bfloat16)
        dq, dk, dv, dh0 = chunk_simple_gla_bwd(
            q, k_arr, v, do,
            g_gamma=g_gamma_jax,
            chunk_size=chunk_size,
        )
        for name, grad in [("dq", dq), ("dk", dk), ("dv", dv)]:
            has_nan = bool(jnp.any(jnp.isnan(grad)))
            has_inf = bool(jnp.any(jnp.isinf(grad)))
            print(f"{name}: shape={grad.shape}, dtype={grad.dtype}, "
                  f"range=[{float(jnp.min(grad)):.6f}, {float(jnp.max(grad)):.6f}], "
                  f"NaN={has_nan}, Inf={has_inf}")
            if has_nan or has_inf:
                for h in range(min(H, 8)):
                    gh = grad[:, :, h, :]
                    print(f"  head {h}: NaN={bool(jnp.any(jnp.isnan(gh)))}, "
                          f"Inf={bool(jnp.any(jnp.isinf(gh)))}, "
                          f"range=[{float(jnp.min(gh)):.4f}, {float(jnp.max(gh)):.4f}]")
    except Exception as e:
        print(f"Backward FAILED: {e}")
        return False

    fwd_ok = not (bool(jnp.any(jnp.isnan(o))) or bool(jnp.any(jnp.isinf(o))))
    bwd_ok = all(
        not (bool(jnp.any(jnp.isnan(g))) or bool(jnp.any(jnp.isinf(g))))
        for g in [dq, dk, dv]
    )
    return fwd_ok and bwd_ok

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

file="tests/ops/simple_gla/test_actual_slopes_chunk128.py"

echo "[1] pytest-discoverable tests in file:"
rg -n '^\s*def\s+test_' "$file" || true

echo "[2] main-guard usage:"
rg -n '__name__\s*==\s*["'"'"']__main__["'"'"']' "$file" || true

echo "[3] compare_tensor usage:"
rg -n '\bcompare_tensor\b' "$file" || true

Repository: primatrix/pallas-kernel

Length of output: 184


Restructure as pytest reference-parity test (currently bypasses pytest collection and lacks correctness verification).

This file is structured as a script under if __name__ == "__main__" and has no pytest-discoverable functions (def test_...). More critically, it lacks numerical correctness verification: the required compare_tensor utility is not used to compare optimized kernel output against a naive/CPU reference implementation.

Without reference-parity assertions, the test can report success while outputs diverge numerically. Per coding guidelines for tests/ops/**/*.py and tests/**/*.py, tests must use compare_tensor with tolerance-based assertions and be executable as standard pytest.

Minimal refactoring needed:

  • Move logic from run_test() into proper test_* functions
  • Implement naive/CPU reference kernel implementation
  • Add compare_tensor(output, reference, atol=..., rtol=...) assertions for forward and backward passes
  • Remove the if __name__ == "__main__" guard

Also applies to: 118-201
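A naive reference of the kind the review asks for might look like the sketch below. This is hedged: the exact semantics of chunk_simple_gla_fwd are assumed from the PR description (per-head scalar decay gamma, causal weighting exp(gamma*(t-s))), and in the real test the repo's compare_tensor utility would replace the plain np.allclose self-check.

```python
import numpy as np

def naive_gla_forward(q, k, v, gamma, scale=1.0):
    """O(T^2) single-head reference: o[t] = scale * sum_{s<=t} exp(gamma*(t-s)) (q_t . k_s) v_s."""
    T = q.shape[0]
    o = np.zeros((T, v.shape[1]), dtype=np.float64)
    for t in range(T):
        for s in range(t + 1):
            w = np.exp(gamma * (t - s)) * float(q[t] @ k[s]) * scale
            o[t] += w * v[s]
    return o

# Self-check against an equivalent vectorized form (decay-matrix attention).
rng = np.random.default_rng(0)
T, K, V = 8, 4, 4
q = rng.standard_normal((T, K))
k = rng.standard_normal((T, K))
v = rng.standard_normal((T, V))
gamma = -0.5

A = np.exp(np.minimum(gamma * (np.arange(T)[:, None] - np.arange(T)[None, :]), 0.0))
A = np.tril(A) * (q @ k.T)  # masked, decayed attention scores
assert np.allclose(A @ v, naive_gla_forward(q, k, v, gamma))
```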

🧰 Tools
🪛 Ruff (0.15.7)

[warning] 64-64: Unpacked variable ht is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


[warning] 81-81: Do not catch blind exception: Exception

(BLE001)


[warning] 89-89: Unpacked variable dh0 is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


[warning] 106-106: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py` around lines 34 - 115,
The test currently only exposes run_test(...) and a main-guard script without
pytest assertions or reference comparisons; convert it into pytest-style tests
by creating one or more test_* functions that call run_test parameters or inline
the logic, implement a simple CPU/naive reference for the forward and backward
computations (matching chunk_simple_gla_fwd and chunk_simple_gla_bwd semantics)
and use compare_tensor(output, reference, atol=..., rtol=...) to assert
numerical parity for both outputs and gradients, then remove the if __name__ ==
"__main__" guard so pytest can collect the tests. Ensure you reference the
actual kernel entrypoints chunk_simple_gla_fwd and chunk_simple_gla_bwd when
computing the reference and assert with compare_tensor for o, dq, dk, dv.

Comment on lines +81 to +83
except Exception as e:
print(f"Forward FAILED: {e}")
return False

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, check if the file exists and get its basic info
wc -l tests/ops/simple_gla/test_actual_slopes_chunk128.py

Repository: primatrix/pallas-kernel

Length of output: 122


🏁 Script executed:

# Read the file to examine the exception handlers at the specified lines
cat -n tests/ops/simple_gla/test_actual_slopes_chunk128.py | sed -n '70,120p'

Repository: primatrix/pallas-kernel

Length of output: 2561


🏁 Script executed:

# Also check lines around 195-197
cat -n tests/ops/simple_gla/test_actual_slopes_chunk128.py | sed -n '185,210p'

Repository: primatrix/pallas-kernel

Length of output: 949


🏁 Script executed:

# Check the function definition to understand the test structure
cat -n tests/ops/simple_gla/test_actual_slopes_chunk128.py | head -70

Repository: primatrix/pallas-kernel

Length of output: 3262


🏁 Script executed:

# Check imports to confirm pytest usage
head -20 tests/ops/simple_gla/test_actual_slopes_chunk128.py

Repository: primatrix/pallas-kernel

Length of output: 658


🏁 Script executed:

# Check the main block to see how run_test is being used
cat -n tests/ops/simple_gla/test_actual_slopes_chunk128.py | tail -30

Repository: primatrix/pallas-kernel

Length of output: 1602


🏁 Script executed:

# Search for pytest tests that might call this run_test function
rg "test_" tests/ops/simple_gla/test_actual_slopes_chunk128.py

Repository: primatrix/pallas-kernel

Length of output: 49


🏁 Script executed:

# Check if there are other test files in the same directory
fd "test_.*\.py" tests/ops/simple_gla/

Repository: primatrix/pallas-kernel

Length of output: 314


🏁 Script executed:

# Check the main block structure to see how run_test is being called
cat -n tests/ops/simple_gla/test_actual_slopes_chunk128.py | sed -n '118,145p'

Repository: primatrix/pallas-kernel

Length of output: 1549


🏁 Script executed:

# Check if pytest has any configuration that runs this as a test
find . -name "pytest.ini" -o -name "pyproject.toml" -o -name "conftest.py" | head -5

Repository: primatrix/pallas-kernel

Length of output: 103


🏁 Script executed:

# Check if there are any pytest tests that wrap or call run_test
rg "run_test\|test_actual_slopes" --type py tests/

Repository: primatrix/pallas-kernel

Length of output: 49


Avoid broad except Exception handlers in test code; let failures propagate with full traceback.

These handlers suppress exceptions and return False or set flags, which weakens failure signals and makes debugging harder. Replace with proper error propagation: either re-raise exceptions with context, or if using pytest, apply pytest.fail() with exception chaining.

Applies to lines 81-83, 106-108, and 195-197.

🧰 Tools
🪛 Ruff (0.15.7)

[warning] 81-81: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py` around lines 81 - 83,
Replace the broad "except Exception as e: print(...); return False" blocks in
the test (the handlers that print "Forward FAILED" and similar at the shown
locations) so exceptions are not swallowed: either remove the try/except
entirely to let the original exception and traceback propagate, or replace the
handler with pytest.fail(str(e), pytrace=True) or re-raise the exception with
"raise" or "raise ... from e" to preserve chaining; update the error-handling
around the Forward block and the other two similar blocks (the ones printing
failure and returning False) to use one of these approaches.

Matches the actual production path:
  bf16 input -> pallas_chunk_gla (custom_vjp) -> f32 conversion
  -> padding -> chunk_simple_gla_fwd/bwd

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
tests/ops/simple_gla/test_actual_slopes_chunk128.py (2)

193-193: Add strict=True to zip() for safety.

Using zip() without strict= can silently truncate if the iterables have different lengths. Since both lists should always have exactly 3 elements, add strict=True to catch mismatches.

♻️ Proposed fix
-        for name, g in zip(["dq", "dk", "dv"], grads):
+        for name, g in zip(["dq", "dk", "dv"], grads, strict=True):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py` at line 193, The
for-loop using zip(["dq", "dk", "dv"], grads) should use zip(..., strict=True)
to ensure the two iterables have the same length and fail fast on mismatches;
update the loop where the test iterates over name and g (i.e., the for name, g
in zip(["dq", "dk", "dv"], grads) line in test_actual_slopes_chunk128.py) to
call zip with strict=True so any length mismatch raises immediately.

179-182: Avoid broad except Exception handlers; let test failures propagate.

Lines 179 and 204 catch all exceptions and swallow them with a return False, which obscures the root cause of failures. Additionally, lines 181 and 206 have multiple statements on one line (semicolon), which violates style conventions.

♻️ Proposed fix: remove try/except or re-raise with context
-    try:
-        o, _ = pallas_chunk_gla(q, k_arr, v, g_gamma_jax, chunk_size=chunk_size)
-        has_nan = bool(jnp.any(jnp.isnan(o)))
-        has_inf = bool(jnp.any(jnp.isinf(o)))
-        print(f"Output: {o.shape} {o.dtype}, range=[{float(jnp.min(o)):.4f}, {float(jnp.max(o)):.4f}]")
-        print(f"NaN: {has_nan}, Inf: {has_inf}")
-        if has_nan or has_inf:
-            for h in range(min(H, 8)):
-                oh = o[:, :, h, :]
-                print(f"  head {h}: NaN={bool(jnp.any(jnp.isnan(oh)))}, "
-                      f"Inf={bool(jnp.any(jnp.isinf(oh)))}")
-    except Exception as e:
-        print(f"Forward FAILED: {e}")
-        import traceback; traceback.print_exc()
-        return False
+    o, _ = pallas_chunk_gla(q, k_arr, v, g_gamma_jax, chunk_size=chunk_size)
+    has_nan = bool(jnp.any(jnp.isnan(o)))
+    has_inf = bool(jnp.any(jnp.isinf(o)))
+    print(f"Output: {o.shape} {o.dtype}, range=[{float(jnp.min(o)):.4f}, {float(jnp.max(o)):.4f}]")
+    print(f"NaN: {has_nan}, Inf: {has_inf}")
+    if has_nan or has_inf:
+        for h in range(min(H, 8)):
+            oh = o[:, :, h, :]
+            print(f"  head {h}: NaN={bool(jnp.any(jnp.isnan(oh)))}, "
+                  f"Inf={bool(jnp.any(jnp.isinf(oh)))}")

Apply the same pattern to the backward pass (lines 186-207).

Also applies to: 204-207

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/simple_gla/test_actual_slopes_chunk128.py` around lines 179 - 182,
Remove the broad try/except blocks that swallow errors in the forward and
backward sections (the blocks that print "Forward FAILED" and "Backward
FAILED"); either delete the try/except so exceptions propagate and fail the
test, or catch specific exceptions and re-raise them with additional context
(e.g., raise RuntimeError("Forward failed") from e). Also split any multiple
statements on one line (remove semicolons used for traceback.print_exc()) so
each statement is on its own line; locate the blocks by the "Forward FAILED" and
"Backward FAILED" print strings and update the surrounding try/except
accordingly.
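For the second option the prompt mentions, here is a self-contained sketch of the `raise ... from e` pattern (the function names are hypothetical stand-ins, not the test's real code): the original exception is preserved on `__cause__`, so pytest reports the full chained traceback instead of an opaque `return False`.

```python
def run_forward():
    # stand-in for the real forward pass failing
    raise FloatingPointError("exp overflow in kernel")

def forward_check():
    try:
        run_forward()
    except FloatingPointError as e:
        # chain the original error instead of swallowing it
        raise RuntimeError("Forward failed") from e

chained = False
try:
    forward_check()
except RuntimeError as err:
    # the original exception survives on __cause__
    chained = isinstance(err.__cause__, FloatingPointError)
assert chained
```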

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d60fa07f-727c-4942-972d-975408db22bd

📥 Commits

Reviewing files that changed from the base of the PR and between 4ae7e74 and dd2b7b7.

📒 Files selected for processing (1)
  • tests/ops/simple_gla/test_actual_slopes_chunk128.py

@0xaskr
Collaborator

0xaskr commented Mar 31, 2026

PR #131 has been merged. Try pointing pallas-kernel at the latest main branch directly in the pretrain repository. @Iamleos

@Iamleos Iamleos closed this Apr 1, 2026


2 participants