fix: clamp upper-triangle decay exponents to prevent overflow at chunk_size>64 #117
pengchengneo wants to merge 6 commits into main
Conversation
…k_size>64

For simple GLA with ALiBi slopes, the Toeplitz decay matrix exp(gamma*(i-j)) overflows float32 in the upper triangle when chunk_size=128 and |gamma|>0.69 (head 0 slope=0.707 gives exp(89.8) > fp32 max). The upper triangle is always masked to 0 by the causal mask, so these values are never used.

The fix clamps the exponent to 0 before exp() in the upper triangle: exp(where(mask, arg, 0)) instead of exp(arg). This produces exp(0)=1 (harmless) instead of inf, avoiding 0*inf=NaN in the backward.

Files changed:
- tops/ops/common/chunk_o.py: fwd kernel _chunk_fwd_o_kernel + bwd kernel chunk_simple_gla_bwd_kernel
- tops/ops/simple_gla/chunk.py: fwd intra-attention kernel

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
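The overflow described above is easy to reproduce outside the kernels. A minimal numpy sketch, standing in for the jnp calls, with the numbers taken from the commit message:

```python
import numpy as np

BT = 128
gamma = np.float32(-0.707)  # ALiBi head-0 slope (negative: decays with distance)

pos = np.arange(BT, dtype=np.float32)
# Toeplitz exponent gamma*(i-j): lower triangle (i>=j) is <= 0, but the
# upper triangle reaches -gamma*(BT-1) = 0.707*127 ~ 89.8 > ln(fp32 max) ~ 88.7.
arg = gamma * (pos[:, None] - pos[None, :])

naive = np.exp(arg)                  # upper triangle overflows to inf
mask = pos[:, None] >= pos[None, :]  # causal (lower-triangle) mask
fixed = np.exp(np.where(mask, arg, np.float32(0.0)))  # exp(0)=1 in masked region

assert np.isinf(naive).any()     # overflow in the never-used upper triangle
assert np.isfinite(fixed).all()  # all exp inputs are <= 0 or exactly 0
# 0 * inf = nan is what poisoned the backward pass:
assert np.isnan(np.float32(0.0) * np.inf)
```

The inf values are harmless in the forward pass (they are multiplied by the mask's zeros only conceptually, after selection), but in the backward pass the literal product 0 * inf produces the NaN this PR removes.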
📝 Walkthrough

The PR clamps exponent arguments before calling exp() and moves causal/intra-chunk mask construction earlier inside chunked kernels to prevent overflow. Forward/backward kernels in

Changes
Sequence Diagram(s)

(Skipped — changes are numerical-stability and kernel-local control-flow tweaks and do not introduce multi-component sequential flows requiring diagramming.)

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Detailed explanation of the kernel changes

Problem background

GLA uses a chunkwise parallel algorithm and builds a Toeplitz decay matrix inside each chunk: decay[i, j] = exp(gamma * (i - j)), where gamma < 0. In the lower triangle (i >= j) the exponent is non-positive, so exp stays bounded; in the upper triangle (i < j) the exponent is positive and grows with chunk size. In the forward pass the upper triangle is zeroed by the causal mask, so the overflowed values are discarded; in the backward pass, however, 0 * inf = NaN, and the NaN propagates.

Changes

5 sites across 3 kernel functions:

# before
decay = jnp.exp(gamma * (pos[:, None] - pos[None, :]))

# after
mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
decay_arg = gamma * (pos[:, None] - pos[None, :])
decay = jnp.exp(jnp.where(mask, decay_arg, 0.0))

Kernel functions involved:

Why correctness is unaffected

Expected effect
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tops/ops/simple_gla/chunk.py (1)
314-316: ⚠️ Potential issue | 🟡 Minor
Add assertion or clamping to prevent overflow in inter-chunk gating computation.

The inter-chunk computation at line 315, g_exp = jnp.exp(gamma * pos), can overflow if gamma > 0 and BT = 128 (e.g., 0.707 * 128 ≈ 90.5 > ln(fp32_max) ≈ 88.7). While all tests use negative g_gamma (which decays toward 0 as designed), there is no runtime assertion enforcing this constraint.

Notably, the intra-chunk kernel (line 208) explicitly clamps exponents to prevent overflow, but the inter-chunk path lacks this protection. Consider adding an assertion assert jnp.all(gamma <= 0) at line 313, or match the intra-chunk pattern by clamping the exponent gamma * pos before the exp.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/simple_gla/chunk.py` around lines 314 - 316, The inter-chunk gating computation in chunk.py (variables gamma, pos, g_exp) can overflow when gamma>0 for large BT; either add a runtime assertion like assert jnp.all(gamma <= 0) before computing g_exp or clamp the exponent input to safe range (match the intra-chunk kernel pattern) by bounding gamma * pos to a max value (e.g., jnp.minimum(gamma * pos, MAX_EXP_ARG)) before calling jnp.exp so g_exp = jnp.exp(clamped_gamma_pos); update the code path that computes g_exp to use the assertion or clamped value to prevent fp32 overflow.
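The reviewer's clamp suggestion can be sketched with numpy standing in for the jnp code; MAX_EXP_ARG is a hypothetical constant, not a name from the codebase:

```python
import numpy as np

BT = 128
gamma = np.float32(0.707)              # pathological positive slope: would overflow
pos = np.arange(1, BT + 1, dtype=np.float32)

MAX_EXP_ARG = np.float32(80.0)         # safely below ln(fp32 max) ~ 88.7
g_exp_naive = np.exp(gamma * pos)      # 0.707 * 128 ~ 90.5 -> inf in float32
g_exp_safe = np.exp(np.minimum(gamma * pos, MAX_EXP_ARG))  # bounded exponent

assert np.isinf(g_exp_naive).any()
assert np.isfinite(g_exp_safe).all()
```

With the intended negative g_gamma the clamp never fires, so it is behavior-neutral for valid inputs while making the invalid-input case finite.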
🧹 Nitpick comments (3)
tops/ops/common/chunk_o.py (2)
387-398: Reference implementation chunk_fwd_o_ref may overflow for large chunk_size.

Similar to the _chunk_simple_gla_fwd_intra_ref issue, this reference computes unmasked exponentials:

A = A * exp(g_c[..., :, None] - g_c[..., None, :])                # line 391
A = A * exp(ramp[..., :, None] - ramp[..., None, :])[None, None]  # line 398

The causal mask is applied afterward at line 402, but by then overflow may have occurred. If this reference is used for numerical testing against the fixed Pallas kernels, consider applying the same masked exponentiation pattern.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/common/chunk_o.py` around lines 387 - 398, The reference forward in chunk_fwd_o_ref can overflow because A is multiplied by unmasked exp(...) terms (variables g_c and ramp) before the causal mask is applied; fix by applying the causal mask to the exponent inputs so only allowed positions are exponentiated (e.g., compute delta = g_c[..., :, None] - g_c[..., None, :] and mask out non-causal entries by replacing them with -inf or a large negative constant before calling exp, and do the same for ramp[..., :, None] - ramp[..., None, :] ), ensuring A and o_inter use exp(masked_delta) so overflow cannot occur before the mask is enforced.
527-541: Reference chunk_bwd_dv has a similar overflow vulnerability.

The backward gradient reference computes:

A = A * exp(g_c[..., None, :] - g_c[..., :, None])    # line 534
A = A * exp(ramp[..., None, :] - ramp[..., :, None])  # line 541

Note: This uses an upper-triangular mask (line 544), so the overflow-prone region differs. For g_c[..., None, :] - g_c[..., :, None], the upper triangle (j > i) has negative exponents (safe), while the lower triangle has positive exponents that are masked out.

The current pattern (mask after exp) is correct for the lower triangle being masked, but verify this matches the kernel behavior if testing against it.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/common/chunk_o.py` around lines 527 - 541, The backward routine chunk_bwd_dv has the same overflow risk when computing A = A * exp(g_c[..., None, :] - g_c[..., :, None]) and A = A * exp(ramp[..., None, :] - ramp[..., :, None]) (and dv_inter updates) — fix by applying the same masking strategy used in the forward path: mask out the positions that will be zeroed by the upper-triangular mask (or otherwise not used) before calling exp so you never exponentiate large positive differences, or alternatively keep the current "mask after exp" approach only if you run the kernel tests and confirm behavior matches the reference kernel; check and adjust chunk_bwd_dv symbols A, dv_inter, g_c, g_gamma, and the triangular mask to be consistent with the forward implementation and kernel behavior.

tops/ops/simple_gla/chunk.py (1)
60-65: Reference implementation may still overflow for large chunk_size.

The reference implementation _chunk_simple_gla_fwd_intra_ref computes the decay without clamping:

decay = jnp.exp(g_h[None, :, None] * (pos[:, None, None] - pos[None, None, :]))

If this reference is used for testing against the fixed Pallas kernel with chunk_size=128, the test may produce incorrect results or NaNs from the reference while the kernel is correct.

♻️ Suggested fix for consistency
  # Toeplitz decay mask: exp(g_gamma[h] * (i - j))
  g_h = g_gamma.reshape(H)
  pos = jnp.arange(1, C + 1, dtype=jnp.float32)
- # decay[i, h, j] = exp(g_h[h] * (pos[i] - pos[j]))
- decay = jnp.exp(g_h[None, :, None] * (pos[:, None, None] - pos[None, None, :]))
- A = A * decay[None, None]  # broadcast over B, NT
+ # decay[i, h, j] = exp(g_h[h] * (pos[i] - pos[j])), clamped for upper triangle
+ mask = jnp.arange(C)[:, None] >= jnp.arange(C)[None, :]  # (C, C)
+ decay_arg = g_h[None, :, None] * (pos[:, None, None] - pos[None, None, :])
+ decay = jnp.exp(jnp.where(mask[:, None, :], decay_arg, 0.0))
+ A = A * decay[None, None]
+ A = jnp.where(mask[None, None, :, None, :], A, 0.0)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/simple_gla/chunk.py` around lines 60 - 65, The reference implementation computes decay = jnp.exp(g_h[None,...] * (pos[:,None,None] - pos[None,None,:])) which can overflow for large chunk_size; modify _chunk_simple_gla_fwd_intra_ref to compute the exponent first (e = g_h[None,:,None] * (pos[:,None,None] - pos[None,None,:])) and clamp it (e = jnp.clip(e, a_min=..., a_max=...); e.g. [-50,50]) before calling jnp.exp, then use decay = jnp.exp(e) and keep the subsequent A = A * decay[None,None] broadcast. Ensure you reference g_gamma/g_h, pos, decay and the function _chunk_simple_gla_fwd_intra_ref when making the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2ec3e179-0de5-4251-b5ea-199e6091d887
📒 Files selected for processing (2)
- tops/ops/common/chunk_o.py
- tops/ops/simple_gla/chunk.py
Replace exp(where(mask, x, 0)) with exp(minimum(x, 0)) in all decay computations. The where+exp pattern can produce intermediate inf on TPU if the compiler evaluates exp before the select; minimum guarantees all exp inputs are ≤ 0, preventing overflow entirely. Also switch chunk_simple_gla_bwd to pass g_gamma directly to chunk_bwd_dh instead of building a synthetic gk tensor, using the kernel's native scalar-decay path.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
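The equivalence of the two patterns, and why minimum is the safer one, can be checked directly. A numpy sketch standing in for the jnp calls:

```python
import numpy as np

BT = 128
gamma = np.float32(-0.707)
pos = np.arange(BT, dtype=np.float32)
arg = gamma * (pos[:, None] - pos[None, :])
mask = pos[:, None] >= pos[None, :]

# exp(where(mask, x, 0)): correct result, but a compiler that evaluates
# exp(arg) before the select materializes inf in the upper triangle.
d_where = np.exp(np.where(mask, arg, np.float32(0.0)))

# exp(minimum(x, 0)): every exp input is <= 0 by construction, so no
# intermediate inf can exist under any evaluation order.
d_min = np.exp(np.minimum(arg, np.float32(0.0)))

# Identical outputs: causal entries have arg <= 0 (minimum is a no-op there),
# and masked entries become exp(0) = 1 under both patterns.
assert np.array_equal(d_where, d_min)
assert np.isfinite(d_min).all()
```

This identity is exactly why the switch is mathematically neutral: it only changes which intermediate values the hardware is allowed to compute.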
Update: switched from
Actionable comments posted: 1
🧹 Nitpick comments (1)
tops/ops/common/chunk_o.py (1)
190-197: Please add a regression that actually hits the float32 overflow boundary.

The closest backward coverage in tests/ref/simple_gla/test_chunk_simple_gla.py, lines 275-325, runs in fp64 with chunk_size=16, so it never exercises the 0 * Inf -> NaN case this patch is fixing. A float32 chunk_size=128 case with a large-magnitude g_gamma should assert finite outputs and grads here.

Also applies to: 224-227
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/common/chunk_o.py` around lines 190 - 197, Add a regression test that runs the chunked algorithm with dtype=float32 and chunk_size=128 using a large-magnitude decay/gamma (the b_gamma / g_gamma used to build decay_arg) so the code path computing pos, decay_arg, and decay actually hits the float32 overflow boundary; run both forward and backward (compute gradients) and assert all outputs and grads are finite (no Inf/NaN). Ensure the test forces float32 (cast inputs and g_gamma), uses the same chunking code that computes pos, mask, decay_arg, decay, and fails if any value is non-finite so the 0 * Inf -> NaN regression is exercised.
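A self-contained sketch of what such a regression could look like; the function and test names are hypothetical, and a real test would call the repo's chunked ops rather than this toy decay construction:

```python
import numpy as np

def decay_matrix(gamma, BT):
    # same masked-exponentiation pattern as the fixed kernels
    pos = np.arange(BT, dtype=np.float32)
    arg = gamma * (pos[:, None] - pos[None, :])
    mask = pos[:, None] >= pos[None, :]
    return np.exp(np.where(mask, arg, np.float32(0.0)))

def test_float32_overflow_boundary():
    # float32, chunk_size=128, large-magnitude gamma: max |exponent| ~ 89.8,
    # which is past ln(fp32 max) ~ 88.7
    d = decay_matrix(np.float32(-0.707), 128)
    assert np.isfinite(d).all()
    # backward-style masked product that previously produced 0 * inf = NaN
    masked = np.where(np.tri(128, dtype=bool), np.float32(1.0), np.float32(0.0)) * d
    assert not np.isnan(masked).any()

test_float32_overflow_boundary()
```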
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9d4bb26f-85a5-4168-b9d8-c6a50dcd0a50
📒 Files selected for processing (2)
- tops/ops/common/chunk_o.py
- tops/ops/simple_gla/chunk.py
✅ Files skipped from review due to trivial changes (1)
- tops/ops/simple_gla/chunk.py
# Clamp exponents to ≤0 before exp to prevent overflow on TPU.
# Upper triangle (i<j) has positive exponents; minimum(x, 0) guarantees
# all exp inputs are non-positive, avoiding intermediate inf in the
# compiler pipeline (TPU may evaluate exp before where/select).
b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * exp(jnp.minimum(b_g_diff, 0.0))
This minimum(..., 0) optimization needs an explicit monotone-decay contract for g.
chunk_fwd_o_ref() still defines the scalar-gate path as exp(g_row - g_col) and only applies the causal mask afterwards at Lines 389-403. The new kernel path here only matches that when every causal g_i - g_j is already non-positive. That's obvious for fixed negative g_gamma, but not for an arbitrary chunk-local g, so a non-monotone chunk would now silently clamp valid lower-triangle weights instead of matching the reference. Please either keep g_ref on a masked formulation or make that monotone-decay invariant explicit and enforced at the public entry points.
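The concern can be illustrated with a toy non-monotone gate; the values below are hypothetical, not from the codebase:

```python
import numpy as np

# Non-monotone chunk-local cumulative gate: g[2] > g[1], so g is not
# non-increasing within the chunk.
g = np.float32([0.0, -1.0, -0.5, -2.0])
diff = g[:, None] - g[None, :]                       # g_i - g_j
mask = np.arange(4)[:, None] >= np.arange(4)[None, :]

# Reference semantics: exp first, then causal mask.
ref = np.where(mask, np.exp(diff), np.float32(0.0))
# Kernel with the minimum(..., 0) clamp applied before exp.
clamped = np.where(mask, np.exp(np.minimum(diff, np.float32(0.0))), np.float32(0.0))

# Causal entry (i=2, j=1) has g_i - g_j = 0.5 > 0: the clamp silently
# replaces exp(0.5) with exp(0) = 1, diverging from the reference.
assert ref[2, 1] > clamped[2, 1]
assert clamped[2, 1] == np.float32(1.0)
```

For the fixed negative g_gamma used in this PR no causal difference is ever positive, so the clamp is a no-op; the divergence only appears if the invariant is violated, which is why the reviewer asks for it to be stated and enforced.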
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/common/chunk_o.py` around lines 70 - 75, The clamp to
jnp.minimum(b_g_diff, 0.0) on b_g_diff (b_g, b_g_diff, b_A) changes semantics
relative to chunk_fwd_o_ref() because that reference computes exp(g_row - g_col)
then masks causally; to fix, either (A) preserve reference semantics by
computing b_A = b_A * exp(b_g_diff) and then applying the causal mask (same mask
used in chunk_fwd_o_ref) before or after multiplication, or (B) enforce a
monotone-decay invariant for g at public entry points (validate/assert that g is
non-increasing within chunks) and document it; update the code paths that
set/validate b_g (and any public APIs that accept g) accordingly so behavior is
deterministic and matches chunk_fwd_o_ref.
Adds jax.debug.print calls after each kernel in the backward pass to identify which kernel produces NaN:
- After chunk_fwd_h (h recomputation)
- After chunk_bwd_dh (hidden state gradient)
- After chunk_simple_gla_bwd_o_pl (fused dq/dk/dv)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…kernel

The fused Pallas backward kernel (chunk_simple_gla_bwd_o_pl) produces NaN on TPU v7x with BT=128 due to suspected compiler issues with the large number of 128x128 matrix operations in a single kernel. CPU interpret mode produces correct results, confirming the algorithm is correct.

Replace the fused Pallas kernel call with the JAX reference implementations (chunk_bwd_dv + chunk_bwd_dqkwg), which use standard jnp.matmul compiled by XLA. This is more robust, as XLA compiles each matmul independently rather than fusing all 8+ dot products into one kernel.

The Pallas kernels for chunk_fwd_h and chunk_bwd_dh are kept, as they are proven stable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
🧹 Nitpick comments (2)
tops/ops/simple_gla/chunk.py (2)
10-10: Remove unused import chunk_simple_gla_bwd_o_pl.

The fused Pallas backward kernel has been replaced with JAX reference implementations (chunk_bwd_dv and chunk_bwd_dqkwg), but chunk_simple_gla_bwd_o_pl is still imported and no longer used in this file.

Proposed fix
-from tops.ops.common.chunk_o import chunk_fwd_o, chunk_simple_gla_bwd_o_pl, chunk_bwd_dv, chunk_bwd_dqkwg +from tops.ops.common.chunk_o import chunk_fwd_o, chunk_bwd_dv, chunk_bwd_dqkwg🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/simple_gla/chunk.py` at line 10, Remove the unused import chunk_simple_gla_bwd_o_pl from the import statement in tops.ops.simple_gla.chunk; currently the line imports chunk_fwd_o, chunk_simple_gla_bwd_o_pl, chunk_bwd_dv, chunk_bwd_dqkwg but chunk_simple_gla_bwd_o_pl is no longer referenced (we now use chunk_bwd_dv and chunk_bwd_dqkwg), so update the import to only include the used symbols (chunk_fwd_o, chunk_bwd_dv, chunk_bwd_dqkwg).
530-543: Remove unused helper function _build_gk_from_gamma.

The function is not called anywhere in the codebase and is no longer needed since the backward path uses g_gamma directly instead of synthesizing a gk tensor.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/simple_gla/chunk.py` around lines 530 - 543, Remove the unused helper function _build_gk_from_gamma: delete its entire definition (the function _build_gk_from_gamma(g_gamma: jax.Array, B: int, T: int, H: int, K: int, chunk_size: int) -> jax.Array and its body) and any tests or imports that only exist to support it; also search for and remove any stray references to _build_gk_from_gamma elsewhere so the codebase uses g_gamma directly as intended.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4005cbf0-2f24-4fe0-a5b9-b493d55436c8
📒 Files selected for processing (1)
tops/ops/simple_gla/chunk.py
Adds jax.debug.print NaN/Inf/absmax checks after:
- Input tensors (q, k, v, do)
- h recomputed from chunk_fwd_h
- dh from chunk_bwd_dh
- dv from chunk_bwd_dv
- dq, dk from chunk_bwd_dqkwg

This will pinpoint exactly which operation produces NaN on TPU.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tops/ops/simple_gla/chunk.py (1)
603-673: ⚠️ Potential issue | 🟠 Major
Guard or remove unconditional NaN/Inf debug prints in backward path.
These checks execute on every backward pass and can add substantial training overhead and log noise.
💡 Suggested change (debug checks off by default)
 def chunk_simple_gla_bwd(
@@
 ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None]:
@@
-    def _nan_check(name, x):
+    enable_nan_checks = False
+
+    def _nan_check(name, x):
         """Print NaN/Inf stats for a tensor."""
@@
-    # Check inputs
-    _nan_check("q_in", q)
-    _nan_check("k_in", k)
-    _nan_check("v_in", v)
-    _nan_check("do_in", do)
+    if enable_nan_checks:
+        _nan_check("q_in", q)
+        _nan_check("k_in", k)
+        _nan_check("v_in", v)
+        _nan_check("do_in", do)
@@
-    _nan_check("h_recomputed", h)
+    if enable_nan_checks:
+        _nan_check("h_recomputed", h)
@@
-    _nan_check("dh_result", dh)
+    if enable_nan_checks:
+        _nan_check("dh_result", dh)
@@
-    _nan_check("dv_result", dv)
+    if enable_nan_checks:
+        _nan_check("dv_result", dv)
@@
-    _nan_check("dq_result", dq)
-    _nan_check("dk_result", dk)
+    if enable_nan_checks:
+        _nan_check("dq_result", dq)
+        _nan_check("dk_result", dk)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/simple_gla/chunk.py` around lines 603 - 673, The unconditional NaN/Inf debug prints (the helper _nan_check and its calls around chunk_fwd_h, chunk_bwd_dh, chunk_bwd_dv, and chunk_bwd_dqkwg) should be guarded or removed to avoid per-backward-pass overhead; add a boolean runtime flag (e.g., enable_gla_nan_check defaulting to False) or respect a logging/debug level inside _nan_check and wrap the existing calls so they only execute when that flag is true (or remove the calls entirely), ensuring the check remains available for debugging but is off by default in normal training.
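One way to sketch the guarded checker; the flag name and helper are hypothetical, and real kernel code would use jax.debug.print rather than Python print so the check works under jit:

```python
import os
import numpy as np

# Off by default; enable via environment for debugging runs only.
ENABLE_NAN_CHECKS = os.environ.get("GLA_NAN_CHECKS", "0") == "1"

def nan_check(name, x):
    if not ENABLE_NAN_CHECKS:
        return  # no overhead or log noise in normal training
    print(f"{name}: nan={bool(np.isnan(x).any())} "
          f"inf={bool(np.isinf(x).any())} "
          f"absmax={float(np.abs(x).max()):.3e}")

nan_check("dv_result", np.ones((2, 2), dtype=np.float32))  # silent by default
```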
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: dc1fe381-8060-4e0e-8427-b1f1f7695aeb
📒 Files selected for processing (1)
tops/ops/simple_gla/chunk.py
     dv = chunk_bwd_dv(
         q, k, do, dh,
         g_gamma=g_gamma,
         scale=scale,
         chunk_size=C,
     )
     _nan_check("dv_result", dv)

-    # 4. Fused dq/dk/dv via simple GLA pallas kernel
-    dq, dk, dv = chunk_simple_gla_bwd_o_pl(
-        q, k, v, g_gamma, h, do, dh,
-        scale=scale, chunk_size=C,
-        cu_seqlens_dev=cu_seqlens_dev,
-        interpret=interpret,
+    dq, dk, _, _ = chunk_bwd_dqkwg(
+        q, k, v, h, do, dh,
+        g_gamma=g_gamma,
+        scale=scale,
+        chunk_size=C,
     )
Pass sequence boundaries into chunk_bwd_dv/chunk_bwd_dqkwg for varlen correctness.
chunk_simple_gla_bwd validates chunked cu_seqlens_cpu, but the new reference backward calls omit those boundaries, which can mix tokens across packed sequences.
✅ Suggested fix
dv = chunk_bwd_dv(
q, k, do, dh,
g_gamma=g_gamma,
scale=scale,
+ cu_seqlens_cpu=cu_seqlens_cpu,
chunk_size=C,
)
@@
dq, dk, _, _ = chunk_bwd_dqkwg(
q, k, v, h, do, dh,
g_gamma=g_gamma,
scale=scale,
+ cu_seqlens=cu_seqlens_cpu,
chunk_size=C,
)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/simple_gla/chunk.py` around lines 658 - 671, The backward calls to
chunk_bwd_dv and chunk_bwd_dqkwg are missing the sequence-boundary info, which
can cause tokens from different packed sequences to be mixed; update the calls
inside chunk_simple_gla_bwd to pass the validated varlen boundaries (e.g.,
cu_seqlens_cpu or the existing cu_seqlens / max_seqlen variables used earlier)
into both chunk_bwd_dv(...) and chunk_bwd_dqkwg(...), preserving the same
parameter name/order those helper functions expect so they perform chunked,
varlen-safe processing.
The JAX reference backward functions chunk_bwd_dv and chunk_bwd_dqkwg compute exp(ramp_diff) for the decay matrix, where masked entries can have positive exponents up to +89.8, producing inf in float32. While jnp.where masks these entries to 0 afterward, the intermediate inf (or 0*inf=NaN if the matrix entry is exactly zero) is computed on TPU hardware before the select. Adding jnp.minimum(diff, 0) before exp() prevents overflow entirely — for unmasked entries the exponents are already non-positive, so the clamp is mathematically neutral. This mirrors the fix already applied to the forward Pallas kernel (_chunk_fwd_o_kernel) at lines 75/83.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This PR [https://github.com//pull/131] has been merged. Try pointing pallas-kernel at the latest main branch directly in the pretrain repository. @pengchengneo
This CI run (https://github.com/primatrix/ant-pretrain/pull/316) has passed, verifying that the main branch works properly. Closing this PR.
Summary
- exp(gamma * (i - j)) overflows float32 when chunk_size=128 and ALiBi slopes are large (head 0: slope=0.707, max exponent=0.707×127=89.8 > 88.7). In the backward pass, 0 * inf = NaN propagates through grad_norm.
- Fix: apply jnp.where(mask, decay_arg, 0.0) before exp() — upper-triangle entries produce exp(0)=1 (harmless, masked away later) instead of exp(89.8)=inf.
- Applied in chunk_o.py and the forward intra-attention kernel in simple_gla/chunk.py.

Changes
- tops/ops/common/chunk_o.py: fwd kernel (_chunk_fwd_o_kernel) and bwd kernel (chunk_simple_gla_bwd_kernel)
- tops/ops/simple_gla/chunk.py: fwd intra-attention kernel (_chunk_simple_gla_fwd_intra_kernel)

Context
Test plan
🤖 Generated with Claude Code
Summary by CodeRabbit
Bug Fixes
Chores