
Fix: correct RMSNorm to use rsqrt and keep projections in FP32#69

Open
YunjiQin wants to merge 2 commits into hw-native-sys:main from YunjiQin:fix/rmsnorm-rsqrt-bf16-projections

Conversation

Contributor

@YunjiQin YunjiQin commented Apr 7, 2026

Summary

  • Replace raw variance with rsqrt-based inv_rms in RMSNorm for both scope1 and tilelet decode examples
  • Keep Q/K/V projection outputs in FP32 (avoid precision loss from BF16 round-trip before RoPE/attention)
  • Restructure scope1 to use auto_incore + parallel loops with UP_DOWN split for projection stages
  • Replace matmul/matmul_acc with add + matmul accumulation pattern
  • Drop redundant FP32 casts when reading FP32 q_proj/k_proj in RoPE/attention stages
  • Add explicit BF16 cast when writing FP32 v_proj into BF16 v_cache
  • Update golden references and tensor specs to match FP32 projection output
  • Remove unused constants (MAX_SEQ, INTERMEDIATE, MLP_OUT_CHUNK)

- Replace raw variance with rsqrt-based inv_rms in RMSNorm for both
  scope1 and tilelet decode examples
- Change Q/K/V projection output dtype from FP32 to BF16
- Restructure scope1 to use auto_incore + parallel loops instead of
  manual incore + range, with UP_DOWN split for projection stages
- Replace matmul/matmul_acc with add + matmul accumulation pattern
- Add explicit FP32 cast for RoPE q/k slices now stored as BF16
- Update golden references to match kernel precision path
- Remove unused constants (MAX_SEQ, INTERMEDIATE, MLP_OUT_CHUNK)
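The rsqrt-based inv_rms described above can be sketched in plain Python (a scalar illustration with made-up values and an assumed EPS; the actual kernels apply pl.rsqrt to [BATCH_TILE, 1] tiles):

```python
import math

EPS = 1e-6  # assumed epsilon; the examples define their own EPS constant


def rms_norm(x, weight):
    """Scale x by the inverse root-mean-square, then by the learned weight."""
    sq_sum = sum(v * v for v in x)
    inv_rms = 1.0 / math.sqrt(sq_sum / len(x) + EPS)  # rsqrt of the mean square
    return [v * inv_rms * w for v, w in zip(x, weight)]


normed = rms_norm([3.0, 4.0], [1.0, 1.0])
# With unit weights the output has (approximately) unit RMS.
rms = math.sqrt(sum(v * v for v in normed) / len(normed))
```

With unit weights, `rms` lands within floating-point noise of 1.0, which is exactly the property the raw-variance version loses.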

coderabbitai bot commented Apr 7, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Refactors Qwen3 decode examples: changes RMSNorm from single-pass variance to two-pass per-batch-tile inv_rms, adds BF16 normalized buffer, and rewrites Q/K/V projection to explicit FP32 accumulation over K_CHUNK with cross-core pl.auto_incore loops; updates docs and golden references accordingly.

Changes

Cohort / File(s): Decode scope1 & tilelet — examples/models/qwen3/qwen3_32b_decode_scope1.py, examples/models/qwen3/qwen3_32b_decode_tilelet.py

Summary: Replaced single-pass RMSNorm with two-pass per-batch-tile sq_sum → inv_rms → normed_buf (BF16); removed unused constants; rewrote Q/K/V projections to use pl.auto_incore cross-core loops with explicit FP32 accumulators (q_acc/k_acc/v_acc) accumulating pl.matmul(..., out_dtype=FP32) over K_CHUNK, then cast/assemble BF16 outputs; updated golden reference math to match; minor formatting and doc constant updates.

Sequence Diagram(s)

(omitted — changes are internal kernel/control-flow updates without a clear multi-component external sequence)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes


Poem

🐰 I hopped through tiles at break of day,
Summed squares, took roots, then hopped away.
BF16 leaves shimmer in each tiny buf,
Accumulators hum, the kernels fluff—
A rabbit cheers: decode runs smooth and gay! 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — docstring coverage is 16.67%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)
  • Title check ✅ Passed — the title accurately summarizes the main changes: correcting RMSNorm to use rsqrt and keeping Q/K/V projections in FP32, which are the primary objectives.
  • Description check ✅ Passed — the description comprehensively details the changes to RMSNorm computation, projection dtype handling, code restructuring, and removal of unused constants, all of which are directly related to the changeset.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the Qwen3 decode kernels to optimize performance and memory usage by switching projection outputs to BF16 and parallelizing the RMSNorm and projection stages using auto-incore and parallel loops. The reference implementations have been updated to match the new precision and chunking logic. Feedback identifies a potential issue where tensor initialization only covers the first batch tile, which could lead to uninitialized data if the batch size exceeds the tile size. Additionally, there are suggestions to fix inconsistent variable naming between 'batch' and 'BATCH_CFG' within the parallel projection loops.

Comment on lines +154 to +168
with pl.incore():
    for ob in pl.range(Q_OUT_BLOCKS):
        q0 = ob * Q_OUT_CHUNK
        zero_1 = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0)
        zero_1_bf = pl.cast(zero_1, target_type=pl.BF16)
        q_proj = pl.assemble(q_proj, zero_1_bf, [0, q0])
        attn_out = pl.assemble(attn_out, zero_1_bf, [0, q0])
        normed_buf = pl.assemble(normed_buf, zero_1_bf, [0, q0])
with pl.incore():
    for ob in pl.range(KV_OUT_BLOCKS):
        kv0 = ob * KV_OUT_CHUNK
        zero_2 = pl.full([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32, value=0.0)
        zero_2_bf = pl.cast(zero_2, target_type=pl.BF16)
        k_proj = pl.assemble(k_proj, zero_2_bf, [0, kv0])
        v_proj = pl.assemble(v_proj, zero_2_bf, [0, kv0])

high

The initialization of q_proj, k_proj, v_proj, attn_out, and normed_buf only covers the first BATCH_TILE rows (using offset 0 in pl.assemble). If BATCH_CFG is greater than BATCH_TILE, the remaining rows will be left uninitialized. Although the current default BATCH matches BATCH_TILE, this is fragile. Consider wrapping these in a loop over batch blocks to ensure the entire tensors are initialized.

with pl.auto_incore(split=pl.SplitMode.UP_DOWN):
    # Stage 2: Q projection (AIC+AIV cross-core incore).
    for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=4):
        for b0 in pl.range(0, batch, BATCH_TILE):

medium

The variable batch is used here, but it is not a parameter of the qwen3_decode_layer function. While it is available via closure from the outer build_qwen3_single_layer_decode_program scope, it is inconsistent with the rest of the function which uses BATCH_CFG (e.g., lines 173 and 245). Please use BATCH_CFG for consistency.

Suggested change
-for b0 in pl.range(0, batch, BATCH_TILE):
+for b0 in pl.range(0, BATCH_CFG, BATCH_TILE):

for kb in pl.range(1, HIDDEN_BLOCKS):
# Stage 3: K/V projection (AIC+AIV cross-core incore).
for ob in pl.parallel(0, KV_OUT_BLOCKS, 1, chunk=8):
    for b0 in pl.range(0, batch, BATCH_TILE):

medium

The variable batch is used here instead of BATCH_CFG. Please use BATCH_CFG to maintain consistency with the program configuration variables used elsewhere in this function.

Suggested change
-for b0 in pl.range(0, batch, BATCH_TILE):
+for b0 in pl.range(0, BATCH_CFG, BATCH_TILE):


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/models/qwen3/qwen3_32b_decode_tilelet.py (1)

530-547: ⚠️ Potential issue | 🟠 Major

Make the golden reference use runtime dimensions.

golden_qwen3_decode() still allocates and iterates from module-level BATCH, MAX_SEQ, HIDDEN, and NUM_* constants. compile_and_run(...) exposes all of those as parameters, so non-default runs will compare against a wrong-sized or partially computed reference.

🧪 Suggested shape derivation
batch, hidden_size = hidden_states.shape
max_seq_len, head_dim = rope_cos.shape
kv_hidden = wk.shape[1]
num_kv_heads = kv_hidden // head_dim
num_heads = hidden_size // head_dim
intermediate_size = w_gate.shape[1]
hidden_inv = 1.0 / hidden_size

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: f3d6ba06-3d2b-4211-9460-a45220ab8b55

📥 Commits

Reviewing files that changed from the base of the PR and between 9a4a25f and 994c33a.

📒 Files selected for processing (2)
  • examples/models/qwen3/qwen3_32b_decode_scope1.py
  • examples/models/qwen3/qwen3_32b_decode_tilelet.py

hidden_blocks = hidden // K_CHUNK
q_out_blocks = hidden // Q_OUT_CHUNK
kv_out_blocks = kv_hidden // KV_OUT_CHUNK
batch_blocks = batch // BATCH_TILE

⚠️ Potential issue | 🟠 Major

Reject or handle non-multiple batch sizes.

batch_blocks = batch // BATCH_TILE truncates the tail, but the projection loops still iterate range(0, batch, BATCH_TILE). For batch=17, Stage 1 never writes row 16 while Stages 2/3 still slice a full [16, ...] tile from normed_buf.

🛡️ Suggested guard
if batch <= 0 or batch % BATCH_TILE != 0:
    raise ValueError(
        f"batch ({batch}) must be a positive multiple of BATCH_TILE ({BATCH_TILE})"
    )

Also applies to: 76-80, 112-145
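The failure mode is easy to reproduce with plain arithmetic: for batch=17, floor division covers only 16 rows while the stage loops still start a second tile. A sketch of the guard under those assumptions:

```python
BATCH_TILE = 16


def validate_batch(batch):
    """Fail fast on batch sizes the full-tile loops cannot handle."""
    if batch <= 0 or batch % BATCH_TILE != 0:
        raise ValueError(
            f"batch ({batch}) must be a positive multiple of BATCH_TILE ({BATCH_TILE})"
        )
    return batch // BATCH_TILE  # batch_blocks now covers every row


# batch=17: Stage 1's tile count and the projection loop ranges disagree.
rows_written = (17 // BATCH_TILE) * BATCH_TILE  # only 16 rows initialized
tiles_sliced = len(range(0, 17, BATCH_TILE))    # but 2 tiles are read
```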


Comment on lines 68 to 75
    q_proj: pl.Out[pl.Tensor[[batch, hidden], pl.BF16]],
    k_proj: pl.Out[pl.Tensor[[batch, kv_hidden], pl.BF16]],
    v_proj: pl.Out[pl.Tensor[[batch, kv_hidden], pl.BF16]],
) -> tuple[
    pl.Tensor[[batch, hidden], pl.FP32],
    pl.Tensor[[batch, kv_hidden], pl.FP32],
    pl.Tensor[[batch, kv_hidden], pl.FP32],
    pl.Tensor[[batch, hidden], pl.BF16],
    pl.Tensor[[batch, kv_hidden], pl.BF16],
    pl.Tensor[[batch, kv_hidden], pl.BF16],
]:

⚠️ Potential issue | 🟠 Major

Align the Scope 1 → Scope 2 dtype contract.

If these outputs still feed examples/models/qwen3/qwen3_32b_decode_scope2.py, that function still declares FP32 q_proj, k_proj, and v_proj inputs on Lines 71-76 while this file now exposes BF16 tensors. The top-level orchestrator already allocates BF16 buffers on Lines 104-106 of examples/models/qwen3/qwen3_32b_decode.py, so the consumer change needs to land with this boundary change.

Also applies to: 189-191
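The dtype at this boundary matters because BF16 keeps only 8 significand bits. A rough stdlib-only emulation, which truncates the low mantissa bits of the float32 encoding (real casts round to nearest even, so this slightly overstates the error), shows the round-trip loss a consistent FP32 contract avoids:

```python
import struct


def bf16_roundtrip(x):
    """Approximate float -> BF16 -> float by zeroing the low 16 bits of the
    float32 encoding (truncation; hardware casts round to nearest even)."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    return struct.unpack('<f', struct.pack('<I', bits & 0xFFFF0000))[0]


exact = 3.14159
lossy = bf16_roundtrip(exact)  # 3.140625 under truncation, ~1e-3 absolute error
```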


Comment on lines +93 to 95
inv_rms_tile: pl.Tensor[[BATCH_TILE, 1], pl.FP32] = pl.reshape(
    pl.rsqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)),
    [BATCH_TILE, 1],

⚠️ Potential issue | 🟠 Major

Use the configured hidden size in RMSNorm.

This path still normalizes with the module-level HIDDEN_INV (1 / 8192) even though build_decode_projection_program() accepts arbitrary hidden_size. The PyTorch golden mirrors the same constant, so non-default sizes will be wrong and still appear green.

⚙️ Suggested fix
hidden_inv = 1.0 / hidden
-                        pl.rsqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)),
+                        pl.rsqrt(pl.add(pl.mul(sq_sum, hidden_inv), EPS)),
-        inv_rms = torch.rsqrt(sq_sum * HIDDEN_INV + EPS)  # [BATCH_TILE, 1]
+        inv_rms = torch.rsqrt(sq_sum / hidden_size + EPS)

Also applies to: 221-229
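The scale error is silent but easy to quantify: the constant only matches the default configuration, and for any other hidden size the mean-square term is off by a factor of hidden / 8192. A numeric sketch (8192 is the module default; 4096 is a hypothetical non-default size):

```python
HIDDEN_INV = 1.0 / 8192  # module-level constant the comment flags


def hidden_inv_for(hidden):
    """Per-call inverse hidden size, as the suggested fix computes."""
    return 1.0 / hidden


matches_default = hidden_inv_for(8192) == HIDDEN_INV  # fine for the default
ratio = hidden_inv_for(4096) / HIDDEN_INV             # 2x off for hidden=4096
```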


Comment on lines +154 to +168
with pl.incore():
    for ob in pl.range(Q_OUT_BLOCKS):
        q0 = ob * Q_OUT_CHUNK
        zero_1 = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0)
        zero_1_bf = pl.cast(zero_1, target_type=pl.BF16)
        q_proj = pl.assemble(q_proj, zero_1_bf, [0, q0])
        attn_out = pl.assemble(attn_out, zero_1_bf, [0, q0])
        normed_buf = pl.assemble(normed_buf, zero_1_bf, [0, q0])
with pl.incore():
    for ob in pl.range(KV_OUT_BLOCKS):
        kv0 = ob * KV_OUT_CHUNK
        zero_2 = pl.full([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32, value=0.0)
        zero_2_bf = pl.cast(zero_2, target_type=pl.BF16)
        k_proj = pl.assemble(k_proj, zero_2_bf, [0, kv0])
        v_proj = pl.assemble(v_proj, zero_2_bf, [0, kv0])

⚠️ Potential issue | 🟠 Major

Guard the full-tile batch invariant.

These loops always read and write [BATCH_TILE, ...] slices. build_qwen3_single_layer_decode_program(batch=...) will go out of bounds for batch < 16 or any batch % 16 != 0, and Scope 3 still has the same assumption. Either add partial-tile handling or fail fast on unsupported batch sizes.

🛡️ Suggested guard
if batch <= 0 or batch % BATCH_TILE != 0:
    raise ValueError(
        f"batch ({batch}) must be a positive multiple of BATCH_TILE ({BATCH_TILE})"
    )

Also applies to: 172-238


- scope1: change q_proj/k_proj/v_proj signature, tensor specs and golden
  to FP32 and drop the BF16 cast on projection write-back
- tilelet: same FP32 promotion for the intermediate Q/K/V tensors,
  zero-init, projection assemble and golden reference
- tilelet: drop now-redundant FP32 casts when reading q_proj/k_proj
  during RoPE/attention, and add an explicit BF16 cast before writing
  the FP32 v_proj slice into the BF16 v_cache
@YunjiQin YunjiQin changed the title from "Fix: correct RMSNorm to use rsqrt and switch projections to BF16 output" to "Fix: correct RMSNorm to use rsqrt and keep projections in FP32" on Apr 7, 2026

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
examples/models/qwen3/qwen3_32b_decode_scope1.py (1)

219-220: ⚠️ Potential issue | 🟡 Minor

Golden reference also uses hardcoded HIDDEN_INV.

Both the kernel (line 94) and the golden reference (line 220) use the module-level HIDDEN_INV constant. This means tests will pass for non-default hidden_size values, but both implementations will produce mathematically incorrect results.

Consider fixing both together:

♻️ Suggested fix for golden
-        inv_rms = torch.rsqrt(sq_sum * HIDDEN_INV + EPS)  # [BATCH_TILE, 1]
+        inv_rms = torch.rsqrt(sq_sum / hidden_size + EPS)  # [BATCH_TILE, 1]
🧹 Nitpick comments (1)
examples/models/qwen3/qwen3_32b_decode_tilelet.py (1)

530-535: Golden reference doesn't match kernel's chunked accumulation pattern.

The kernel performs chunked FP32 accumulation over K_CHUNK (lines 210-214), but the golden reference performs a single full matmul (normed.float() @ wq.float()). While mathematically equivalent in exact arithmetic, BF16 intermediate precision and FP32 accumulation ordering can produce different rounding behavior.

For consistency with the kernel's precision path (and to match the golden_decode_projection in qwen3_32b_decode_scope1.py which does use chunked accumulation), consider updating this golden to use the same chunked pattern:

♻️ Suggested alignment with kernel accumulation
-        q_proj[b0:b_end, :] = normed.float() @ wq.float()
-        k_proj[b0:b_end, :] = normed.float() @ wk.float()
-        v_proj[b0:b_end, :] = normed.float() @ wv.float()
+        q_acc = torch.zeros(b_end - b0, hidden_size, dtype=torch.float32)
+        k_acc = torch.zeros(b_end - b0, kv_hidden, dtype=torch.float32)
+        v_acc = torch.zeros(b_end - b0, kv_hidden, dtype=torch.float32)
+        for k0 in range(0, hidden_size, K_CHUNK):
+            normed_chunk = normed[:, k0:k0 + K_CHUNK].float()
+            q_acc = q_acc + normed_chunk @ wq[k0:k0 + K_CHUNK, :].float()
+            k_acc = k_acc + normed_chunk @ wk[k0:k0 + K_CHUNK, :].float()
+            v_acc = v_acc + normed_chunk @ wv[k0:k0 + K_CHUNK, :].float()
+        q_proj[b0:b_end, :] = q_acc
+        k_proj[b0:b_end, :] = k_acc
+        v_proj[b0:b_end, :] = v_acc
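The rounding-order caveat above is a general floating-point fact: addition is not associative, so a chunked accumulation can legitimately round differently from a single-pass matmul. A stdlib illustration with float64 (the effect is much larger in BF16/FP32 pipelines):

```python
vals = [1e16, 1.0, 1.0]

# Left to right, each 1.0 is absorbed into 1e16 and lost to rounding.
left_to_right = sum(vals)

# Summing the small terms first lets them survive.
small_first = sum(reversed(vals))  # differs from left_to_right by 2.0
```

This is why aligning the golden's accumulation order with the kernel's, rather than relying on mathematical equivalence, tightens the comparison.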

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 21e9a707-d40e-4083-a9fb-0a593104c361

📥 Commits

Reviewing files that changed from the base of the PR and between 994c33a and 6bce850.

📒 Files selected for processing (2)
  • examples/models/qwen3/qwen3_32b_decode_scope1.py
  • examples/models/qwen3/qwen3_32b_decode_tilelet.py

Comment on lines +530 to +531
inv_rms = torch.rsqrt(sq_sum / hidden_size + EPS)
normed = (x_tile * inv_rms * input_rms_weight.float()).bfloat16()

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if qwen3_32b_decode_scope12.py uses the old variance formulation
rg -n "variance\s*=\s*sq_sum" --type py

Repository: hw-native-sys/pypto-lib

Length of output: 166


🏁 Script executed:

# First, confirm the tilelet file has the rsqrt formulation at lines 530-531
fd "qwen3_32b_decode_tilelet.py" --type f

Repository: hw-native-sys/pypto-lib

Length of output: 116


🏁 Script executed:

# Read the tilelet file around lines 530-531
sed -n '525,535p' examples/models/qwen3/qwen3_32b_decode_tilelet.py

Repository: hw-native-sys/pypto-lib

Length of output: 609


🏁 Script executed:

# Read the scope12 file around the variance line (401) with context
sed -n '395,410p' examples/models/qwen3/qwen3_32b_decode_scope12.py

Repository: hw-native-sys/pypto-lib

Length of output: 810


🏁 Script executed:

# Check if both files are imported or used together
rg -l "qwen3_32b_decode_tilelet|qwen3_32b_decode_scope12" --type py

Repository: hw-native-sys/pypto-lib

Length of output: 164


Critical RMSNorm formulation mismatch between implementations.

qwen3_32b_decode_tilelet.py (lines 530-531) uses the correct RMSNorm formulation:

inv_rms = torch.rsqrt(sq_sum / hidden_size + EPS)
normed = x_tile * inv_rms * input_rms_weight.float()

But qwen3_32b_decode_scope12.py (line 401) uses an incorrect formula that multiplies by variance instead of dividing by its square root:

variance = sq_sum / hidden_size + EPS
normed = x_tile * variance * input_rms_weight.float()

These compute fundamentally different operations and will produce drastically different outputs. This inconsistency will cause golden reference divergence and test failures when both implementations are used together.
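The gap between the two formulations can be checked with a one-element example (values are illustrative): relative to the correct rsqrt scaling, multiplying by the variance inflates the output by roughly variance**1.5.

```python
import math

EPS = 1e-6
x, mean_sq = 2.0, 4.0  # one activation and its tile's mean of squares

variance = mean_sq + EPS                  # scope12's buggy scale factor
inv_rms = 1.0 / math.sqrt(mean_sq + EPS)  # tilelet's correct rsqrt scale

wrong = x * variance  # ~8.0
right = x * inv_rms   # ~1.0
```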


