
Refactor: Qwen3 decode with 3-scope architecture and TILELET rename #99

Merged

zhangqi-chen merged 1 commit into hw-native-sys:main from bumble0918:scope123 on Apr 10, 2026

Conversation

@bumble0918
Contributor

  • qwen3_32b_decode.py: Refactored into 3 scopes for better incore
    • Scope 1: Input RMSNorm + Q/K/V projection
    • Scope 2: Attention (K RoPE + cache, QK matmul, softmax, SV matmul)
    • Scope 3: Output projection, residual, RMSNorm, MLP
  • Updated HIDDEN size from 5120 to 8192 (64 heads × 128 dim)
  • Renamed qwen3_32b_decode_tilelet.py to qwen3_32b_decode_mixed.py for clearer TILELET-aware version naming
  • Adjusted tiling constants for each scope
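The three-scope split described above can be sketched end to end in plain numpy. This is a hedged illustration of the dataflow only, not the repository's `golden_qwen3_decode()`: all weight names, shapes, and the omission of RoPE are assumptions for brevity, and the `pl` tiling constants are not modeled.

```python
import numpy as np

def rms_norm(x, w, eps=1e-6):
    # Normalize each row by its root-mean-square, then apply the learned scale.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * w

def softmax(x):
    m = np.max(x, axis=-1, keepdims=True)
    e = np.exp(x - m)
    return e / np.sum(e, axis=-1, keepdims=True)

def decode_layer(x, wq, wk, wv, wo, w_gate, w_up, w_down, ln1, ln2,
                 k_cache, v_cache, num_heads, num_kv_heads, head_dim):
    batch, hidden = x.shape
    # Scope 1: input RMSNorm + Q/K/V projections
    h = rms_norm(x, ln1)
    q = (h @ wq).reshape(batch, num_heads, head_dim)
    k = (h @ wk).reshape(batch, num_kv_heads, head_dim)
    v = (h @ wv).reshape(batch, num_kv_heads, head_dim)
    # Scope 2: append the new token to the KV cache, then grouped-query
    # attention (RoPE omitted in this sketch).
    k_cache = np.concatenate([k_cache, k[:, None]], axis=1)  # [b, seq, kvH, d]
    v_cache = np.concatenate([v_cache, v[:, None]], axis=1)
    group = num_heads // num_kv_heads
    kk = np.repeat(k_cache, group, axis=2)  # share each KV head across its Q group
    vv = np.repeat(v_cache, group, axis=2)
    scores = np.einsum('bhd,bshd->bhs', q, kk) / np.sqrt(head_dim)
    attn = np.einsum('bhs,bshd->bhd', softmax(scores), vv).reshape(batch, -1)
    # Scope 3: output projection, residual, post-attention RMSNorm, SwiGLU MLP
    x = x + attn @ wo
    h = rms_norm(x, ln2)
    silu = lambda t: t / (1.0 + np.exp(-t))
    x = x + (silu(h @ w_gate) * (h @ w_up)) @ w_down
    return x, k_cache, v_cache
```

Each scope's outputs feed the next exactly as in the sequence diagram below: projections into attention, attention output into the projection/residual/MLP tail.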

@coderabbitai

coderabbitai bot commented Apr 10, 2026

📝 Walkthrough

Restructured Qwen3 decode from single-layer into three explicit phases (Scope 1/2/3): Scope 1 performs input normalization and Q/K/V projections in FP32; Scope 2 implements grouped-query attention with KV cache updates and staged softmax; Scope 3 handles output projection, residual operations, and MLP. Added PyTorch golden reference implementation and updated function signatures with new parameters.

Changes

Cohort / File(s): Qwen3 Decode Restructure — examples/models/qwen3/qwen3_32b_decode.py

Summary: Replaced Qwen3SingleLayerDecode with a Qwen3Scope123 program dividing computation into Scope 1 (RMSNorm + FP32 Q/K/V projections), Scope 2 (grouped-query attention with KV caching and staged softmax), and Scope 3 (output projection, residual, post-attention RMSNorm, SwiGLU MLP). Renamed build_qwen3_single_layer_decode_program() to build_qwen3_decode_program() with updated parameter names (max_seq_len → max_seq, added hidden_size). Added golden_qwen3_decode() PyTorch reference. Updated tiling constants, attention scale handling, tensor shape derivation, and runtime verification. Changed the default platform from "a2a3" to "a5" and added a use_max_seq parameter with a corresponding --max-seq CLI flag.

Sequence Diagram(s)

sequenceDiagram
    participant Input as Input Tensors
    participant S1 as Scope 1<br/>RMSNorm + Projections
    participant S2 as Scope 2<br/>Attention + KV Cache
    participant S3 as Scope 3<br/>Output + MLP + Residual
    participant Output as Output Tensors

    Input->>S1: raw activations
    S1->>S1: apply input RMSNorm
    S1->>S1: project to Q/K/V (FP32)
    S1->>S2: Q, K, V intermediates

    S2->>S2: update KV caches
    S2->>S2: apply RoPE to Q/K
    S2->>S2: grouped-query attention<br/>(staged softmax + accumulation)
    S2->>S3: attention output

    S3->>S3: output projection
    S3->>S3: add residual
    S3->>S3: post-attention RMSNorm
    S3->>S3: SwiGLU MLP
    S3->>S3: final residual assembly
    S3->>Output: decoded layer output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Poem

🐰 Three scopes now dance in perfect harmony,
Scope One normalizes with graceful ceremony,
Scope Two attends with padded precision so fine,
Scope Three projects and fuses the line,
From single to triple, the decode takes flight! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 5.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title accurately summarizes the main changes: refactoring Qwen3 decode into a 3-scope architecture and the TILELET-related rename.
  • Description check — ✅ Passed: the description directly relates to the changeset, detailing the 3-scope refactoring, HIDDEN size update, TILELET rename, and tiling adjustments.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.




@gemini-code-assist bot left a comment


Code Review

This pull request refactors the Qwen3-32B single-layer decode forward implementation into three distinct scopes, optimizing tiling strategies and adding a PyTorch reference function for validation. Feedback focuses on improving code robustness and performance: specifically, reshaping the results of pl.row_max and pl.row_sum to maintain consistent 2D tensor dimensions, and replacing inefficient loops of small pl.assemble calls with pl.full for tensor initialization to reduce DMA overhead.

)
scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
scores = pl.mul(scores_padded, attn_scale)
cur_mi = pl.row_max(scores)


Severity: high

The result of pl.row_max should be reshaped to [Q_HEAD_BATCH, 1] to ensure consistent dimensions for subsequent pl.assemble and pl.row_expand_sub operations, as the target tensors are 2D.

Suggested change
cur_mi = pl.row_max(scores)
cur_mi = pl.reshape(pl.row_max(scores), [Q_HEAD_BATCH, 1])

exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
cur_li = pl.row_sum(exp_scores_fp32)


Severity: high

The result of pl.row_sum should be reshaped to [Q_HEAD_BATCH, 1] to match the shape of all_cur_li and ensure correct broadcasting in the online softmax update.

Suggested change
cur_li = pl.row_sum(exp_scores_fp32)
cur_li = pl.reshape(pl.row_sum(exp_scores_fp32), [Q_HEAD_BATCH, 1])
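Both reshape suggestions guard against the same broadcasting pitfall: a 1-D reduction result does not align with a 2-D tile the way a `[Q_HEAD_BATCH, 1]` column does. The `pl` DSL's exact semantics are not shown here, so the point is illustrated with numpy, where the same shape mismatch is visible directly (tile shapes below are made up for the example):

```python
import numpy as np

scores = np.array([[1.0, 3.0, 2.0],
                   [0.0, -1.0, 4.0]])        # [Q_HEAD_BATCH=2, SEQ_TILE=3]

row_max_1d = scores.max(axis=1)               # shape (2,)  -- like an unreshaped row_max
row_max_2d = scores.max(axis=1).reshape(2, 1) # shape (2, 1) -- the suggested form

# The (2, 1) column broadcasts per row, as the online-softmax max
# subtraction requires: each row is shifted by its own maximum.
shifted = scores - row_max_2d
assert shifted.max(axis=1).tolist() == [0.0, 0.0]

# The bare (2,) vector does not align with a (2, 3) tile at all here;
# with a square tile it would instead broadcast along columns and
# silently compute the wrong thing.
try:
    scores - row_max_1d
    raised = False
except ValueError:
    raised = True
assert raised
```

The same reasoning applies to `row_sum`: keeping `cur_li` as a `[Q_HEAD_BATCH, 1]` column matches `all_cur_li` and keeps the online-softmax normalizer update per-row.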

Comment on lines +185 to +190
for idx in pl.range(batch * total_q_groups):
all_q_padded = pl.assemble(
all_q_padded,
pl.cast(pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
[idx * Q_HEAD_PAD + Q_HEAD_BATCH, 0],
)


Severity: medium

Initializing the padding rows of all_q_padded using a loop of pl.assemble calls is inefficient as each call typically triggers a separate DMA transfer. It would be better to initialize the tensor with zeros using pl.full or perform a single larger pl.assemble if possible.

Comment on lines +258 to +289
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.incore():
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
all_raw_scores = pl.assemble(
all_raw_scores,
pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0),
[sb * Q_HEAD_PAD, 0],
)
all_exp_padded = pl.assemble(
all_exp_padded,
pl.cast(pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
[sb * Q_HEAD_PAD, 0],
)
all_oi_tmp = pl.assemble(
all_oi_tmp,
pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0),
[sb * Q_HEAD_PAD, 0],
)
mi_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
all_cur_mi = pl.assemble(
all_cur_mi,
pl.reshape(mi_init_flat, [Q_HEAD_BATCH, 1]),
[sb * Q_HEAD_BATCH, 0],
)
li_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
all_cur_li = pl.assemble(
all_cur_li,
pl.reshape(li_init_flat, [Q_HEAD_BATCH, 1]),
[sb * Q_HEAD_BATCH, 0],
)


Severity: medium

This 'Workaround' for initializing intermediate tensors uses nested loops with many small pl.assemble calls. This pattern is highly inefficient for Global Memory operations. Consider using pl.full on the entire tensor or larger chunks to reduce the number of DMA operations.
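The trade-off both reviewers point at — many small writes versus one bulk initialization — can be shown with a numpy analogy. numpy has no DMA, so this only models the access pattern; the `Q_HEAD_PAD`/`Q_HEAD_BATCH` values below are illustrative, not the PR's constants:

```python
import numpy as np

Q_HEAD_PAD, Q_HEAD_BATCH, head_dim, groups = 4, 3, 8, 5
q = np.arange(groups * Q_HEAD_BATCH * head_dim, dtype=np.float32)
q = q.reshape(groups * Q_HEAD_BATCH, head_dim)

# Loop of small writes: two writes per group, analogous to one pl.assemble
# (and hence one transfer) per padding region -- the pattern the review flags.
padded_loop = np.empty((groups * Q_HEAD_PAD, head_dim), dtype=np.float32)
for g in range(groups):
    padded_loop[g * Q_HEAD_PAD : g * Q_HEAD_PAD + Q_HEAD_BATCH] = \
        q[g * Q_HEAD_BATCH : (g + 1) * Q_HEAD_BATCH]
    padded_loop[g * Q_HEAD_PAD + Q_HEAD_BATCH : (g + 1) * Q_HEAD_PAD] = 0.0

# Bulk alternative: zero the whole buffer once (one pl.full-style fill),
# then place all valid rows with a single strided copy.
padded_bulk = np.zeros((groups * Q_HEAD_PAD, head_dim), dtype=np.float32)
view = padded_bulk.reshape(groups, Q_HEAD_PAD, head_dim)
view[:, :Q_HEAD_BATCH, :] = q.reshape(groups, Q_HEAD_BATCH, head_dim)

# Both layouts are identical; only the number of discrete writes differs.
assert np.array_equal(padded_loop, padded_bulk)
```

Whether the `pl` compiler can fuse the per-group `pl.assemble` calls is not stated in the thread, so the bulk form is the safer default if the DSL supports a whole-tensor fill.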


@coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
examples/models/qwen3/qwen3_32b_decode.py (3)

23-23: Consider using ASCII x instead of Unicode × in documentation.

The docstring uses the Unicode multiplication sign (×) which could cause issues in some editors or environments. This is flagged by static analysis (RUF002/RUF003).

Suggested fix
-  1. Output projection: attn_out × wo
+  1. Output projection: attn_out x wo

Also applies to line 382.


253-289: Clarify the "Workaround" comment purpose.

The comment on line 252 says "# Workaround" without explaining what it's working around. Consider adding a brief explanation for maintainability.

Suggested clarification
-                    # Workaround
+                    # Workaround: pre-allocate intermediate tensors for all context blocks
+                    # to avoid dynamic tensor creation within the attention loop
                     all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)

72-79: Integer division assumes exact divisibility without validation.

The tiling calculations silently truncate if dimensions aren't evenly divisible by chunk sizes. With default constants this works, but custom configurations could produce incorrect results.

Consider adding assertions for non-default configurations:

Proposed validation
     hidden = hidden_size
     kv_hidden = num_kv_heads * head_dim
     inter = intermediate_size
+    assert hidden % K_CHUNK == 0, f"hidden ({hidden}) must be divisible by K_CHUNK ({K_CHUNK})"
+    assert hidden % Q_OUT_CHUNK == 0, f"hidden ({hidden}) must be divisible by Q_OUT_CHUNK ({Q_OUT_CHUNK})"
+    assert kv_hidden % KV_OUT_CHUNK == 0, f"kv_hidden ({kv_hidden}) must be divisible by KV_OUT_CHUNK ({KV_OUT_CHUNK})"
+    assert inter % MLP_OUT_CHUNK == 0, f"inter ({inter}) must be divisible by MLP_OUT_CHUNK ({MLP_OUT_CHUNK})"
     hidden_blocks = hidden // K_CHUNK
     q_out_blocks = hidden // Q_OUT_CHUNK
     kv_out_blocks = kv_hidden // KV_OUT_CHUNK
     mlp_out_blocks = inter // MLP_OUT_CHUNK


📥 Commits

Reviewing files that changed from the base of the PR and between 0d48e70 and 3ee28c7.

📒 Files selected for processing (2)
  • examples/models/qwen3/qwen3_32b_decode.py
  • examples/models/qwen3/qwen3_32b_decode_mixed.py

Comment on lines +493 to +496
 hidden = num_heads * head_dim
 kv_hidden = num_kv_heads * head_dim
-cache_rows = batch * num_kv_heads * max_seq_len
 inter = intermediate_size
+cache_rows = batch * num_kv_heads * max_seq


⚠️ Potential issue | 🟡 Minor

Unused variable hidden and inconsistent usage with hidden_size.

Line 493 defines hidden = num_heads * head_dim which shadows the intent but is mostly unused. Line 575 uses hidden for the output tensor shape, while all other tensor specs use hidden_size. This inconsistency could cause bugs if hidden_size ever differs from num_heads * head_dim.

Suggested fix for consistency
-    hidden = num_heads * head_dim
     kv_hidden = num_kv_heads * head_dim
     inter = intermediate_size
     cache_rows = batch * num_kv_heads * max_seq

And at line 575:

-        TensorSpec("out", [batch, hidden], torch.bfloat16, is_output=True),
+        TensorSpec("out", [batch, hidden_size], torch.bfloat16, is_output=True),
