Conversation
When chunk_size > 64, exp(g_diff) in the masked-out region of the attention matrix can overflow f32 (e.g. |gamma|*128 = 90.5 > 88.72). On TPU, XLA may compile jnp.where(mask, val, 0) as mask*val, causing 0*Inf=NaN to propagate. Fix: use jnp.where(mask, g_diff, 0.0) before exp() so masked entries get exp(0)=1 instead of Inf. This is sign-agnostic (it works for any g_gamma, not just negative) and matches Triton's tl.where semantics.

Affected paths:
- _chunk_fwd_o_kernel (g_gamma branch)
- chunk_simple_gla_bwd_kernel (A recomputation + dA gating)
- chunk_fwd_o_ref, chunk_bwd_dv, chunk_bwd_dqkwg (reference impls)

Tested with actual ALiBi slopes (H=16, 20 layers) at chunk_size=128.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
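The failure mode and the fix can be reproduced outside the kernel. A minimal NumPy sketch, assuming `mask * value` as a stand-in for the `mask*val` lowering XLA may choose (this is illustrative, not the kernel code itself):

```python
import numpy as np

BT = 128
gamma = -0.7071  # ALiBi-style negative slope; |gamma| * 128 ~ 90.5 > ln(f32 max) ~ 88.72
g = (gamma * np.arange(BT)).astype(np.float32)
mask = np.arange(BT)[:, None] >= np.arange(BT)[None, :]  # causal mask
g_diff = g[:, None] - g[None, :]                         # gamma * (i - j)

# Failure mode: when where() is lowered as mask * value, the masked-out
# region computes 0 * Inf = NaN wherever exp(g_diff) overflows.
with np.errstate(over="ignore", invalid="ignore"):
    naive = mask.astype(np.float32) * np.exp(g_diff)

# Fix: mask the exponent first, so masked entries evaluate exp(0) = 1
# and never overflow, regardless of the sign of gamma.
safe = mask.astype(np.float32) * np.exp(np.where(mask, g_diff, 0.0))

assert np.isnan(naive).any()            # 0 * Inf has poisoned the masked region
assert not np.isnan(safe).any() and not np.isinf(safe).any()
```

The exponent in the kept (causal) region is gamma*(i-j) with i >= j, which stays non-positive for negative gamma; the overflow only ever lives in the region the mask discards, which is why masking before `exp()` is sufficient.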
📝 Walkthrough

Refactors SimpleGLA GPU tests to call Triton forward/back entry points, adds Torch→JAX and ALiBi helpers, introduces per-case dtype/tolerance handling and chunk-size variants, tightens numeric casting and gated-exponential logic across chunk kernels, and adds a NaN/Inf stability test for chunk_size=128.
Code Review
This pull request refactors the simple_gla chunk operations and associated GPU tests to improve numerical stability and alignment with Triton implementations. Key updates include the introduction of scaled tolerances for cross-platform comparisons, refactoring Triton test runners to use direct forward and backward calls, and adding utility functions for ALiBi slope generation. In the JAX kernels, explicit type casting was removed in favor of preferred_element_type in dot products, and exponential calculations were updated for better stability. Feedback identifies a typo in an environment variable and redundant mask calculations in the forward and backward kernels.
```python
import sys
from pathlib import Path
os.environ["TRITON_F32_DEFAULT"] = "ieee"
os.environ["‘TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # Avoid OOM errors with large T and chunk_size=128
```
There is a typo in the environment variable name. The leading character ‘ (a smart quote) should be removed so that it correctly sets TF_GPU_ALLOCATOR. This is important for avoiding OOM errors as mentioned in the comment.
```diff
- os.environ["‘TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # Avoid OOM errors with large T and chunk_size=128
+ os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # Avoid OOM errors with large T and chunk_size=128
```
```python
_mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
b_A = b_A * jnp.exp(jnp.where(_mask, b_g_gamma[:, None] - b_g_gamma[None, :], 0.0))
# ...
mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
# ...
mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
b_dA = jnp.where(mask, b_dA, 0.0)
```
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/ops/simple_gla/test_chunk_simple_gla_gpu.py (1)
- 199: Minor: Redundant conversion chain. The conversion `np.array(make_alibi_g_gamma(...))` is redundant since `make_alibi_g_gamma` returns a `jnp.ndarray`, which can be converted directly via `np.asarray()`. However, this works correctly as-is.

♻️ Optional simplification

```diff
- g_gamma = torch.from_numpy(np.array(make_alibi_g_gamma(H, 32, 0)))
+ g_gamma = torch.from_numpy(np.asarray(make_alibi_g_gamma(H, 32, 0)))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/simple_gla/test_chunk_simple_gla_gpu.py` at line 199: the conversion chain is redundant; instead of wrapping the jnp.ndarray returned by make_alibi_g_gamma(...) in np.array and then converting to a torch tensor, call np.asarray on the jnp.ndarray and pass that into torch.from_numpy. Update the assignment to g_gamma to use np.asarray(make_alibi_g_gamma(H, 32, 0)) so you eliminate the unnecessary np.array call and keep the torch.from_numpy(gamma_array) usage.

tests/utils.py (1)
- 9-23: Potential precision loss for float64 inputs. The function converts to `.float()` (float32) unconditionally before converting to NumPy, which loses precision for float64 inputs. If the caller passes a float64 tensor, the conversion chain is float64 → float32 → numpy → jax, losing significant precision. Consider preserving precision for float64:
🔧 Proposed fix

```diff
 def torch_to_jax(t: torch.Tensor, dtype=None) -> jax.Array:
     """Convert a torch tensor to a JAX array.

     Args:
         t: Input torch tensor.
         dtype: Target JAX dtype. If None, preserves bfloat16 when the
             input is bfloat16, otherwise uses float32.
     """
-    np_arr = t.detach().cpu().float().numpy()
+    if t.dtype == torch.float64:
+        np_arr = t.detach().cpu().double().numpy()
+    else:
+        np_arr = t.detach().cpu().float().numpy()
     jax_arr = jnp.array(np_arr)
     if dtype is not None:
         return jax_arr.astype(dtype)
     if t.dtype == torch.bfloat16:
         return jax_arr.astype(jnp.bfloat16)
+    if t.dtype == torch.float64:
+        return jax_arr
     return jax_arr
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils.py` around lines 9 - 23, The helper torch_to_jax currently forces t.detach().cpu().float().numpy(), which downcasts float64 to float32 and loses precision; change it to preserve the original tensor dtype by using t.detach().cpu().numpy() (no .float()), then create the JAX array (jnp.array(np_arr)), and only cast afterwards: if dtype parameter is set, return .astype(dtype); else if input t.dtype is torch.bfloat16 return .astype(jnp.bfloat16); otherwise return the JAX array as-is. Update torch_to_jax to use the preserved np_arr and the same variable names (t, np_arr, jax_arr) so callers and tests keep working.
📒 Files selected for processing (5)
- tests/ops/simple_gla/test_chunk_simple_gla_gpu.py
- tests/utils.py
- tops/ops/common/chunk_h.py
- tops/ops/common/chunk_o.py
- tops/ops/simple_gla/chunk.py
* rm gk, rm recompute of h
* add disable_bounds_checks
* revert store h
…straint (#134)
- Define high-precision CPU JAX (float64) as the sole Golden baseline
- GPU implementation compared against high-precision CPU JAX (ε_GPU)
- TPU implementation compared against high-precision CPU JAX (ε_TPU)
- Enforce error bound constraint: ε_TPU ≤ ε_GPU
- Add GPU/TPU test file naming convention (_gpu/_tpu suffix)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tops/ops/common/chunk_h.py (1)
- 336-342: ⚠️ Potential issue | 🟠 Major: Use an explicit DotAlgorithm or restore explicit fp32 casts to guarantee precision behavior.

The assumption that `precision=lax.Precision.HIGHEST` and `preferred_element_type=jnp.float32` guarantee float32 accumulation for bfloat16 inputs is incorrect. The JAX documentation clarifies that `preferred_element_type` is a hint to the compiler, not a guarantee; backend behavior varies. For bfloat16 inputs, the dot product may accumulate at bf16 precision instead of float32, producing different numerical results than the explicit `.astype(jnp.float32)` casts that were removed.

To restore guaranteed float32 accumulation, either:
- Use an explicit DotAlgorithm (e.g., `lax.DotAlgorithmPreset.BF16_BF16_F32` or equivalent) where supported on your target hardware, or
- Restore the explicit `.astype(jnp.float32)` casts on `b_k` and `b_v` before the `dot_general` call.

Do not rely on implicit compiler behavior for precision-critical operations.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/common/chunk_h.py` around lines 336 - 342, The dot_general call using b_k and b_v relies on preferred_element_type and lax.Precision.HIGHEST which is not guaranteed to force FP32 accumulation; update the lax.dot_general invocation in chunk_h.py to ensure FP32 accumulation by either (A) specifying a backend-supported algorithm enum on lax.dot_general (e.g., the appropriate DotAlgorithm/precision constant such as a BF16->F32 accumulation variant instead of lax.Precision.HIGHEST) or (B) restore explicit casts: cast b_k and b_v to jnp.float32 before calling lax.dot_general so accumulation is performed in float32; apply the change where kv = lax.dot_general(...) is defined and keep references to b_k, b_v, preferred_element_type and lax.Precision consistent with the chosen approach.
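A minimal sketch of the second option (explicit casts). It assumes only public JAX APIs; `upcast_matmul` and the toy shapes are illustrative names, not code from this repo:

```python
import jax.numpy as jnp
from jax import lax

def upcast_matmul(a, b):
    # Explicit fp32 casts before the contraction make float32 accumulation
    # unconditional, instead of treating preferred_element_type as a guarantee.
    return lax.dot_general(
        a.astype(jnp.float32),
        b.astype(jnp.float32),
        dimension_numbers=(((1,), (0,)), ((), ())),  # plain [M, K] @ [K, N]
        preferred_element_type=jnp.float32,
    )

out = upcast_matmul(
    jnp.ones((2, 3), dtype=jnp.bfloat16),
    jnp.ones((3, 4), dtype=jnp.bfloat16),
)
# out is float32 and every entry is the exact sum of three ones.
```

The same shape-preserving pattern applies to the `b_k`/`b_v` operands in `chunk_h.py`; the cost is two extra casts, traded for bit-identical accumulation behavior across backends.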
📒 Files selected for processing (4)
- ARCHITECTURE.md
- ARCHITECTURE.zh.md
- tops/ops/common/chunk_h.py
- tops/ops/simple_gla/chunk.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tops/ops/simple_gla/chunk.py
### 3.2 Kernel Correctness Methodology
* Test types within the `tests/` directory are strictly restricted to the following two types of reference comparisons:
  1. **CPU Reference Tests (vs JAX-CPU)**: Output and gradients from Pallas kernels are checked for tolerance against reference implementations in `tops/cpu/` written in pure JAX.
  2. **GPU Reference Tests (vs Torch-GPU/Triton)**: Aligning the computation results of Pallas kernels against known-correct, cross-framework computation libraries (such as those based on PyTorch or existing high-priority components like FlashAttention) under identical hardware conditions.
* **`tests/ops/`**: Modifications to low-level operators (e.g., scheduling optimizations of Pallas kernels) must use the two comparison test categories above to verify results or gradient tolerances. It is strictly prohibited to overstep and rely on high-level tests (such as `test_gla.py` in the layers tier) as a workaround for validation.

All Kernel correctness validation uses the **high-precision JAX CPU reference implementations** in `tops/cpu/` as the sole Golden baseline. Both GPU and TPU implementations are compared against this baseline with tolerance checks, and the TPU error must not exceed the GPU error:

1. **GPU Comparison Tests (GPU vs High-Precision CPU JAX)**: Output and gradients from GPU Pallas/JAX kernels are checked for tolerance against the high-precision pure JAX reference implementations in `tops/cpu/`. The error level is recorded as $\epsilon_{\text{GPU}}$.
2. **TPU Comparison Tests (TPU vs High-Precision CPU JAX)**: Output and gradients from TPU Pallas kernels are checked for tolerance against the high-precision pure JAX reference implementations in `tops/cpu/`. The error level is recorded as $\epsilon_{\text{TPU}}$.
3. **Error Bound Constraint**: **The TPU→CPU error must not exceed the GPU→CPU error**, i.e., $\epsilon_{\text{TPU}} \leq \epsilon_{\text{GPU}}$ must hold. If a TPU implementation exhibits higher error than the corresponding GPU implementation, it is treated as a precision defect in the TPU Kernel that must be fixed. This constraint ensures TPU Kernels maintain numerical precision at least on par with GPU Kernels.

### 3.3 Operator & Layer Tests
* **`tests/ops/`**: Modifications to low-level operators (e.g., scheduling optimizations of Pallas kernels) must use the GPU/TPU comparison tests described above to verify results or gradient tolerances. It is strictly prohibited to overstep and rely on high-level tests (such as `test_gla.py` in the layers tier) as a workaround for validation.
* **`tests/modules/` & `tests/layers/`**: Modifications at the network component or layer levels must include corresponding integration encapsulation and data flow validation tests.
* **Default comparator and GPU-specific exception**: The default comparator is `tops/cpu/`. If GPU-based comparison is required (e.g., Torch/Triton), create a separate test file and append the `_gpu` suffix to its filename.
* **Default comparator**: All tests default to comparing against the `tops/cpu/` high-precision reference implementation. GPU test files are named with a `_gpu` suffix, and TPU test files are named with a `_tpu` suffix.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify GPU test structure to confirm the discrepancy
rg -n 'tops.cpu|cpu_chunk_simple_gla' tests/ops/simple_gla/test_chunk_simple_gla_gpu.py || echo "No tops.cpu imports found in GPU test"
rg -n '_run_triton_fwd|_run_jax_fwd' -A2 tests/ops/simple_gla/test_chunk_simple_gla_gpu.py | head -30

Repository: primatrix/pallas-kernel
Length of output: 763
🏁 Script executed:
#!/bin/bash
# Get the full structure of the GPU test to see what's being compared
wc -l tests/ops/simple_gla/test_chunk_simple_gla_gpu.py
echo "---"
# Check imports at the top of the file
head -50 tests/ops/simple_gla/test_chunk_simple_gla_gpu.py | cat -n
echo "---"
# Check the full test function and comparisons
sed -n '200,250p' tests/ops/simple_gla/test_chunk_simple_gla_gpu.py | cat -n

Repository: primatrix/pallas-kernel
Length of output: 5174
🏁 Script executed:
#!/bin/bash
# Check if there's a corresponding TPU test and how it compares
find tests/ops/simple_gla -name "*tpu*.py" -o -name "*cpu*.py" | head -20
echo "---"
# Look for tops.cpu usage patterns in the test directory
rg -l 'tops\.cpu|from tops' tests/ops/simple_gla/ | head -10

Repository: primatrix/pallas-kernel
Length of output: 467
🏁 Script executed:
#!/bin/bash
# Search for CPU reference implementations referenced in the test directory
rg -n 'tops\.cpu' tests/ops/ -A2 -B1 | head -50

Repository: primatrix/pallas-kernel
Length of output: 2603
🏁 Script executed:
#!/bin/bash
# Get the _run_jax_fwd implementation
sed -n '152,200p' tests/ops/simple_gla/test_chunk_simple_gla_gpu.py | cat -n

Repository: primatrix/pallas-kernel
Length of output: 2092
🏁 Script executed:
#!/bin/bash
# Check if JAX is being explicitly placed on GPU in the test
rg -n 'jax.device_put|with jax.default_device|gpu|cuda' tests/ops/simple_gla/test_chunk_simple_gla_gpu.py | head -20

Repository: primatrix/pallas-kernel
Length of output: 365
🏁 Script executed:
#!/bin/bash
# Verify the file date to understand if ARCHITECTURE.md is a new requirement
git log --oneline -n 1 -- ARCHITECTURE.md
git log --oneline -n 1 -- tests/ops/simple_gla/test_chunk_simple_gla_gpu.py

Repository: primatrix/pallas-kernel
Length of output: 250
🏁 Script executed:
#!/bin/bash
# Check test_simple_gla_gpu.py for tops.cpu imports
rg -n 'tops\.cpu|from tops\.cpu' tests/ops/simple_gla/test_simple_gla_gpu.py | head -20
echo "---"
# Check first 60 lines of test_simple_gla_gpu.py
head -60 tests/ops/simple_gla/test_simple_gla_gpu.py | cat -n

Repository: primatrix/pallas-kernel
Length of output: 2630
🏁 Script executed:
#!/bin/bash
# Check if there's a CPU test for chunk_simple_gla
find tests/ -name "*chunk*simple*" -type f | sort
echo "---"
# Look for any CPU validation of chunk_simple_gla_fwd
rg -l 'chunk_simple_gla_fwd.*cpu|cpu.*chunk_simple_gla_fwd' tests/

Repository: primatrix/pallas-kernel
Length of output: 373
🏁 Script executed:
#!/bin/bash
# Check what's in the tests/ref directory (reference tests)
head -80 tests/ref/simple_gla/test_chunk_simple_gla.py | cat -n
echo "---"
# Check the full imports and structure
rg -n 'tops\.cpu|def test_' tests/ref/simple_gla/test_chunk_simple_gla.py | head -30

Repository: primatrix/pallas-kernel
Length of output: 3617
🏁 Script executed:
#!/bin/bash
# Check the full scope of the GPU test file to see what tests exist
grep -n "^def test_" tests/ops/simple_gla/test_chunk_simple_gla_gpu.py

Repository: primatrix/pallas-kernel
Length of output: 186
GPU test does not validate JAX Pallas against tops/cpu/ per the new methodology.
The ARCHITECTURE.md section 3.2 specifies that GPU Pallas/JAX kernels must compare against the high-precision tops/cpu/ reference baseline. However, test_chunk_simple_gla_gpu.py only compares Triton vs JAX directly—it does not import or validate against tops.cpu/:
```python
# Line 207 of test_chunk_simple_gla_gpu.py
assert compare_tensor("output", o_tri, o_jax, atol=atol, rtol=rtol, max_ulp=max_ulp)
```

In contrast, test_chunk_simple_gla_tpu.py correctly imports and compares Pallas/JAX kernels against the CPU reference. Update the GPU test to validate the JAX Pallas kernel against tops/cpu/ separately, following the pattern established in the TPU test.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@ARCHITECTURE.md` around lines 83 - 95, The GPU test currently only compares
Triton vs JAX; add the high-precision CPU reference check as in the TPU test by
importing the reference from tops.cpu (the same module used in the TPU test),
computing the CPU reference outputs (e.g., o_cpu) with the same inputs used for
the JAX Pallas run (o_jax), and then call compare_tensor("output", o_jax, o_cpu,
atol=atol, rtol=rtol, max_ulp=max_ulp). Keep the Triton vs JAX assertion
unchanged but ensure the JAX→CPU assertion uses the same tolerance variables and
mirrors the TPU test pattern in test_chunk_simple_gla_tpu.py so the GPU test
validates JAX Pallas against tops.cpu/.
### 3.2 Kernel Correctness Methodology

* Test types under the `tests/` directory are strictly restricted to the following two kinds of reference comparisons:
  1. **CPU Reference Tests (vs JAX-CPU)**: Check the output and gradients of Pallas kernels, within tolerances, against the pure-JAX reference implementations in `tops/cpu/`.
  2. **GPU Reference Tests (vs Torch-GPU/Triton)**: Align Pallas kernels with the results of known-correct, cross-framework computation libraries (e.g., PyTorch-based code or existing high-priority components such as FlashAttention) under identical hardware conditions.
* **`tests/ops/`**: Modifications to low-level operators (e.g., scheduling optimizations of Pallas kernels) must be validated with the two comparison test categories above for result or gradient tolerances; never fall back on higher-level tests (such as the layer-level `test_gla.py`) as a substitute.

All Kernel correctness validation uses the **high-precision JAX CPU reference implementations** in `tops/cpu/` as the sole Golden baseline. GPU and TPU implementations are each compared against this baseline with tolerance checks, and the TPU error must not exceed the GPU error:

1. **GPU Comparison Tests (GPU vs High-Precision CPU JAX)**: Check the output and gradients of the Pallas/JAX kernels on GPU against the high-precision pure-JAX reference implementations in `tops/cpu/`, recording the error level $\epsilon_{\text{GPU}}$.
2. **TPU Comparison Tests (TPU vs High-Precision CPU JAX)**: Check the output and gradients of the Pallas kernels on TPU against the high-precision pure-JAX reference implementations in `tops/cpu/`, recording the error level $\epsilon_{\text{TPU}}$.
3. **Error Bound Constraint**: **The TPU→CPU error must not exceed the GPU→CPU error**, i.e., $\epsilon_{\text{TPU}} \leq \epsilon_{\text{GPU}}$ must hold. If a TPU implementation exhibits higher error than the GPU implementation, this is treated as a precision defect in the TPU Kernel and must be fixed. This constraint ensures the numerical precision of TPU Kernels is at least on par with GPU Kernels.
Suggest adding a unified definition of ε and executable rules for asserting ε_TPU ≤ ε_GPU.

The current text only states the principle without an executable specification (e.g., whether ε means max_abs, RMSE, ULP, or the statistics returned by compare_tensor) and without a defined assertion point (run GPU/TPU within the same test case and compare, or use an offline baseline). This will lead to inconsistent implementations across test authors and make the constraint hard to enforce.
📌 Suggested addition (doc diff)

```diff
 ### 3.2 Kernel Correctness Methodology
 All Kernel correctness validation uses the high-precision JAX CPU reference implementations in `tops/cpu/` as the sole Golden baseline. GPU and TPU implementations are each compared against this baseline with tolerance checks, and the TPU error must not exceed the GPU error:
+> Unified convention: `ε` uses a single statistical definition (recommended: the
+> combination of `max_abs_error` + `max_ulp`), and GPU/TPU must compute errors on
+> the **same inputs with the same compare configuration**.
+> For operators that provide both `_gpu` and `_tpu` tests, one place in the test
+> logic must explicitly assert `ε_TPU <= ε_GPU` (or an equivalent
+> no-worse-per-metric constraint).
+
 1. **GPU Comparison Tests (GPU vs High-Precision CPU JAX)**: Check the output and gradients of the Pallas/JAX kernels on GPU against the high-precision pure-JAX reference implementations in `tops/cpu/`, recording the error level $\epsilon_{\text{GPU}}$.
 2. **TPU Comparison Tests (TPU vs High-Precision CPU JAX)**: Check the output and gradients of the Pallas kernels on TPU against the high-precision pure-JAX reference implementations in `tops/cpu/`, recording the error level $\epsilon_{\text{TPU}}$.
 3. **Error Bound Constraint**: **The TPU→CPU error must not exceed the GPU→CPU error**, i.e., $\epsilon_{\text{TPU}} \leq \epsilon_{\text{GPU}}$ must hold. If a TPU implementation exhibits higher error than the GPU implementation, this is treated as a precision defect in the TPU Kernel and must be fixed. This constraint ensures the numerical precision of TPU Kernels is at least on par with GPU Kernels.
```
🪛 LanguageTool
[uncategorized] ~87-~87: 您的意思是“少"于"”吗?
Context: ...rnel 存在精度缺陷,需修复。此约束确保 TPU Kernel 的数值精度至少与 GPU Kernel 持平。 ### 3.3 算子与层级测试 (Operat...
(YU7_YU8)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@ARCHITECTURE.zh.md` around lines 81 - 88: to make "ε_TPU ≤ ε_GPU" enforceable, add a unified ε definition and assertion rules. In the Kernel Correctness Methodology section, specify how ε is computed (e.g., support max_abs, RMSE, and ULP options, with a recommended default such as using the max_abs and RMSE statistics returned by compare_tensor), state that the comparison baseline is the high-precision JAX implementation in tops/cpu/, and describe the statistics output format of compare_tensor. Specify the execution flow as "for the same test case, run the Pallas/JAX kernels on both GPU and TPU, compare against the CPU reference, and record ε_GPU and ε_TPU". Define aggregation and pass/fail rules (e.g., each case must satisfy ε_TPU ≤ ε_GPU and the overall suite passes at p%, or use mean/percentile summaries), and list the report fields on failure (case id, metric name, ε_GPU, ε_TPU, raw difference) so implementations are consistent and automatable.
🧹 Nitpick comments (2)
tests/ops/simple_gla/test_chunk_simple_gla_gpu.py (1)
- 198-200: Consider simplifying the conversion chain. The conversion `torch.from_numpy(np.array(make_alibi_g_gamma(...)))` is verbose. Since `make_alibi_g_gamma` returns a JAX array backed by NumPy, you can use `.to_py()` or rely on NumPy's array protocol directly.

♻️ Proposed simplification

```diff
 if gate == "g_gamma":
-    g_gamma = torch.from_numpy(np.array(make_alibi_g_gamma(H, 32, 0)))
+    g_gamma = torch.from_numpy(np.asarray(make_alibi_g_gamma(H, 32, 0)))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/simple_gla/test_chunk_simple_gla_gpu.py` around lines 198 - 200: replace the verbose conversion chain around g_gamma; instead of torch.from_numpy(np.array(make_alibi_g_gamma(...))), extract a NumPy array directly from the JAX array returned by make_alibi_g_gamma (e.g. call its .to_py() or use np.asarray on the JAX result) and then create the tensor with torch.from_numpy; update the assignment to g_gamma (and the conditional using gate == "g_gamma") to use that direct conversion so you avoid the redundant np.array wrapper and preserve zero-copy semantics.

tops/ops/simple_gla/chunk.py (1)
- 464-517: Missing docstring for public function `chunk_simple_gla_fwd`. As per coding guidelines for `tops/**/*.py`, all public functions must have a clear docstring explaining business semantics and tensor shape/dimension meanings for all inputs and outputs. The function `chunk_simple_gla_fwd` lacks this documentation.

📝 Proposed docstring

```diff
 def chunk_simple_gla_fwd(
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
     *,
     g: jax.Array | None = None,
     g_gamma: jax.Array | None = None,
     scale: float | None = None,
     h0: jax.Array | None = None,
     use_ht: bool = False,
     cu_seqlens_cpu: jax.Array | None = None,
     cu_seqlens_dev: jax.Array | None = None,
     chunk_size: int = 64,
 ) -> tuple[jax.Array, jax.Array | None]:
+    """Simple GLA chunked forward pass.
+
+    Computes the GLA recurrence: h_t = h_{t-1} * exp(g) + k_t^T @ v_t
+    and output o_t = q_t^T @ h_t using chunked computation.
+
+    Args:
+        q: [B, T, H, K] queries
+        k: [B, T, H, K] keys
+        v: [B, T, H, V] values
+        g: [B, T, H] per-position gate (optional, mutually exclusive with g_gamma)
+        g_gamma: [H] constant per-head log-decay (optional)
+        scale: attention scale factor (default: K^-0.5)
+        h0: [N, H, K, V] initial recurrent state (N=B or num_seqs)
+        use_ht: whether to return final state
+        cu_seqlens_cpu: [B+1] cumulative sequence lengths on CPU
+        cu_seqlens_dev: [B+1] cumulative sequence lengths on device
+        chunk_size: block size for chunked computation
+
+    Returns:
+        (o, ht): output [B, T, H, V] and optional final state [N, H, K, V]
+    """
     B, T, H, K, V = *q.shape, v.shape[-1]
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tops/ops/simple_gla/chunk.py` around lines 464 - 517: add a comprehensive docstring to the public function chunk_simple_gla_fwd that explains its purpose (chunked forward pass for the simple GLA), lists and describes every parameter (q, k, v, g, g_gamma, scale, h0, use_ht, cu_seqlens_cpu, cu_seqlens_dev, chunk_size) including expected tensor shapes and meanings (B, T, H, K, V, N where N = len(cu_seqlens_cpu)-1), the return values (o: output tensor shape (B, T, H, V) and ht: final state or None with shape (N, H, K, V) when use_ht=True), and all preconditions/assertions (T % chunk_size == 0, cu_seqlens_cpu divisible by chunk_size, K % 128 == 0, V % 128 == 0). Also document side effects/behavior (calls chunk_fwd_h and chunk_fwd_o, when h0 is used, when states are returned), and what exceptions/assertions will be raised on invalid inputs to satisfy the project's public-function docstring guidelines.
📒 Files selected for processing (3)
- tests/ops/simple_gla/test_chunk_simple_gla_gpu.py
- tops/ops/common/chunk_o.py
- tops/ops/simple_gla/chunk.py
✅ Files skipped from review due to trivial changes (1)
- tops/ops/common/chunk_o.py
This pull request significantly refactors and improves the GPU test suite for the `chunk_simple_gla` operator, focusing on more accurate, scalable, and robust cross-framework (Triton vs JAX) testing. It introduces dynamic, shape-aware tolerance scaling, adds new utilities for ALiBi gating parameter generation, and enhances test coverage for numerical stability. The code is also modernized to directly call forward/backward kernels, and test configurations are made more flexible and realistic.

Key changes:

**Test infrastructure improvements:**
- Introduced a `_tol` helper, so numerical tolerances for tests now adapt automatically to sequence length and chunk size, improving reliability of cross-platform comparisons. [1] [2]
- Test cases now set `chunk_size` and `dtype` per case, and these are reflected in test IDs and tensor creation. [1] [2] [3] [4]

**Refactoring and modernization:**
- Tests call `chunk_simple_gla_fwd` and `chunk_simple_gla_bwd` directly for both Triton and JAX, removing reliance on autograd for reference gradients and making the tests more explicit and robust.
- Replaced the local `_torch_to_jax` with a new `torch_to_jax` utility in `tests/utils.py` that better preserves dtypes and is reused across tests.

**ALiBi gating and parameter utilities:**
- Added `build_alibi_slopes` and `make_alibi_g_gamma` utilities to `tests/utils.py`, enabling test cases to use realistic ALiBi gating parameters matching those used in production models.

**Numerical stability and coverage:**
- Added a NaN/Inf stability test covering large `g_gamma` values, especially for `chunk_size=128` and multi-layer settings.

**Minor fixes:**
- Cleaned up `tops/ops/common/chunk_h.py` by removing unnecessary type casts in matrix multiplication, ensuring correct precision is used.

These changes make the test suite much more robust, maintainable, and reflective of real-world usage.
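For reference, the standard ALiBi slope construction, as a sketch of what a `build_alibi_slopes`-style helper typically computes for a power-of-two head count (this is the formula from the ALiBi paper, not necessarily this repo's exact implementation):

```python
import numpy as np

def alibi_slopes(num_heads: int) -> np.ndarray:
    # Geometric sequence 2^(-8/n), 2^(-16/n), ..., 2^(-8) for n heads
    # (valid as-is when num_heads is a power of two).
    start = 2.0 ** (-8.0 / num_heads)
    return (start ** np.arange(1, num_heads + 1)).astype(np.float32)

slopes = alibi_slopes(16)
# The steepest head has slope 2**-0.5 ~ 0.7071, so with chunk_size=128 the
# magnitude |slope| * 128 ~ 90.5 already exceeds ln(float32 max) ~ 88.72,
# which is exactly the overflow regime the stability fix targets.
```

This is why the H=16 ALiBi configuration is a useful stress test: its largest slope puts the masked-out exponent past the float32 overflow threshold at chunk_size=128 but not at 64.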