Skip to content

feat: add GMM + block-wise FP8 kernels for MoE layers#141

Open
sii-xinglong wants to merge 28 commits intomainfrom
feat/gmm-fp8
Open

feat: add GMM + block-wise FP8 kernels for MoE layers#141
sii-xinglong wants to merge 28 commits intomainfrom
feat/gmm-fp8

Conversation

@sii-xinglong
Copy link
Copy Markdown
Contributor

@sii-xinglong sii-xinglong commented Apr 1, 2026

Summary

  • Add grouped matrix multiplication (GMM/TGMM) Pallas TPU kernels with block-wise FP8 e4m3fn quantization to tops.ops.gmm
  • Implement TE-style dual quantization (rowwise + columnwise) for memory-efficient backward pass (~48% reduction vs bf16 residuals)
  • Provide high-level gmm() API with jax.custom_vjp that handles quantization and differentiation automatically
  • Zero dependency on tokamax, qwix, or any external quantization framework — all code written from scratch

Production features (gap-fill)

  • group_offset — expert-parallel sharding support, wired through the full custom VJP chain (forward, residuals, backward)
  • bwd_quantize — TE-style FP8 backward: dual-quantize gradients before dgrad/wgrad matmuls. Configurable independently from forward quantization
  • LUT-based tilingtiling parameter accepts a Callable[[int,int,int], tuple[int,int,int]] for dynamic tile-size selection based on problem dimensions
  • Input validation_validate_args checks shapes/dtypes, FP8 constraint validation (tk == block_size)
  • Cost estimationpl.CostEstimate(flops, bytes_accessed) on all Pallas kernels for XLA scheduling
  • Zero padding_zero_uninitialized_memory zeroes rows outside active groups when using group_offset

Package layout

tops/ops/gmm/
  __init__.py   — public exports (gmm, gmm_kernel, fp8_quantize, LutFn, etc.)
  common.py     — TPU detection, dtype selection
  quantize.py   — fp8_quantize, fp8_dequantize, fp8_dual_quantize
  backend.py    — Pallas kernels (gmm, tgmm, gmm_fp8, tgmm_fp8) + _ref companions + make_group_metadata
  ops.py        — High-level gmm() with custom_vjp + bwd_quantize + group_offset

Key design decisions

  • Explicit scale arrays — no custom QArray/pytree types; quantized tensors are (qvalue, scale) pairs
  • Block-wise (1,128) FP8 e4m3fn for both forward and backward passes
  • Dual quantization: each operand is quantized in both rowwise and columnwise orientations during forward, so backward can reuse the correct orientation without re-quantization
  • Reference implementations (_ref) for every kernel using NumPy float64 for golden-standard testing

Test plan

  • 116 tests passing (1 skipped — TPU-only device detection test)
  • Non-quantized GMM/TGMM: Pallas kernel vs pure-JAX ref
  • FP8 GMM/TGMM: Pallas FP8 kernel vs pure-JAX FP8 ref
  • Block-wise FP8 quantize roundtrip + dual quantize consistency
  • Custom VJP: gradient shapes, finite differences, FP8 finiteness
  • group_offset: forward zeroing, backward gradients
  • bwd_quantize: auto/true/false modes, Frobenius error < 15% vs BF16
  • LUT tiling: callable dispatch, None → ValueError
  • Input validation: wrong shapes, dtypes, FP8 constraints
  • Zero padding: rows outside active groups zeroed
  • TPU integration test (requires TPU runner)

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features
    • Added grouped and transposed grouped matrix-multiplication with TPU-aware dtype selection, FP8 block-wise quantization/dequantization, Pallas TPU kernels, and a high-level differentiable API exposing stable forward and backward behavior.
  • Tests
    • Large deterministic test suite covering correctness, shapes/dtypes, FP8 fidelity and roundtrips, numeric tolerances, existing-output/accumulation behavior, validation/error cases, zero-padding/group-offset semantics, and gradient (VJP) correctness.

sii-xinglong and others added 11 commits April 1, 2026 11:32
Add foundational package structure and utility functions for GMM kernels:
- is_tpu / tpu_kind / tpu_generation: device detection
- supports_bfloat16_matmul / select_input_dtype: dtype selection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Implements fp8_quantize, fp8_dequantize, and fp8_dual_quantize for
symmetric per-block absmax calibration to float8_e4m3fn. These primitives
are used by the FP8 kernel variants and the high-level gmm() API.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ackend.py)

Port make_group_metadata tile-to-group mapping from megablox, add pure-JAX
gmm_ref and tgmm_ref reference implementations for correctness testing,
and _get_store_mask for per-tile row validity. 22 tests covering all
specified (M,K,N,G) combinations including edge cases.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Port the grouped matrix multiplication kernel from megablox to a
standalone Pallas implementation without QArray/qwix dependencies.
Uses PrefetchScalarGridSpec with (tiles_n, num_active_tiles, tiles_k)
grid, f32 VMEM scratch accumulator, k-remainder masking, and
group-boundary store masks. Supports transpose_rhs, existing_out
accumulation, and group_offset for sharded dispatch.

Includes 18 tests covering 7 problem sizes (both normal and
transpose_rhs), output shape/dtype, bf16 inputs, and existing_out,
all verified via interpret mode on CPU against gmm_ref.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add tgmm() function to backend.py for transposed grouped matrix
multiplication (lhs[group_g].T @ rhs[group_g]). Used for weight
gradients in backward pass. Also add _get_group_size helper.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add gmm_fp8_ref (pure-JAX reference with block-wise dequantization) and
gmm_fp8 (Pallas TPU kernel with per-block scale outer-product correction)
to backend.py, plus comprehensive tests in test_gmm_fp8.py.

Key design: tk must equal block_size so each k-tile maps to exactly one
scale value, enabling clean [tm,1]*[1,tn] scale broadcast after dot_general.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add tgmm_fp8_ref (pure-JAX reference) and tgmm_fp8 (Pallas TPU kernel)
for FP8-quantized transposed grouped matrix multiplication. Unlike
gmm_fp8 where scales can be applied as an outer product after the dot,
tgmm_fp8 applies scales before the dot since both scale axes are along
the contraction dimension (M).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add differentiable gmm() function that wraps backend kernels with
jax.custom_vjp to support automatic differentiation. When quantize=True,
applies TE-style dual quantization in forward and saves fp8 residuals
for memory-efficient backward pass (~48% activation memory savings).

Key design decisions:
- Non-array params (quantize, block_size, tiling, etc.) captured via
  closures rather than passed through custom_vjp
- Gradient dtype cast uses preferred_element_type (not residual dtype)
  to avoid incorrectly casting to float8_e4m3fn in quantized path
- Auto-detects interpret mode for CPU testing via is_tpu()

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Fix E731 lambda assignment violations and apply ruff formatting
across all GMM source and test files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… and dead code

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@chatgpt-codex-connector
Copy link
Copy Markdown

You have reached your Codex usage limits for code reviews. You can see your limits in the Codex usage dashboard.

@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Apr 1, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds a grouped-matrix-multiplication subsystem: Pallas TPU kernels (gmm/tgmm and FP8 variants), pure‑JAX reference implementations, FP8 block-wise quantize/dequantize utilities, a differentiable high-level gmm() with custom VJP, TPU/dtype helpers, package exports, and comprehensive tests.

Changes

