Skip to content

Fix/chunk nan overflow safe exp#131

Merged
0xaskr merged 10 commits intomainfrom
fix/chunk-nan-overflow-safe-exp
Mar 31, 2026
Merged

Fix/chunk nan overflow safe exp#131
0xaskr merged 10 commits intomainfrom
fix/chunk-nan-overflow-safe-exp

Conversation

@0xaskr
Copy link
Copy Markdown
Collaborator

@0xaskr 0xaskr commented Mar 31, 2026

This pull request significantly refactors and improves the GPU test suite for the chunk_simple_gla operator, focusing on more accurate, scalable, and robust cross-framework (Triton vs JAX) testing. It introduces dynamic, shape-aware tolerance scaling, adds new utilities for ALiBi gating parameter generation, and enhances test coverage for numerical stability. The code is also modernized to directly call forward/backward kernels, and test configurations are made more flexible and realistic.

Key changes:

Test infrastructure improvements:

  • Introduced dynamic tolerance scaling with the _tol helper, so numerical tolerances for tests now adapt automatically to sequence length and chunk size, improving reliability of cross-platform comparisons. [1] [2]
  • Added support for configuring chunk_size and dtype per test case, and ensured these are reflected in test IDs and tensor creation. [1] [2] [3] [4]

Refactoring and modernization:

  • Refactored test logic to call chunk_simple_gla_fwd and chunk_simple_gla_bwd directly for both Triton and JAX, removing reliance on autograd for reference gradients and making the tests more explicit and robust.
  • Replaced the internal _torch_to_jax with a new torch_to_jax utility in tests/utils.py that better preserves dtypes and is reused across tests.

ALiBi gating and parameter utilities:

  • Added build_alibi_slopes and make_alibi_g_gamma utilities to tests/utils.py, enabling test cases to use realistic ALiBi gating parameters matching those used in production models.
  • Updated all tests to use these utilities, ensuring gating parameters are consistent and realistic. [1] [2]

Numerical stability and coverage:

  • Added a new test suite to check for NaN/Inf stability with large shapes and realistic ALiBi g_gamma values, especially for chunk_size=128 and multi-layer settings.
  • Increased test coverage for edge cases, including multi-batch, long sequence, and per-head gating.

Minor fixes:

  • Fixed a bug in tops/ops/common/chunk_h.py by removing unnecessary type casts in matrix multiplication, ensuring correct precision is used.

These changes make the test suite much more robust, maintainable, and reflective of real-world usage.

Summary by CodeRabbit

  • Bug Fixes

    • Reduced unnecessary FP32 casts, stabilized gating/decay logic and matmul precision paths to improve numeric stability and consistent dtype behavior.
  • Tests

    • Added JAX interop utilities, per-case tolerance controls, configurable dtype and chunk-size matrices, direct forward/backward comparisons to JAX, GPU runtime env setup, and a NaN/Inf stability test for large chunks.
  • Documentation

    • Updated testing methodology to use high‑precision CPU baselines as the canonical golden reference.

0xaskr and others added 2 commits March 31, 2026 09:11
When chunk_size > 64, exp(g_diff) in the masked-out region of the
attention matrix can overflow f32 (e.g. |gamma|*128 = 90.5 > 88.72).
On TPU, XLA may compile jnp.where(mask, val, 0) as mask*val, causing
0*Inf=NaN to propagate.

Fix: use jnp.where(mask, g_diff, 0.0) before exp() so masked entries
get exp(0)=1 instead of Inf. This is sign-agnostic (works for any
g_gamma, not just negative) and matches Triton's tl.where semantics.

Affected paths:
- _chunk_fwd_o_kernel (g_gamma branch)
- chunk_simple_gla_bwd_kernel (A recomputation + dA gating)
- chunk_fwd_o_ref, chunk_bwd_dv, chunk_bwd_dqkwg (reference impls)

Tested with actual ALiBi slopes (H=16, 20 layers) at chunk_size=128.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Mar 31, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Refactors SimpleGLA GPU tests to call Triton forward/back entry points, adds Torch→JAX and ALiBi helpers, introduces per-case dtype/tolerance handling and chunk-size variants, tightens numeric casting and gated-exponential logic across chunk kernels, and adds a NaN/Inf stability test for chunk_size=128.

Changes

