fix: tgmm subchannel axis and loop for blockwise FP8#4

Open
MokusMokun wants to merge 2 commits into main from fix/tgmm-subchannel-axis

Conversation

@MokusMokun

@MokusMokun MokusMokun commented Mar 30, 2026

Summary

Fix two bugs in tgmm that caused blockwise FP8 to produce wrong results (cosine similarity ~0.34–0.49) when tm > 128, and to crash with a NotImplementedError when tk > 128.

Bug 1 — Wrong axis for subchannel_iters (lines 874 and 878):
subchannel_iters = max(1, tk // lhs_eps) uses tk (the K-tile), but tgmm's reduction axis is M, not K. Changed to tm // lhs_eps.

Bug 2 — Missing subchannel loop (lines 738–741):
if subchannel_iters != 1: raise NotImplementedError. Replaced with a subchannel loop (modeled after gmm's existing implementation) that iterates over M-axis sub-tiles, applying the correct scale per sub-tile.
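The corrected loop structure can be illustrated with a minimal NumPy sketch. Shapes, the scale layout, and names like `eps` and `sc_tile` are assumptions for illustration; the actual kernel operates on Pallas refs:

```python
import numpy as np

def tgmm_subchannel_tile(lhs, rhs, lhs_scale, eps):
    """Model of one tgmm output tile with M-axis subchannel scaling.

    lhs: (tm, tk) quantized values; rhs: (tm, tn);
    lhs_scale: (tm // eps, tk), one scale row per M block of size eps.
    """
    tm = lhs.shape[0]
    # The reduction axis of tgmm is M (axis 0), so subchannels
    # partition M: tm // eps, NOT tk // eps.
    subchannel_iters = max(1, tm // eps)
    sc_tile = tm // subchannel_iters
    acc = np.zeros((lhs.shape[1], rhs.shape[1]), dtype=np.float32)
    for it in range(subchannel_iters):
        s = it * sc_tile
        lhs_sub = lhs[s:s + sc_tile].astype(np.float32)
        rhs_sub = rhs[s:s + sc_tile].astype(np.float32)
        # Each M-axis sub-tile applies its own scale row.
        acc += (lhs_sub * lhs_scale[it]).T @ rhs_sub
    return acc
```

With unit scales this reduces to a plain `lhs.T @ rhs` over the tile, which is what makes the axis choice easy to check.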

Benchmark results (TPU v7x)

| Shape | Before (BW/CW gap) | After (BW/CW gap) |
|---|---|---|
| medium (2048, 1024, 512, G=8) | 1.49x | 1.09x |
| large/DSv3 (8192, 2048, 1024, G=16) | 3.40x | 1.48x |
| wide K (2048, 7168, 2048, G=8) | 3.49x | 1.35x |
| wide N (1024, 512, 4096, G=8) | 1.51x | 1.21x |

Correctness: all tilings (including tm=512, tk=1024) produce cosine similarity 0.999318 vs the bf16 reference.

Test plan

  • Existing tokamax ragged_dot tests pass
  • Blockwise FP8 tgmm benchmark with full tiling space (verified on TPU v7x)
  • End-to-end training validation (ant-pretrain CI)

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Performance & Optimization
    • Improved quantization support and tile size optimization for TPU-accelerated operations, enhancing performance for quantized tensor computations with better resource utilization.

Fix two bugs in tgmm that prevent blockwise FP8 from using optimal tiling:

1. subchannel_iters computed from wrong axis (tk instead of tm).
   tgmm reduction axis is M (axis 0), so subchannel_iters should be
   tm // eps, not tk // eps.

2. Replace NotImplementedError with subchannel loop implementation,
   modeled after gmm's existing loop. Each sub-tile applies its own
   scale and group boundary mask independently.

Before: tm and tk locked to 128, BW/CW gap 1.5-3.5x.
After: all dims free, BW/CW gap 1.1-1.5x (subchannel overhead).

Verified on TPU v7x: all tilings produce cosine=0.999318 vs bf16.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@coderabbitai

coderabbitai Bot commented Mar 30, 2026

📝 Walkthrough

The change modifies the TGMM TPU kernel to support arbitrary subchannel iterations by removing a guard condition, capping tile sizes based on quantization granularity, and introducing subchannel-aware slicing for both matrix and scale tensors during computation.

Changes

Cohort / File(s): TGMM TPU Kernel Modifications — tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py
Summary: Added quantization-granularity-based tm tile-size capping via eps = cdiv(qvalue.shape[0], scale.shape[0]); removed the NotImplementedError guard for subchannel_iters != 1; changed group_metadata unpacking to include m_tile_ids; refactored quantized scale handling to extract full-tile scales before masking; introduced a subchannel iteration loop with sc_tile = tm // subchannel_iters slicing; updated the subchannel_iters computation from tk // eps to tm // eps.
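The tile-capping step summarized above can be modeled in a few lines. The function names and argument shapes here are assumptions based on the walkthrough, not the kernel's actual API:

```python
def cdiv(a, b):
    # Ceiling division, as used in the summary's eps expression.
    return -(-a // b)

def cap_tm(requested_tm, qvalue_rows, scale_rows):
    """Cap tm to the blockwise scale granularity along M.

    eps = cdiv(qvalue.shape[0], scale.shape[0]) is the quantization
    block size; capping tm to eps forces subchannel_iters == 1, so the
    inner subchannel loop is skipped entirely.
    """
    eps = cdiv(qvalue_rows, scale_rows)
    return min(requested_tm, eps)
```

For example, with block_size=128 (1024 quantized rows, 8 scale rows), a requested tm of 512 is capped to 128, and max(1, 128 // 128) gives a single subchannel iteration.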

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 A kernel blooms with tile precision bright,
Subchannels now slice through the tpu night,
Quantization scales align with care,
No guard to stop—just freedom in the air! ✨

🚥 Pre-merge checks: 2 passed, 1 warning

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ — coverage is 25.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2)
  • Description Check — skipped; CodeRabbit's high-level summary is enabled.
  • Title Check — the title 'fix: tgmm subchannel axis and loop for blockwise FP8' accurately captures the two main fixes: correcting the subchannel axis calculation (tk to tm) and implementing the subchannel loop that replaces the NotImplementedError.


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request enables support for subchannel_iters > 1 in the Pallas Mosaic TPU kernel for ragged dot operations. Key changes include the introduction of a loop to process sub-tiles, sub-slicing of input references on the M-axis, and the implementation of group boundary masking using broadcasted iotas. Feedback suggests refactoring the sub-slicing logic to improve efficiency by moving helper function definitions outside the loop and simplifying redundant conditional checks for single iterations.

Comment on lines +769 to +775
else:
  def _get_sub(x, idx=it):
    size = x.shape[0] // subchannel_iters
    s = idx * size
    return x[s:s + size, :]
  lhs = jax.tree.map(_get_sub, lhs_ref)
  rhs = jax.tree.map(_get_sub, rhs_ref)

medium

Defining the _get_sub function inside the for loop causes it to be redefined on each iteration. This is inefficient and makes the code harder to read.

For better performance and readability, I suggest defining this helper function once, outside of the for loop. This would also align better with the implementation of _maybe_get_subslice in the gmm kernel, which you've modeled this change after.

Additionally, the if subchannel_iters == 1: check is redundant. The slicing logic in the else branch correctly handles the case where subchannel_iters is 1, so the if/else structure can be simplified.
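A sketch of the restructuring this comment suggests, with the helper defined once and no special case for a single iteration (names are illustrative, not the kernel's actual code):

```python
import numpy as np

def get_subslice(x, idx, subchannel_iters):
    """Return the idx-th M-axis sub-tile of x.

    With subchannel_iters == 1 this returns the full tile, so no
    separate if/else branch is needed at the call site.
    """
    size = x.shape[0] // subchannel_iters
    start = idx * size
    return x[start:start + size, :]
```

The loop body then just calls `get_subslice(lhs_ref, it, subchannel_iters)` on each iteration instead of redefining a closure.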

Two changes to eliminate subchannel loop overhead for blockwise FP8 TGMM:

1. Cap tm to blockwise scale granularity (eps) when QArrays are detected.
   This ensures subchannel_iters=1, avoiding the inner loop entirely.
   For block_size=128, tm is capped to 128 regardless of the requested tiling.

2. Hoist group boundary mask outside the subchannel loop. The previous
   implementation (c24fbaa) applied per-sub-tile masking with broadcasted_iota
   inside each iteration — 4x iota + 4x comparisons + 4x where ops of
   Vector ALU overhead. Now the full tile is masked once before sub-slicing.

Benchmark on TPU v7x (production MoE shapes, block_size=128):

  w0 (K=2048): BW/bf16 1.49x → 1.23x
  wo (K=512):  BW/bf16 1.25x → 1.19x

The remaining gap is inherent to blockwise: extra scale data loading for
w0 (~20%), and grid-level overhead for wo (K=512 gives small per-tile
compute, so fixed costs dominate).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
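The mask hoisting described in the commit message can be sketched in NumPy; the shapes and names below are illustrative assumptions, not the kernel's actual code:

```python
import numpy as np

def masked_subtiles(tile, row_in_group, subchannel_iters):
    """Mask the full (tm, k) tile once, then sub-slice.

    The hoisted version pays for one where op per tile, instead of an
    iota + compare + where inside every subchannel iteration.
    """
    masked = np.where(row_in_group[:, None], tile, 0.0)
    sc_tile = tile.shape[0] // subchannel_iters
    return [masked[i * sc_tile:(i + 1) * sc_tile]
            for i in range(subchannel_iters)]
```

Rows outside the current group contribute zeros to every sub-tile's accumulation, which is the same effect the per-sub-tile masks achieved.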
@MokusMokun
Author

Identified the biggest bottleneck as the tgmm custom-call compute (not fully analysed, since LLO is not listed):
https://alidocs.dingtalk.com/i/nodes/3NwLYZXWynN5ADnOtQxvRQloVkyEqBQm

@MokusMokun MokusMokun changed the title [WIP] fix: tgmm subchannel axis and loop for blockwise FP8 fix: tgmm subchannel axis and loop for blockwise FP8 Apr 7, 2026
@MokusMokun MokusMokun marked this pull request as ready for review April 7, 2026 07:24

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py`:
- Around line 803-815: The scale slicing uses the shared subchannel_iters and
loop index it, which can be incorrect when lhs_scale_full and rhs_scale_full use
different per-operand subchannel counts; instead compute the slice size and
index from each scale tensor's own shape (e.g. use lhs_sz =
lhs_scale_full.shape[0] // <lhs_subchannel_iters_or_derive_from_lhs_shape> and
rhs_sz = rhs_scale_full.shape[0] //
<rhs_subchannel_iters_or_derive_from_rhs_shape>) and slice lhs_scale_full and
rhs_scale_full using their respective sz and it_for_that_operand (derived from
the same loop index but relative to that operand's subchannel partitioning).
Update the scales assembly in pallas_mosaic_tpu_kernel.py (variables: scales,
lhs_scale_full, rhs_scale_full, subchannel_iters, it) and apply the same fix to
the analogous block at the later location (lines referenced in the review) so
each scale's slice uses its own shape-derived subchannel count rather than the
global subchannel_iters.
- Around line 817-818: Replace the temporary lambda is_int with an inline dtype
check when computing acc_dtype: remove the is_int assignment and use
jnp.issubdtype(lhs_sub.dtype, jnp.integer) and jnp.issubdtype(rhs_sub.dtype,
jnp.integer) directly in the acc_dtype expression so acc_dtype = jnp.int32 if
jnp.issubdtype(lhs_sub.dtype, jnp.integer) and jnp.issubdtype(rhs_sub.dtype,
jnp.integer) else jnp.float32 (references: lhs_sub, rhs_sub, acc_dtype, jnp).

📥 Commits

Reviewing files that changed from the base of the PR and between e656568 and 651fde2.

📒 Files selected for processing (1)
  • tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py

Comment on lines +803 to +815
scales = []
if lhs_scale_full is not None:
  if subchannel_iters == 1:
    scales.append(lhs_scale_full.T)
  else:
    sz = lhs_scale_full.shape[0] // subchannel_iters
    scales.append(lhs_scale_full[it * sz:(it + 1) * sz, :].T)
if rhs_scale_full is not None:
  if subchannel_iters == 1:
    scales.append(rhs_scale_full)
  else:
    sz = rhs_scale_full.shape[0] // subchannel_iters
    scales.append(rhs_scale_full[it * sz:(it + 1) * sz, :])

⚠️ Potential issue | 🟠 Major

Map scale rows from the qvalue window, not from the global loop count.

subchannel_iters is chosen across both operands, but these slices assume each scale tensor has at least that many rows. tokamax/_src/mosaic_tpu.py:100-130 derives eps from each tensor’s own scale.shape, so mixed M-axis scale grids are valid inputs. With a finer-grained operand driving subchannel_iters (for example, per-row on one side and blockwise/full-channel on the other), shape[0] // subchannel_iters becomes 0, producing an empty scale slice and a crash or wrong scaling. The common blockwise-FP8 fast path hides this because it usually forces subchannel_iters == 1, but the generic loop is still broken.

💡 Possible fix
         if lhs_scale_full is not None:
           if subchannel_iters == 1:
             scales.append(lhs_scale_full.T)
           else:
-            sz = lhs_scale_full.shape[0] // subchannel_iters
-            scales.append(lhs_scale_full[it * sz:(it + 1) * sz, :].T)
+            scale_idx = min(
+                (s * lhs_scale_full.shape[0]) // tm,
+                lhs_scale_full.shape[0] - 1,
+            )
+            scales.append(lhs_scale_full[scale_idx:scale_idx + 1, :].T)
         if rhs_scale_full is not None:
           if subchannel_iters == 1:
             scales.append(rhs_scale_full)
           else:
-            sz = rhs_scale_full.shape[0] // subchannel_iters
-            scales.append(rhs_scale_full[it * sz:(it + 1) * sz, :])
+            scale_idx = min(
+                (s * rhs_scale_full.shape[0]) // tm,
+                rhs_scale_full.shape[0] - 1,
+            )
+            scales.append(rhs_scale_full[scale_idx:scale_idx + 1, :])

Also applies to: 905-909
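The index mapping the possible fix relies on, where each operand's sub-tile is mapped onto that operand's own scale grid rather than the global subchannel count, can be modeled as follows (names are illustrative, not the kernel's actual code):

```python
def scale_row_for_subtile(scale_rows, tm, it, subchannel_iters):
    """Map the it-th M sub-tile onto an operand's own scale grid.

    scale_rows may be coarser or finer than the subchannel partition;
    the proportional mapping (start_row * scale_rows) // tm picks the
    scale row covering the sub-tile's first row, clamped to the grid.
    """
    sc_tile = tm // subchannel_iters
    start_row = it * sc_tile
    return min((start_row * scale_rows) // tm, scale_rows - 1)
```

With tm=128 and four sub-tiles, a 2-row scale grid maps sub-tiles 0,1 to row 0 and 2,3 to row 1; a 1-row (full-channel) grid always maps to row 0 instead of producing an empty slice.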


Comment on lines +817 to +818
is_int = lambda x: jnp.issubdtype(x.dtype, jnp.integer)
acc_dtype = jnp.int32 if is_int(lhs_sub) and is_int(rhs_sub) else jnp.float32

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, verify the file exists and check the content at lines 817-818
if [ -f "tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py" ]; then
  echo "File found. Checking lines 817-818:"
  sed -n '815,820p' "tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py"
else
  echo "File not found. Searching for the file..."
  find . -name "pallas_mosaic_tpu_kernel.py" -type f
fi

Repository: primatrix/tokamax

Length of output: 365


🏁 Script executed:

# Check if the file exists and get some context around the target lines
if [ -f "tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py" ]; then
  wc -l "tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py"
fi

Repository: primatrix/tokamax

Length of output: 120


🌐 Web query:

Ruff E731 rule definition lambda assignment

💡 Result:

In Ruff, rule E731 (rule name: lambda-assignment) flags code that assigns a lambda to a name, e.g. f = lambda x: x + 1, with the message: “Do not assign a lambda expression, use a def.” [1][2]

Why it’s flagged: if you’re giving the function a name anyway, def is clearer and provides better debugging/introspection than a named lambda. [2]

Fix (recommended):

def f(x):
    return x + 1

(Instead of f = lambda x: x + 1.) [1][2]

What’s still fine: using lambda inline as an anonymous function (e.g., as a key=), since it’s not being assigned to a name. [2]

Suppress if needed:

  • Per line: # noqa: E731
  • Project config: ignore E731 (e.g., add it to Ruff’s ignore list). [1]

Sources: Ruff rules table for E731 / lambda-assignment [1], pycodestyle’s E731 definition [2].

[1] (docs.astral.sh)
[2] (pycodestyle.pycqa.org)


🏁 Script executed:

#!/bin/bash
# Search for any existing lint configuration in the repository
find . -maxdepth 2 -name "pyproject.toml" -o -name "ruff.toml" -o -name ".ruff.toml" -o -name "setup.cfg" | head -5

Repository: primatrix/tokamax

Length of output: 89


🏁 Script executed:

#!/bin/bash
# Check the ruff configuration
echo "=== ruff.toml ==="
cat ./ruff.toml
echo ""
echo "=== pyproject.toml (ruff section) ==="
grep -A 20 "\[tool.ruff" ./pyproject.toml | head -30

Repository: primatrix/tokamax

Length of output: 1195


🏁 Script executed:

#!/bin/bash
# Verify the syntax of the suggested fix by checking if Python can parse it
python3 << 'EOF'
import ast
import sys

# The suggested inline fix
code = '''acc_dtype = (
    jnp.int32
    if jnp.issubdtype(lhs_sub.dtype, jnp.integer)
    and jnp.issubdtype(rhs_sub.dtype, jnp.integer)
    else jnp.float32
)'''

try:
    ast.parse(code)
    print("✓ Suggested fix has valid Python syntax")
except SyntaxError as e:
    print(f"✗ Syntax error in suggested fix: {e}")
    sys.exit(1)
EOF

Repository: primatrix/tokamax

Length of output: 100


Inline the dtype check to clear Ruff E731.

The lambda assignment at line 817 triggers the lambda-assignment (E731) lint error. Inline the dtype check to fix it:

♻️ Small cleanup
-        is_int = lambda x: jnp.issubdtype(x.dtype, jnp.integer)
-        acc_dtype = jnp.int32 if is_int(lhs_sub) and is_int(rhs_sub) else jnp.float32
+        acc_dtype = (
+            jnp.int32
+            if jnp.issubdtype(lhs_sub.dtype, jnp.integer)
+            and jnp.issubdtype(rhs_sub.dtype, jnp.integer)
+            else jnp.float32
+        )
🧰 Tools
🪛 Ruff (0.15.9): [error] 817-817 — Do not assign a lambda expression, use a def; rewrite is_int as a def (E731)
