
fix(gla): midpoint stabilization in chunk_gla_fwd_intra_gk_ref for BT=128 #133

Open
jimoosciuc wants to merge 8 commits into main from fix/bwd-intra-stabilize

Conversation

Contributor

@jimoosciuc jimoosciuc commented Mar 31, 2026

Summary

  • Fix NaN in backward pass at chunk_size=128: chunk_gla_fwd_intra_gk_ref used first-position stabilization (g_n = g_c[0]), causing exp(|g_gamma|*127) to overflow float32 when |g_gamma| > 0.69
  • Switch to midpoint stabilization (matching the Pallas kernel chunk_gla_fwd_intra_gk_pl), halving the max exponent from 101.6 to 50.8
  • Fix chunk_gla_bwd call to chunk_bwd_dh_ref to use keyword args matching updated signature

Root Cause

At chunk_size=128 with |g_gamma|=0.8:

  • exp(g_n - g_c[-1]) = exp(0.8 * 127) = exp(101.6) → float32 overflow (Inf)
  • k_gated[-1] = k * Inf, q_gated[-1] = q * 0 → 0 * Inf = NaN on diagonal of A matrix
  • NaN propagates through causal mask into dq/dk/dv

At chunk_size=64: exp(0.8 * 63) = exp(50.4) → safe (float32 max exp ≈ 88.7)
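The overflow boundary above can be verified with a few lines of NumPy (a standalone sketch; the actual kernels operate on cumulative gate sums rather than this scalar shortcut):

```python
import numpy as np

# Worst-case exponent under first-position stabilization is
# |g_gamma| * (chunk_size - 1); float32 exp() overflows above ~88.7.
for chunk_size in (64, 128):
    worst = abs(-0.8) * (chunk_size - 1)
    with np.errstate(over="ignore"):
        value = np.exp(np.float32(worst))
    print(f"chunk_size={chunk_size}: exp({worst:.1f}) = {value}")

# chunk_size=64  -> exp(50.4) is finite (~7.7e21)
# chunk_size=128 -> exp(101.6) overflows float32 to inf
```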

Test plan

  • All 11 overflow stabilization tests pass (test_exp_overflow_stabilization.py)
  • New test_constant_g_gamma_bwd_no_nan covers exact AL model scenario (BT=128, g_gamma=-0.8)
  • Validate in ant-pretrain E2E CI (PR #316)

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes

    • Improved intra-chunk gate/exponent handling and clamped exponent inputs to prevent overflow and NaN/Inf in attention-like computations.
    • Revised backward-gradient routing to avoid NaN/Inf propagation for large chunks or extreme gate values.
    • Added runtime validation to reject positive gate-decay inputs that would cause unsafe exponentiation.
  • Tests

    • Added parameterized tests validating gradients remain finite (no NaN/Inf) on large-chunk scenarios with varied gate values.

The reference intra-chunk attention function used first-position
stabilization (g_n = g_c[0]), causing exp(g_n - g_c[-1]) to overflow
float32 when |g_gamma| * (chunk_size - 1) > 88.7. At chunk_size=128
with |g_gamma|=0.8, this produced exp(101.6) → Inf, then 0 * Inf = NaN
on diagonal entries of the attention matrix.

Switch to midpoint stabilization (already used by the Pallas kernel
variant), halving the max exponent from 101.6 to 50.8 at BT=128.

Also fix chunk_gla_bwd call to chunk_bwd_dh_ref to use keyword args
matching the updated signature.
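The effect of the reference-point choice can be sketched in NumPy (names like g_c and gated_exponents are illustrative, not the kernel's actual identifiers):

```python
import numpy as np

def gated_exponents(g_c: np.ndarray, reference: str) -> np.ndarray:
    """Exponents fed into exp() when gating q/k against a reference row g_n."""
    if reference == "first":       # old behavior: g_n = g_c[0]
        g_n = g_c[0]
    else:                          # midpoint: average of first and last rows
        g_n = 0.5 * (g_c[0] + g_c[-1])
    return g_n - g_c

# Cumulative gate for a constant decay g_gamma = -0.8 over a 128-step chunk
g_c = np.float32(-0.8) * np.arange(128, dtype=np.float32)

print(np.abs(gated_exponents(g_c, "first")).max())     # ~101.6, overflows float32
print(np.abs(gated_exponents(g_c, "midpoint")).max())  # ~50.8, safe
```

Centering the reference halves the worst-case magnitude: the exponents span roughly [-50.8, 50.8] instead of [0, 101.6], keeping exp() within float32 range on both tails.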

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

coderabbitai bot commented Mar 31, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Replace single-row intra-chunk gate references with midpoint-based gating in forward kernels, clamp exponent differences to 0 on the causal upper triangle before exp() in chunk kernels, and route g_cumsum as gk with adjusted chunk_gla_bwd -> chunk_bwd_dh_ref args; add tests asserting no NaN/Inf in backward gradients.

Changes

Cohort / File(s) Summary
Gate stabilization & backward routing
tops/ops/gla/chunk.py
Use midpoint (average of first and last rows) as intra-chunk gate reference in forward kernels; change chunk_gla_bwd to call chunk_bwd_dh_ref with g=None, g_gamma=None, pass gk=g_cumsum, and set output_dh0=(initial_state is not None or dht is not None).
Safe exponent handling in chunk kernels
tops/ops/common/chunk_o.py
Clamp exponent differences to 0.0 on the causal upper triangle before exp() for gating and decay computations (forward/backward kernels) to avoid overflow and 0 * Inf → NaN.
Backward gradient tests & matrix sizes
tests/ops/gla/test_exp_overflow_stabilization.py, tests/ops/simple_gla/test_chunk_simple_gla_tpu.py
Add/extend tests with chunk_size=128 and parametrized g_gamma cases ([-0.5,-0.8,-1.0]) that construct fixed inputs and assert dq, dk, dv, dg contain no NaN/Inf when running chunked GLA backward paths.
Runtime validation
tops/ops/simple_gla/chunk.py
Add assertion in orchestration functions requiring g_gamma to be None or all elements <= 0 to prevent positive-gamma overflow in decay/gating exponentials.

Sequence Diagram(s)

sequenceDiagram
  participant Test as Test Harness
  participant Input as Input Tensors
  participant Fwd as chunk_gla_fwd_intra_gk_*
  participant Bwd as chunk_gla_bwd
  participant DhRef as chunk_bwd_dh_ref

  Test->>Input: build fixed-size GLA inputs + constant g_gamma
  Input->>Fwd: forward (use g_mid = (first+last)/2)
  Fwd-->>Input: produce gated q/k with midpoint-based exponents
  Test->>Bwd: call chunk_gla_bwd(g=None, g_gamma=None, gk=g_cumsum, ...)
  Bwd->>DhRef: route backward dh with output_dh0 flag
  DhRef-->>Bwd: return dh gradients
  Bwd-->>Test: return dq, dk, dv, dg (checked for NaN/Inf)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • 0xaskr

Poem

🐰 I nibble at gates, front and tail I bind,
Averaged mid-step so exps stay kind.
No NaN in my hop, no Inf in my track,
Gradients bounce clean — I give them a snack.
Hooray, stable chunks — a rabbit's small clap!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 77.78%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title accurately captures the main change: switching to midpoint stabilization in chunk_gla_fwd_intra_gk_ref at chunk_size=128 to prevent overflow.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces midpoint stabilization in the Gated Linear Attention (GLA) forward pass to prevent exponent overflow when using large chunk sizes. It also adds a regression test to ensure numerical stability for constant g_gamma values in the backward pass. A bug was identified in the backward pass implementation where the 'output_dh0' argument was omitted, leading to the loss of initial state gradients; a code suggestion was provided to resolve this.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ops/gla/test_exp_overflow_stabilization.py`:
- Line 206: The test unpacks five values from chunk_gla_bwd into dq, dk, dv, dg,
dh0 but dh0 is unused which triggers Ruff RUF059; update the unpacking in the
test to use a deliberately ignored variable name (e.g., _dh0) instead of dh0 so
the intent is clear and the linter warning is silenced—locate the call to
chunk_gla_bwd in the test and change the last target from dh0 to _dh0.
- Around line 174-183: The file calls pytest.main([__file__, "-v"]) before new
tests are defined, causing tests like test_constant_g_gamma_bwd_no_nan to be
skipped when running the module directly; move (or remove) the pytest.main
invocation so it executes only after all tests are declared — e.g., place the if
__name__ == "__main__": pytest.main([...]) block at the end of the file (below
test_constant_g_gamma_bwd_no_nan and any other tests), ensuring pytest.main runs
after test definitions.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7b2ed14a-5332-4bf3-b8c7-d9ce701212b0

📥 Commits

Reviewing files that changed from the base of the PR and between 2b47657 and 08d124b.

📒 Files selected for processing (2)
  • tests/ops/gla/test_exp_overflow_stabilization.py
  • tops/ops/gla/chunk.py

jimoosciuc and others added 2 commits March 31, 2026 23:06
- Add output_dh0 flag to chunk_bwd_dh_ref call so dh0 is only computed
  when initial_state or dht is provided
- Move pytest.main guard to end of test file (after all test functions)
- Mark unused dh0 as _dh0 to satisfy linter

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The Toeplitz decay exp(gamma*(i-j)) overflows float32 in the upper
triangle when |gamma|*(BT-1) > 88.7 (e.g. BT=128, |gamma|=0.8 gives
exp(101.6) = Inf).  On TPU, 0 * Inf = NaN leaks through jnp.where
because XLA may compile select as mask*x + (1-mask)*y instead of
a true IEEE select.

Fix: clamp the exponent to 0.0 in the upper triangle *before* exp(),
so exp(0)=1 replaces the overflow.  The upper triangle is masked to
zero anyway, so the result is mathematically identical.

Applied to both forward (_chunk_fwd_o_kernel) and backward
(chunk_simple_gla_bwd_kernel) kernels.
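The 0 * Inf → NaN failure mode and the clamp-before-exp fix can be reproduced in plain NumPy, using the mask-multiply form that the commit message says XLA may lower a select to (this is a sketch, not the Pallas kernel code):

```python
import numpy as np

BT, gamma = 128, np.float32(-0.8)
i = np.arange(BT, dtype=np.float32)[:, None]
j = np.arange(BT, dtype=np.float32)[None, :]
causal = (i >= j).astype(np.float32)        # 1 on/below the diagonal, 0 above

exponent = gamma * (i - j)                  # upper triangle: up to +101.6

with np.errstate(over="ignore", invalid="ignore"):
    # Masking after exp(), as in mask*x + (1-mask)*y: 0 * inf = NaN above diagonal
    naive = causal * np.exp(exponent)

# Fix: clamp the exponent to 0.0 in the upper triangle *before* exp();
# exp(0) = 1 is then masked to 0, which is mathematically identical.
safe = np.where(causal > 0, exponent, np.float32(0.0))
fixed = causal * np.exp(safe)

print(np.isnan(naive).any())   # True
print(np.isnan(fixed).any())   # False
```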

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@0xaskr
Collaborator

0xaskr commented Mar 31, 2026

This PR[https://github.com//pull/131] has been merged. Try pointing pallas-kernel to the latest main branch directly in the pretrain repository. @jimoosciuc


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tops/ops/common/chunk_o.py (1)

66-69: Add a direct regression for the chunk_fwd_o masked-exp path.

The new stabilization tests called out in this PR exercise chunk_gla_bwd and chunk_gla_fwd_intra_gk_ref, but not this _chunk_fwd_o_kernel/chunk_fwd_o path. A chunk_size=128, g_gamma=-0.8 case here would keep this exact clamp from regressing independently of the backward fix.

Also applies to: 76-79

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/ops/common/chunk_o.py`:
- Around line 218-222: The kernel assumes decay rates are non-positive but only
masks positive values locally; enforce the invariant at the public entry
point(s) that supply the gamma/ b_gamma arrays (or immediately before any exp()
use) by asserting all entries are <= 0 or by clamping b_gamma =
jnp.minimum(b_gamma, 0.0); apply this fix for every location that exponentiates
gamma-like values (symbols: b_gamma, g_gamma, g_gn, safe_diff, decay) so no
unmasked positive values reach jnp.exp and cause overflow.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 53a4b97b-6e5f-416d-a927-bcfabe1556b2

📥 Commits

Reviewing files that changed from the base of the PR and between 0379606 and ef98322.

📒 Files selected for processing (1)
  • tops/ops/common/chunk_o.py

Add test_chunk_bwd_large_gamma_no_nan with g_gamma=-0.5/-0.8/-1.0 at
chunk_size=128.  This test runs in pallas-kernel TPU CI (not ignored
like tests/ops/gla/) and exercises the Toeplitz decay overflow path
that causes NaN in production.

Also add chunk_size=128 cases to existing fwd/bwd test suites.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
tests/ops/simple_gla/test_chunk_simple_gla_tpu.py (1)

306-337: Consider extending the NaN regression to also validate dh0 finiteness.

This test currently checks dq/dk/dv only. Adding an h0=True branch would directly cover the output_dh0 backward path as well.

Suggested extension
-@pytest.mark.parametrize("g_gamma_val", [-0.5, -0.8, -1.0])
-def test_chunk_bwd_large_gamma_no_nan(g_gamma_val):
+@pytest.mark.parametrize("g_gamma_val", [-0.5, -0.8, -1.0])
+@pytest.mark.parametrize("use_h0", [False, True])
+def test_chunk_bwd_large_gamma_no_nan(g_gamma_val, use_h0):
@@
-    q = jax.random.normal(keys[0], (B, T, H, K), dtype=jnp.bfloat16)
-    k = jax.random.normal(keys[1], (B, T, H, K), dtype=jnp.bfloat16)
-    v = jax.random.normal(keys[2], (B, T, H, V), dtype=jnp.bfloat16)
-    do = jax.random.normal(keys[3], (B, T, H, V), dtype=jnp.bfloat16)
+    q = jax.random.normal(keys[0], (B, T, H, K), dtype=jnp.bfloat16)
+    k = jax.random.normal(keys[1], (B, T, H, K), dtype=jnp.bfloat16)
+    v = jax.random.normal(keys[2], (B, T, H, V), dtype=jnp.bfloat16)
+    do = jax.random.normal(keys[3], (B, T, H, V), dtype=jnp.bfloat16)
+    h0 = jax.random.normal(keys[4], (B, H, K, V), dtype=jnp.bfloat16) if use_h0 else None
@@
-    dq, dk, dv, _dh0 = chunk_simple_gla_bwd(
+    dq, dk, dv, dh0 = chunk_simple_gla_bwd(
         q, k, v, do,
         g_gamma=g_gamma,
         scale=scale,
+        h0=h0,
         chunk_size=C,
     )
@@
-    for name, arr in [("dq", dq), ("dk", dk), ("dv", dv)]:
+    grads = [("dq", dq), ("dk", dk), ("dv", dv)]
+    if dh0 is not None:
+        grads.append(("dh0", dh0))
+    for name, arr in grads:
         assert not jnp.any(jnp.isnan(arr)), (
             f"{name} contains NaN (g_gamma={g_gamma_val}, chunk_size={C})"
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/simple_gla/test_chunk_simple_gla_tpu.py` around lines 306 - 337,
The test test_chunk_bwd_large_gamma_no_nan currently asserts finiteness for
dq/dk/dv but ignores dh0 returned by chunk_simple_gla_bwd; update the test to
also check the dh0 output (variable _dh0 or rename to dh0) for NaN/Inf so the
output_dh0 backward path is covered. Specifically, capture the fourth return
value from chunk_simple_gla_bwd as dh0 and add the same jnp.any(jnp.isnan(...))
and jnp.any(jnp.isinf(...)) assertions used for dq/dk/dv, with a clear failure
message referencing g_gamma_val and chunk_size.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@tests/ops/simple_gla/test_chunk_simple_gla_tpu.py`:
- Around line 306-337: The test test_chunk_bwd_large_gamma_no_nan currently
asserts finiteness for dq/dk/dv but ignores dh0 returned by
chunk_simple_gla_bwd; update the test to also check the dh0 output (variable
_dh0 or rename to dh0) for NaN/Inf so the output_dh0 backward path is covered.
Specifically, capture the fourth return value from chunk_simple_gla_bwd as dh0
and add the same jnp.any(jnp.isnan(...)) and jnp.any(jnp.isinf(...)) assertions
used for dq/dk/dv, with a clear failure message referencing g_gamma_val and
chunk_size.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: eeb2de0b-8068-49a5-9d4e-f0fe18fa497b

📥 Commits

Reviewing files that changed from the base of the PR and between ef98322 and 631cebe.

📒 Files selected for processing (1)
  • tests/ops/simple_gla/test_chunk_simple_gla_tpu.py

Enforces the decay-rate invariant (g_gamma must be non-positive)
at chunk_simple_gla_fwd and chunk_simple_gla_bwd entry points,
rather than relying on kernel-local masking assumptions.
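A minimal sketch of such an entry-point check (the function name and exact assertion text are assumptions based on the diff summary, not the merged code):

```python
import numpy as np

def validate_g_gamma(g_gamma) -> None:
    """Reject positive decay rates before they can reach exp() in the kernels."""
    assert g_gamma is None or (np.asarray(g_gamma) <= 0).all(), (
        "g_gamma must be <= 0 (decay rate); positive values overflow exp()"
    )

validate_g_gamma(None)                      # a missing gamma is allowed
validate_g_gamma(np.array([-0.8, -0.5]))    # non-positive decay rates pass
try:
    validate_g_gamma(np.array([0.1]))       # a positive rate is rejected eagerly
except AssertionError as exc:
    print("rejected:", exc)
```

Running the check eagerly at the public entry point, before any tracing begins, echoes the review suggestion above and avoids assertion failures inside traced code.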

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@jimoosciuc jimoosciuc closed this Mar 31, 2026
@jimoosciuc jimoosciuc reopened this Mar 31, 2026
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tops/ops/simple_gla/chunk.py (1)

741-750: 🛠️ Refactor suggestion | 🟠 Major

Consider adding validation at the public entry point chunk_simple_gla before tracing.

The g_gamma validation assertions are placed in chunk_simple_gla_fwd (line 489) and chunk_simple_gla_bwd (line 585), but these are called during tracing by _pallas_chunk_gla_fwd and _pallas_chunk_gla_bwd respectively.

The safest location for the assertion is in the chunk_simple_gla function body itself (before line 769), which executes eagerly before custom_vjp begins tracing the forward/backward rules:

Proposed fix: add validation at entry point
 def chunk_simple_gla(
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
     g_gamma: jax.Array,
     initial_state: jax.Array | None = None,
     scale: float | None = None,
     output_final_state: bool = False,
     chunk_size: int = 64,
 ) -> tuple[jax.Array, jax.Array | None]:
     """Chunked Simple GLA using pallas-kernel's Pallas TPU kernels.
     ...
     """
+    assert g_gamma is None or (g_gamma <= 0).all(), \
+        "g_gamma must be <= 0 (decay rate); positive values overflow exp()"
     dtype = q.dtype

Then consider whether the assertions in chunk_simple_gla_fwd and chunk_simple_gla_bwd are still needed (they may cause tracing issues if those functions are ever called in a traced context).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 741 - 750, Add eager validation of
g_gamma at the public entry point chunk_simple_gla before any tracing begins:
check that g_gamma has the expected shape/dtype and non-null constraints at the
top of chunk_simple_gla (before calls to
_pallas_chunk_gla_fwd/_pallas_chunk_gla_bwd), mirroring the existing assertions
in chunk_simple_gla_fwd and chunk_simple_gla_bwd; after adding this pre-check
you can optionally remove or relax the redundant assertions inside
chunk_simple_gla_fwd and chunk_simple_gla_bwd to avoid tracing-time issues.
🧹 Nitpick comments (1)
tops/ops/simple_gla/chunk.py (1)

467-480: Missing docstring for public function chunk_simple_gla_fwd.

Per coding guidelines, public functions must have a docstring explaining business semantics and detailing tensor shapes for all inputs/outputs. This function lacks any documentation.

While not part of the current PR's changes, consider adding a docstring for maintainability:

Example docstring
def chunk_simple_gla_fwd(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    *,
    g: jax.Array | None = None,
    g_gamma: jax.Array | None = None,
    scale: float | None = None,
    h0: jax.Array | None = None,
    use_ht: bool = False,
    cu_seqlens_cpu: jax.Array | None = None,
    cu_seqlens_dev: jax.Array | None = None,
    chunk_size: int = 64,
) -> tuple[jax.Array, jax.Array | None]:
    """Simple GLA forward pass using chunked computation.

    Args:
        q: Query tensor [B, T, H, K]
        k: Key tensor [B, T, H, K]
        v: Value tensor [B, T, H, V]
        g: Per-element gate [B, T, H] or None
        g_gamma: Per-head decay rate [H], must be <= 0
        scale: Attention scale factor (default: K^-0.5)
        h0: Initial state [N, H, K, V] or None
        use_ht: Whether to return final state
        cu_seqlens_cpu: Cumulative sequence lengths on CPU [B+1]
        cu_seqlens_dev: Cumulative sequence lengths on device [B+1]
        chunk_size: Chunk size for block computation

    Returns:
        (o, ht): Output [B, T, H, V] and optional final state [N, H, K, V]
    """

As per coding guidelines: "All public functions must have a clear docstring... that explicitly explains the business semantics of the function and clearly details the tensor shape and dimension meaning for every input and output argument."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 467 - 480, Add a descriptive
docstring to the public function chunk_simple_gla_fwd that explains its business
semantics (chunked/simple GLA forward pass) and documents the tensor shapes and
meanings for every parameter and return value; explicitly describe q, k, v
shapes (e.g. [B, T, H, K] or whatever this module expects), optional tensors g,
g_gamma, scale, h0, booleans like use_ht, cu_seqlens_cpu/cu_seqlens_dev shapes,
chunk_size semantics, and the returned tuple (output o shape and optional final
state ht shape), so consumers and maintainers can unambiguously understand
inputs/outputs and behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@tops/ops/simple_gla/chunk.py`:
- Around line 741-750: Add eager validation of g_gamma at the public entry point
chunk_simple_gla before any tracing begins: check that g_gamma has the expected
shape/dtype and non-null constraints at the top of chunk_simple_gla (before
calls to _pallas_chunk_gla_fwd/_pallas_chunk_gla_bwd), mirroring the existing
assertions in chunk_simple_gla_fwd and chunk_simple_gla_bwd; after adding this
pre-check you can optionally remove or relax the redundant assertions inside
chunk_simple_gla_fwd and chunk_simple_gla_bwd to avoid tracing-time issues.

---

Nitpick comments:
In `@tops/ops/simple_gla/chunk.py`:
- Around line 467-480: Add a descriptive docstring to the public function
chunk_simple_gla_fwd that explains its business semantics (chunked/simple GLA
forward pass) and documents the tensor shapes and meanings for every parameter
and return value; explicitly describe q, k, v shapes (e.g. [B, T, H, K] or
whatever this module expects), optional tensors g, g_gamma, scale, h0, booleans
like use_ht, cu_seqlens_cpu/cu_seqlens_dev shapes, chunk_size semantics, and the
returned tuple (output o shape and optional final state ht shape), so consumers
and maintainers can unambiguously understand inputs/outputs and behavior.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 55869d70-f2da-4b09-9eb4-6f608ac5cb59

📥 Commits

Reviewing files that changed from the base of the PR and between 631cebe and 68407db.

📒 Files selected for processing (1)
  • tops/ops/simple_gla/chunk.py

jimoosciuc and others added 2 commits April 1, 2026 02:14
- test_chunk_simple_gla_grad_no_nan: uses chunk_simple_gla (custom_vjp)
  with jax.grad to test the full forward+backward pipeline
- test_chunk_bwd_components_no_nan: tests each backward stage
  individually (chunk_fwd_h, chunk_bwd_dh, bwd_o_pl)

Both pass on CPU interpret mode. Need TPU CI to verify.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>