Cohort / File(s) Summary
Test utilities & JAX helpers
tests/utils.py
Added torch_to_jax(), build_alibi_slopes(), and make_alibi_g_gamma() to convert Torch tensors to JAX and produce ALiBi g_gamma for tests.
GPU tests — SimpleGLA
tests/ops/simple_gla/test_chunk_simple_gla_gpu.py
Switched Triton coverage to direct _run_triton_fwd/_run_triton_bwd, removed local Torch→JAX helpers in favor of tests.utils.torch_to_jax, added per-case dtype/atol/rtol/max_ulp logic, included chunk_size in case IDs, changed g_gamma generation to ALiBi, and added test_chunk128_no_nan.
Chunk kernels — recurrence (h) numerics
tops/ops/common/chunk_h.py
Removed unconditional .astype(jnp.float32) casts on dot operands (rely on preferred_element_type), retained high-precision settings, and added disable_bounds_checks=True to Pallas compiler params.
Chunk kernels — output/backward/gating numerics
tops/ops/common/chunk_o.py
Removed many unconditional fp32 casts, switched gating/ramp exponentials to jnp.where-masked forms before triangular masking, adjusted accumulation/scaling order, and added targeted dtype alignment casts for inter-/intra-chunk matmuls.
SimpleGLA orchestration
tops/ops/simple_gla/chunk.py
Forward now calls chunk_fwd_h with states_in_fp32=False; removed _build_gk_from_gamma; backward now receives g_gamma directly and omits gk construction.
Documentation
ARCHITECTURE.md, ARCHITECTURE.zh.md
Standardized kernel-correctness testing to use high-precision tops/cpu/ as the canonical golden baseline and tightened error-ordering and test guidelines.

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • use scratch #64 — modifies the same chunk_h scan paths and operand casting changes present here.
  • Test/update ci test #109 — overlaps on chunked SimpleGLA tests and core operator implementations (chunk_o.py, chunk_h.py, chunk.py).
  • fix chunk_bwd_dh_coredump #61 — touches backward handling for g/g_gamma and states_in_fp32 changes similar to this PR.

Poem

🐇 I hopped from Torch into JAX with care,
I masked my exps so numerics fare,
Chunks keep their order, dtypes aligned,
No NaNs in the burrow — gradients kind,
A rabbit cheers for kernels repaired.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 64.29% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'Fix/chunk nan overflow safe exp' directly relates to the core objective of improving numerical stability and handling NaN/Inf issues in the chunk_simple_gla operator through safer exponential operations and precision fixes.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch fix/chunk-nan-overflow-safe-exp

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the simple_gla chunk operations and associated GPU tests to improve numerical stability and alignment with Triton implementations. Key updates include the introduction of scaled tolerances for cross-platform comparisons, refactoring Triton test runners to use direct forward and backward calls, and adding utility functions for ALiBi slope generation. In the JAX kernels, explicit type casting was removed in favor of preferred_element_type in dot products, and exponential calculations were updated for better stability. Feedback identifies a typo in an environment variable and redundant mask calculations in the forward and backward kernels.

import sys
from pathlib import Path
os.environ["TRITON_F32_DEFAULT"] = "ieee"
os.environ["‘TF_GPU_ALLOCATOR"] = "cuda_malloc_async" # Avoid OOM errors with large T and chunk_size=128
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a typo in the environment variable name. The leading character ‘ (a smart quote) should be removed so that it correctly sets TF_GPU_ALLOCATOR. This is important for avoiding OOM errors as mentioned in the comment.

Suggested change
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async" # Avoid OOM errors with large T and chunk_size=128
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async" # Avoid OOM errors with large T and chunk_size=128

Comment on lines +71 to 74
_mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
b_A = b_A * jnp.exp(jnp.where(_mask, b_g_gamma[:, None] - b_g_gamma[None, :], 0.0))

mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The causal mask is being computed twice in this function. You can define mask once before the if g_gamma_ref is not None: block and reuse it inside the block to avoid redundant calculations.

