
Refactor: replace manual chunked loops with pl.parallel chunk syntax and clean up scope2#103

Merged
zhangqi-chen merged 1 commit into hw-native-sys:main from zhangqi-chen:refactor/scope2-chunked-parallel-syntax
Apr 13, 2026
Merged

Refactor: replace manual chunked loops with pl.parallel chunk syntax and clean up scope2#103
zhangqi-chen merged 1 commit intohw-native-sys:mainfrom
zhangqi-chen:refactor/scope2-chunked-parallel-syntax

Conversation

Collaborator

@zhangqi-chen zhangqi-chen commented Apr 13, 2026

Summary

  • Replace hand-rolled for sb0 in pl.range(0, ctx_blocks, SB_BATCH) + inner pl.range(SB_BATCH) + if sb < ctx_blocks guard pattern with pl.parallel(ctx_blocks, chunk=SB_BATCH) using the compiler-managed guarded chunk policy
  • Remove redundant zero-initialization loop for intermediate tensors in decode scope2 (all_raw_scores, all_exp_padded, etc. are overwritten by subsequent stages)
  • Rename program class Qwen3Scope123 → Qwen3Decode
  • Add qwen3_32b_decode_tile.py: InCore + Orchestration separated rewrite with explicit pl.load/pl.store/pl.move data movement
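
The loop transformation in the first bullet can be modeled in plain Python (a hypothetical sketch: `pl.parallel` is the project's DSL construct, so this only illustrates that the two forms cover the same iteration space, including the tail chunk the manual `if sb < ctx_blocks` guard exists for):

```python
SB_BATCH = 4  # chunk size; illustrative value, not the file's actual constant

def manual_chunked(ctx_blocks):
    """Hand-rolled pattern: outer stride loop + inner fixed-trip loop + guard."""
    visited = []
    for sb0 in range(0, ctx_blocks, SB_BATCH):  # outer chunk loop
        for i in range(SB_BATCH):               # inner loop of SB_BATCH steps
            sb = sb0 + i
            if sb < ctx_blocks:                 # tail guard for the last chunk
                visited.append(sb)
    return visited

def chunked_parallel(ctx_blocks):
    """What pl.parallel(ctx_blocks, chunk=SB_BATCH) expresses: a flat
    iteration space; the compiler emits the chunking and tail guard itself."""
    return list(range(ctx_blocks))

# Same indices visited, including a ragged tail (10 is not a multiple of 4).
assert manual_chunked(10) == chunked_parallel(10)
```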

Files changed

  • qwen3_32b_decode_scope2.py — 3 stages converted to pl.parallel chunk syntax
  • qwen3_32b_decode.py — 3 stages converted + removed redundant init loop + class rename
  • qwen3_32b_decode_tile.py — new tile DSL version of decode

@coderabbitai

coderabbitai bot commented Apr 13, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Refactors per-block attention iteration to use chunked parallel loops and updates related loop-optimizer annotations; adds a tiled decode implementation and renames the program-scoped class from Qwen3Scope123 to Qwen3Decode.

Changes

  • Scope2 / Decode refactor — examples/models/qwen3/qwen3_32b_decode.py, examples/models/qwen3/qwen3_32b_decode_scope2.py: Replaced nested pl.range(..., SB_BATCH) + inner pl.range(SB_BATCH) + if sb < ctx_blocks guards with pl.parallel(ctx_blocks, chunk=SB_BATCH) and pl.at(..., optimization=pl.chunked_loop_optimizer). Stages 2–4 iterate directly over sb and assemble all_raw_scores, all_exp_padded, all_oi_tmp, all_cur_mi, and all_cur_li using indexed offsets. Renamed the program class Qwen3Scope123 → Qwen3Decode in the decode file.
  • New tiled decode example — examples/models/qwen3/qwen3_32b_decode_tile.py: Added a new tile-DSL single-layer Qwen3 32B decode implementation: InCore kernels, an orchestration @pl.program class Qwen3DecodeTile, a builder build_qwen3_decode_program(), build_tensor_specs(), a PyTorch reference golden_qwen3_decode(), and a compile_and_run() CLI entrypoint. Large new file with tiled attention, projection, norm, and MLP stages.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 I hopped through blocks both near and far,

Replaced old guards with a chunked-parallel star.
Scores line up, softmax sings in tune,
Tiles assemble under the optimizing moon.
A tiny rabbit claps—decode complete! 🥕✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 2.44%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title accurately reflects the main refactoring: replacing manual chunked loops with pl.parallel chunk syntax and cleaning up scope2.
  • Description check — ✅ Passed: the description clearly describes the changes: replacing manual chunked loops with pl.parallel syntax, removing redundant initialization, renaming a class, and adding a new tile DSL implementation.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/models/qwen3/qwen3_32b_decode_scope2.py`:
- Around line 159-169: Zero-initialize the padded regions before the
chunked-parallel stages by creating a pl.full(0.0, [Q_HEAD_PAD, ...]) tile and
using pl.assemble to write it into all_exp_padded, all_raw_scores, and
all_oi_tmp for every ctx block (loop over ctx_blocks) so Stage 4 won't read
uninitialized values; apply the same pattern used in qwen3_32b_decode.py:
allocate a zero tile sized to Q_HEAD_PAD rows and assemble it into those buffers
for each ctx block prior to the block-parallel loops that use
SEQ_TILE/Q_HEAD_BATCH.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 6eb0c213-21e1-4d63-a8cd-884dcd87d37a

📥 Commits

Reviewing files that changed from the base of the PR and between e663c14 and f721536.

📒 Files selected for processing (2)
  • examples/models/qwen3/qwen3_32b_decode.py
  • examples/models/qwen3/qwen3_32b_decode_scope2.py

Comment on lines +159 to +169
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
    for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
        s0 = sb * SEQ_TILE
        cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
        k_tile = pl.slice(
            k_cache,
            [SEQ_TILE, head_dim],
            [cache_row0, 0],
        )
        raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
        all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "*qwen3*.py" | head -20

Repository: hw-native-sys/pypto-lib

Length of output: 571


🏁 Script executed:

# Check the specific file and line ranges mentioned
wc -l examples/models/qwen3/qwen3_32b_decode_scope2.py
wc -l examples/models/qwen3/qwen3_32b_decode.py

Repository: hw-native-sys/pypto-lib

Length of output: 165


🏁 Script executed:

# Inspect the Stage 2-4 regions in qwen3_32b_decode_scope2.py
sed -n '150,220p' examples/models/qwen3/qwen3_32b_decode_scope2.py | cat -n

Repository: hw-native-sys/pypto-lib

Length of output: 5449


🏁 Script executed:

# Check the sibling file qwen3_32b_decode.py at the referenced lines
sed -n '245,295p' examples/models/qwen3/qwen3_32b_decode.py | cat -n

Repository: hw-native-sys/pypto-lib

Length of output: 3666


Add explicit zero-initialization before the chunked-parallel stages.

Stage 3 writes only the first Q_HEAD_BATCH rows of the all_exp_padded, all_raw_scores, and all_oi_tmp buffers, but Stage 4 immediately reads the full Q_HEAD_PAD tile from each. Without zeroing, the padded region contains uninitialized data that flows into Stage 4's matmul. The sibling file qwen3_32b_decode.py (lines 253–286) has already implemented this fix with an explicit pl.full(0.0) + pl.assemble loop over all ctx_blocks.

Apply the same initialization pattern here before Stage 2 begins.

Suggested fix
all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)
all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16)
all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
+with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+    for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
+        all_raw_scores = pl.assemble(
+            all_raw_scores,
+            pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0),
+            [sb * Q_HEAD_PAD, 0],
+        )
+        all_exp_padded = pl.assemble(
+            all_exp_padded,
+            pl.cast(pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
+            [sb * Q_HEAD_PAD, 0],
+        )
+        all_oi_tmp = pl.assemble(
+            all_oi_tmp,
+            pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0),
+            [sb * Q_HEAD_PAD, 0],
+        )

Also applies to: lines 172–191, 194–209
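
The hazard this comment describes can be reproduced with a small NumPy sketch (shapes and names are illustrative, not the file's actual values; NaN stands in for uninitialized memory):

```python
import numpy as np

# Illustrative sizes: Q_HEAD_BATCH valid rows inside a Q_HEAD_PAD-row tile.
Q_HEAD_BATCH, Q_HEAD_PAD, SEQ_TILE, HEAD_DIM = 4, 8, 16, 32

# Padded buffer filled with NaN to model uninitialized memory.
exp_padded = np.full((Q_HEAD_PAD, SEQ_TILE), np.nan)

# Stage 3 analogue: only the first Q_HEAD_BATCH rows are written.
rng = np.random.default_rng(0)
exp_padded[:Q_HEAD_BATCH] = rng.random((Q_HEAD_BATCH, SEQ_TILE))

# Stage 4 analogue: the matmul reads the full Q_HEAD_PAD tile, so the
# unwritten rows poison the output.
v_tile = np.ones((SEQ_TILE, HEAD_DIM))
assert np.isnan(exp_padded @ v_tile).any()

# The suggested fix: zero the padded region first. Padded rows then
# contribute exact zeros and the valid rows are unaffected.
exp_padded[Q_HEAD_BATCH:] = 0.0
out = exp_padded @ v_tile
assert not np.isnan(out).any()
assert np.allclose(out[Q_HEAD_BATCH:], 0.0)
```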

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/models/qwen3/qwen3_32b_decode_scope2.py` around lines 159 - 169,
Zero-initialize the padded regions before the chunked-parallel stages by
creating a pl.full(0.0, [Q_HEAD_PAD, ...]) tile and using pl.assemble to write
it into all_exp_padded, all_raw_scores, and all_oi_tmp for every ctx block (loop
over ctx_blocks) so Stage 4 won't read uninitialized values; apply the same
pattern used in qwen3_32b_decode.py: allocate a zero tile sized to Q_HEAD_PAD
rows and assemble it into those buffers for each ctx block prior to the
block-parallel loops that use SEQ_TILE/Q_HEAD_BATCH.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the decoding logic in qwen3_32b_decode.py and qwen3_32b_decode_scope2.py by replacing manual chunked loops and explicit boundary checks with pl.parallel loops and the pl.chunked_loop_optimizer flag. These changes simplify the implementation and allow the compiler to more efficiently manage loop boundaries and core group transitions. I have no feedback to provide.

@zhangqi-chen zhangqi-chen force-pushed the refactor/scope2-chunked-parallel-syntax branch from f721536 to 6557828 Compare April 13, 2026 03:40
@zhangqi-chen zhangqi-chen changed the title from "Refactor: replace manual chunked loops with pl.parallel chunk syntax" to "Refactor: replace manual chunked loops with pl.parallel chunk syntax and clean up scope2" Apr 13, 2026
…and clean up scope2

- Replace hand-rolled `for sb0 in pl.range(0, ctx_blocks, SB_BATCH)` +
  inner `pl.range(SB_BATCH)` + `if sb < ctx_blocks` guard pattern with
  `pl.parallel(ctx_blocks, chunk=SB_BATCH)` in scope2 and decode
- Remove redundant zero-initialization loop for intermediate tensors in
  decode scope2 (all_raw_scores, all_exp_padded, etc. are overwritten
  by subsequent stages)
- Rename program class Qwen3Scope123 → Qwen3Decode
- Add qwen3_32b_decode_tile.py: InCore + Orchestration separated rewrite
  with explicit pl.load/pl.store/pl.move data movement
@zhangqi-chen zhangqi-chen force-pushed the refactor/scope2-chunked-parallel-syntax branch from 6557828 to 5c3808d Compare April 13, 2026 03:41

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/models/qwen3/qwen3_32b_decode_tile.py`:
- Around line 205-271: kernel_rope_kv_cache_q_pad is only writing the first
query group per KV head (writing at q_pad_base + ki * Q_HEAD_PAD + qi), which
breaks when q_groups > 1; fix by either iterating over qg and writing into
all_q_padded at index (q_pad_base + (ki * q_groups + qg) * Q_HEAD_PAD + qi)
inside kernel_rope_kv_cache_q_pad (and the sibling region at 601-611), or
enforce/validate q_groups == 1 during construction and fail fast; update the
function(s) that reference q_pad_base/all_q_padded (kernel_rope_kv_cache_q_pad
and the similar block at 601-611) to implement the qg loop or the validation.
- Around line 31-35: Replace the Unicode multiplication sign "×" with ASCII "x"
in the documentation/comments — specifically update the "Scope 3:" comment block
(the line listing "Output projection: attn_out × wo") and the other occurrence
around line 681 to use "x" instead; search the file
examples/models/qwen3/qwen3_32b_decode_tile.py for any remaining "×" characters
and replace them only in comments/docstrings (do not change actual code logic or
variable names).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 917c578a-7e56-4a27-ac1d-eeced18e2dc7

📥 Commits

Reviewing files that changed from the base of the PR and between f721536 and 6557828.

📒 Files selected for processing (3)
  • examples/models/qwen3/qwen3_32b_decode.py
  • examples/models/qwen3/qwen3_32b_decode_scope2.py
  • examples/models/qwen3/qwen3_32b_decode_tile.py
✅ Files skipped from review due to trivial changes (1)
  • examples/models/qwen3/qwen3_32b_decode_scope2.py

Comment on lines +31 to +35
Scope 3:
1. Output projection: attn_out × wo
2. Residual addition with hidden_states
3. Post-attention RMSNorm
4. MLP: gate/up projections, SiLU activation, down projection

⚠️ Potential issue | 🟡 Minor

Replace × with ASCII x in docs/comments.

Ruff already flags these multiplication signs as ambiguous Unicode, so they will keep tripping lint and are harder to search/copy/paste than plain ASCII.

Also applies to: 681-681

🧰 Tools
🪛 Ruff (0.15.9)

[warning] 32-32: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/models/qwen3/qwen3_32b_decode_tile.py` around lines 31 - 35, Replace
the Unicode multiplication sign "×" with ASCII "x" in the documentation/comments
— specifically update the "Scope 3:" comment block (the line listing "Output
projection: attn_out × wo") and the other occurrence around line 681 to use "x"
instead; search the file examples/models/qwen3/qwen3_32b_decode_tile.py for any
remaining "×" characters and replace them only in comments/docstrings (do not
change actual code logic or variable names).

Comment on lines +70 to +92
def build_qwen3_decode_program(
    batch: int = BATCH,
    max_seq: int = MAX_SEQ,
    hidden_size: int = HIDDEN,
    intermediate_size: int = INTERMEDIATE,
    num_heads: int = NUM_HEADS,
    num_kv_heads: int = NUM_KV_HEADS,
    head_dim: int = HEAD_DIM,
):
    hidden = hidden_size
    kv_hidden = num_kv_heads * head_dim
    inter = intermediate_size
    hidden_blocks = hidden // K_CHUNK
    q_out_blocks = hidden // Q_OUT_CHUNK
    kv_out_blocks = kv_hidden // KV_OUT_CHUNK
    mlp_out_blocks = inter // MLP_OUT_CHUNK
    cache_rows = batch * num_kv_heads * max_seq
    half_dim = head_dim // 2
    q_per_kv = num_heads // num_kv_heads
    q_groups = q_per_kv // Q_HEAD_BATCH
    total_q_groups = num_kv_heads * q_groups
    attn_scale = 1.0 / (head_dim ** 0.5)
    max_ctx_blocks = (max_seq + SEQ_TILE - 1) // SEQ_TILE

⚠️ Potential issue | 🟠 Major

Validate the supported shape invariants up front.

This builder accepts arbitrary batch, hidden_size, intermediate_size, num_heads, num_kv_heads, and head_dim, but the implementation assumes aligned tiles everywhere: fixed BATCH_TILE loads, //-based block counts, and Q_HEAD_BATCH grouping. For unsupported inputs, this will either skip tail work or issue out-of-range fixed-size accesses. Please fail fast here with explicit checks, or narrow the public API to the single supported model shape.

Suggested guardrail
 def build_qwen3_decode_program(
     batch: int = BATCH,
     max_seq: int = MAX_SEQ,
     hidden_size: int = HIDDEN,
     intermediate_size: int = INTERMEDIATE,
     num_heads: int = NUM_HEADS,
     num_kv_heads: int = NUM_KV_HEADS,
     head_dim: int = HEAD_DIM,
 ):
+    if hidden_size != num_heads * head_dim:
+        raise ValueError("hidden_size must equal num_heads * head_dim")
+    if batch % BATCH_TILE != 0:
+        raise ValueError(f"batch must be a multiple of {BATCH_TILE}")
+    if hidden_size % K_CHUNK != 0 or hidden_size % Q_OUT_CHUNK != 0:
+        raise ValueError("hidden_size must align with K_CHUNK and Q_OUT_CHUNK")
+    if (num_kv_heads * head_dim) % KV_OUT_CHUNK != 0:
+        raise ValueError("num_kv_heads * head_dim must align with KV_OUT_CHUNK")
+    if intermediate_size % MLP_OUT_CHUNK != 0:
+        raise ValueError(f"intermediate_size must be a multiple of {MLP_OUT_CHUNK}")
+    if num_heads % num_kv_heads != 0:
+        raise ValueError("num_heads must be divisible by num_kv_heads")
+    if (num_heads // num_kv_heads) % Q_HEAD_BATCH != 0:
+        raise ValueError("num_heads // num_kv_heads must be divisible by Q_HEAD_BATCH")
+
     hidden = hidden_size
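
The suggested guardrail can be exercised standalone in plain Python (the tile constants below are illustrative placeholders, not the file's actual values; the checks mirror the diff above):

```python
# Hypothetical tile constants standing in for the module-level values.
BATCH_TILE, K_CHUNK, Q_OUT_CHUNK = 16, 128, 128
KV_OUT_CHUNK, MLP_OUT_CHUNK, Q_HEAD_BATCH = 128, 128, 4

def validate_shape(batch, hidden_size, intermediate_size,
                   num_heads, num_kv_heads, head_dim):
    """Fail-fast divisibility checks in the spirit of the suggested guardrail."""
    if hidden_size != num_heads * head_dim:
        raise ValueError("hidden_size must equal num_heads * head_dim")
    if batch % BATCH_TILE != 0:
        raise ValueError(f"batch must be a multiple of {BATCH_TILE}")
    if hidden_size % K_CHUNK != 0 or hidden_size % Q_OUT_CHUNK != 0:
        raise ValueError("hidden_size must align with K_CHUNK and Q_OUT_CHUNK")
    if (num_kv_heads * head_dim) % KV_OUT_CHUNK != 0:
        raise ValueError("num_kv_heads * head_dim must align with KV_OUT_CHUNK")
    if intermediate_size % MLP_OUT_CHUNK != 0:
        raise ValueError(f"intermediate_size must be a multiple of {MLP_OUT_CHUNK}")
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    if (num_heads // num_kv_heads) % Q_HEAD_BATCH != 0:
        raise ValueError("num_heads // num_kv_heads must be divisible by Q_HEAD_BATCH")

# A shape satisfying every invariant passes silently; a misaligned batch
# (or any other violation) raises before any kernel is built.
validate_shape(16, 8192, 25600, 64, 8, 128)
```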

@zhangqi-chen zhangqi-chen merged commit 6b1ff63 into hw-native-sys:main Apr 13, 2026
8 of 9 checks passed
@zhangqi-chen zhangqi-chen deleted the refactor/scope2-chunked-parallel-syntax branch April 13, 2026 04:41