Cohort / File(s) Summary
Backend & core algorithms
tops/ops/gmm/backend.py
Implements group metadata, CSR-style tile mapping, JAX reference funcs (gmm_ref, tgmm_ref, FP8 refs) and TPU Pallas kernels (gmm, tgmm, gmm_fp8, tgmm_fp8) with masking, sharding, K/M remainder handling, existing_out accumulation, and FP8 scale handling.
High-level ops & autodiff
tops/ops/gmm/ops.py
Adds differentiable gmm() with custom VJP: input validation, tiling normalization, interpret auto-selection, FP8 forward path via fp8_dual_quantize, backward paths using backend tgmm/tgmm_fp8, dtype round-tripping, and bwd_quantize resolution/validation.
Quantization utilities
tops/ops/gmm/quantize.py
Adds FP8 block-wise quantize/dequantize and dual-quantize for float8_e4m3fn: block splitting, per-block absolute-max scale computation, clamping, reshape/axis handling, and dequantize target dtype support.
TPU & dtype helpers
tops/ops/gmm/common.py
Adds TPU detection and parsing (is_tpu, tpu_kind, tpu_generation), hardware bfloat16 matmul support check, and select_input_dtype logic to pick bfloat16 vs float32.
Package exports
tops/ops/gmm/__init__.py
Introduces package-level re-exports and explicit __all__ for kernels, refs, ops, quantize utilities, and metadata builder (make_group_metadata).
Tests — utilities, kernels, FP8, VJP
tests/ops/gmm/test_common.py, tests/ops/gmm/test_gmm.py, tests/ops/gmm/test_gmm_fp8.py, tests/ops/gmm/test_gmm_ref.py, tests/ops/gmm/test_gmm_vjp.py, tests/ops/gmm/test_quantize.py, tests/ops/gmm/test_tgmm.py, tests/ops/gmm/test_tgmm_fp8.py
Adds comprehensive unit tests covering TPU-aware tolerances/skips, deterministic input factories, numerical correctness vs references (including FP8), dtype/shape checks, VJP/grad tests, validation/error cases, existing_out behavior, and FP8 block-size alignment checks. Review attention: tolerances, TPU skips, deterministic seeds, and FP8 alignment assertions.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant Ops as "gmm.ops (custom_vjp)"
    participant Quant as "gmm.quantize"
    participant Backend
    participant Device as "TPU / Host"

    Caller->>Ops: gmm(lhs, rhs, group_sizes, quantize?)
    alt quantize
        Ops->>Quant: fp8_dual_quantize(lhs, rhs)
        Quant-->>Ops: (lhs_q, lhs_scale, rhs_q, rhs_scale)
        Ops->>Backend: gmm_fp8(lhs_q, lhs_scale, rhs_q, rhs_scale, ...)
    else no quantize
        Ops->>Backend: gmm(lhs, rhs, ...)
    end
    Backend->>Device: launch kernel / JIT
    Device-->>Backend: compute results
    Backend-->>Ops: forward result
    Ops-->>Caller: return output

    Note over Caller,Ops: Backward flow
    Caller->>Ops: request gradients
    Ops->>Backend: tgmm / tgmm_fp8 calls for gradients
    Backend->>Device: kernel launches for grads
    Device-->>Backend: gradient results
    Backend-->>Ops: gradients
    Ops-->>Caller: dlhs, drhs
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested reviewers

  • labyrinth-ssr

Poem

🐇 I hopped through tiles both wide and neat,

Scaled little blocks and quantized each beat.
Kernels hum on TPU night's light,
Gradients bounce — tensors take flight.
Hooray for grouped matmul, hops and byte!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 57.23% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'feat: add GMM + block-wise FP8 kernels for MoE layers' clearly and concisely describes the primary change: adding grouped matrix multiplication kernels with FP8 quantization support for mixture-of-experts layers.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/gmm-fp8

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🧹 Nitpick comments (4)
tops/ops/gmm/common.py (1)

48-67: Consider adding input validation for select_input_dtype.

Per coding guidelines, public functions should enforce strict constraints on input shape and types using asserts. This function does not validate that lhs and rhs are valid arrays with .dtype attributes before accessing them.

🛡️ Optional: Add defensive input validation
 def select_input_dtype(
   lhs: jax.Array,
   rhs: jax.Array,
 ) -> jnp.dtype:
   """Select the input dtype for a matmul operation.

   Args:
     lhs: [M, K] -- left operand.
     rhs: [K, N] -- right operand.

   Returns:
     bfloat16 if both inputs are bf16 and hardware supports it, else float32.
   """
+  assert hasattr(lhs, 'dtype'), "lhs must be a JAX array with dtype"
+  assert hasattr(rhs, 'dtype'), "rhs must be a JAX array with dtype"
   if (
     lhs.dtype == jnp.bfloat16
     and rhs.dtype == jnp.bfloat16
     and supports_bfloat16_matmul()
   ):
     return jnp.bfloat16
   return jnp.float32

