Fix: correct RMSNorm to use rsqrt and keep projections in FP32 #69

YunjiQin wants to merge 2 commits into hw-native-sys:main
Conversation
- Replace raw variance with rsqrt-based `inv_rms` in RMSNorm for both scope1 and tilelet decode examples
- Change Q/K/V projection output dtype from FP32 to BF16
- Restructure scope1 to use `auto_incore` + `parallel` loops instead of manual `incore` + `range`, with UP_DOWN split for projection stages
- Replace `matmul`/`matmul_acc` with `add` + `matmul` accumulation pattern
- Add explicit FP32 cast for RoPE q/k slices now stored as BF16
- Update golden references to match kernel precision path
- Remove unused constants (MAX_SEQ, INTERMEDIATE, MLP_OUT_CHUNK)
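The rsqrt-based RMSNorm described above can be sketched in plain NumPy. This is a hedged illustration of the math only, not the kernel's `pl.*` API; the function and variable names are hypothetical:

```python
import numpy as np

def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Two-pass RMSNorm: mean of squares first, then scale by inv_rms."""
    x32 = x.astype(np.float32)
    sq_mean = np.mean(x32 * x32, axis=-1, keepdims=True)
    inv_rms = 1.0 / np.sqrt(sq_mean + eps)  # rsqrt-based, as in the fix
    return x32 * inv_rms * weight

x = np.full((2, 8), 3.0, dtype=np.float32)
w = np.ones(8, dtype=np.float32)
out = rms_norm(x, w)  # each element ~1.0, since 3 / sqrt(mean(9)) = 3 / 3
```

The key point of the fix is the `1 / sqrt(...)` (rsqrt) scaling; multiplying by the raw variance instead produces a completely different operation.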
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the review settings. Use the following commands to manage reviews.
📝 Walkthrough: Refactors Qwen3 decode examples: changes RMSNorm from single-pass variance to two-pass per-batch-tile `inv_rms`, adds a BF16 normalized buffer, and rewrites the Q/K/V projection as explicit FP32 accumulation over `K_CHUNK` with cross-core parallel loops.
Sequence Diagram(s): omitted, since the changes are internal kernel/control-flow updates without a clear multi-component external sequence.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request refactors the Qwen3 decode kernels to optimize performance and memory usage by switching projection outputs to BF16 and parallelizing the RMSNorm and projection stages using auto-incore and parallel loops. The reference implementations have been updated to match the new precision and chunking logic. Feedback identifies a potential issue where tensor initialization only covers the first batch tile, which could lead to uninitialized data if the batch size exceeds the tile size. Additionally, there are suggestions to fix inconsistent variable naming between 'batch' and 'BATCH_CFG' within the parallel projection loops.
```python
with pl.incore():
    for ob in pl.range(Q_OUT_BLOCKS):
        q0 = ob * Q_OUT_CHUNK
        zero_1 = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0)
        zero_1_bf = pl.cast(zero_1, target_type=pl.BF16)
        q_proj = pl.assemble(q_proj, zero_1_bf, [0, q0])
        attn_out = pl.assemble(attn_out, zero_1_bf, [0, q0])
        normed_buf = pl.assemble(normed_buf, zero_1_bf, [0, q0])
with pl.incore():
    for ob in pl.range(KV_OUT_BLOCKS):
        kv0 = ob * KV_OUT_CHUNK
        zero_2 = pl.full([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32, value=0.0)
        zero_2_bf = pl.cast(zero_2, target_type=pl.BF16)
        k_proj = pl.assemble(k_proj, zero_2_bf, [0, kv0])
        v_proj = pl.assemble(v_proj, zero_2_bf, [0, kv0])
```
The initialization of q_proj, k_proj, v_proj, attn_out, and normed_buf only covers the first BATCH_TILE rows (using offset 0 in pl.assemble). If BATCH_CFG is greater than BATCH_TILE, the remaining rows will be left uninitialized. Although the current default BATCH matches BATCH_TILE, this is fragile. Consider wrapping these in a loop over batch blocks to ensure the entire tensors are initialized.
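As a plain-NumPy illustration of the reviewer's point (hypothetical names and sizes, not the `pl.*` API): zero-initializing only the first tile at offset 0 leaves later rows untouched, while a loop over batch blocks covers the whole tensor:

```python
import numpy as np

BATCH, BATCH_TILE, OUT = 32, 16, 8
buf = np.full((BATCH, OUT), np.nan, dtype=np.float32)  # stand-in for uninitialized memory

# Fragile: only the first BATCH_TILE rows are written (offset 0 only).
buf[0:BATCH_TILE, :] = 0.0
first_tile_only_ok = not np.isnan(buf[:BATCH_TILE]).any()
tail_uninitialized = np.isnan(buf[BATCH_TILE:]).any()

# Robust: loop over batch blocks so every row is initialized.
for b0 in range(0, BATCH, BATCH_TILE):
    buf[b0:b0 + BATCH_TILE, :] = 0.0
```

With the single-tile write, rows `BATCH_TILE..BATCH-1` keep whatever was in the buffer; the loop variant is what the comment suggests wrapping the `pl.assemble` calls in.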
```python
with pl.auto_incore(split=pl.SplitMode.UP_DOWN):
    # Stage 2: Q projection (AIC+AIV cross-core incore).
    for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=4):
        for b0 in pl.range(0, batch, BATCH_TILE):
```
The variable batch is used here, but it is not a parameter of the qwen3_decode_layer function. While it is available via closure from the outer build_qwen3_single_layer_decode_program scope, it is inconsistent with the rest of the function which uses BATCH_CFG (e.g., lines 173 and 245). Please use BATCH_CFG for consistency.
```diff
- for b0 in pl.range(0, batch, BATCH_TILE):
+ for b0 in pl.range(0, BATCH_CFG, BATCH_TILE):
```
```python
for kb in pl.range(1, HIDDEN_BLOCKS):
```

```python
# Stage 3: K/V projection (AIC+AIV cross-core incore).
for ob in pl.parallel(0, KV_OUT_BLOCKS, 1, chunk=8):
    for b0 in pl.range(0, batch, BATCH_TILE):
```
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/models/qwen3/qwen3_32b_decode_tilelet.py (1)
Lines 530-547: ⚠️ Potential issue | 🟠 Major: Make the golden reference use runtime dimensions.

`golden_qwen3_decode()` still allocates and iterates from module-level `BATCH`, `MAX_SEQ`, `HIDDEN`, and `NUM_*` constants. `compile_and_run(...)` exposes all of those as parameters, so non-default runs will compare against a wrong-sized or partially computed reference.

🧪 Suggested shape derivation

```python
batch, hidden_size = hidden_states.shape
max_seq_len, head_dim = rope_cos.shape
kv_hidden = wk.shape[1]
num_kv_heads = kv_hidden // head_dim
num_heads = hidden_size // head_dim
intermediate_size = w_gate.shape[1]
hidden_inv = 1.0 / hidden_size
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/models/qwen3/qwen3_32b_decode_tilelet.py` around lines 530 - 547, The golden_qwen3_decode reference still uses module-level constants (BATCH, MAX_SEQ, HIDDEN, NUM_*) causing mismatched shapes for non-default runs; change golden_qwen3_decode to derive all dimensions at runtime from inputs (e.g., batch, hidden_size = hidden_states.shape; max_seq_len, head_dim = rope_cos.shape; kv_hidden = wk.shape[1]; num_kv_heads = kv_hidden // head_dim; num_heads = hidden_size // head_dim; intermediate_size = w_gate.shape[1]; hidden_inv = 1.0 / hidden_size) and replace any allocations, loop bounds, and slices that use BATCH, MAX_SEQ, HIDDEN, NUM_* with these derived variables so compile_and_run(...) comparisons use correct sizes for arbitrary arguments.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/models/qwen3/qwen3_32b_decode_scope1.py`:
- Around line 93-95: The RMSNorm calculation in
build_decode_projection_program() uses the module-level constant HIDDEN_INV
instead of the per-call hidden size; replace uses of HIDDEN_INV (e.g., where
inv_rms_tile is computed) with 1.0 / hidden (or a local hidden_inv computed from
the hidden parameter) so RMS scaling uses the configured hidden size; apply the
same change to the other occurrence around lines 221-229 where HIDDEN_INV is
referenced.
- Around line 68-75: The scope boundary types are mismatched:
qwen3_32b_decode_scope1.py now emits BF16 q_proj/k_proj/v_proj but the
downstream consumer in qwen3_32b_decode_scope2.py still declares those inputs as
FP32; update the scope2 function signature to accept pl.BF16 for q_proj, k_proj,
and v_proj (replace pl.FP32/float annotations with pl.BF16) so it matches the
outputs from qwen3_32b_decode_scope1.py, and make the same annotation fixes for
the other occurrences mentioned (around the second instance at the lines
referenced 189-191) to keep the dtype contract consistent across the boundary.
- Line 56: The code computes batch_blocks = batch // BATCH_TILE which silently
drops a tail when batch is not a multiple of BATCH_TILE, causing Stage 1 to skip
the last partial tile while later stages still read from normed_buf; add an
explicit validation at the start of the pipeline (before computing batch_blocks
and before any projection loops that use range(0, batch, BATCH_TILE)) to either
raise a ValueError when batch <= 0 or batch % BATCH_TILE != 0 or else change the
tiling logic to use math.ceil to include a final partial tile and ensure Stage 1
writes the tail into normed_buf; reference the symbols batch_blocks, BATCH_TILE,
the projection loop ranges (range(0, batch, BATCH_TILE)), and normed_buf when
making the change so all stages remain consistent.
In `@examples/models/qwen3/qwen3_32b_decode_tilelet.py`:
- Around line 154-168: The code assumes full BATCH_TILE tiles when assembling
buffers (e.g., in the pl.incore loops that write slices of shape [BATCH_TILE,
...]), so add a guard in build_qwen3_single_layer_decode_program to validate the
batch argument (e.g., if batch <= 0 or batch % BATCH_TILE != 0: raise
ValueError(f"batch ({batch}) must be a positive multiple of BATCH_TILE
({BATCH_TILE})")); update any other entry points that construct these programs
(Scope 3 / callers) to perform the same check or propagate the validated value
rather than attempting partial-tile handling.
- Around line 186-188: The RMS normalization uses the module-level constant
HIDDEN_INV (1/8192) which breaks when build_qwen3_single_layer_decode_program()
is called with a non-default hidden_size; compute a per-call inverse hidden size
(e.g., hidden_inv_cfg = 1.0 / HIDDEN_CFG or based on the hidden_size parameter)
and replace uses of HIDDEN_INV in the inv_rms_tile calculation (the
pl.rsqrt(...) / reshape block that defines inv_rms_tile) and the later
post-RMSNorm path so both normalization sites use the per-call hidden_inv_cfg
instead of the global HIDDEN_INV.
---
Outside diff comments:
In `@examples/models/qwen3/qwen3_32b_decode_tilelet.py`:
- Around line 530-547: The golden_qwen3_decode reference still uses module-level
constants (BATCH, MAX_SEQ, HIDDEN, NUM_*) causing mismatched shapes for
non-default runs; change golden_qwen3_decode to derive all dimensions at runtime
from inputs (e.g., batch, hidden_size = hidden_states.shape; max_seq_len,
head_dim = rope_cos.shape; kv_hidden = wk.shape[1]; num_kv_heads = kv_hidden //
head_dim; num_heads = hidden_size // head_dim; intermediate_size =
w_gate.shape[1]; hidden_inv = 1.0 / hidden_size) and replace any allocations,
loop bounds, and slices that use BATCH, MAX_SEQ, HIDDEN, NUM_* with these
derived variables so compile_and_run(...) comparisons use correct sizes for
arbitrary arguments.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: f3d6ba06-3d2b-4211-9460-a45220ab8b55
📒 Files selected for processing (2)
- examples/models/qwen3/qwen3_32b_decode_scope1.py
- examples/models/qwen3/qwen3_32b_decode_tilelet.py
```python
hidden_blocks = hidden // K_CHUNK
q_out_blocks = hidden // Q_OUT_CHUNK
kv_out_blocks = kv_hidden // KV_OUT_CHUNK
batch_blocks = batch // BATCH_TILE
```
Reject or handle non-multiple batch sizes.
`batch_blocks = batch // BATCH_TILE` truncates the tail, but the projection loops still iterate `range(0, batch, BATCH_TILE)`. For `batch=17`, Stage 1 never writes row 16, while Stages 2/3 still slice a full `[16, ...]` tile from `normed_buf`.
🛡️ Suggested guard
```python
if batch <= 0 or batch % BATCH_TILE != 0:
    raise ValueError(
        f"batch ({batch}) must be a positive multiple of BATCH_TILE ({BATCH_TILE})"
    )
```

Also applies to: 76-80, 112-145
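A minimal standalone sketch of the suggested guard and the silent truncation it prevents (pure Python; the `BATCH_TILE` value and function name are assumed for illustration):

```python
BATCH_TILE = 16  # assumed tile size from the example kernels

def batch_blocks_checked(batch: int) -> int:
    """Reject batch sizes that floor division would silently truncate."""
    if batch <= 0 or batch % BATCH_TILE != 0:
        raise ValueError(
            f"batch ({batch}) must be a positive multiple of BATCH_TILE ({BATCH_TILE})"
        )
    return batch // BATCH_TILE

# batch=17 floor-divides to 1 block, so the 17th row would never be tiled:
truncated = 17 // BATCH_TILE  # 1
```

Failing fast here is cheaper than partial-tile handling, since every downstream stage shares the full-tile assumption.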
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/models/qwen3/qwen3_32b_decode_scope1.py` at line 56, The code
computes batch_blocks = batch // BATCH_TILE which silently drops a tail when
batch is not a multiple of BATCH_TILE, causing Stage 1 to skip the last partial
tile while later stages still read from normed_buf; add an explicit validation
at the start of the pipeline (before computing batch_blocks and before any
projection loops that use range(0, batch, BATCH_TILE)) to either raise a
ValueError when batch <= 0 or batch % BATCH_TILE != 0 or else change the tiling
logic to use math.ceil to include a final partial tile and ensure Stage 1 writes
the tail into normed_buf; reference the symbols batch_blocks, BATCH_TILE, the
projection loop ranges (range(0, batch, BATCH_TILE)), and normed_buf when making
the change so all stages remain consistent.
```diff
  q_proj: pl.Out[pl.Tensor[[batch, hidden], pl.BF16]],
  k_proj: pl.Out[pl.Tensor[[batch, kv_hidden], pl.BF16]],
  v_proj: pl.Out[pl.Tensor[[batch, kv_hidden], pl.BF16]],
  ) -> tuple[
-     pl.Tensor[[batch, hidden], pl.FP32],
-     pl.Tensor[[batch, kv_hidden], pl.FP32],
-     pl.Tensor[[batch, kv_hidden], pl.FP32],
+     pl.Tensor[[batch, hidden], pl.BF16],
+     pl.Tensor[[batch, kv_hidden], pl.BF16],
+     pl.Tensor[[batch, kv_hidden], pl.BF16],
  ]:
```
Align the Scope 1 → Scope 2 dtype contract.
If these outputs still feed examples/models/qwen3/qwen3_32b_decode_scope2.py, that function still declares FP32 q_proj, k_proj, and v_proj inputs on Lines 71-76 while this file now exposes BF16 tensors. The top-level orchestrator already allocates BF16 buffers on Lines 104-106 of examples/models/qwen3/qwen3_32b_decode.py, so the consumer change needs to land with this boundary change.
Also applies to: 189-191
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/models/qwen3/qwen3_32b_decode_scope1.py` around lines 68 - 75, The
scope boundary types are mismatched: qwen3_32b_decode_scope1.py now emits BF16
q_proj/k_proj/v_proj but the downstream consumer in qwen3_32b_decode_scope2.py
still declares those inputs as FP32; update the scope2 function signature to
accept pl.BF16 for q_proj, k_proj, and v_proj (replace pl.FP32/float annotations
with pl.BF16) so it matches the outputs from qwen3_32b_decode_scope1.py, and
make the same annotation fixes for the other occurrences mentioned (around the
second instance at the lines referenced 189-191) to keep the dtype contract
consistent across the boundary.
```python
inv_rms_tile: pl.Tensor[[BATCH_TILE, 1], pl.FP32] = pl.reshape(
    pl.rsqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)),
    [BATCH_TILE, 1],
```
Use the configured hidden size in RMSNorm.
This path still normalizes with the module-level HIDDEN_INV (1 / 8192) even though build_decode_projection_program() accepts arbitrary hidden_size. The PyTorch golden mirrors the same constant, so non-default sizes will be wrong and still appear green.
⚙️ Suggested fix
```diff
+ hidden_inv = 1.0 / hidden
- pl.rsqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)),
+ pl.rsqrt(pl.add(pl.mul(sq_sum, hidden_inv), EPS)),
```

```diff
- inv_rms = torch.rsqrt(sq_sum * HIDDEN_INV + EPS)  # [BATCH_TILE, 1]
+ inv_rms = torch.rsqrt(sq_sum / hidden_size + EPS)
```

Also applies to: 221-229
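The effect of the stale constant is easy to demonstrate in NumPy: normalizing a hidden size of 4096 with a hardcoded 1/8192 inflates every output by roughly sqrt(2). The numbers below are illustrative only, mirroring the torch logic:

```python
import numpy as np

rng = np.random.default_rng(0)
EPS = 1e-6
x = rng.standard_normal((4, 4096)).astype(np.float32)
sq_sum = (x * x).sum(axis=-1, keepdims=True)

good = x / np.sqrt(sq_sum / 4096 + EPS)         # per-call hidden size
bad = x / np.sqrt(sq_sum * (1.0 / 8192) + EPS)  # stale HIDDEN_INV for hidden=8192
# bad is ~sqrt(2) larger than good at every position
ratio = float(np.median(bad / good))
```

Because the golden reference shares the same constant, tests stay green while both sides are off by this factor, which is exactly why the review flags it.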
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/models/qwen3/qwen3_32b_decode_scope1.py` around lines 93 - 95, The
RMSNorm calculation in build_decode_projection_program() uses the module-level
constant HIDDEN_INV instead of the per-call hidden size; replace uses of
HIDDEN_INV (e.g., where inv_rms_tile is computed) with 1.0 / hidden (or a local
hidden_inv computed from the hidden parameter) so RMS scaling uses the
configured hidden size; apply the same change to the other occurrence around
lines 221-229 where HIDDEN_INV is referenced.
```python
with pl.incore():
    for ob in pl.range(Q_OUT_BLOCKS):
        q0 = ob * Q_OUT_CHUNK
        zero_1 = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0)
        zero_1_bf = pl.cast(zero_1, target_type=pl.BF16)
        q_proj = pl.assemble(q_proj, zero_1_bf, [0, q0])
        attn_out = pl.assemble(attn_out, zero_1_bf, [0, q0])
        normed_buf = pl.assemble(normed_buf, zero_1_bf, [0, q0])
with pl.incore():
    for ob in pl.range(KV_OUT_BLOCKS):
        kv0 = ob * KV_OUT_CHUNK
        zero_2 = pl.full([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32, value=0.0)
        zero_2_bf = pl.cast(zero_2, target_type=pl.BF16)
        k_proj = pl.assemble(k_proj, zero_2_bf, [0, kv0])
        v_proj = pl.assemble(v_proj, zero_2_bf, [0, kv0])
```
Guard the full-tile batch invariant.
These loops always read and write [BATCH_TILE, ...] slices. build_qwen3_single_layer_decode_program(batch=...) will go out of bounds for batch < 16 or any batch % 16 != 0, and Scope 3 still has the same assumption. Either add partial-tile handling or fail fast on unsupported batch sizes.
🛡️ Suggested guard
```python
if batch <= 0 or batch % BATCH_TILE != 0:
    raise ValueError(
        f"batch ({batch}) must be a positive multiple of BATCH_TILE ({BATCH_TILE})"
    )
```

Also applies to: 172-238
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/models/qwen3/qwen3_32b_decode_tilelet.py` around lines 154 - 168,
The code assumes full BATCH_TILE tiles when assembling buffers (e.g., in the
pl.incore loops that write slices of shape [BATCH_TILE, ...]), so add a guard in
build_qwen3_single_layer_decode_program to validate the batch argument (e.g., if
batch <= 0 or batch % BATCH_TILE != 0: raise ValueError(f"batch ({batch}) must
be a positive multiple of BATCH_TILE ({BATCH_TILE})")); update any other entry
points that construct these programs (Scope 3 / callers) to perform the same
check or propagate the validated value rather than attempting partial-tile
handling.
- scope1: change q_proj/k_proj/v_proj signature, tensor specs and golden to FP32 and drop the BF16 cast on projection write-back
- tilelet: same FP32 promotion for the intermediate Q/K/V tensors, zero-init, projection assemble and golden reference
- tilelet: drop now-redundant FP32 casts when reading q_proj/k_proj during RoPE/attention, and add an explicit BF16 cast before writing the FP32 v_proj slice into the BF16 v_cache
Actionable comments posted: 1
♻️ Duplicate comments (1)
examples/models/qwen3/qwen3_32b_decode_scope1.py (1)
Lines 219-220: ⚠️ Potential issue | 🟡 Minor: Golden reference also uses hardcoded `HIDDEN_INV`.

Both the kernel (line 94) and the golden reference (line 220) use the module-level `HIDDEN_INV` constant. This means tests will pass for non-default `hidden_size` values, but both implementations will produce mathematically incorrect results. Consider fixing both together:

♻️ Suggested fix for golden

```diff
- inv_rms = torch.rsqrt(sq_sum * HIDDEN_INV + EPS)  # [BATCH_TILE, 1]
+ inv_rms = torch.rsqrt(sq_sum / hidden_size + EPS)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/models/qwen3/qwen3_32b_decode_scope1.py` around lines 219 - 220, The code incorrectly uses the module-level constant HIDDEN_INV; compute a dynamic inverse based on the actual hidden_size and use it instead in both the kernel and golden reference: replace uses of HIDDEN_INV in the sq_sum -> inv_rms computation with a computed hidden_inv = 1.0 / float(hidden_size) (or torch.tensor(1.0/hidden_size, device=..., dtype=... ) if needed) and change the line inv_rms = torch.rsqrt(sq_sum * HIDDEN_INV + EPS) to inv_rms = torch.rsqrt(sq_sum * hidden_inv + EPS); apply the same replacement wherever HIDDEN_INV is used (e.g., in the kernel implementation and in the golden reference) so sq_sum, inv_rms, EPS all use the runtime hidden_size rather than a hardcoded constant.
🧹 Nitpick comments (1)
examples/models/qwen3/qwen3_32b_decode_tilelet.py (1)
Lines 530-535: Golden reference doesn't match kernel's chunked accumulation pattern.

The kernel performs chunked FP32 accumulation over `K_CHUNK` (lines 210-214), but the golden reference performs a single full matmul (`normed.float() @ wq.float()`). While mathematically equivalent in exact arithmetic, BF16 intermediate precision and FP32 accumulation ordering can produce different rounding behavior. For consistency with the kernel's precision path (and to match the `golden_decode_projection` in `qwen3_32b_decode_scope1.py`, which does use chunked accumulation), consider updating this golden to use the same chunked pattern:

♻️ Suggested alignment with kernel accumulation

```diff
- q_proj[b0:b_end, :] = normed.float() @ wq.float()
- k_proj[b0:b_end, :] = normed.float() @ wk.float()
- v_proj[b0:b_end, :] = normed.float() @ wv.float()
+ q_acc = torch.zeros(b_end - b0, hidden_size, dtype=torch.float32)
+ k_acc = torch.zeros(b_end - b0, kv_hidden, dtype=torch.float32)
+ v_acc = torch.zeros(b_end - b0, kv_hidden, dtype=torch.float32)
+ for k0 in range(0, hidden_size, K_CHUNK):
+     normed_chunk = normed[:, k0:k0 + K_CHUNK].float()
+     q_acc = q_acc + normed_chunk @ wq[k0:k0 + K_CHUNK, :].float()
+     k_acc = k_acc + normed_chunk @ wk[k0:k0 + K_CHUNK, :].float()
+     v_acc = v_acc + normed_chunk @ wv[k0:k0 + K_CHUNK, :].float()
+ q_proj[b0:b_end, :] = q_acc
+ k_proj[b0:b_end, :] = k_acc
+ v_proj[b0:b_end, :] = v_acc
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/models/qwen3/qwen3_32b_decode_tilelet.py` around lines 530 - 535, The golden reference uses a single full matmul for q/k/v (normed.float() @ wq.float(), etc.) which doesn't match the kernel's chunked FP32 accumulation over K_CHUNK; update the block that computes q_proj, k_proj, v_proj to perform chunked accumulation: iterate over input feature dimension in slices of K_CHUNK, compute partial = normed[:, k0:k1].float() @ wX[k0:k1, :].float() for each of wq/wk/wv, accumulate into FP32 accumulators, and only after the loop cast/assign to q_proj/k_proj/v_proj as in the kernel and as done in golden_decode_projection in qwen3_32b_decode_scope1.py so rounding/order match the kernel.
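To see why the two golden formulations agree only up to rounding, here is a small NumPy check that chunked FP32 accumulation over a chunk width (named `K_CHUNK` here, value assumed for illustration) matches the single full matmul within float tolerance:

```python
import numpy as np

rng = np.random.default_rng(42)
K_CHUNK = 64  # assumed chunk width for this sketch
A = rng.standard_normal((16, 256)).astype(np.float32)
W = rng.standard_normal((256, 64)).astype(np.float32)

# Chunked accumulation over the reduction dimension, as the kernel does.
acc = np.zeros((16, 64), dtype=np.float32)
for k0 in range(0, A.shape[1], K_CHUNK):
    acc += A[:, k0:k0 + K_CHUNK] @ W[k0:k0 + K_CHUNK, :]

full = A @ W  # single matmul: same result up to FP32 accumulation-order rounding
max_diff = float(np.abs(acc - full).max())
```

In FP32 the two orderings differ only in the last bits; the review's concern is that with BF16 inputs and a golden that accumulates in a different order, those last-bit differences can exceed tight test tolerances.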
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/models/qwen3/qwen3_32b_decode_tilelet.py`:
- Around line 530-531: The RMSNorm in the other implementation is wrong: instead
of multiplying by "variance" you must divide by the RMS (i.e. multiply by
1/sqrt(variance)). In the code that computes variance = sq_sum / hidden_size +
EPS and then normed = x_tile * variance * input_rms_weight.float(), replace that
with computing inv_rms = torch.rsqrt(sq_sum / hidden_size + EPS) and set normed
= x_tile * inv_rms * input_rms_weight.float(), using the same variable names
(sq_sum, hidden_size, EPS, x_tile, input_rms_weight, normed) so both
implementations match.
---
Duplicate comments:
In `@examples/models/qwen3/qwen3_32b_decode_scope1.py`:
- Around line 219-220: The code incorrectly uses the module-level constant
HIDDEN_INV; compute a dynamic inverse based on the actual hidden_size and use it
instead in both the kernel and golden reference: replace uses of HIDDEN_INV in
the sq_sum -> inv_rms computation with a computed hidden_inv = 1.0 /
float(hidden_size) (or torch.tensor(1.0/hidden_size, device=..., dtype=... ) if
needed) and change the line inv_rms = torch.rsqrt(sq_sum * HIDDEN_INV + EPS) to
inv_rms = torch.rsqrt(sq_sum * hidden_inv + EPS); apply the same replacement
wherever HIDDEN_INV is used (e.g., in the kernel implementation and in the
golden reference) so sq_sum, inv_rms, EPS all use the runtime hidden_size rather
than a hardcoded constant.
---
Nitpick comments:
In `@examples/models/qwen3/qwen3_32b_decode_tilelet.py`:
- Around line 530-535: The golden reference uses a single full matmul for q/k/v
(normed.float() @ wq.float(), etc.) which doesn't match the kernel's chunked
FP32 accumulation over K_CHUNK; update the block that computes q_proj, k_proj,
v_proj to perform chunked accumulation: iterate over input feature dimension in
slices of K_CHUNK, compute partial = normed[:, k0:k1].float() @ wX[k0:k1,
:].float() for each of wq/wk/wv, accumulate into FP32 accumulators, and only
after the loop cast/assign to q_proj/k_proj/v_proj as in the kernel and as done
in golden_decode_projection in qwen3_32b_decode_scope1.py so rounding/order
match the kernel.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 21e9a707-d40e-4083-a9fb-0a593104c361
📒 Files selected for processing (2)
- examples/models/qwen3/qwen3_32b_decode_scope1.py
- examples/models/qwen3/qwen3_32b_decode_tilelet.py
```python
inv_rms = torch.rsqrt(sq_sum / hidden_size + EPS)
normed = (x_tile * inv_rms * input_rms_weight.float()).bfloat16()
```
🧩 Analysis chain
🏁 Script executed:
```shell
#!/bin/bash
# Check if qwen3_32b_decode_scope12.py uses the old variance formulation
rg -n "variance\s*=\s*sq_sum" --type py
```

Repository: hw-native-sys/pypto-lib
Length of output: 166
🏁 Script executed:
```shell
# First, confirm the tilelet file has the rsqrt formulation at lines 530-531
fd "qwen3_32b_decode_tilelet.py" --type f
```

Repository: hw-native-sys/pypto-lib
Length of output: 116
🏁 Script executed:
```shell
# Read the tilelet file around lines 530-531
sed -n '525,535p' examples/models/qwen3/qwen3_32b_decode_tilelet.py
```

Repository: hw-native-sys/pypto-lib
Length of output: 609
🏁 Script executed:
```shell
# Read the scope12 file around the variance line (401) with context
sed -n '395,410p' examples/models/qwen3/qwen3_32b_decode_scope12.py
```

Repository: hw-native-sys/pypto-lib
Length of output: 810
🏁 Script executed:
```shell
# Check if both files are imported or used together
rg -l "qwen3_32b_decode_tilelet|qwen3_32b_decode_scope12" --type py
```

Repository: hw-native-sys/pypto-lib
Length of output: 164
Critical RMSNorm formulation mismatch between implementations.
qwen3_32b_decode_tilelet.py (lines 530-531) uses the correct RMSNorm formulation:
```python
inv_rms = torch.rsqrt(sq_sum / hidden_size + EPS)
normed = x_tile * inv_rms * input_rms_weight.float()
```

But qwen3_32b_decode_scope12.py (line 401) uses an incorrect formula that multiplies by variance instead of dividing by its square root:

```python
variance = sq_sum / hidden_size + EPS
normed = x_tile * variance * input_rms_weight.float()
```

These compute fundamentally different operations and will produce drastically different outputs. This inconsistency will cause golden reference divergence and test failures when both implementations are used together.
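A tiny NumPy repro of the divergence between the two formulations (hypothetical values; it mirrors the torch lines quoted above):

```python
import numpy as np

EPS = 1e-6
x_tile = np.full((1, 8), 2.0, dtype=np.float32)
sq_mean = (x_tile * x_tile).mean(axis=-1, keepdims=True)  # 4.0

correct = x_tile / np.sqrt(sq_mean + EPS)  # rsqrt path: ~1.0 per element
wrong = x_tile * (sq_mean + EPS)           # variance-multiply bug: ~8.0 per element
```

For this input the buggy formula scales outputs up 8x instead of normalizing them to unit RMS, so any golden comparison between the two files must fail.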
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/models/qwen3/qwen3_32b_decode_tilelet.py` around lines 530 - 531,
The RMSNorm in the other implementation is wrong: instead of multiplying by
"variance" you must divide by the RMS (i.e. multiply by 1/sqrt(variance)). In
the code that computes variance = sq_sum / hidden_size + EPS and then normed =
x_tile * variance * input_rms_weight.float(), replace that with computing
inv_rms = torch.rsqrt(sq_sum / hidden_size + EPS) and set normed = x_tile *
inv_rms * input_rms_weight.float(), using the same variable names (sq_sum,
hidden_size, EPS, x_tile, input_rms_weight, normed) so both implementations
match.
Summary
- `auto_incore` + `parallel` loops with `UP_DOWN` split for projection stages
- `matmul`/`matmul_acc` replaced with `add` + `matmul` accumulation pattern