fix: tgmm subchannel axis and loop for blockwise FP8 (#4)
MokusMokun wants to merge 2 commits into main from
Conversation
Fix two bugs in tgmm that prevent blockwise FP8 from using optimal tiling:

1. subchannel_iters computed from the wrong axis (tk instead of tm). tgmm's reduction axis is M (axis 0), so subchannel_iters should be tm // eps, not tk // eps.
2. Replace NotImplementedError with a subchannel loop implementation, modeled after gmm's existing loop. Each sub-tile applies its own scale and group boundary mask independently.

Before: tm and tk locked to 128, BW/CW gap 1.5-3.5x. After: all dims free, BW/CW gap 1.1-1.5x (subchannel overhead).

Verified on TPU v7x: all tilings produce cosine=0.999318 vs bf16.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
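The axis fix described in the commit message can be sketched as follows. This is an illustrative stand-in, not the kernel's actual code; the names `subchannel_iters`, `tm`, and `lhs_eps` mirror the commit message.

```python
# Illustrative sketch of the axis fix, assuming the scale granularity divides
# the tile evenly; names follow the commit message, not the real kernel.
def subchannel_iters(tm: int, lhs_eps: int) -> int:
    # tgmm reduces over the M axis (axis 0), so the sub-channel count must
    # come from the M tile (tm), not the K tile (tk) as before the fix.
    return max(1, tm // lhs_eps)
```

For block_size=128 this gives one iteration when tm is 128 and four when tm is 512, which is why the pre-fix kernel (dividing tk instead) picked the wrong loop count whenever tm and tk differed.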
📝 Walkthrough
The change modifies the TGMM TPU kernel to support arbitrary subchannel iterations by removing a guard condition, capping tile sizes based on quantization granularity, and introducing subchannel-aware slicing for both matrix and scale tensors during computation.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed, ❌ 1 failed (1 warning)
Code Review
This pull request enables support for subchannel_iters > 1 in the Pallas Mosaic TPU kernel for ragged dot operations. Key changes include the introduction of a loop to process sub-tiles, sub-slicing of input references on the M-axis, and the implementation of group boundary masking using broadcasted iotas. Feedback suggests refactoring the sub-slicing logic to improve efficiency by moving helper function definitions outside the loop and simplifying redundant conditional checks for single iterations.
```python
else:
  def _get_sub(x, idx=it):
    size = x.shape[0] // subchannel_iters
    s = idx * size
    return x[s:s + size, :]
  lhs = jax.tree.map(_get_sub, lhs_ref)
  rhs = jax.tree.map(_get_sub, rhs_ref)
```
Defining the _get_sub function inside the for loop causes it to be redefined on each iteration. This is inefficient and makes the code harder to read.
For better performance and readability, I suggest defining this helper function once, outside of the for loop. This would also align better with the implementation of _maybe_get_subslice in the gmm kernel, which you've modeled this change after.
Additionally, the if subchannel_iters == 1: check is redundant. The slicing logic in the else branch correctly handles the case where subchannel_iters is 1, so the if/else structure can be simplified.
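A minimal sketch of the suggested refactor, with the helper hoisted out of the loop. This is a pure-Python stand-in (lists of rows instead of JAX refs) for illustration only; the function name and signature are hypothetical.

```python
# Helper defined once, outside the loop, rather than redefined per iteration.
# Pure-Python stand-in: `rows` is a list of rows, not a JAX array ref.
# Assumes len(rows) is divisible by subchannel_iters.
def get_sub(rows, idx, subchannel_iters):
    size = len(rows) // subchannel_iters
    start = idx * size
    return rows[start:start + size]
```

Note that with `subchannel_iters == 1` the slice covers the whole input, which is why the separate `if subchannel_iters == 1:` branch is redundant.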
Two changes to eliminate subchannel loop overhead for blockwise FP8 TGMM:

1. Cap tm to the blockwise scale granularity (eps) when QArrays are detected. This ensures subchannel_iters=1, avoiding the inner loop entirely. For block_size=128, tm is capped to 128 regardless of the requested tiling.
2. Hoist the group boundary mask outside the subchannel loop. The previous implementation (c24fbaa) applied per-sub-tile masking with broadcasted_iota inside each iteration (4x iota + 4x comparison + 4x where ops of vector ALU overhead). Now the full tile is masked once before sub-slicing.

Benchmark on TPU v7x (production MoE shapes, block_size=128):
w0 (K=2048): BW/bf16 1.49x → 1.23x
wo (K=512): BW/bf16 1.25x → 1.19x

The remaining gap is inherent to blockwise: extra scale data loading for w0 (~20%), and grid-level overhead for wo (K=512 gives small per-tile compute, so fixed costs dominate).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
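The tile-capping logic in change 1 can be sketched roughly as below. The function name and the None-sentinel for "no quantized input" are illustrative assumptions; the real kernel's detection condition may differ.

```python
# Illustrative sketch: cap the M tile to the blockwise scale granularity so
# that subchannel_iters == tm // eps == 1 and the inner loop disappears.
def cap_tm(tm, eps):
    if eps is None:       # no quantized (QArray) input detected (assumed sentinel)
        return tm
    return min(tm, eps)   # e.g. block_size=128 caps tm to 128
```

The trade-off is that a caller requesting tm=512 silently gets tm=128 on quantized inputs, exchanging larger tiles for the elimination of the per-sub-tile loop.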
Identified the biggest bottleneck as the tgmm custom-call compute (not fully analysed, since LLO is not listed). Implemented two changes to eliminate subchannel loop overhead for blockwise FP8 TGMM.

Benchmark on TPU v7x (production MoE shapes, block_size=128): w0 (K=2048): BW/bf16 1.49x → 1.23x. The remaining gap is inherent to blockwise: extra scale data loading for
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py`:
- Around line 803-815: The scale slicing uses the shared subchannel_iters and
loop index it, which can be incorrect when lhs_scale_full and rhs_scale_full use
different per-operand subchannel counts; instead compute the slice size and
index from each scale tensor's own shape (e.g. use lhs_sz =
lhs_scale_full.shape[0] // <lhs_subchannel_iters_or_derive_from_lhs_shape> and
rhs_sz = rhs_scale_full.shape[0] //
<rhs_subchannel_iters_or_derive_from_rhs_shape>) and slice lhs_scale_full and
rhs_scale_full using their respective sz and it_for_that_operand (derived from
the same loop index but relative to that operand's subchannel partitioning).
Update the scales assembly in pallas_mosaic_tpu_kernel.py (variables: scales,
lhs_scale_full, rhs_scale_full, subchannel_iters, it) and apply the same fix to
the analogous block at the later location (lines referenced in the review) so
each scale's slice uses its own shape-derived subchannel count rather than the
global subchannel_iters.
- Around line 817-818: Replace the temporary lambda is_int with an inline dtype
check when computing acc_dtype: remove the is_int assignment and use
jnp.issubdtype(lhs_sub.dtype, jnp.integer) and jnp.issubdtype(rhs_sub.dtype,
jnp.integer) directly in the acc_dtype expression so acc_dtype = jnp.int32 if
jnp.issubdtype(lhs_sub.dtype, jnp.integer) and jnp.issubdtype(rhs_sub.dtype,
jnp.integer) else jnp.float32 (references: lhs_sub, rhs_sub, acc_dtype, jnp).
ℹ️ Review info
Configuration used: defaults | Review profile: CHILL | Plan: Pro | Run ID: 8e4ff141-1c75-4306-8577-1b3bf9936578
📒 Files selected for processing (1)
tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py
```python
scales = []
if lhs_scale_full is not None:
  if subchannel_iters == 1:
    scales.append(lhs_scale_full.T)
  else:
    sz = lhs_scale_full.shape[0] // subchannel_iters
    scales.append(lhs_scale_full[it * sz:(it + 1) * sz, :].T)
if rhs_scale_full is not None:
  if subchannel_iters == 1:
    scales.append(rhs_scale_full)
  else:
    sz = rhs_scale_full.shape[0] // subchannel_iters
    scales.append(rhs_scale_full[it * sz:(it + 1) * sz, :])
```
Map scale rows from the qvalue window, not from the global loop count.
subchannel_iters is chosen across both operands, but these slices assume each scale tensor has at least that many rows. tokamax/_src/mosaic_tpu.py:100-130 derives eps from each tensor’s own scale.shape, so mixed M-axis scale grids are valid inputs. With a finer-grained operand driving subchannel_iters (for example, per-row on one side and blockwise/full-channel on the other), shape[0] // subchannel_iters becomes 0, producing an empty scale slice and a crash or wrong scaling. The common blockwise-FP8 fast path hides this because it usually forces subchannel_iters == 1, but the generic loop is still broken.
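The failure mode can be illustrated with a pure-Python sketch of the slice arithmetic; the function name is hypothetical and only mimics the index math in the snippet above.

```python
# Slice bounds as computed by the generic loop: when one operand's scale grid
# has fewer rows than subchannel_iters (e.g. a single full-channel scale row
# while the other operand is per-row), integer division yields sz == 0 and
# every iteration produces an empty slice.
def scale_slice(scale_rows, subchannel_iters, it):
    sz = scale_rows // subchannel_iters
    return (it * sz, (it + 1) * sz)
```

For example, `scale_slice(1, 4, 2)` returns `(0, 0)`, an empty slice, whereas a well-matched grid such as `scale_slice(8, 4, 2)` returns `(4, 6)`.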
💡 Possible fix

```diff
 if lhs_scale_full is not None:
   if subchannel_iters == 1:
     scales.append(lhs_scale_full.T)
   else:
-    sz = lhs_scale_full.shape[0] // subchannel_iters
-    scales.append(lhs_scale_full[it * sz:(it + 1) * sz, :].T)
+    scale_idx = min(
+        (s * lhs_scale_full.shape[0]) // tm,
+        lhs_scale_full.shape[0] - 1,
+    )
+    scales.append(lhs_scale_full[scale_idx:scale_idx + 1, :].T)
 if rhs_scale_full is not None:
   if subchannel_iters == 1:
     scales.append(rhs_scale_full)
   else:
-    sz = rhs_scale_full.shape[0] // subchannel_iters
-    scales.append(rhs_scale_full[it * sz:(it + 1) * sz, :])
+    scale_idx = min(
+        (s * rhs_scale_full.shape[0]) // tm,
+        rhs_scale_full.shape[0] - 1,
+    )
+    scales.append(rhs_scale_full[scale_idx:scale_idx + 1, :])
```

Also applies to: 905-909
```python
is_int = lambda x: jnp.issubdtype(x.dtype, jnp.integer)
acc_dtype = jnp.int32 if is_int(lhs_sub) and is_int(rhs_sub) else jnp.float32
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# First, verify the file exists and check the content at lines 817-818
if [ -f "tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py" ]; then
  echo "File found. Checking lines 817-818:"
  sed -n '815,820p' "tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py"
else
  echo "File not found. Searching for the file..."
  find . -name "pallas_mosaic_tpu_kernel.py" -type f
fi
```

Repository: primatrix/tokamax
Length of output: 365
🏁 Script executed:

```shell
# Check if the file exists and get some context around the target lines
if [ -f "tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py" ]; then
  wc -l "tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py"
fi
```

Repository: primatrix/tokamax
Length of output: 120
🌐 Web query: Ruff E731 rule definition lambda assignment

💡 Result:
In Ruff, rule E731 (rule name: lambda-assignment) flags code that assigns a lambda to a name, e.g. f = lambda x: x + 1, with the message: "Do not assign a lambda expression, use a def." [1][2]
Why it's flagged: if you're giving the function a name anyway, def is clearer and provides better debugging/introspection than a named lambda. [2]
Fix (recommended):

```python
def f(x):
    return x + 1
```

(Instead of f = lambda x: x + 1.) [1][2]
What's still fine: using lambda inline as an anonymous function (e.g., as a key=), since it's not being assigned to a name. [2]
Suppress if needed:
- Per line: # noqa: E731
- Project config: ignore E731 (e.g., add it to Ruff's ignore list). [1]

Sources: Ruff rules table for E731 / lambda-assignment [1], pycodestyle's E731 definition [2].
Citations:
- 1: https://docs.astral.sh/ruff/rules/?utm_source=openai
- 2: https://pycodestyle.pycqa.org/_/downloads/en/2.10.0/pdf/?utm_source=openai
🏁 Script executed:

```shell
#!/bin/bash
# Search for any existing lint configuration in the repository
find . -maxdepth 2 -name "pyproject.toml" -o -name "ruff.toml" -o -name ".ruff.toml" -o -name "setup.cfg" | head -5
```

Repository: primatrix/tokamax
Length of output: 89
🏁 Script executed:

```shell
#!/bin/bash
# Check the ruff configuration
echo "=== ruff.toml ==="
cat ./ruff.toml
echo ""
echo "=== pyproject.toml (ruff section) ==="
grep -A 20 "\[tool.ruff" ./pyproject.toml | head -30
```

Repository: primatrix/tokamax
Length of output: 1195
🏁 Script executed:

```shell
#!/bin/bash
# Verify the syntax of the suggested fix by checking if Python can parse it
python3 << 'EOF'
import ast
import sys

# The suggested inline fix
code = '''acc_dtype = (
    jnp.int32
    if jnp.issubdtype(lhs_sub.dtype, jnp.integer)
    and jnp.issubdtype(rhs_sub.dtype, jnp.integer)
    else jnp.float32
)'''
try:
    ast.parse(code)
    print("✓ Suggested fix has valid Python syntax")
except SyntaxError as e:
    print(f"✗ Syntax error in suggested fix: {e}")
    sys.exit(1)
EOF
```

Repository: primatrix/tokamax
Length of output: 100
Inline the dtype check to clear Ruff E731.
The lambda assignment at line 817 triggers the lambda-assignment (E731) lint error. Inline the dtype check to fix it:

♻️ Small cleanup

```diff
- is_int = lambda x: jnp.issubdtype(x.dtype, jnp.integer)
- acc_dtype = jnp.int32 if is_int(lhs_sub) and is_int(rhs_sub) else jnp.float32
+ acc_dtype = (
+     jnp.int32
+     if jnp.issubdtype(lhs_sub.dtype, jnp.integer)
+     and jnp.issubdtype(rhs_sub.dtype, jnp.integer)
+     else jnp.float32
+ )
```
🧰 Tools
🪛 Ruff (0.15.9)
[error] 817-817: Do not assign a lambda expression, use a def; rewrite is_int as a def (E731)
Summary
Fix two bugs in tgmm that caused blockwise FP8 to produce wrong results (cosine ~0.34-0.49) when tm > 128, and crash with NotImplementedError when tk > 128.

Bug 1: Wrong axis for subchannel_iters (lines 874, 878). `subchannel_iters = max(1, tk // lhs_eps)` uses tk (the K tile), but tgmm's reduction axis is M, not K. Changed to `tm // lhs_eps`.

Bug 2: Missing subchannel loop (lines 738-741). `if subchannel_iters != 1: raise NotImplementedError`. Replaced with a subchannel loop (modeled after gmm's existing implementation) that iterates over M-axis sub-tiles, applying the correct scale per sub-tile.

Benchmark results (TPU v7x)
Correctness: all tilings (including tm=512, tk=1024) produce cosine=0.999318 vs bf16 reference.
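As a rough illustration of the Bug 2 fix, the loop structure can be sketched in pure Python with lists of rows standing in for tiles. Names, the per-sub-tile scale handling, and the triple loop are illustrative assumptions, not the Pallas kernel's actual code.

```python
# Sketch: tgmm reduces over M (axis 0), computing lhs^T @ rhs for one tile.
# The loop splits the M tile into sub-tiles and applies one scale per
# sub-tile, as the fixed kernel does. Assumes len(lhs) divides evenly by
# len(sub_scales); lhs is M x N, rhs is M x P (lists of rows).
def tgmm_subchannel(lhs, rhs, sub_scales):
    iters = len(sub_scales)            # subchannel_iters
    size = len(lhs) // iters           # M rows per sub-tile
    n, p = len(lhs[0]), len(rhs[0])
    acc = [[0.0] * p for _ in range(n)]
    for it in range(iters):
        for r in range(it * size, (it + 1) * size):
            for i in range(n):
                for j in range(p):
                    acc[i][j] += sub_scales[it] * lhs[r][i] * rhs[r][j]
    return acc
```

With all scales equal to 1.0 this reduces to a plain `lhs^T @ rhs`, which is how the loop stays a no-op for unquantized inputs while still applying distinct scales per M-axis block for blockwise FP8.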
Test plan
🤖 Generated with Claude Code
Summary by CodeRabbit