As per coding guidelines: "All public functions must enforce strict constraints and validation on the shape and types of input variables using assert instructions."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/common.py` around lines 48 - 67, The function select_input_dtype
lacks input validation: add asserts at the start of select_input_dtype to ensure
lhs and rhs are jax.Array (or at least have .dtype), have ndim == 2, and that
their inner dimensions match (lhs.shape[1] == rhs.shape[0]) before using .dtype;
also assert that lhs.dtype and rhs.dtype are valid jnp.dtypes (or raise/convert)
to avoid attribute errors. Keep the existing logic and call to
supports_bfloat16_matmul() unchanged—these assertions should be the first
statements in select_input_dtype so invalid inputs fail fast.
tops/ops/gmm/quantize.py (1)

67-104: Consider adding input validation for fp8_dequantize.

Unlike fp8_quantize, this function doesn't validate that qvalue and scale shapes are compatible or that the axis dimension is divisible by block_size. If mismatched inputs are passed, the reshape will fail with a confusing error.

🛡️ Optional: Add defensive validation
 def fp8_dequantize(
   qvalue: jnp.ndarray,
   scale: jnp.ndarray,
   target_dtype: jnp.dtype = jnp.bfloat16,
   axis: int = -1,
   block_size: int = 128,
 ) -> jnp.ndarray:
   ...
   ndim = qvalue.ndim
   ax = axis % ndim
   dim = qvalue.shape[ax]
-  num_blocks = dim // block_size
+  assert dim % block_size == 0, (
+    f"Axis {axis} size {dim} must be divisible by block_size {block_size}"
+  )
+  num_blocks = dim // block_size

As per coding guidelines: "All public functions must enforce strict constraints and validation on the shape and types of input variables using assert instructions."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/quantize.py` around lines 67 - 104, Add strict input validation
at the start of fp8_dequantize: assert block_size > 0, compute ndim/ax as
already done and assert 0 <= axis < ndim, get dim = qvalue.shape[ax] and assert
dim % block_size == 0 (use num_blocks = dim // block_size), assert scale.ndim ==
ndim and that scale.shape equals qvalue.shape with the axis dimension replaced
by num_blocks (i.e., scale.shape == (*qvalue.shape[:ax], num_blocks,
*qvalue.shape[ax+1:])), and assert scale.dtype == jnp.float32; include clear
assertion messages referencing qvalue, scale, axis, block_size, num_blocks and
fail fast so reshape operations later (new_shape, scale_expanded, out) are safe.
tests/ops/gmm/test_gmm.py (1)

163-185: existing_out test currently doesn’t validate existing_out semantics.

Line 165 uses group_sizes that cover all M rows, so comparing result to ref on Line 180 effectively retests plain GMM. Consider a partial-coverage case (e.g., non-default num_actual_groups/offset) and assert untouched regions remain from existing_out.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm.py` around lines 163 - 185, The test
test_gmm_existing_out currently covers all M rows so it just rechecks plain GMM;
modify it to use a partial-coverage group layout (e.g., call _make_inputs or
construct group_sizes so num_actual_groups < G or include a row offset so only a
subset of rows are active) and pass an existing_out buffer, then call gmm and
assert two things: (1) the active regions of result equal gmm_ref(lhs, rhs,
group_sizes, ...) and (2) the inactive/uncovered rows in result remain unchanged
and equal the original existing_out values; update the assertions accordingly to
check both overwritten and untouched regions using gmm, gmm_ref, existing_out,
and _make_inputs to locate the changes.
tests/ops/gmm/test_gmm_vjp.py (1)

86-104: Add drhs numerical-gradient spot checks.

Line 86 currently validates finite differences only for dlhs. Add a small spot-check loop for analytical_grads[1] too, so regressions in the wgrad path are caught.

Proposed test extension
   analytical_grads = jax.grad(loss_fn, argnums=(0, 1))(lhs_f32, rhs_f32)

@@
   for idx in [(0, 0), (0, 1), (cfg["M"] // 2, cfg["K"] // 2)]:
@@
     np.testing.assert_allclose(
       float(analytical_grads[0][idx]),
       dlhs_numerical_val,
       atol=1e-1,
       rtol=1e-1,
     )
+
+  # Spot-check drhs too.
+  for idx in [(0, 0, 0), (0, 1, 1), (cfg["G"] - 1, cfg["K"] // 2, cfg["N"] // 2)]:
+    rhs_plus = rhs_f32.at[idx].add(eps)
+    rhs_minus = rhs_f32.at[idx].add(-eps)
+    drhs_numerical_val = (
+      float(loss_fn(lhs_f32, rhs_plus)) - float(loss_fn(lhs_f32, rhs_minus))
+    ) / (2 * eps)
+    np.testing.assert_allclose(
+      float(analytical_grads[1][idx]),
+      drhs_numerical_val,
+      atol=1e-1,
+      rtol=1e-1,
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_vjp.py` around lines 86 - 104, Add symmetric
finite-difference spot checks for drhs analogous to the existing dlhs checks:
for a few element indices in rhs (use cfg to pick middle and small indices),
create rhs_plus and rhs_minus by adding/subtracting eps to rhs_f32 at that
index, compute numerical derivative as (loss_fn(lhs_f32,
rhs_plus)-loss_fn(lhs_f32, rhs_minus))/(2*eps), and assert that
analytical_grads[1][idx] matches the numerical value with
np.testing.assert_allclose using the same atol/rtol as the dlhs checks; this
verifies the wgrad path given analytical_grads and loss_fn already compute both
argnums (0,1).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ops/gmm/test_gmm_fp8.py`:
- Around line 83-88: Replace the direct numpy assertion with the repository
helper compare_tensor: instead of np.testing.assert_allclose(np.array(result),
np.array(expected), rtol=1e-4, atol=1e-4) call compare_tensor(result, expected,
rtol=1e-4, atol=1e-4) (ensuring compare_tensor is imported in the test file);
keep using the same result and expected variables but pass them directly to
compare_tensor so the test follows the standard kernel vs reference comparison
pattern.

In `@tests/ops/gmm/test_gmm.py`:
- Around line 69-74: Replace direct np.testing.assert_allclose checks with the
project's compare_tensor test helper: import compare_tensor if missing, then
call compare_tensor(result, expected, rtol=1e-4, atol=1e-4) (or matching
tolerances) wherever np.testing.assert_allclose is used (the shown call and the
other occurrences mentioned around the other ranges). Update the assertions in
the kernel reference tests (the compare between optimized kernel outputs and
naive CPU references) to use compare_tensor to conform to repo standards.

In `@tests/ops/gmm/test_tgmm_fp8.py`:
- Around line 87-92: Replace the direct NumPy assertion
np.testing.assert_allclose in the test_tgmm_fp8 test with the shared compare
helper compare_tensor: locate the np.testing.assert_allclose call that compares
result and expected and call compare_tensor(result, expected, rtol=1e-4,
atol=1e-4) instead so the TGMM FP8 kernel-vs-reference validation uses the
standard test-suite comparator; ensure you import or reference compare_tensor as
used by other tests in the suite if not already available.

In `@tests/ops/gmm/test_tgmm.py`:
- Around line 71-76: Replace the np.testing.assert_allclose usages in this test
with the repository's compare_tensor utility: locate the assertions that call
np.testing.assert_allclose comparing result and expected (lines around the shown
diff and also the similar blocks at 127-132 and 152-157) and change them to use
compare_tensor(result, expected, rtol=1e-4, atol=1e-4). Ensure the
compare_tensor symbol is imported in tests/ops/gmm/test_tgmm.py if not already
present and remove the old np.testing.assert_allclose calls to standardize
tolerance-based kernel/reference checks.

---

Nitpick comments:
In `@tests/ops/gmm/test_gmm_vjp.py`:
- Around line 86-104: Add symmetric finite-difference spot checks for drhs
analogous to the existing dlhs checks: for a few element indices in rhs (use cfg
to pick middle and small indices), create rhs_plus and rhs_minus by
adding/subtracting eps to rhs_f32 at that index, compute numerical derivative as
(loss_fn(lhs_f32, rhs_plus)-loss_fn(lhs_f32, rhs_minus))/(2*eps), and assert
that analytical_grads[1][idx] matches the numerical value with
np.testing.assert_allclose using the same atol/rtol as the dlhs checks; this
verifies the wgrad path given analytical_grads and loss_fn already compute both
argnums (0,1).

In `@tests/ops/gmm/test_gmm.py`:
- Around line 163-185: The test test_gmm_existing_out currently covers all M
rows so it just rechecks plain GMM; modify it to use a partial-coverage group
layout (e.g., call _make_inputs or construct group_sizes so num_actual_groups <
G or include a row offset so only a subset of rows are active) and pass an
existing_out buffer, then call gmm and assert two things: (1) the active regions
of result equal gmm_ref(lhs, rhs, group_sizes, ...) and (2) the
inactive/uncovered rows in result remain unchanged and equal the original
existing_out values; update the assertions accordingly to check both overwritten
and untouched regions using gmm, gmm_ref, existing_out, and _make_inputs to
locate the changes.

In `@tops/ops/gmm/common.py`:
- Around line 48-67: The function select_input_dtype lacks input validation: add
asserts at the start of select_input_dtype to ensure lhs and rhs are jax.Array
(or at least have .dtype), have ndim == 2, and that their inner dimensions match
(lhs.shape[1] == rhs.shape[0]) before using .dtype; also assert that lhs.dtype
and rhs.dtype are valid jnp.dtypes (or raise/convert) to avoid attribute errors.
Keep the existing logic and call to supports_bfloat16_matmul() unchanged—these
assertions should be the first statements in select_input_dtype so invalid
inputs fail fast.

In `@tops/ops/gmm/quantize.py`:
- Around line 67-104: Add strict input validation at the start of
fp8_dequantize: assert block_size > 0, compute ndim/ax as already done and
assert 0 <= axis < ndim, get dim = qvalue.shape[ax] and assert dim % block_size
== 0 (use num_blocks = dim // block_size), assert scale.ndim == ndim and that
scale.shape equals qvalue.shape with the axis dimension replaced by num_blocks
(i.e., scale.shape == (*qvalue.shape[:ax], num_blocks, *qvalue.shape[ax+1:])),
and assert scale.dtype == jnp.float32; include clear assertion messages
referencing qvalue, scale, axis, block_size, num_blocks and fail fast so reshape
operations later (new_shape, scale_expanded, out) are safe.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1a684cb9-d4d3-4c01-a1cd-299732c84293

📥 Commits

Reviewing files that changed from the base of the PR and between ccef5bd and fecd3fc.

📒 Files selected for processing (14)
  • tests/ops/gmm/__init__.py
  • tests/ops/gmm/test_common.py
  • tests/ops/gmm/test_gmm.py
  • tests/ops/gmm/test_gmm_fp8.py
  • tests/ops/gmm/test_gmm_ref.py
  • tests/ops/gmm/test_gmm_vjp.py
  • tests/ops/gmm/test_quantize.py
  • tests/ops/gmm/test_tgmm.py
  • tests/ops/gmm/test_tgmm_fp8.py
  • tops/ops/gmm/__init__.py
  • tops/ops/gmm/backend.py
  • tops/ops/gmm/common.py
  • tops/ops/gmm/ops.py
  • tops/ops/gmm/quantize.py

Comment on lines +83 to +88
np.testing.assert_allclose(
np.array(result),
np.array(expected),
rtol=1e-4,
atol=1e-4,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Switch this kernel/reference tolerance assertion to compare_tensor.

The comparison intent is correct, but this should use the repository’s standard comparison helper for consistency.

As per coding guidelines "Each JAX/Pallas kernel must have a corresponding CPU reference test that compares the optimized kernel against naive implementations with tolerance-based assertions using compare_tensor utility".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_fp8.py` around lines 83 - 88, Replace the direct numpy
assertion with the repository helper compare_tensor: instead of
np.testing.assert_allclose(np.array(result), np.array(expected), rtol=1e-4,
atol=1e-4) call compare_tensor(result, expected, rtol=1e-4, atol=1e-4) (ensuring
compare_tensor is imported in the test file); keep using the same result and
expected variables but pass them directly to compare_tensor so the test follows
the standard kernel vs reference comparison pattern.

Comment on lines +69 to +74
np.testing.assert_allclose(
np.array(result),
np.array(expected),
rtol=1e-4,
atol=1e-4,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Use compare_tensor instead of direct np.testing.assert_allclose in kernel reference checks.

These assertions are doing the right numerical comparison, but the test helper expected by the repo standards is not used.

As per coding guidelines "Each JAX/Pallas kernel must have a corresponding CPU reference test that compares the optimized kernel against naive implementations with tolerance-based assertions using compare_tensor utility".

Also applies to: 101-106, 156-161, 180-185

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm.py` around lines 69 - 74, Replace direct
np.testing.assert_allclose checks with the project's compare_tensor test helper:
import compare_tensor if missing, then call compare_tensor(result, expected,
rtol=1e-4, atol=1e-4) (or matching tolerances) wherever
np.testing.assert_allclose is used (the shown call and the other occurrences
mentioned around the other ranges). Update the assertions in the kernel
reference tests (the compare between optimized kernel outputs and naive CPU
references) to use compare_tensor to conform to repo standards.

Comment on lines +87 to +92
np.testing.assert_allclose(
np.array(result),
np.array(expected),
rtol=1e-4,
atol=1e-4,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Use compare_tensor for TGMM FP8 kernel-vs-reference validation.

Please replace direct np.testing.assert_allclose here with the shared compare helper used by the test suite standards.

As per coding guidelines "Each JAX/Pallas kernel must have a corresponding CPU reference test that compares the optimized kernel against naive implementations with tolerance-based assertions using compare_tensor utility".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_tgmm_fp8.py` around lines 87 - 92, Replace the direct
NumPy assertion np.testing.assert_allclose in the test_tgmm_fp8 test with the
shared compare helper compare_tensor: locate the np.testing.assert_allclose call
that compares result and expected and call compare_tensor(result, expected,
rtol=1e-4, atol=1e-4) instead so the TGMM FP8 kernel-vs-reference validation
uses the standard test-suite comparator; ensure you import or reference
compare_tensor as used by other tests in the suite if not already available.

Comment on lines +71 to +76
np.testing.assert_allclose(
np.array(result),
np.array(expected),
rtol=1e-4,
atol=1e-4,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Please standardize tolerance-based tensor checks with compare_tensor.

The assertions are valid numerically, but this test file should use the repository’s comparison utility for kernel/reference parity checks.

As per coding guidelines "Each JAX/Pallas kernel must have a corresponding CPU reference test that compares the optimized kernel against naive implementations with tolerance-based assertions using compare_tensor utility".

Also applies to: 127-132, 152-157

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_tgmm.py` around lines 71 - 76, Replace the
np.testing.assert_allclose usages in this test with the repository's
compare_tensor utility: locate the assertions that call
np.testing.assert_allclose comparing result and expected (lines around the shown
diff and also the similar blocks at 127-132 and 152-157) and change them to use
compare_tensor(result, expected, rtol=1e-4, atol=1e-4). Ensure the
compare_tensor symbol is imported in tests/ops/gmm/test_tgmm.py if not already
present and remove the old np.testing.assert_allclose calls to standardize
tolerance-based kernel/reference checks.

- Load full scale dimension in FP8 kernels to satisfy TPU Pallas
  requirement (block dims must be divisible by 8/128 or equal to array
  dim). Use jax.lax.dynamic_slice to select the correct scale block
  inside the kernel.
- Relax test tolerances: bf16 MXU ~2e-2, FP8 ~5e-2 (only on TPU;
  interpret mode keeps tight 1e-4).
- Increase VJP numerical gradient eps from 1e-3 to 0.1 on TPU (bf16
  MXU truncation loses small perturbations).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (3)
tops/ops/gmm/backend.py (3)

291-302: LGTM. Minor: prefix unused M with underscore.

Same as gmm_ref - the static analyzer flagged M (line 294) as unused. Consider _M, K = lhs.shape.

🧹 Optional fix for unused variable
-  M, K = lhs.shape
+  _M, K = lhs.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 291 - 302, The variable M is assigned
but unused in the gmm_ref implementation; rename M to _M in the unpacking (e.g.,
change "M, K = lhs.shape" to "_M, K = lhs.shape") to satisfy the static analyzer
and follow the unused-variable convention, leaving all other logic in functions
like gmm_ref and uses of K unchanged.

233-247: LGTM. Minor: prefix unused K with underscore.

The implementation is correct. The static analyzer flagged K (line 236) as unused. Consider renaming to _K to silence the warning, though this is purely cosmetic.

🧹 Optional fix for unused variable
-  M, K = lhs.shape
+  M, _K = lhs.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 233 - 247, The variable K assigned from
lhs.shape in the block starting with "M, K = lhs.shape" is unused and flagged by
static analysis; rename it to _K (i.e., use "M, _K = lhs.shape") to signal it's
intentionally unused and silence the warning, leaving the rest of the logic
using lhs, rhs, transpose_rhs, group_sizes, lhs_np, rhs_np, and out unchanged.

576-580: Magic number for existing_out_arg_index is fragile.

The comment explains the calculation, but if the number of prefetch arrays or input specs changes, this index will silently become incorrect. Consider documenting this more explicitly or deriving it programmatically.

📝 Suggested documentation improvement
   # Prefetch args: group_metadata (3 arrays), group_offset (1 array) = 4
   # Then: lhs, rhs, existing_out -> existing_out is arg index 6
+  # NOTE: Update this index if in_specs order changes.
   existing_out_arg_index = 6
   input_output_aliases = {existing_out_arg_index: 0}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 576 - 580, The hard-coded
existing_out_arg_index = 6 is fragile; instead compute it from the actual counts
so it stays correct if prefetch or input lists change: determine the number of
prefetch args (e.g., group_metadata arrays + group_offset arrays) and the number
of preceding input args, then set existing_out_arg_index = prefetch_count +
num_preceding_inputs (and update input_output_aliases = {existing_out_arg_index:
0}); update the nearby comment to state how existing_out_arg_index is derived
and reference in_out_block_spec / out_block_spec to show which inputs are being
counted.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ops/gmm/test_gmm.py`:
- Around line 167-189: The test docstring for test_gmm_existing_out is
misleading: it says "accumulates into the provided buffer" but the
implementation and assertion show existing_out is overwritten in active regions;
update the docstring (and any inline comment) to state that existing_out is
overwritten (or "overwritten in active regions") rather than accumulated,
referencing the test function name test_gmm_existing_out and the gmm / gmm_ref
behavior to guide the change.

In `@tops/ops/gmm/backend.py`:
- Around line 129-131: The histogram range currently uses range=(0, tiles_m - 1)
which excludes the last tile index; update the jnp.histogram call that computes
tile_visits (the line using partial_tile_ids, tiles_m and jnp.histogram) to use
range=(0, tiles_m) so all tile indices 0..tiles_m-1 are binned correctly.

---

Nitpick comments:
In `@tops/ops/gmm/backend.py`:
- Around line 291-302: The variable M is assigned but unused in the gmm_ref
implementation; rename M to _M in the unpacking (e.g., change "M, K = lhs.shape"
to "_M, K = lhs.shape") to satisfy the static analyzer and follow the
unused-variable convention, leaving all other logic in functions like gmm_ref
and uses of K unchanged.
- Around line 233-247: The variable K assigned from lhs.shape in the block
starting with "M, K = lhs.shape" is unused and flagged by static analysis;
rename it to _K (i.e., use "M, _K = lhs.shape") to signal it's intentionally
unused and silence the warning, leaving the rest of the logic using lhs, rhs,
transpose_rhs, group_sizes, lhs_np, rhs_np, and out unchanged.
- Around line 576-580: The hard-coded existing_out_arg_index = 6 is fragile;
instead compute it from the actual counts so it stays correct if prefetch or
input lists change: determine the number of prefetch args (e.g., group_metadata
arrays + group_offset arrays) and the number of preceding input args, then set
existing_out_arg_index = prefetch_count + num_preceding_inputs (and update
input_output_aliases = {existing_out_arg_index: 0}); update the nearby comment
to state how existing_out_arg_index is derived and reference in_out_block_spec /
out_block_spec to show which inputs are being counted.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0c6f5269-b6e7-4e3f-b7e5-227475951e61

📥 Commits

Reviewing files that changed from the base of the PR and between fecd3fc and 61d8475.

📒 Files selected for processing (6)
  • tests/ops/gmm/test_gmm.py
  • tests/ops/gmm/test_gmm_fp8.py
  • tests/ops/gmm/test_gmm_vjp.py
  • tests/ops/gmm/test_tgmm.py
  • tests/ops/gmm/test_tgmm_fp8.py
  • tops/ops/gmm/backend.py
✅ Files skipped from review due to trivial changes (2)
  • tests/ops/gmm/test_gmm_vjp.py
  • tests/ops/gmm/test_tgmm_fp8.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/ops/gmm/test_gmm_fp8.py

…TPU)

jax.lax.dynamic_slice is not implemented in Pallas TPU lowering.
Replace with pl.ds() Ref indexing which is a native Pallas primitive.
Also widen tolerances: bf16 atol=0.2, FP8 atol=0.5.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (2)
tests/ops/gmm/test_gmm.py (1)

167-189: ⚠️ Potential issue | 🟡 Minor

Clarify docstring: existing_out is overwritten, not accumulated.

The docstring says "accumulates into the provided buffer" but the kernel implementation (lines 480-508 in backend.py) shows that existing_out is overwritten in active regions, not accumulated. The test assertion (result == ref) and the inline comment on line 182 are correct; only the docstring is misleading.

Note: This differs from tgmm which does accumulate (to_store += existing_out_ref).

📝 Suggested docstring fix
   def test_gmm_existing_out(self):
-    """GMM with existing_out accumulates into the provided buffer."""
+    """GMM with existing_out overwrites active regions, preserving inactive."""
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm.py` around lines 167 - 189, Update the test docstring
in test_gmm_existing_out to state that existing_out is overwritten (not
accumulated) by the gmm kernel: mention that gmm(lhs, rhs, group_sizes,
existing_out=existing, ...) overwrites active regions of the provided buffer,
matching the assertion that result == gmm_ref(...); you can optionally note that
this differs from tgmm which accumulates (to_store += existing_out_ref), so the
docstring should say "overwrites the provided buffer in active regions" instead
of "accumulates into the provided buffer."
tops/ops/gmm/backend.py (1)

129-131: ⚠️ Potential issue | 🟡 Minor

Fix histogram range to properly bin all tile indices.

The range=(0, tiles_m - 1) argument creates bins that don't properly cover tile index tiles_m - 1. This can cause incorrect tile visit counts for the last tile.

🐛 Proposed fix
   tile_visits = (
-    jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + 1
+    jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m))[0] + 1
   )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 129 - 131, The histogram call that
computes tile_visits using jnp.histogram(partial_tile_ids, bins=tiles_m,
range=(0, tiles_m - 1)) excludes the last tile index; update the range to (0,
tiles_m) so all integer tile indices 0..tiles_m-1 are properly binned (i.e.,
change the range argument in the jnp.histogram call that uses partial_tile_ids
and tiles_m to (0, tiles_m)), keeping bins=tiles_m and the +1 offset unchanged.
🧹 Nitpick comments (3)
tops/ops/gmm/backend.py (3)

291-314: Consider prefixing unused M with underscore.

Static analysis flags that M (line 294) is unpacked but never explicitly used. Same suggestion as for gmm_ref:

-  M, K = lhs.shape
+  _, K = lhs.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 291 - 314, The variable M is unpacked
but unused; change the tuple unpack in this block (where M, K = lhs.shape) to
use an underscore-prefixed name (e.g., _M, K = lhs.shape) to signal it's
intentionally unused and satisfy static analysis; update any references if
present (the block around out computation that uses K and group sizes in
functions like the current group-matrix multiply, similar to the gmm_ref change)
and keep types and return (jnp.array(out, dtype=preferred_element_type))
unchanged.

233-262: Consider prefixing unused K with underscore.

Static analysis flags that K (line 236) is unpacked but never explicitly used in the function body. While the shape unpacking documents the expected layout, the linter warning can be silenced:

-  M, K = lhs.shape
+  M, _ = lhs.shape

Alternatively, if K is useful for documentation purposes, consider adding a brief comment or using it in an assertion like assert lhs.shape[1] == rhs.shape[1].

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 233 - 262, The variable K returned when
unpacking M, K = lhs.shape is unused and triggers linter warnings; either rename
it to _K (or _k) to indicate it is intentionally unused, or add a short
assertion using K (for example assert lhs.shape[1] == rhs.shape[2] or similar)
to make its purpose explicit; update the symbol in the unpacking line (M, K ->
M, _K) or add the assertion after the shapes are computed, and leave all
remaining logic (lhs, rhs, group_sizes loop, out construction, return
jnp.array(...)) unchanged.

60-88: Add input validation to make_group_metadata.

Per coding guidelines, public functions should enforce strict constraints on input shapes and types. This function lacks validation for:

  • group_sizes being 1-D
  • m, tm being positive integers
  • start_group bounds relative to group_sizes.shape[0]

Consider adding assertions at the function entry:

🛡️ Suggested validation
 def make_group_metadata(
   *,
   group_sizes: jnp.ndarray,
   m: int,
   tm: int,
   start_group: jnp.ndarray,
   num_nonzero_groups: int,
   visit_empty_groups: bool = False,
 ) -> tuple[tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
   """Build tile-to-group mapping for Pallas grid dispatch.
   ...
   """
+  assert group_sizes.ndim == 1, f"group_sizes must be 1-D, got {group_sizes.ndim}-D"
+  assert m > 0 and tm > 0, f"m ({m}) and tm ({tm}) must be positive"
   num_groups = group_sizes.shape[0]

As per coding guidelines: "All public functions must enforce strict constraints and validation on the shape and types of input variables using assert instructions."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 60 - 88, Add explicit input assertions
at the start of make_group_metadata to validate shapes and types: assert
group_sizes is a 1-D jnp.ndarray (group_sizes.ndim == 1) and not empty, assert m
and tm are positive integers (m > 0, tm > 0) and that m equals the total rows
implied by group_sizes (m == int(jnp.sum(group_sizes))), assert start_group is a
scalar integer (0 <= int(start_group) < group_sizes.shape[0]), and assert
num_nonzero_groups is non-negative; also optionally assert visit_empty_groups is
a bool—place these checks at the top of make_group_metadata to fail fast and
reference the parameters group_sizes, m, tm, start_group, num_nonzero_groups,
and visit_empty_groups.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@tests/ops/gmm/test_gmm.py`:
- Around line 167-189: Update the test docstring in test_gmm_existing_out to
state that existing_out is overwritten (not accumulated) by the gmm kernel:
mention that gmm(lhs, rhs, group_sizes, existing_out=existing, ...) overwrites
active regions of the provided buffer, matching the assertion that result ==
gmm_ref(...); you can optionally note that this differs from tgmm which
accumulates (to_store += existing_out_ref), so the docstring should say
"overwrites the provided buffer in active regions" instead of "accumulates into
the provided buffer."

In `@tops/ops/gmm/backend.py`:
- Around line 129-131: The histogram call that computes tile_visits using
jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1)) excludes
the last tile index; update the range to (0, tiles_m) so all integer tile
indices 0..tiles_m-1 are properly binned (i.e., change the range argument in the
jnp.histogram call that uses partial_tile_ids and tiles_m to (0, tiles_m)),
keeping bins=tiles_m and the +1 offset unchanged.