Comment on lines 224 to 225
mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
b_dA = jnp.where(mask, b_dA, 0.0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The mask variable is redefined here, but it was already computed at line 209 and is still in scope. This redundant calculation can be removed.

Suggested change
mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
b_dA = jnp.where(mask, b_dA, 0.0)
b_dA = jnp.where(mask, b_dA, 0.0)

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/ops/simple_gla/test_chunk_simple_gla_gpu.py (1)

199-199: Minor: Redundant conversion chain.

The conversion np.array(make_alibi_g_gamma(...)) is redundant since make_alibi_g_gamma returns a jnp.ndarray which can be directly converted via np.asarray(). However, this works correctly as-is.

♻️ Optional simplification
-        g_gamma = torch.from_numpy(np.array(make_alibi_g_gamma(H, 32, 0)))
+        g_gamma = torch.from_numpy(np.asarray(make_alibi_g_gamma(H, 32, 0)))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/simple_gla/test_chunk_simple_gla_gpu.py` at line 199, The
conversion chain is redundant: instead of wrapping the jnp.ndarray returned by
make_alibi_g_gamma(...) in np.array and then converting to a torch tensor, call
np.asarray on the jnp.ndarray and pass that into torch.from_numpy; update the
assignment to g_gamma to use np.asarray(make_alibi_g_gamma(H, 32, 0)) so you
eliminate the unnecessary np.array call and keep torch.from_numpy(gamma_array)
usage.
tests/utils.py (1)

9-23: Potential precision loss for float64 inputs.

The function converts to .float() (float32) unconditionally before converting to NumPy, which loses precision for float64 inputs. If the caller passes a float64 tensor, the conversion chain is: float64 → float32 → numpy → jax, losing significant precision.

Consider preserving precision for float64:

🔧 Proposed fix
 def torch_to_jax(t: torch.Tensor, dtype=None) -> jax.Array:
     """Convert a torch tensor to a JAX array.

     Args:
         t: Input torch tensor.
         dtype: Target JAX dtype. If None, preserves bfloat16 when the input
             is bfloat16, otherwise uses float32.
     """
-    np_arr = t.detach().cpu().float().numpy()
+    if t.dtype == torch.float64:
+        np_arr = t.detach().cpu().double().numpy()
+    else:
+        np_arr = t.detach().cpu().float().numpy()
     jax_arr = jnp.array(np_arr)
     if dtype is not None:
         return jax_arr.astype(dtype)
     if t.dtype == torch.bfloat16:
         return jax_arr.astype(jnp.bfloat16)
+    if t.dtype == torch.float64:
+        return jax_arr
     return jax_arr
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils.py` around lines 9 - 23, The helper torch_to_jax currently forces
t.detach().cpu().float().numpy(), which downcasts float64 to float32 and loses
precision; change it to preserve the original tensor dtype by using
t.detach().cpu().numpy() (no .float()), then create the JAX array
(jnp.array(np_arr)), and only cast afterwards: if dtype parameter is set, return
.astype(dtype); else if input t.dtype is torch.bfloat16 return
.astype(jnp.bfloat16); otherwise return the JAX array as-is. Update torch_to_jax
to use the preserved np_arr and the same variable names (t, np_arr, jax_arr) so
callers and tests keep working.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ops/simple_gla/test_chunk_simple_gla_gpu.py`:
- Line 9: Replace the incorrect environment variable key containing a stray left
single quotation mark by updating the os.environ assignment (the line that
currently sets os.environ["‘TF_GPU_ALLOCATOR"] = "cuda_malloc_async") to use the
correct key os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"; ensure the
bracketed string uses a normal ASCII double-quoted identifier "TF_GPU_ALLOCATOR"
so the GPU allocator setting is applied correctly.

---

Nitpick comments:
In `@tests/ops/simple_gla/test_chunk_simple_gla_gpu.py`:
- Line 199: The conversion chain is redundant: instead of wrapping the
jnp.ndarray returned by make_alibi_g_gamma(...) in np.array and then converting
to a torch tensor, call np.asarray on the jnp.ndarray and pass that into
torch.from_numpy; update the assignment to g_gamma to use
np.asarray(make_alibi_g_gamma(H, 32, 0)) so you eliminate the unnecessary
np.array call and keep torch.from_numpy(gamma_array) usage.

In `@tests/utils.py`:
- Around line 9-23: The helper torch_to_jax currently forces
t.detach().cpu().float().numpy(), which downcasts float64 to float32 and loses
precision; change it to preserve the original tensor dtype by using
t.detach().cpu().numpy() (no .float()), then create the JAX array
(jnp.array(np_arr)), and only cast afterwards: if dtype parameter is set, return
.astype(dtype); else if input t.dtype is torch.bfloat16 return
.astype(jnp.bfloat16); otherwise return the JAX array as-is. Update torch_to_jax
to use the preserved np_arr and the same variable names (t, np_arr, jax_arr) so
callers and tests keep working.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 04e641e6-69d9-4ed0-bdf6-cd3f083a96b7

📥 Commits

Reviewing files that changed from the base of the PR and between 2b47657 and 13a40eb.

📒 Files selected for processing (5)
  • tests/ops/simple_gla/test_chunk_simple_gla_gpu.py
  • tests/utils.py
  • tops/ops/common/chunk_h.py
  • tops/ops/common/chunk_o.py
  • tops/ops/simple_gla/chunk.py

0xaskr and others added 3 commits March 31, 2026 22:47
* rm gk, rm recompute of h

* add disable_bounds_checks

* revert store h
…straint (#134)

- Define high-precision CPU JAX (float64) as the sole Golden baseline
- GPU implementation compared against high-precision CPU JAX (ε_GPU)
- TPU implementation compared against high-precision CPU JAX (ε_TPU)
- Enforce error bound constraint: ε_TPU ≤ ε_GPU
- Add GPU/TPU test file naming convention (_gpu/_tpu suffix)
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tops/ops/common/chunk_h.py (1)

336-342: ⚠️ Potential issue | 🟠 Major

Use explicit DotAlgorithm or restore explicit fp32 casts to guarantee precision behavior.

The assumption that precision=lax.Precision.HIGHEST and preferred_element_type=jnp.float32 guarantee float32 accumulation for bfloat16 inputs is incorrect. The JAX documentation clarifies that preferred_element_type is a hint to the compiler, not a guarantee—backend behavior varies. For bfloat16 inputs, the dot product may accumulate at bf16 precision instead of float32, producing different numerical results than the explicit .astype(jnp.float32) casts that were removed.

To restore guaranteed float32 accumulation, either:

  1. Use an explicit DotAlgorithm (e.g., lax.Precision.BF16_BF16_F32 or equivalent) where supported on your target hardware, or
  2. Restore the explicit .astype(jnp.float32) casts on b_k and b_v before the dot_general call.

Do not rely on implicit compiler behavior for precision-critical operations.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/chunk_h.py` around lines 336 - 342, The dot_general call
using b_k and b_v relies on preferred_element_type and lax.Precision.HIGHEST
which is not guaranteed to force FP32 accumulation; update the lax.dot_general
invocation in chunk_h.py to ensure FP32 accumulation by either (A) specifying a
backend-supported algorithm enum on lax.dot_general (e.g., the appropriate
DotAlgorithm/precision constant such as a BF16->F32 accumulation variant instead
of lax.Precision.HIGHEST) or (B) restore explicit casts: cast b_k and b_v to
jnp.float32 before calling lax.dot_general so accumulation is performed in
float32; apply the change where kv = lax.dot_general(...) is defined and keep
references to b_k, b_v, preferred_element_type and lax.Precision consistent with
the chosen approach.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@ARCHITECTURE.md`:
- Around line 83-95: The GPU test currently only compares Triton vs JAX; add the
high-precision CPU reference check as in the TPU test by importing the reference
from tops.cpu (the same module used in the TPU test), computing the CPU
reference outputs (e.g., o_cpu) with the same inputs used for the JAX Pallas run
(o_jax), and then call compare_tensor("output", o_jax, o_cpu, atol=atol,
rtol=rtol, max_ulp=max_ulp). Keep the Triton vs JAX assertion unchanged but
ensure the JAX→CPU assertion uses the same tolerance variables and mirrors the
TPU test pattern in test_chunk_simple_gla_tpu.py so the GPU test validates JAX
Pallas against tops.cpu/.

In `@ARCHITECTURE.zh.md`:
- Around line 81-88: 为使“ε_TPU ≤ ε_GPU”可执行,补充统一的 ε 定义和断言规则:在文档(Kernel Correctness
Methodology)中明确 ε 的计算方法(例如支持选项 max_abs、RMSE、ULP,并推荐默认项,如使用 compare_tensor 返回的
max_abs 和 RMSE 两项作为标准),规定对比基准为 tops/cpu/ 高精度 JAX 实现并说明 compare_tensor
的统计输出格式;规定执行流程为“同一测试用例下同时在 GPU/TPU 上跑 Pallas/JAX 内核并与 CPU 参考比较,记录 ε_GPU 和
ε_TPU”,并定义聚合与判定规则(例如单用例必须满足 ε_TPU ≤ ε_GPU 且总体用例集按 p%
通过,或使用均值/百分位汇总),以及失败时的报告字段(用例 id、metric 名称、ε_GPU、ε_TPU、raw 差异)以便统一实现和自动化判断。

---

Outside diff comments:
In `@tops/ops/common/chunk_h.py`:
- Around line 336-342: The dot_general call using b_k and b_v relies on
preferred_element_type and lax.Precision.HIGHEST which is not guaranteed to
force FP32 accumulation; update the lax.dot_general invocation in chunk_h.py to
ensure FP32 accumulation by either (A) specifying a backend-supported algorithm
enum on lax.dot_general (e.g., the appropriate DotAlgorithm/precision constant
such as a BF16->F32 accumulation variant instead of lax.Precision.HIGHEST) or
(B) restore explicit casts: cast b_k and b_v to jnp.float32 before calling
lax.dot_general so accumulation is performed in float32; apply the change where
kv = lax.dot_general(...) is defined and keep references to b_k, b_v,
preferred_element_type and lax.Precision consistent with the chosen approach.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b419bbf5-1e1a-4df2-9772-6ce3908385ec

📥 Commits

Reviewing files that changed from the base of the PR and between 494328c and e8b09fb.

📒 Files selected for processing (4)
  • ARCHITECTURE.md
  • ARCHITECTURE.zh.md
  • tops/ops/common/chunk_h.py
  • tops/ops/simple_gla/chunk.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tops/ops/simple_gla/chunk.py

Comment on lines +83 to +95
### 3.2 Kernel Correctness Methodology
* Test types within the `tests/` directory are strictly restricted to the following two types of reference comparisons:
1. **CPU Reference Tests (vs JAX-CPU)**: Output and gradients from Pallas kernels are checked for tolerance against reference implementations in `tops/cpu/` written in pure JAX.
2. **GPU Reference Tests (vs Torch-GPU/Triton)**: Aligning the computation results of Pallas kernels against known-correct, cross-framework computation libraries (such as those based on PyTorch or existing high-priority components like FlashAttention) under identical hardware conditions.
* **`tests/ops/`**: Modifications to low-level operators (e.g., scheduling optimizations of Pallas kernels) must use the two comparison test categories above to verify results or gradient tolerances. It is strictly prohibited to overstep and rely on high-level tests (such as `test_gla.py` in the layers tier) as a workaround for validation.
All Kernel correctness validation uses the **high-precision JAX CPU reference implementations** in `tops/cpu/` as the sole Golden baseline. Both GPU and TPU implementations are compared against this baseline with tolerance checks, and the TPU error must not exceed the GPU error:
1. **GPU Comparison Tests (GPU vs High-Precision CPU JAX)**: Output and gradients from GPU Pallas/JAX kernels are checked for tolerance against the high-precision pure JAX reference implementations in `tops/cpu/`. The error level is recorded as $\epsilon_{\text{GPU}}$.
2. **TPU Comparison Tests (TPU vs High-Precision CPU JAX)**: Output and gradients from TPU Pallas kernels are checked for tolerance against the high-precision pure JAX reference implementations in `tops/cpu/`. The error level is recorded as $\epsilon_{\text{TPU}}$.
3. **Error Bound Constraint**: **The TPU→CPU error must not exceed the GPU→CPU error**, i.e., $\epsilon_{\text{TPU}} \leq \epsilon_{\text{GPU}}$ must hold. If a TPU implementation exhibits higher error than the corresponding GPU implementation, it is treated as a precision defect in the TPU Kernel that must be fixed. This constraint ensures TPU Kernels maintain numerical precision at least on par with GPU Kernels.
### 3.3 Operator & Layer Tests
* **`tests/ops/`**: Modifications to low-level operators (e.g., scheduling optimizations of Pallas kernels) must use the GPU/TPU comparison tests described above to verify results or gradient tolerances. It is strictly prohibited to overstep and rely on high-level tests (such as `test_gla.py` in the layers tier) as a workaround for validation.
* **`tests/modules/` & `tests/layers/`**: Modifications at the network component or layer levels must include corresponding integration encapsulation and data flow validation tests.
* **Default comparator and GPU-specific exception**: The default comparator is `tops/cpu/`. If GPU-based comparison is required (e.g., Torch/Triton), create a separate test file and append the `_gpu` suffix to its filename.
* **Default comparator**: All tests default to comparing against the `tops/cpu/` high-precision reference implementation. GPU test files are named with a `_gpu` suffix, and TPU test files are named with a `_tpu` suffix.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify GPU test structure to confirm the discrepancy
rg -n 'tops.cpu|cpu_chunk_simple_gla' tests/ops/simple_gla/test_chunk_simple_gla_gpu.py || echo "No tops.cpu imports found in GPU test"
rg -n '_run_triton_fwd|_run_jax_fwd' -A2 tests/ops/simple_gla/test_chunk_simple_gla_gpu.py | head -30

Repository: primatrix/pallas-kernel

Length of output: 763


🏁 Script executed:

#!/bin/bash
# Get the full structure of the GPU test to see what's being compared
wc -l tests/ops/simple_gla/test_chunk_simple_gla_gpu.py
echo "---"
# Check imports at the top of the file
head -50 tests/ops/simple_gla/test_chunk_simple_gla_gpu.py | cat -n
echo "---"
# Check the full test function and comparisons
sed -n '200,250p' tests/ops/simple_gla/test_chunk_simple_gla_gpu.py | cat -n

Repository: primatrix/pallas-kernel

Length of output: 5174


🏁 Script executed:

#!/bin/bash
# Check if there's a corresponding TPU test and how it compares
find tests/ops/simple_gla -name "*tpu*.py" -o -name "*cpu*.py" | head -20
echo "---"
# Look for tops.cpu usage patterns in the test directory
rg -l 'tops\.cpu|from tops' tests/ops/simple_gla/ | head -10

Repository: primatrix/pallas-kernel

Length of output: 467


🏁 Script executed:

#!/bin/bash
# Search for CPU reference implementations referenced in the test directory
rg -n 'tops\.cpu' tests/ops/ -A2 -B1 | head -50

Repository: primatrix/pallas-kernel

Length of output: 2603


🏁 Script executed:

#!/bin/bash
# Get the _run_jax_fwd implementation
sed -n '152,200p' tests/ops/simple_gla/test_chunk_simple_gla_gpu.py | cat -n

Repository: primatrix/pallas-kernel

Length of output: 2092


🏁 Script executed:

#!/bin/bash
# Check if JAX is being explicitly placed on GPU in the test
rg -n 'jax.device_put|with jax.default_device|gpu|cuda' tests/ops/simple_gla/test_chunk_simple_gla_gpu.py | head -20

Repository: primatrix/pallas-kernel

Length of output: 365


🏁 Script executed:

#!/bin/bash
# Verify the file date to understand if ARCHITECTURE.md is a new requirement
git log --oneline -n 1 -- ARCHITECTURE.md
git log --oneline -n 1 -- tests/ops/simple_gla/test_chunk_simple_gla_gpu.py

Repository: primatrix/pallas-kernel

Length of output: 250


🏁 Script executed:

#!/bin/bash
# Check test_simple_gla_gpu.py for tops.cpu imports
rg -n 'tops\.cpu|from tops\.cpu' tests/ops/simple_gla/test_simple_gla_gpu.py | head -20
echo "---"
# Check first 60 lines of test_simple_gla_gpu.py
head -60 tests/ops/simple_gla/test_simple_gla_gpu.py | cat -n

Repository: primatrix/pallas-kernel

Length of output: 2630


🏁 Script executed:

#!/bin/bash
# Check if there's a CPU test for chunk_simple_gla
find tests/ -name "*chunk*simple*" -type f | sort
echo "---"
# Look for any CPU validation of chunk_simple_gla_fwd
rg -l 'chunk_simple_gla_fwd.*cpu|cpu.*chunk_simple_gla_fwd' tests/

Repository: primatrix/pallas-kernel

Length of output: 373


🏁 Script executed:

#!/bin/bash
# Check what's in the tests/ref directory (reference tests)
head -80 tests/ref/simple_gla/test_chunk_simple_gla.py | cat -n
echo "---"
# Check the full imports and structure
rg -n 'tops\.cpu|def test_' tests/ref/simple_gla/test_chunk_simple_gla.py | head -30

Repository: primatrix/pallas-kernel

Length of output: 3617


🏁 Script executed:

#!/bin/bash
# Check the full scope of the GPU test file to see what tests exist
grep -n "^def test_" tests/ops/simple_gla/test_chunk_simple_gla_gpu.py

Repository: primatrix/pallas-kernel

Length of output: 186


GPU test does not validate JAX Pallas against tops/cpu/ per the new methodology.

The ARCHITECTURE.md section 3.2 specifies that GPU Pallas/JAX kernels must compare against the high-precision tops/cpu/ reference baseline. However, test_chunk_simple_gla_gpu.py only compares Triton vs JAX directly—it does not import or validate against tops.cpu/:

# Line 207 of test_chunk_simple_gla_gpu.py
assert compare_tensor("output", o_tri, o_jax, atol=atol, rtol=rtol, max_ulp=max_ulp)

In contrast, test_chunk_simple_gla_tpu.py correctly imports and compares Pallas/JAX kernels against the CPU reference. Update the GPU test to validate the JAX Pallas kernel against tops/cpu/ separately, following the pattern established in the TPU test.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ARCHITECTURE.md` around lines 83 - 95, The GPU test currently only compares
Triton vs JAX; add the high-precision CPU reference check as in the TPU test by
importing the reference from tops.cpu (the same module used in the TPU test),
computing the CPU reference outputs (e.g., o_cpu) with the same inputs used for
the JAX Pallas run (o_jax), and then call compare_tensor("output", o_jax, o_cpu,
atol=atol, rtol=rtol, max_ulp=max_ulp). Keep the Triton vs JAX assertion
unchanged but ensure the JAX→CPU assertion uses the same tolerance variables and
mirrors the TPU test pattern in test_chunk_simple_gla_tpu.py so the GPU test
validates JAX Pallas against tops.cpu/.

Comment on lines +81 to +88
### 3.2 算子正确性验证方法 (Kernel Correctness Methodology)

* 在 `tests/` 目录下的测试类型被严格限制为以下两类标准的对比(Reference):
1. **CPU 参考对比测试 (vs JAX-CPU)**:将 Pallas 内核的输出和梯度,与 `tops/cpu/` 中纯 JAX 编写的参考实现做容差(Tolerances)检查。
2. **GPU 参考对比测试 (vs Torch-GPU/Triton)**:将 Pallas 内核与已知正确、跨框架计算库(如基于 PyTorch 编写或者 FlashAttention 等现有高优组件)在同等硬件条件下的计算结果进行对齐。
* **`tests/ops/`**:底层算子的修改(如 Pallas 内核的调度优化),必须采用上述两类对照测试验证结果或梯度容差,绝不可越权依赖高层级(如层的 `test_gla.py`)来变相验证。
所有 Kernel 的正确性验证统一以 `tops/cpu/` 中**高精度 JAX CPU 参考实现**作为唯一 Golden 基准。GPU 和 TPU 实现分别与该基准进行容差对比,并强制要求 TPU 的误差不得高于 GPU 的误差:

1. **GPU 对比测试 (GPU vs High-Precision CPU JAX)**:将 GPU 上的 Pallas/JAX 内核输出和梯度,与 `tops/cpu/` 中高精度纯 JAX 参考实现做容差检查,记录误差水平 $\epsilon_{\text{GPU}}$。
2. **TPU 对比测试 (TPU vs High-Precision CPU JAX)**:将 TPU 上的 Pallas 内核输出和梯度,与 `tops/cpu/` 中高精度纯 JAX 参考实现做容差检查,记录误差水平 $\epsilon_{\text{TPU}}$。
3. **误差上界约束(Error Bound Constraint)**:**TPU→CPU 的误差不得超过 GPU→CPU 的误差**,即必须满足 $\epsilon_{\text{TPU}} \leq \epsilon_{\text{GPU}}$。若 TPU 实现的误差高于 GPU 实现,则视为 TPU Kernel 存在精度缺陷,需修复。此约束确保 TPU Kernel 的数值精度至少与 GPU Kernel 持平。

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

建议补充 ε 的统一定义与“如何断言 ε_TPU ≤ ε_GPU”的执行规则。

当前只给了原则,没有给可执行口径(例如 ε 是 max_abs、RMSE、ULP 还是 compare_tensor 返回的统计)和断言落点(同一用例内同时跑 GPU/TPU 并比较,还是离线基线)。这会导致不同测试作者实现不一致,约束难以落地。

📌 建议补充(文档 diff)
 ### 3.2 算子正确性验证方法 (Kernel Correctness Methodology)

 所有 Kernel 的正确性验证统一以 `tops/cpu/` 中**高精度 JAX CPU 参考实现**作为唯一 Golden 基准。GPU 和 TPU 实现分别与该基准进行容差对比,并强制要求 TPU 的误差不得高于 GPU 的误差:

+> 统一口径:`ε` 采用同一统计定义(建议为 `max_abs_error` + `max_ulp` 组合),
+> 且 GPU/TPU 必须在**同一组输入与同一 compare 配置**下计算误差。
+> 对于同时提供 `_gpu` 与 `_tpu` 测试的算子,必须在一处测试逻辑中显式断言
+> `ε_TPU <= ε_GPU`(或等价的逐指标不劣于约束)。
+
 1.  **GPU 对比测试 (GPU vs High-Precision CPU JAX)**:将 GPU 上的 Pallas/JAX 内核输出和梯度,与 `tops/cpu/` 中高精度纯 JAX 参考实现做容差检查,记录误差水平 $\epsilon_{\text{GPU}}$。
 2.  **TPU 对比测试 (TPU vs High-Precision CPU JAX)**:将 TPU 上的 Pallas 内核输出和梯度,与 `tops/cpu/` 中高精度纯 JAX 参考实现做容差检查,记录误差水平 $\epsilon_{\text{TPU}}$。
 3.  **误差上界约束(Error Bound Constraint)**:**TPU→CPU 的误差不得超过 GPU→CPU 的误差**,即必须满足 $\epsilon_{\text{TPU}} \leq \epsilon_{\text{GPU}}$。若 TPU 实现的误差高于 GPU 实现,则视为 TPU Kernel 存在精度缺陷,需修复。此约束确保 TPU Kernel 的数值精度至少与 GPU Kernel 持平。
🧰 Tools
🪛 LanguageTool

[uncategorized] ~87-~87: 您的意思是“少"于"”吗?
Context: ...rnel 存在精度缺陷,需修复。此约束确保 TPU Kernel 的数值精度至少与 GPU Kernel 持平。 ### 3.3 算子与层级测试 (Operat...

(YU7_YU8)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ARCHITECTURE.zh.md` around lines 81 - 88, 为使“ε_TPU ≤ ε_GPU”可执行,补充统一的 ε
定义和断言规则:在文档(Kernel Correctness Methodology)中明确 ε 的计算方法(例如支持选项
max_abs、RMSE、ULP,并推荐默认项,如使用 compare_tensor 返回的 max_abs 和 RMSE 两项作为标准),规定对比基准为
tops/cpu/ 高精度 JAX 实现并说明 compare_tensor 的统计输出格式;规定执行流程为“同一测试用例下同时在 GPU/TPU 上跑
Pallas/JAX 内核并与 CPU 参考比较,记录 ε_GPU 和 ε_TPU”,并定义聚合与判定规则(例如单用例必须满足 ε_TPU ≤ ε_GPU
且总体用例集按 p% 通过,或使用均值/百分位汇总),以及失败时的报告字段(用例 id、metric 名称、ε_GPU、ε_TPU、raw
差异)以便统一实现和自动化判断。

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
tests/ops/simple_gla/test_chunk_simple_gla_gpu.py (1)

198-200: Consider simplifying the conversion chain.

The conversion torch.from_numpy(np.array(make_alibi_g_gamma(...))) is verbose. Since make_alibi_g_gamma returns a JAX array backed by NumPy, you can use .to_py() or rely on NumPy's array protocol directly.

♻️ Proposed simplification
     if gate == "g_gamma":
-        g_gamma = torch.from_numpy(np.array(make_alibi_g_gamma(H, 32, 0)))
+        g_gamma = torch.from_numpy(np.asarray(make_alibi_g_gamma(H, 32, 0)))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/simple_gla/test_chunk_simple_gla_gpu.py` around lines 198 - 200,
Replace the verbose conversion chain around g_gamma: instead of
torch.from_numpy(np.array(make_alibi_g_gamma(...))), extract a NumPy array
directly from the JAX array returned by make_alibi_g_gamma (e.g. call its
.to_py() or use np.asarray on the JAX result) and then create the tensor with
torch.from_numpy; update the assignment to g_gamma (and the conditional using
gate == "g_gamma") to use that direct conversion so you avoid the redundant
np.array wrapper and preserve zero-copy semantics.
tops/ops/simple_gla/chunk.py (1)

464-517: Missing docstring for public function chunk_simple_gla_fwd.

As per coding guidelines for tops/**/*.py, all public functions must have a clear docstring explaining business semantics and tensor shape/dimension meanings for all inputs and outputs. The function chunk_simple_gla_fwd lacks this documentation.

📝 Proposed docstring
 def chunk_simple_gla_fwd(
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
     *,
     g: jax.Array | None = None,
     g_gamma: jax.Array | None = None,
     scale: float | None = None,
     h0: jax.Array | None = None,
     use_ht: bool = False,
     cu_seqlens_cpu: jax.Array | None = None,
     cu_seqlens_dev: jax.Array | None = None,
     chunk_size: int = 64,
 ) -> tuple[jax.Array, jax.Array | None]:
+    """Simple GLA chunked forward pass.
+
+    Computes the GLA recurrence: h_t = h_{t-1} * exp(g) + k_t^T @ v_t
+    and output o_t = q_t^T @ h_t using chunked computation.
+
+    Args:
+        q: [B, T, H, K] queries
+        k: [B, T, H, K] keys
+        v: [B, T, H, V] values
+        g: [B, T, H] per-position gate (optional, mutually exclusive with g_gamma)
+        g_gamma: [H] constant per-head log-decay (optional)
+        scale: attention scale factor (default: K^-0.5)
+        h0: [N, H, K, V] initial recurrent state (N=B or num_seqs)
+        use_ht: whether to return final state
+        cu_seqlens_cpu: [B+1] cumulative sequence lengths on CPU
+        cu_seqlens_dev: [B+1] cumulative sequence lengths on device
+        chunk_size: block size for chunked computation
+
+    Returns:
+        (o, ht): output [B, T, H, V] and optional final state [N, H, K, V]
+    """
     B, T, H, K, V = *q.shape, v.shape[-1]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 464 - 517, Add a comprehensive
docstring to the public function chunk_simple_gla_fwd that explains its purpose
(chunked forward pass for the simple GLA), lists and describes every parameter
(q, k, v, g, g_gamma, scale, h0, use_ht, cu_seqlens_cpu, cu_seqlens_dev,
chunk_size) including expected tensor shapes and meanings (B, T, H, K, V, N
where N = len(cu_seqlens_cpu)-1), the return values (o: output tensor shape (B,
T, H, V) and ht: final state or None with shape (N, H, K, V) when use_ht=True),
and all preconditions/assertions (T % chunk_size == 0, cu_seqlens_cpu divisible
by chunk_size, K % 128 == 0, V % 128 == 0). Also document side effects/behavior
(calls chunk_fwd_h and chunk_fwd_o, when h0 is used, when states are returned),
and what exceptions/assertions will be raised on invalid inputs to satisfy the
project’s public-function docstring guidelines.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@tests/ops/simple_gla/test_chunk_simple_gla_gpu.py`:
- Around line 198-200: Replace the verbose conversion chain around g_gamma:
instead of torch.from_numpy(np.array(make_alibi_g_gamma(...))), extract a NumPy
array directly from the JAX array returned by make_alibi_g_gamma (e.g. call its
.to_py() or use np.asarray on the JAX result) and then create the tensor with
torch.from_numpy; update the assignment to g_gamma (and the conditional using
gate == "g_gamma") to use that direct conversion so you avoid the redundant
np.array wrapper and preserve zero-copy semantics.

In `@tops/ops/simple_gla/chunk.py`:
- Around line 464-517: Add a comprehensive docstring to the public function
chunk_simple_gla_fwd that explains its purpose (chunked forward pass for the
simple GLA), lists and describes every parameter (q, k, v, g, g_gamma, scale,
h0, use_ht, cu_seqlens_cpu, cu_seqlens_dev, chunk_size) including expected
tensor shapes and meanings (B, T, H, K, V, N where N = len(cu_seqlens_cpu)-1),
the return values (o: output tensor shape (B, T, H, V) and ht: final state or
None with shape (N, H, K, V) when use_ht=True), and all preconditions/assertions
(T % chunk_size == 0, cu_seqlens_cpu divisible by chunk_size, K % 128 == 0, V %
128 == 0). Also document side effects/behavior (calls chunk_fwd_h and
chunk_fwd_o, when h0 is used, when states are returned), and what
exceptions/assertions will be raised on invalid inputs to satisfy the project’s
public-function docstring guidelines.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8ee35684-c556-4767-8271-d6671ded5a11

📥 Commits

Reviewing files that changed from the base of the PR and between e8b09fb and c8852e1.

📒 Files selected for processing (3)
  • tests/ops/simple_gla/test_chunk_simple_gla_gpu.py
  • tops/ops/common/chunk_o.py
  • tops/ops/simple_gla/chunk.py
✅ Files skipped from review due to trivial changes (1)
  • tops/ops/common/chunk_o.py

@0xaskr 0xaskr added this pull request to the merge queue Mar 31, 2026
Merged via the queue into main with commit 7196900 Mar 31, 2026
2 checks passed
@0xaskr 0xaskr deleted the fix/chunk-nan-overflow-safe-exp branch March 31, 2026 17:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants