
fix: clamp upper-triangle decay exponents to prevent overflow at chunk_size>64 #117

Closed
pengchengneo wants to merge 6 commits into main from fix/chunk-decay-overflow

Conversation


@pengchengneo pengchengneo commented Mar 31, 2026

Summary

  • Root cause: GLA Toeplitz decay matrix exp(gamma * (i - j)) overflows float32 when chunk_size=128 and ALiBi slopes are large (head 0: slope=0.707, max exponent=0.707×127=89.8 > 88.7). In backward pass, 0 * inf = NaN propagates through grad_norm.
  • Fix: Apply jnp.where(mask, decay_arg, 0.0) before exp() — upper-triangle entries produce exp(0)=1 (harmless, masked away later) instead of exp(89.8)=inf.
  • Patches both forward and backward kernels in chunk_o.py and forward intra-attention kernel in simple_gla/chunk.py.
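The failure mode and the fix described above can be reproduced outside the kernels in a few lines. This sketch uses NumPy as a stand-in for jnp (the array operations are identical); BT and the slope value are taken from the PR description:

```python
import numpy as np

BT = 128                      # chunk_size from the PR
gamma = -0.707                # ALiBi slope for head 0 (negative)
pos = np.arange(BT, dtype=np.float32)

# Unclamped Toeplitz decay: the upper triangle (i < j) has positive exponents.
decay_arg = gamma * (pos[:, None] - pos[None, :])
with np.errstate(over="ignore", invalid="ignore"):
    decay = np.exp(decay_arg)          # max exponent 0.707 * 127 = 89.8 > 88.7 -> inf

    # Backward-pass failure: a causally masked (zero) gradient times inf is NaN.
    masked_grad = np.float32(0.0)
    assert np.isinf(decay[0, BT - 1])
    assert np.isnan(masked_grad * decay[0, BT - 1])

# The fix: clamp the exponent argument in the upper triangle before exp().
mask = pos[:, None] >= pos[None, :]    # lower triangle incl. diagonal
safe = np.exp(np.where(mask, decay_arg, 0.0))
assert np.isfinite(safe).all()         # upper triangle becomes exp(0) = 1
```

The clamped upper-triangle entries equal 1 rather than inf; they are still zeroed by the causal mask later, so results are unchanged.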

Changes

  • tops/ops/common/chunk_o.py: Clamp decay exponents in fwd kernel (_chunk_fwd_o_kernel) and bwd kernel (chunk_simple_gla_bwd_kernel)
  • tops/ops/simple_gla/chunk.py: Clamp decay exponents in intra-attention kernel (_chunk_simple_gla_fwd_intra_kernel)

Context

  • Enables chunk_size=128 for ~13% training speed improvement (see ant-pretrain#199, ant-pretrain#280)
  • Previously chunk_size was limited to 64 to avoid this overflow
  • ant-pretrain PR: primatrix/ant-pretrain#323

Test plan

  • ant-pretrain CI: unit tests pass with chunk_size=128
  • TPU smoke test: no NaN at step 0
  • TPU 1k-step run: stable training, no grad_norm=nan

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes

    • Reduced floating-point overflow in attention by clamping exponent arguments before exponentiation.
    • Applied causal masking earlier to avoid large temporary intermediates and ensure stable forward/backward results.
    • Improved gating and precision handling for masked attention regions to prevent spurious values and preserve numerical stability.
  • Chores

    • Added runtime diagnostic logging and simplified internal backward computations for more robust debugging and maintenance.

…k_size>64

For simple GLA with ALiBi slopes, the Toeplitz decay matrix exp(gamma*(i-j))
overflows float32 in the upper triangle when chunk_size=128 and |gamma|>0.69
(head 0 slope=0.707 gives exp(89.8) > fp32 max).

The upper triangle is always masked to 0 by the causal mask, so these values
are never used. The fix clamps the exponent to 0 before exp() in the upper
triangle: exp(where(mask, arg, 0)) instead of exp(arg). This produces
exp(0)=1 (harmless) instead of inf, avoiding 0*inf=NaN in the backward.

Files changed:
- tops/ops/common/chunk_o.py: fwd kernel _chunk_fwd_o_kernel + bwd kernel
  chunk_simple_gla_bwd_kernel
- tops/ops/simple_gla/chunk.py: fwd intra-attention kernel

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@coderabbitai

coderabbitai bot commented Mar 31, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Walkthrough

The PR clamps exponent arguments before calling exp() and moves causal/intra-chunk mask construction earlier inside chunked kernels to prevent overflow. Forward/backward kernels in tops/ops/common/chunk_o.py and tops/ops/simple_gla/chunk.py were updated; no public APIs changed.

Changes

  • Chunked forward/backward kernels (tops/ops/common/chunk_o.py): Create the causal mask earlier; compute exponent differences and use exp(jnp.minimum(diff, 0.0)) (clamped) for intra-chunk attention multipliers and gated/backward terms; move mask declarations and avoid unconditional exp(...) to prevent overflow.
  • Simple GLA Toeplitz decay & backward path (tops/ops/simple_gla/chunk.py): Clamp Toeplitz decay exponents with exp(min(decay_arg, 0.0)) and apply the causal (i >= j) mask in _chunk_simple_gla_fwd_intra_kernel; add runtime NaN/Inf diagnostics in backward; remove synthetic gk, pass g_gamma directly into chunk_bwd_dh; replace fused backward with separate chunk_bwd_dv and chunk_bwd_dqkwg calls; update imports.

Sequence Diagram(s)

(Skipped — changes are numerical-stability and kernel-local control-flow tweaks and do not introduce multi-component sequential flows requiring diagramming.)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • 0xaskr

Poem

🐇
I nudged the exponents tiny and light,
Masks hopped in early to keep things right.
No runaway exp, no wild overflow,
Kernels hum softly where numbers now flow. 🥕

🚥 Pre-merge checks: ✅ 3 passed

  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title accurately describes the primary change: clamping decay exponents to prevent overflow in chunk computations when chunk_size exceeds 64.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@pengchengneo
Author

Kernel changes in detail

Background

GLA uses a chunkwise-parallel algorithm and builds a Toeplitz decay matrix inside each chunk:

decay[i,j] = exp(gamma * (pos[i] - pos[j]))

where gamma is the per-head ALiBi slope (negative) and pos is the position within the chunk. For the lower triangle (i≥j), pos[i]-pos[j] ≥ 0; multiplied by the negative gamma, the exponent is ≤ 0 and exp stays in (0, 1], which is safe.

For the upper triangle (i<j), however, pos[i]-pos[j] < 0, so the exponent becomes positive after multiplying by the negative gamma. With chunk_size=128 the maximum exponent is:

|gamma| × (BT-1) = 0.707 × 127 = 89.8

The largest exponent float32 can represent is about 88.7 (exp(88.7) ≈ 3.4e38), so 89.8 overflows to inf.

In the forward pass the upper triangle is zeroed by the causal mask: 0.0 * inf is ill-defined, but the subsequent jnp.where(mask, b_A, 0.0) overwrites it with 0, so the result is unaffected.

In the backward pass, b_dA has already been zeroed by the causal mask (upper triangle is 0), but the subsequent b_dA * exp(g_diff) computes 0 * inf = NaN in the upper triangle. The NaN propagates through the dq/dk matrix multiplications and ultimately produces grad_norm = NaN.

Changes

Across 3 kernel functions, add an upper-triangle clamp before each of the 5 exp(decay_arg) call sites:

# before
decay = jnp.exp(gamma * (pos[:, None] - pos[None, :]))

# after
mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
decay_arg = gamma * (pos[:, None] - pos[None, :])
decay = jnp.exp(jnp.where(mask, decay_arg, 0.0))

Affected kernel functions:

  • chunk_o.py, _chunk_fwd_o_kernel (forward): 2 sites (g_ref branch + g_gamma_ref branch)
  • chunk_o.py, chunk_simple_gla_bwd_kernel (backward): 2 sites (decay recomputation + b_dA_gated computation)
  • simple_gla/chunk.py, _chunk_simple_gla_fwd_intra_kernel (forward intra-attention): 1 site

Why correctness is unaffected

  1. The upper triangle carries no meaning in causal attention: GLA is a causal model, so A[i,j] (i<j) would represent a future token attending to the current one and is always 0 under the causal mask. The fix only changes the exp input from an overflowing positive value to 0, producing exp(0)=1, and that position is later zeroed by jnp.where(mask, b_A, 0.0), so the final value is unchanged.

  2. The lower triangle is completely unaffected: jnp.where(mask, decay_arg, 0.0) takes the original decay_arg where mask=True (lower triangle), so the computation path is identical to before.

  3. The diagonal is unaffected: for i==j, decay_arg=0 and exp(0)=1, the original behavior.

  4. Mathematical equivalence: let the causal mask be $M$. Before the change we compute $M \odot (A \cdot \exp(D))$; after the change, $M \odot (A \cdot \exp(M \odot D))$. On the support of $M$ (the lower triangle) the values of $D$ are unchanged, and on the zero set of $M$ (the upper triangle) the exp value is zeroed out regardless, so the two are strictly equivalent mathematically. The only difference is that intermediate inf and NaN values are avoided.

Expected impact

  • Fixes grad_norm=NaN at chunk_size=128: eliminates the upper-triangle 0 * inf = NaN numerical issue
  • Unlocks a ~13% training speedup: chunk_size goes from 64 to 128 (see the performance data in ant-pretrain#280)
  • No effect at chunk_size=64: the maximum exponent is 0.707 × 63 = 44.5, far from overflow, so the clamp never fires (the true branch of where takes the original value)
  • Forward compatible: even if chunk_size is increased further (e.g. 256), the fix remains valid as long as the causal mask is kept
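The equivalence argument above can be checked numerically. This is a small NumPy stand-in (jnp.where and jnp.exp behave the same way); the chunk size and gamma here are illustrative, chosen small enough that neither path overflows:

```python
import numpy as np

BT = 8
gamma = -0.5
pos = np.arange(BT, dtype=np.float32)
rng = np.random.default_rng(0)
A = rng.standard_normal((BT, BT)).astype(np.float32)

M = pos[:, None] >= pos[None, :]             # causal mask (lower triangle)
D = gamma * (pos[:, None] - pos[None, :])    # decay exponents

before = np.where(M, A * np.exp(D), 0.0)                    # M ⊙ (A · exp(D))
after = np.where(M, A * np.exp(np.where(M, D, 0.0)), 0.0)   # M ⊙ (A · exp(M ⊙ D))

# On the support of M the exponents are identical, so the two results
# match exactly, bit for bit.
assert np.array_equal(before, after)
```

With an overflowing gamma the left-hand form would contain intermediate inf values before masking, while the right-hand form never does; that is the whole content of the fix.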


@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tops/ops/simple_gla/chunk.py (1)

314-316: ⚠️ Potential issue | 🟡 Minor

Add assertion or clamping to prevent overflow in inter-chunk gating computation.

The inter-chunk computation at line 315 g_exp = jnp.exp(gamma * pos) can overflow if gamma > 0 and BT = 128 (e.g., 0.707 * 128 ≈ 90.5 > fp32_max ≈ 88.7). While all tests use negative g_gamma (which decays toward 0 as designed), there is no runtime assertion enforcing this constraint.

Notably, the intra-chunk kernel (line 208) explicitly clamps exponents to prevent overflow, but the inter-chunk path lacks this protection. Consider adding an assertion assert jnp.all(gamma <= 0) at line 313, or match the intra-chunk pattern by clamping negative exponents in gamma * pos before the exp.
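A minimal sketch of the suggested clamp, using NumPy as a stand-in for jnp; the names gamma, pos, and g_exp follow the review comment, and the exact kernel code is not shown in this thread:

```python
import numpy as np

BT = 128
pos = np.arange(BT, dtype=np.float32)

# Intended regime: negative gamma decays toward zero, so no overflow is possible.
gamma = np.float32(-0.707)
g_exp = np.exp(np.minimum(gamma * pos, 0.0))   # all inputs <= 0 -> output in (0, 1]
assert np.isfinite(g_exp).all()
assert g_exp.max() <= 1.0

# Without a clamp, a (buggy) positive gamma overflows float32 at this BT.
with np.errstate(over="ignore"):
    bad = np.exp(np.float32(0.707) * pos)
assert np.isinf(bad[-1])                       # exp(0.707 * 127) -> inf
```

An `assert jnp.all(gamma <= 0)` would catch the same misuse at trace time; the clamp additionally makes the compiled code safe even if the invariant is violated.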

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 314 - 316, The inter-chunk gating
computation in chunk.py (variables gamma, pos, g_exp) can overflow when gamma>0
for large BT; either add a runtime assertion like assert jnp.all(gamma <= 0)
before computing g_exp or clamp the exponent input to safe range (match the
intra-chunk kernel pattern) by bounding gamma * pos to a max value (e.g.,
jnp.minimum(gamma * pos, MAX_EXP_ARG)) before calling jnp.exp so g_exp =
jnp.exp(clamped_gamma_pos); update the code path that computes g_exp to use the
assertion or clamped value to prevent fp32 overflow.
🧹 Nitpick comments (3)
tops/ops/common/chunk_o.py (2)

387-398: Reference implementation chunk_fwd_o_ref may overflow for large chunk_size.

Similar to the _chunk_simple_gla_fwd_intra_ref issue, this reference computes unmasked exponentials:

A = A * exp(g_c[..., :, None] - g_c[..., None, :])  # line 391
A = A * exp(ramp[..., :, None] - ramp[..., None, :])[None, None]  # line 398

The causal mask is applied afterward at line 402, but by then overflow may have occurred. If this reference is used for numerical testing against the fixed Pallas kernels, consider applying the same masked exponentiation pattern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/chunk_o.py` around lines 387 - 398, The reference forward in
chunk_fwd_o_ref can overflow because A is multiplied by unmasked exp(...) terms
(variables g_c and ramp) before the causal mask is applied; fix by applying the
causal mask to the exponent inputs so only allowed positions are exponentiated
(e.g., compute delta = g_c[..., :, None] - g_c[..., None, :] and mask out
non-causal entries by replacing them with -inf or a large negative constant
before calling exp, and do the same for ramp[..., :, None] - ramp[..., None, :]
), ensuring A and o_inter use exp(masked_delta) so overflow cannot occur before
the mask is enforced.

527-541: Reference chunk_bwd_dv has similar overflow vulnerability.

The backward gradient reference computes:

  • A = A * exp(g_c[..., None, :] - g_c[..., :, None]) (line 534)
  • A = A * exp(ramp[..., None, :] - ramp[..., :, None]) (line 541)

Note: This uses an upper-triangular mask (line 544), so the overflow-prone region differs. For g_c[..., None, :] - g_c[..., :, None], the upper triangle (j > i) has negative exponents (safe), while the lower triangle has positive exponents that are masked out.

The current pattern (mask after exp) is correct for the lower triangle being masked, but verify this matches the kernel behavior if testing against it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/chunk_o.py` around lines 527 - 541, The backward routine
chunk_bwd_dv has the same overflow risk when computing A = A * exp(g_c[...,
None, :] - g_c[..., :, None]) and A = A * exp(ramp[..., None, :] - ramp[..., :,
None]) (and dv_inter updates) — fix by applying the same masking strategy used
in the forward path: mask out the positions that will be zeroed by the
upper-triangular mask (or otherwise not used) before calling exp so you never
exponentiate large positive differences, or alternatively keep the current "mask
after exp" approach only if you run the kernel tests and confirm behavior
matches the reference kernel; check and adjust chunk_bwd_dv symbols A, dv_inter,
g_c, g_gamma, and the triangular mask to be consistent with the forward
implementation and kernel behavior.
tops/ops/simple_gla/chunk.py (1)

60-65: Reference implementation may still overflow for large chunk_size.

The reference implementation _chunk_simple_gla_fwd_intra_ref computes the decay without clamping:

decay = jnp.exp(g_h[None, :, None] * (pos[:, None, None] - pos[None, None, :]))

If this reference is used for testing against the fixed Pallas kernel with chunk_size=128, the test may produce incorrect results or NaNs from the reference while the kernel is correct.

♻️ Suggested fix for consistency
     # Toeplitz decay mask: exp(g_gamma[h] * (i - j))
     g_h = g_gamma.reshape(H)
     pos = jnp.arange(1, C + 1, dtype=jnp.float32)
-    # decay[i, h, j] = exp(g_h[h] * (pos[i] - pos[j]))
-    decay = jnp.exp(g_h[None, :, None] * (pos[:, None, None] - pos[None, None, :]))
-    A = A * decay[None, None]  # broadcast over B, NT
+    # decay[i, h, j] = exp(g_h[h] * (pos[i] - pos[j])), clamped for upper triangle
+    mask = jnp.arange(C)[:, None] >= jnp.arange(C)[None, :]  # (C, C)
+    decay_arg = g_h[None, :, None] * (pos[:, None, None] - pos[None, None, :])
+    decay = jnp.exp(jnp.where(mask[:, None, :], decay_arg, 0.0))
+    A = A * decay[None, None]
+    A = jnp.where(mask[None, None, :, None, :], A, 0.0)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 60 - 65, The reference
implementation computes decay = jnp.exp(g_h[None,...] * (pos[:,None,None] -
pos[None,None,:])) which can overflow for large chunk_size; modify
_chunk_simple_gla_fwd_intra_ref to compute the exponent first (e =
g_h[None,:,None] * (pos[:,None,None] - pos[None,None,:])) and clamp it (e =
jnp.clip(e, a_min=..., a_max=...); e.g. [-50,50]) before calling jnp.exp, then
use decay = jnp.exp(e) and keep the subsequent A = A * decay[None,None]
broadcast. Ensure you reference g_gamma/g_h, pos, decay and the function
_chunk_simple_gla_fwd_intra_ref when making the change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2ec3e179-0de5-4251-b5ea-199e6091d887

📥 Commits

Reviewing files that changed from the base of the PR and between abe6ccb and fae0e06.

📒 Files selected for processing (2)
  • tops/ops/common/chunk_o.py
  • tops/ops/simple_gla/chunk.py

Replace exp(where(mask, x, 0)) with exp(minimum(x, 0)) in all decay
computations.  The where+exp pattern can produce intermediate inf on TPU
if the compiler evaluates exp before the select -- minimum guarantees all
exp inputs are ≤ 0, preventing overflow entirely.

Also switch chunk_simple_gla_bwd to pass g_gamma directly to chunk_bwd_dh
instead of building a synthetic gk tensor, using the kernel's native
scalar-decay path.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@pengchengneo
Author

Update: switched from where to minimum clamping

The initial exp(where(mask, x, 0)) fix was correct on CPU (interpret mode) but still produced NaN on TPU. Root cause: the TPU compiler may evaluate exp(x) for all elements before applying where/select, producing intermediate inf that leaks into the result.

New approach: exp(minimum(x, 0))

Replaced all exp(jnp.where(mask, x, 0.0)) with exp(jnp.minimum(x, 0.0)):

  • minimum(x, 0) guarantees all exp inputs are ≤ 0, so exp output is always in (0, 1]
  • No intermediate inf anywhere — the compiler cannot reorder minimum past exp in a way that produces overflow
  • For the causal lower triangle, x = γ*(i−j) ≤ 0 (since γ < 0 and i ≥ j), so minimum is a no-op
  • For the upper triangle, x > 0, so minimum(x, 0) = 0 and exp(0) = 1 — same as before, harmless
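The difference between the two patterns can be seen directly. NumPy stands in for jnp here; on TPU the select in the where pattern may be reordered after the exp, which is exactly what made the intermediate inf observable:

```python
import numpy as np

x = np.float32(89.8)   # upper-triangle exponent at chunk_size=128, slope 0.707

# where-pattern: the select hides the overflow only if it runs before exp.
# Evaluated eagerly (as a compiler might), exp(x) is computed first:
with np.errstate(over="ignore"):
    intermediate = np.exp(x)             # inf in float32
assert np.isinf(intermediate)

# minimum-pattern: the exp input is clamped to <= 0, so no inf can ever exist.
clamped = np.exp(np.minimum(x, np.float32(0.0)))
assert clamped == np.float32(1.0)

# Lower triangle (x <= 0) is untouched: minimum is a no-op there.
y = np.float32(-3.0)
assert np.exp(np.minimum(y, np.float32(0.0))) == np.exp(y)
```

With minimum, overflow safety is a property of the values rather than of the operation order, so it cannot be broken by compiler scheduling.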

Also: use g_gamma directly in backward dh

Changed chunk_simple_gla_bwd to pass g_gamma directly to chunk_bwd_dh instead of building a synthetic gk tensor via _build_gk_from_gamma. This:

  • Uses the kernel's native scalar-decay path (simpler code path)
  • Avoids creating a large [B, T, H, K] temporary tensor
  • Eliminates a potential source of numerical divergence between forward and backward

Verification

  • CPU test (interpret mode) passes with both approaches
  • CI triggered — awaiting TPU results


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tops/ops/common/chunk_o.py (1)

190-197: Please add a regression that actually hits the float32 overflow boundary.

The closest backward coverage in tests/ref/simple_gla/test_chunk_simple_gla.py, Lines 275-325, runs in fp64 with chunk_size=16, so it never exercises the 0 * Inf -> NaN case this patch is fixing. A float32 chunk_size=128 case with a large-magnitude g_gamma should assert finite outputs and grads here.

Also applies to: 224-227
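A sketch of such a regression. NumPy models only the decay arithmetic here; a real test would call the fixed kernels, whose entry points are not shown in this thread, and additionally check gradients:

```python
import numpy as np

def decay_grad_is_finite(chunk_size: int, g_gamma: float) -> bool:
    """Exercise the 0 * inf -> NaN path: masked gradient times the decay matrix."""
    pos = np.arange(chunk_size, dtype=np.float32)
    decay_arg = np.float32(g_gamma) * (pos[:, None] - pos[None, :])
    # Clamped, as in the fix: every exp input is <= 0.
    decay = np.exp(np.minimum(decay_arg, np.float32(0.0)))
    mask = pos[:, None] >= pos[None, :]
    grad = np.where(mask, np.float32(1.0), np.float32(0.0))  # causally masked
    return bool(np.isfinite(grad * decay).all())

# float32, chunk_size=128, large-magnitude slope: hits the overflow boundary
# that the fp64 chunk_size=16 test never reaches.
assert decay_grad_is_finite(128, -0.707)
```

The key properties of the requested test are all present: float32 throughout, chunk_size=128, a slope large enough that the unclamped exponent would exceed 88.7, and a finiteness assertion.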

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/chunk_o.py` around lines 190 - 197, Add a regression test
that runs the chunked algorithm with dtype=float32 and chunk_size=128 using a
large-magnitude decay/gamma (the b_gamma / g_gamma used to build decay_arg) so
the code path computing pos, decay_arg, and decay actually hits the float32
overflow boundary; run both forward and backward (compute gradients) and assert
all outputs and grads are finite (no Inf/NaN). Ensure the test forces float32
(cast inputs and g_gamma), uses the same chunking code that computes pos, mask,
decay_arg, decay, and fails if any value is non-finite so the 0 * Inf -> NaN
regression is exercised.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9d4bb26f-85a5-4168-b9d8-c6a50dcd0a50

📥 Commits

Reviewing files that changed from the base of the PR and between fae0e06 and 55dbc84.

📒 Files selected for processing (2)
  • tops/ops/common/chunk_o.py
  • tops/ops/simple_gla/chunk.py
✅ Files skipped from review due to trivial changes (1)
  • tops/ops/simple_gla/chunk.py

Comment on lines +70 to +75
# Clamp exponents to ≤0 before exp to prevent overflow on TPU.
# Upper triangle (i<j) has positive exponents; minimum(x, 0) guarantees
# all exp inputs are non-positive, avoiding intermediate inf in the
# compiler pipeline (TPU may evaluate exp before where/select).
b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * exp(jnp.minimum(b_g_diff, 0.0))

⚠️ Potential issue | 🟠 Major

This minimum(..., 0) optimization needs an explicit monotone-decay contract for g.

chunk_fwd_o_ref() still defines the scalar-gate path as exp(g_row - g_col) and only applies the causal mask afterwards at Lines 389-403. The new kernel path here only matches that when every causal g_i - g_j is already non-positive. That's obvious for fixed negative g_gamma, but not for an arbitrary chunk-local g, so a non-monotone chunk would now silently clamp valid lower-triangle weights instead of matching the reference. Please either keep g_ref on a masked formulation or make that monotone-decay invariant explicit and enforced at the public entry points.
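The concern is concrete: for a non-monotone chunk-local g, a causal (lower-triangle) difference g_i - g_j can be positive, and the minimum clamp then silently alters a valid weight. A NumPy stand-in with hypothetical g values:

```python
import numpy as np

g = np.array([0.0, 1.0, 0.5], dtype=np.float32)   # non-monotone within the chunk
diff = g[:, None] - g[None, :]                    # g_row - g_col

# Causal entry (i=1, j=0): diff = 1.0 > 0, a legitimate lower-triangle weight.
reference = np.exp(diff[1, 0])                            # exp(1.0), ~2.718
clamped = np.exp(np.minimum(diff[1, 0], np.float32(0.0)))  # exp(0) = 1.0

assert reference > clamped   # the clamp changed a valid causal weight

# With a monotonically non-increasing g (e.g. g = gamma * pos, gamma < 0),
# every causal diff is <= 0 and the clamp is a no-op, matching the reference.
```

This is why the comment asks for either a masked formulation in the reference or an explicitly enforced monotone-decay invariant on g.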

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/chunk_o.py` around lines 70 - 75, The clamp to
jnp.minimum(b_g_diff, 0.0) on b_g_diff (b_g, b_g_diff, b_A) changes semantics
relative to chunk_fwd_o_ref() because that reference computes exp(g_row - g_col)
then masks causally; to fix, either (A) preserve reference semantics by
computing b_A = b_A * exp(b_g_diff) and then applying the causal mask (same mask
used in chunk_fwd_o_ref) before or after multiplication, or (B) enforce a
monotone-decay invariant for g at public entry points (validate/assert that g is
non-increasing within chunks) and document it; update the code paths that
set/validate b_g (and any public APIs that accept g) accordingly so behavior is
deterministic and matches chunk_fwd_o_ref.

pengchengneo and others added 2 commits March 31, 2026 15:10
Adds jax.debug.print calls after each kernel in the backward pass to
identify which kernel produces NaN:
- After chunk_fwd_h (h recomputation)
- After chunk_bwd_dh (hidden state gradient)
- After chunk_simple_gla_bwd_o_pl (fused dq/dk/dv)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…kernel

The fused Pallas backward kernel (chunk_simple_gla_bwd_o_pl) produces NaN
on TPU v7x with BT=128 due to suspected compiler issues with the large
number of 128x128 matrix operations in a single kernel.  CPU interpret mode
produces correct results, confirming the algorithm is correct.

Replace the fused Pallas kernel call with the JAX reference implementations
(chunk_bwd_dv + chunk_bwd_dqkwg) which use standard jnp.matmul compiled by
XLA.  This is more robust as XLA compiles each matmul independently rather
than fusing all 8+ dot products into one kernel.

The Pallas kernels for chunk_fwd_h and chunk_bwd_dh are kept as they are
proven stable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
tops/ops/simple_gla/chunk.py (2)

10-10: Remove unused import chunk_simple_gla_bwd_o_pl.

The fused Pallas backward kernel has been replaced with JAX reference implementations (chunk_bwd_dv and chunk_bwd_dqkwg), but chunk_simple_gla_bwd_o_pl is still imported and no longer used in this file.

Proposed fix
-from tops.ops.common.chunk_o import chunk_fwd_o, chunk_simple_gla_bwd_o_pl, chunk_bwd_dv, chunk_bwd_dqkwg
+from tops.ops.common.chunk_o import chunk_fwd_o, chunk_bwd_dv, chunk_bwd_dqkwg
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` at line 10, Remove the unused import
chunk_simple_gla_bwd_o_pl from the import statement in
tops.ops.simple_gla.chunk; currently the line imports chunk_fwd_o,
chunk_simple_gla_bwd_o_pl, chunk_bwd_dv, chunk_bwd_dqkwg but
chunk_simple_gla_bwd_o_pl is no longer referenced (we now use chunk_bwd_dv and
chunk_bwd_dqkwg), so update the import to only include the used symbols
(chunk_fwd_o, chunk_bwd_dv, chunk_bwd_dqkwg).

530-543: Remove unused helper function _build_gk_from_gamma.

The function is not called anywhere in the codebase and is no longer needed since the backward path uses g_gamma directly instead of synthesizing a gk tensor.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 530 - 543, Remove the unused
helper function _build_gk_from_gamma: delete its entire definition (the function
_build_gk_from_gamma(g_gamma: jax.Array, B: int, T: int, H: int, K: int,
chunk_size: int) -> jax.Array and its body) and any tests or imports that only
exist to support it; also search for and remove any stray references to
_build_gk_from_gamma elsewhere so the codebase uses g_gamma directly as
intended.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4005cbf0-2f24-4fe0-a5b9-b493d55436c8

📥 Commits

Reviewing files that changed from the base of the PR and between e85894d and e90cf26.

📒 Files selected for processing (1)
  • tops/ops/simple_gla/chunk.py

Adds jax.debug.print NaN/Inf/absmax checks after:
- Input tensors (q, k, v, do)
- h recomputed from chunk_fwd_h
- dh from chunk_bwd_dh
- dv from chunk_bwd_dv
- dq, dk from chunk_bwd_dqkwg

This will pinpoint exactly which operation produces NaN on TPU.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tops/ops/simple_gla/chunk.py (1)

603-673: ⚠️ Potential issue | 🟠 Major

Guard or remove unconditional NaN/Inf debug prints in backward path.

These checks execute on every backward pass and can add substantial training overhead and log noise.

💡 Suggested change (debug checks off by default)
 def chunk_simple_gla_bwd(
@@
 ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None]:
@@
-    def _nan_check(name, x):
+    enable_nan_checks = False
+
+    def _nan_check(name, x):
       """Print NaN/Inf stats for a tensor."""
@@
-    # Check inputs
-    _nan_check("q_in", q)
-    _nan_check("k_in", k)
-    _nan_check("v_in", v)
-    _nan_check("do_in", do)
+    if enable_nan_checks:
+        _nan_check("q_in", q)
+        _nan_check("k_in", k)
+        _nan_check("v_in", v)
+        _nan_check("do_in", do)
@@
-    _nan_check("h_recomputed", h)
+    if enable_nan_checks:
+        _nan_check("h_recomputed", h)
@@
-    _nan_check("dh_result", dh)
+    if enable_nan_checks:
+        _nan_check("dh_result", dh)
@@
-    _nan_check("dv_result", dv)
+    if enable_nan_checks:
+        _nan_check("dv_result", dv)
@@
-    _nan_check("dq_result", dq)
-    _nan_check("dk_result", dk)
+    if enable_nan_checks:
+        _nan_check("dq_result", dq)
+        _nan_check("dk_result", dk)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 603 - 673, The unconditional
NaN/Inf debug prints (the helper _nan_check and its calls around chunk_fwd_h,
chunk_bwd_dh, chunk_bwd_dv, and chunk_bwd_dqkwg) should be guarded or removed to
avoid per-backward-pass overhead; add a boolean runtime flag (e.g.,
enable_gla_nan_check defaulting to False) or respect a logging/debug level
inside _nan_check and wrap the existing calls so they only execute when that
flag is true (or remove the calls entirely), ensuring the check remains
available for debugging but is off by default in normal training.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: dc1fe381-8060-4e0e-8427-b1f1f7695aeb

📥 Commits

Reviewing files that changed from the base of the PR and between e90cf26 and af1a891.

📒 Files selected for processing (1)
  • tops/ops/simple_gla/chunk.py

Comment on lines +658 to 671
dv = chunk_bwd_dv(
q, k, do, dh,
g_gamma=g_gamma,
scale=scale,
chunk_size=C,
)
_nan_check("dv_result", dv)

# 4. Fused dq/dk/dv via simple GLA pallas kernel
dq, dk, dv = chunk_simple_gla_bwd_o_pl(
q, k, v, g_gamma, h, do, dh,
scale=scale, chunk_size=C,
cu_seqlens_dev=cu_seqlens_dev,
interpret=interpret,
dq, dk, _, _ = chunk_bwd_dqkwg(
q, k, v, h, do, dh,
g_gamma=g_gamma,
scale=scale,
chunk_size=C,
)

⚠️ Potential issue | 🟠 Major

Pass sequence boundaries into chunk_bwd_dv/chunk_bwd_dqkwg for varlen correctness.

chunk_simple_gla_bwd validates chunked cu_seqlens_cpu, but the new reference backward calls omit those boundaries, which can mix tokens across packed sequences.

✅ Suggested fix
     dv = chunk_bwd_dv(
         q, k, do, dh,
         g_gamma=g_gamma,
         scale=scale,
+        cu_seqlens_cpu=cu_seqlens_cpu,
         chunk_size=C,
     )
@@
     dq, dk, _, _ = chunk_bwd_dqkwg(
         q, k, v, h, do, dh,
         g_gamma=g_gamma,
         scale=scale,
+        cu_seqlens=cu_seqlens_cpu,
         chunk_size=C,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 658 - 671, The backward calls to
chunk_bwd_dv and chunk_bwd_dqkwg are missing the sequence-boundary info, which
can cause tokens from different packed sequences to be mixed; update the calls
inside chunk_simple_gla_bwd to pass the validated varlen boundaries (e.g.,
cu_seqlens_cpu or the existing cu_seqlens / max_seqlen variables used earlier)
into both chunk_bwd_dv(...) and chunk_bwd_dqkwg(...), preserving the same
parameter name/order those helper functions expect so they perform chunked,
varlen-safe processing.

The JAX reference backward functions chunk_bwd_dv and chunk_bwd_dqkwg
compute exp(ramp_diff) for the decay matrix where masked entries can
have positive exponents up to +89.8, producing inf in float32.

While jnp.where masks these entries to 0 afterward, the intermediate
inf (or 0*inf=NaN if the matrix entry is exactly zero) is computed on
TPU hardware before the select. Adding jnp.minimum(diff, 0) before
exp() prevents overflow entirely — for unmasked entries the exponents
are already non-positive, so the clamp is mathematically neutral.

This mirrors the fix already applied to the forward Pallas kernel
(_chunk_fwd_o_kernel) at line 75/83.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
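The overflow mechanism and both clamping variants can be reproduced in a small standalone JAX sketch (an illustration, not the kernel code itself; `chunk_size=128` and the head-0 slope 0.707 are taken from the PR description):

```python
import jax.numpy as jnp

chunk_size = 128
slope = 0.707                 # ALiBi slope for head 0, per the PR description
gamma = -slope                # per-step log-decay (negative)

i = jnp.arange(chunk_size, dtype=jnp.float32)
decay_arg = gamma * (i[:, None] - i[None, :])   # gamma * (i - j)
mask = i[:, None] >= i[None, :]                 # causal (lower-triangle) mask

# Naive: the masked upper triangle reaches +89.8 > log(float32 max) ~ 88.7,
# so exp() overflows to inf before the select; 0 * inf then yields NaN.
naive = jnp.exp(decay_arg)

# Fix used in the Pallas kernels: zero masked exponents before exp().
safe = jnp.exp(jnp.where(mask, decay_arg, 0.0))

# Fix used in the JAX reference backward: clamp to <= 0; mathematically
# neutral for unmasked entries, whose exponents are already non-positive.
safe2 = jnp.exp(jnp.minimum(decay_arg, 0.0))

out = jnp.where(mask, safe, 0.0)
```

Here `jnp.isinf(naive).any()` is true, while `out` and `safe2` stay finite everywhere.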
@0xaskr
Collaborator

0xaskr commented Mar 31, 2026

PR #131 has been merged. Try pointing pallas-kernel to the latest main branch directly in the pretrain repository. @pengchengneo

@0xaskr
Collaborator

0xaskr commented Apr 1, 2026

This CI run (https://github.com/primatrix/ant-pretrain/pull/316) passed, verifying that the main branch works properly. Closing this PR.

@0xaskr 0xaskr closed this Apr 1, 2026