---

Nitpick comments:
In `@tops/ops/gmm/backend.py`:
- Around line 291-314: The variable M is unpacked but unused; change the tuple
unpack in this block (where M, K = lhs.shape) to use an underscore-prefixed name
(e.g., _M, K = lhs.shape) to signal it's intentionally unused and satisfy static
analysis; update any references if present (the block around out computation
that uses K and group sizes in functions like the current group-matrix multiply,
similar to the gmm_ref change) and keep types and return (jnp.array(out,
dtype=preferred_element_type)) unchanged.
- Around line 233-262: The variable K returned when unpacking M, K = lhs.shape
is unused and triggers linter warnings; either rename it to _K (or _k) to
indicate it is intentionally unused, or add a short assertion using K (for
example assert lhs.shape[1] == rhs.shape[2] or similar) to make its purpose
explicit; update the symbol in the unpacking line (M, K -> M, _K) or add the
assertion after the shapes are computed, and leave all remaining logic (lhs,
rhs, group_sizes loop, out construction, return jnp.array(...)) unchanged.
- Around line 60-88: Add explicit input assertions at the start of
make_group_metadata to validate shapes and types: assert group_sizes is a 1-D
jnp.ndarray (group_sizes.ndim == 1) and not empty, assert m and tm are positive
integers (m > 0, tm > 0) and that m equals the total rows implied by group_sizes
(m == int(jnp.sum(group_sizes))), assert start_group is a scalar integer (0 <=
int(start_group) < group_sizes.shape[0]), and assert num_nonzero_groups is
non-negative; also optionally assert visit_empty_groups is a bool—place these
checks at the top of make_group_metadata to fail fast and reference the
parameters group_sizes, m, tm, start_group, num_nonzero_groups, and
visit_empty_groups.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 13dbb9c1-702e-45a4-a74a-98f9e56cdeee

📥 Commits

Reviewing files that changed from the base of the PR and between 61d8475 and d7870bc.

📒 Files selected for processing (5)
  • tests/ops/gmm/test_gmm.py
  • tests/ops/gmm/test_gmm_fp8.py
  • tests/ops/gmm/test_tgmm.py
  • tests/ops/gmm/test_tgmm_fp8.py
  • tops/ops/gmm/backend.py
✅ Files skipped from review due to trivial changes (1)
  • tests/ops/gmm/test_gmm_fp8.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/ops/gmm/test_tgmm_fp8.py

… alignment

Replace pl.ds(k_i, 1) dynamic indexing along the last dimension with
broadcasted_iota one-hot masking. TPU Mosaic requires indices along
dim -1 to be multiples of 128 (E2003 alignment error), but k_i is a
runtime grid index that can't satisfy this. Loading the full scale block
and selecting via multiply+sum avoids the alignment constraint entirely.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/ops/gmm/backend.py`:
- Around line 1060-1077: The kernel path must reject cases where K is not a
multiple of block_size because tiles_k is rounded up while num_k_blocks = k //
block_size rounds down, causing the last partial K tile to be dropped; in
gmm_fp8 (around the block computing tiles_k, tiles_n and num_k_blocks) add a
guard that asserts or raises when k % block_size != 0 (before using num_k_blocks
and validating lhs_scale/rhs_scale), so the code mirrors the ref-path behavior
and ensures lhs_scale/rhs_scale shape checks are correct for all valid inputs.
- Around line 334-352: The current _validate_args only checks ranks and dtype
but must also enforce group_sizes invariants: verify all entries in group_sizes
are non-negative, sum(group_sizes) == lhs.shape[0], and that any requested group
window (as used by TGMM/ref helpers and Pallas entry points) fits inside each
group (i.e., window start+width <= group size). Add these checks inside
_validate_args (or a small helper called from it) using assert statements (or
tops.utils assertion helpers) and reuse this validator from the ref helpers and
Pallas entry points; reference the functions _validate_args,
make_group_metadata, and the TGMM/ref/Pallas entry-point functions so callers
use the same strict validation before building group metadata or launching
kernels.
- Around line 1427-1445: The CPU reference tgmm_fp8_ref() incorrectly indexes
lhs_scale/rhs_scale by a group-local scale_row (mb * block_size) rather than the
global M-block ids, so when a group starts mid-FP8 block the dequantization uses
the wrong scale for some rows; fix by iterating over the intersecting global
M-block ids for each group (use lhs_q_np/rhs_q_np row ranges clipped to each
global block) and multiply each clipped slice by the corresponding
lhs_scale/rhs_scale row (instead of a single scale_row per mb), ensuring out[g]
accumulation still sums lhs_block.T @ rhs_block for each clipped interval; then
add a regression test that constructs a tensor where a group boundary falls
inside an FP8 M block and compares the JAX/Pallas kernel output to
tgmm_fp8_ref() using the compare_tensor utility with appropriate tolerances.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2ba46b50-c56e-45a3-936f-71b58196b6c2

📥 Commits

Reviewing files that changed from the base of the PR and between d7870bc and 0377a4a.

📒 Files selected for processing (1)
  • tops/ops/gmm/backend.py

Comment on lines +334 to +352
def _validate_args(
*,
lhs: jnp.ndarray,
rhs: jnp.ndarray,
group_sizes: jnp.ndarray,
expected_rhs_dims: int = 3,
) -> jnp.ndarray:
"""Validate argument shapes and dtypes for GMM / TGMM kernels."""
if lhs.ndim != 2:
raise ValueError(f"Expected 2-tensor for 'lhs' but got {lhs.ndim}-tensor.")
if rhs.ndim != expected_rhs_dims:
raise ValueError(
f"Expected {expected_rhs_dims}-tensor for 'rhs' but got {rhs.ndim}-tensor."
)
if group_sizes.dtype != jnp.int32:
raise ValueError(
f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype}."
)
return group_sizes
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Tighten the shared metadata validation before launch.

_validate_args() only checks ranks and group_sizes.dtype, but every backend path below assumes group_sizes is non-negative, sums to lhs.shape[0], and that the requested group window stays inside group_sizes. When those invariants are violated, make_group_metadata() can build truncated mappings and the TGMM paths can leave trailing groups unwritten instead of failing fast. Please add a shared pre-launch validator here (or a small companion next to it) and reuse it from the ref helpers and the Pallas entry points.

As per coding guidelines, "All public functions must enforce strict constraints and validation on the shape and types of input variables using assert instructions (or utilities like assert_shape_or_none from tops.utils) before executing the main logic".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 334 - 352, The current _validate_args
only checks ranks and dtype but must also enforce group_sizes invariants: verify
all entries in group_sizes are non-negative, sum(group_sizes) == lhs.shape[0],
and that any requested group window (as used by TGMM/ref helpers and Pallas
entry points) fits inside each group (i.e., window start+width <= group size).
Add these checks inside _validate_args (or a small helper called from it) using
assert statements (or tops.utils assertion helpers) and reuse this validator
from the ref helpers and Pallas entry points; reference the functions
_validate_args, make_group_metadata, and the TGMM/ref/Pallas entry-point
functions so callers use the same strict validation before building group
metadata or launching kernels.

sii-xinglong and others added 5 commits April 1, 2026 17:28
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Wire the backend group_offset parameter through the public gmm()
function and its custom VJP (forward, residuals, backward), enabling
MoE layers to skip initial expert groups when using expert parallelism.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Allow decoupling forward and backward FP8 quantization via bwd_quantize
parameter. When quantize=True and bwd_quantize=False, residuals are
dequantized and BF16 kernels are used for the backward pass.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (3)
tops/ops/gmm/backend.py (3)

1095-1112: ⚠️ Potential issue | 🔴 Critical

Reject partial K blocks in gmm_fp8().

tiles_k rounds up, but num_k_blocks = k // block_size rounds down. When k % block_size != 0, the last kernel iteration has no matching scale column/row, so the K tail is silently dropped. The ref path already rejects this case, so the kernel should do the same before validating lhs_scale / rhs_scale.

Suggested guard
   tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk)
   tiles_n, _ = _calculate_irregular_num_tiles(n, tn)
 
   # Validate scale shapes.
+  if k % block_size != 0:
+    raise ValueError(
+      f"K ({k}) must be divisible by block_size ({block_size}) for FP8 GMM."
+    )
   num_k_blocks = k // block_size
   assert lhs_scale.shape == (m, num_k_blocks), (
     f"lhs_scale shape {lhs_scale.shape} != expected ({m}, {num_k_blocks})"
   )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 1095 - 1112, In gmm_fp8(), reject
partial K blocks before validating lhs_scale/rhs_scale: check if k % block_size
!= 0 (using variables k and block_size) and raise/assert an error (same behavior
as the ref path) so the kernel does not silently drop the K tail; add this guard
before the existing scale-shape validation that uses num_k_blocks and tiles_k to
ensure num_k_blocks = k // block_size is valid for all iterations.

1462-1480: ⚠️ Potential issue | 🟠 Major

Dequantize tgmm_fp8_ref() on global M-block boundaries.

mb * block_size walks each group in local chunks, but lhs_scale and rhs_scale are indexed in global M // block_size space. If a group starts mid-block, the first slice spans two scale rows and Lines 1471 and 1476 apply only the first one, so the reference no longer matches the kernel. Iterate over the intersecting global block ids and clip each slice to the group, then add a regression where a group boundary lands inside an FP8 M block.

Based on learnings, "Each JAX/Pallas kernel must have a corresponding CPU reference test that compares the optimized kernel against naive implementations with tolerance-based assertions using compare_tensor utility".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 1462 - 1480, The dequantization in
tgmm_fp8_ref() incorrectly assumes lhs_scale_rhs_scale align with group-local mb
indices; when a group starts mid-global M block the slice m_start:m_end may span
multiple global scale rows so using scale_row = m_start // block_size only
applies the first partial row. Fix tgmm_fp8_ref() by iterating over the
intersecting global M-block indices for each group (compute global_block_id
range that intersects [m_start, m_end)), clip each global block slice to the
group's local range, multiply each clipped lhs_q_np and rhs_q_np slice by the
corresponding lhs_scale and rhs_scale rows and accumulate, and add a regression
test where a group boundary falls inside an FP8 M block to assert the kernel
matches the corrected CPU reference using compare_tensor tolerances.

361-379: ⚠️ Potential issue | 🟠 Major

Validate the active group window before building metadata.

_validate_args() still accepts non-1D / negative group_sizes, totals that do not match lhs.shape[0], and shard windows that can run past group_sizes.shape[0]. Once those slip through, make_group_metadata() and _zero_uninitialized_memory() can build truncated mappings or index past group_offsets instead of failing fast. Please add the shared checks here and reuse them from the ref helpers as well as the Pallas entry points.

As per coding guidelines, "All public functions must enforce strict constraints and validation on the shape and types of input variables using assert instructions (or utilities like assert_shape_or_none from tops.utils) before executing the main logic".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 361 - 379, Update _validate_args to
enforce strict checks on group_sizes: assert it's a 1-D int32 vector
(group_sizes.ndim == 1), assert all entries are non-negative (>= 0), and assert
that group_sizes.sum() equals lhs.shape[0]; also validate that any shard/window
indices derived later will be in-bounds by ensuring group_sizes.shape[0] is
large enough for downstream indexing (so callers cannot pass windows that index
past group_sizes). Replace ad-hoc checks in make_group_metadata and
_zero_uninitialized_memory by calling this strengthened _validate_args (or the
shared ref/pallas helper validator) at the Pallas entry points; use assert-style
checks per coding guidelines so failures occur fast and locally.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/ops/gmm/ops.py`:
- Around line 141-167: The backward branches must mirror the forward transpose:
when transpose_rhs is True use the quantization pair and shapes for rhs
transposed (swap rhs_rq/rhs_rs with rhs_cq/rhs_cs or use rhs.T equivalents) and
form drhs/dlhs with the correct axis order so drhs has shape [G,N,K] for
transpose_rhs=True and [G,K,N] otherwise; update the gmm_fp8 call (dlhs) to pass
the correct (rhs quant pair) and tiling slice depending on transpose_rhs and
update the tgmm_fp8 call (drhs) to use the corresponding lhs/rhs quant pairs and
num_actual_groups argument so the contraction axes match transpose_rhs, and add
a regression test that differentiates transpose_rhs=True/False in both quantized
(gmm_fp8/tgmm_fp8) and non-quantized paths to catch shape/scale assertion
failures.
- Around line 43-70: Add stricter input validation before any quantization:
assert the contraction dim K matches between lhs and rhs (respecting
transpose_rhs), assert group_sizes.dtype is int32 and group_sizes.ndim == 1,
assert sum(group_sizes) == lhs.shape[0], normalize and validate group_offset as
a non-negative scalar int (not shape [1]) and assert (group_offset +
rhs.shape[0]) <= group_sizes.shape[0], and assert block_size divides the
dimension expected by fp8_dual_quantize (e.g., M or K per that routine) so
quantization blocks are valid; place these asserts immediately after the
existing rank/tiling checks in the gmm/ops entrypoint (referencing lhs, rhs,
group_sizes, transpose_rhs, group_offset, block_size, and fp8_dual_quantize).
- Around line 93-112: The FP8 rhs scale pair passed to backend.gmm_fp8 is wrong
when transpose_rhs=True: for rhs shaped [G,N,K], fp8_dual_quantize(rhs) returns
the contraction-axis scales in (rhs_rq, rhs_rs), not (rhs_cq, rhs_cs), so when
transpose_rhs is True call backend.gmm_fp8 with rhs_rq and rhs_rs (instead of
rhs_cq/rhs_cs); also include the alternate rhs orientation scales in the JAX
residuals so the backward pass can choose the correct pair (update the residuals
tuple that currently contains lhs_dtype_sentinel, rhs_dtype_sentinel, lhs_cq,
lhs_cs, rhs_rq, rhs_rs, group_sizes, group_offset to preserve both rhs pairs and
reflect the chosen pair for transpose_rhs).

---

Duplicate comments:
In `@tops/ops/gmm/backend.py`:
- Around line 1095-1112: In gmm_fp8(), reject partial K blocks before validating
lhs_scale/rhs_scale: check if k % block_size != 0 (using variables k and
block_size) and raise/assert an error (same behavior as the ref path) so the
kernel does not silently drop the K tail; add this guard before the existing
scale-shape validation that uses num_k_blocks and tiles_k to ensure num_k_blocks
= k // block_size is valid for all iterations.
- Around line 1462-1480: The dequantization in tgmm_fp8_ref() incorrectly
assumes lhs_scale_rhs_scale align with group-local mb indices; when a group
starts mid-global M block the slice m_start:m_end may span multiple global scale
rows so using scale_row = m_start // block_size only applies the first partial
row. Fix tgmm_fp8_ref() by iterating over the intersecting global M-block
indices for each group (compute global_block_id range that intersects [m_start,
m_end)), clip each global block slice to the group's local range, multiply each
clipped lhs_q_np and rhs_q_np slice by the corresponding lhs_scale and rhs_scale
rows and accumulate, and add a regression test where a group boundary falls
inside an FP8 M block to assert the kernel matches the corrected CPU reference
using compare_tensor tolerances.
- Around line 361-379: Update _validate_args to enforce strict checks on
group_sizes: assert it's a 1-D int32 vector (group_sizes.ndim == 1), assert all
entries are non-negative (>= 0), and assert that group_sizes.sum() equals
lhs.shape[0]; also validate that any shard/window indices derived later will be
in-bounds by ensuring group_sizes.shape[0] is large enough for downstream
indexing (so callers cannot pass windows that index past group_sizes). Replace
ad-hoc checks in make_group_metadata and _zero_uninitialized_memory by calling
this strengthened _validate_args (or the shared ref/pallas helper validator) at
the Pallas entry points; use assert-style checks per coding guidelines so
failures occur fast and locally.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 911e7e71-25ae-4f9b-aeea-a4e386e49f11

📥 Commits

Reviewing files that changed from the base of the PR and between 0377a4a and 5331fbc.

📒 Files selected for processing (6)
  • tests/ops/gmm/test_gmm.py
  • tests/ops/gmm/test_gmm_fp8.py
  • tests/ops/gmm/test_gmm_vjp.py
  • tops/ops/gmm/__init__.py
  • tops/ops/gmm/backend.py
  • tops/ops/gmm/ops.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tops/ops/gmm/init.py
  • tests/ops/gmm/test_gmm.py

@sii-xinglong
Copy link
Copy Markdown
Contributor Author

TPU v7x End-to-End Test Results

Cluster: tpu7x-cluster (GKE, us-central1)
Topology: 2x2x1 (4 chips, 8 devices)
JAX: 0.9.2 + libtpu 0.0.37

============================= test session starts ==============================
collected 117 items
======================== 117 passed in 70.37s (0:01:10) ========================

All 117 GMM tests passed on real TPU hardware, including:

  • G1 (group_offset): forward zeroing + backward gradients with group_offset
  • G2 (bwd_quantize): auto, explicit true/false, FP8 grad fidelity <15% Frobenius error
  • G3/G6 (validation): invalid shapes, dtypes, K%block_size, tk!=block_size
  • G4 (LUT tiling): callable dispatch + None error
  • G8 (zero padding): rows outside active groups zeroed

No regressions in existing GMM, TGMM, FP8, quantize, or VJP tests.

pengchengneo and others added 2 commits April 2, 2026 00:29
…XU throughput

The tgmm_fp8 kernel was dequantizing FP8 operands to f32 BEFORE the dot
product, which meant the MXU operated on f32 x f32 inputs. Since both
scales are along M (contraction axis), the correction is a separable
outer product: (lhs * s_l).T @ (rhs * s_r) = s_l.T * (lhs.T @ rhs) * s_r.

This change does the FP8 dot product first (leveraging native FP8 MXU
throughput with f32 accumulation), then applies the scale correction as
an outer-product post-multiply. This matches the approach already used
in the gmm_fp8 forward kernel.

Verified: CPU interpret mode passes all test cases (max_diff < 1e-5,
fro_rel_err vs BF16 < 5%). Full forward+backward pipeline produces
no NaN and maintains gradient accuracy (fro_rel_err ~2.5%).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Fix histogram range in make_group_metadata (off-by-one)
- Fix tgmm_fp8_ref global M-block scale indexing for cross-boundary groups
- Fix existing_out_arg_index to use flattened prefetch count (3+1=4)
- Add input validation: group_sizes ndim/dtype, K dimension, block_size
- Add docstrings to kernel bodies, index maps, and public API functions
- Add drhs gradient spot-checks in VJP numerical test
- Prefix unused variables with underscore

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
pengchengneo and others added 7 commits April 2, 2026 12:05
Replace the `tk == block_size` constraint in gmm_fp8 and `tm == block_size`
in tgmm_fp8 with `tk/tm % block_size == 0`.  The kernel body now uses an
unrolled Python loop over block_size sub-blocks within each tile, applying
per-block FP8 scales independently.

This eliminates the 8x grid-point penalty of tk=128 (block_size), enabling
the same tile sizes as BF16 (e.g., tk=1024).  Each sub-block still does a
native FP8 dot at MXU throughput with outer-product scale correction.

Add 12 new tests verifying subchannel correctness (vs ref and vs single-block).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
In gmm_fp8 subchannel loop, replace the one-hot masking pattern
(iota + cmp + mul + sum per sub-block) with lax.dynamic_slice_in_dim.
This eliminates 16 elementwise operations per grid point (8 sub-blocks
× 2 scales × 4 ops each) in favor of 16 dynamic slices.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace lax.dynamic_slice_in_dim (which may not be supported in Pallas
TPU lowering) with the proven approach from blockwise_fp8.py:
- Transpose lhs_scale [M, nk] → [nk, M] before passing to kernel
- Use pl.ds(global_k_idx, 1) on axis-0 (always aligned on TPU)
- Reshape [1, tm] → [tm, 1] for broadcasting

This eliminates the iota+cmp+mul+sum one-hot masking pattern for
lhs_scale. rhs_scale retains one-hot masking due to 3D shape constraints.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Extend the transpose+pl.ds() optimization to rhs_scale (no-transpose path):
- lhs_scale: transposed [nk, M], pl.ds(idx, 1) on axis-0
- rhs_scale (no-transpose): [nk, tn] has nk on axis-0, pl.ds directly
- rhs_scale (transpose): [tn, nk] retains one-hot (nk on last dim)

Eliminates ALL one-hot masking in the forward pass (non-transpose).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Transpose rhs_scale so K-blocks are always on axis-0, enabling pl.ds()
for both lhs_scale and rhs_scale. This eliminates the one-hot masking
fallback for transpose_rhs=True (TGMM) path, reducing per-subchannel
instruction count from 5 ops (iota+cmp+cast+mul+sum) to 1 op (pl.ds).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The block_size parameter was accepted by the high-level gmm() API
and used for quantization, but was NOT passed to the backend kernel
calls (gmm_fp8, tgmm_fp8), which silently used the default=128.
This would produce wrong results when block_size != 128.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Validates FP8 block-wise quantized GMM/TGMM kernel fidelity using
robust aggregate metrics (cosine similarity > 0.999, Frobenius < 5%,
SNR > 20dB) across multiple block_size values (128, 256, 512).

Tests cover:
- gmm_fp8 kernel vs BF16 ref + vs FP8 ref (8 shapes x 2 = 16)
- tgmm_fp8 kernel vs BF16 ref + vs FP8 ref (5 shapes x 2 = 10)
- High-level gmm(quantize=True) API forward (5 cases)
- VJP gradient finiteness + quality vs BF16 grads (5 x 2 = 10)
- Finite-output smoke tests (5)
- Precision report with metrics table (1)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants