Refactor: replace manual chunked loops with pl.parallel chunk syntax and clean up scope2 #103
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings. Use the following commands to manage reviews:
📝 Walkthrough
Refactors per-block attention iteration to use chunked parallel loops and updates related loop-optimizer annotations; adds a tiled decode implementation and renames the program-scoped class from Qwen3Scope123 to Qwen3Decode.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/models/qwen3/qwen3_32b_decode_scope2.py`:
- Around line 159-169: Zero-initialize the padded regions before the
chunked-parallel stages by creating a pl.full(0.0, [Q_HEAD_PAD, ...]) tile and
using pl.assemble to write it into all_exp_padded, all_raw_scores, and
all_oi_tmp for every ctx block (loop over ctx_blocks) so Stage 4 won't read
uninitialized values; apply the same pattern used in qwen3_32b_decode.py:
allocate a zero tile sized to Q_HEAD_PAD rows and assemble it into those buffers
for each ctx block prior to the block-parallel loops that use
SEQ_TILE/Q_HEAD_BATCH.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 6eb0c213-21e1-4d63-a8cd-884dcd87d37a
📒 Files selected for processing (2)
- examples/models/qwen3/qwen3_32b_decode.py
- examples/models/qwen3/qwen3_32b_decode_scope2.py
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
    for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
        s0 = sb * SEQ_TILE
        cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
        k_tile = pl.slice(
            k_cache,
            [SEQ_TILE, head_dim],
            [cache_row0, 0],
        )
        raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
        all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])
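The loop transformation in this diff can be illustrated outside the DSL. A minimal plain-Python sketch (function names are hypothetical; the `pl.parallel` semantics are assumed from this diff) showing that the hand-rolled stride loop with a tail guard and a compiler-managed chunked loop visit exactly the same block indices:

```python
def manual_chunked(ctx_blocks, sb_batch):
    # Hand-rolled pattern this PR removes: outer stride loop + inner chunk
    # loop + `if sb < ctx_blocks` guard for the ragged last chunk.
    visited = []
    for sb0 in range(0, ctx_blocks, sb_batch):
        for i in range(sb_batch):
            sb = sb0 + i
            if sb < ctx_blocks:  # boundary guard
                visited.append(sb)
    return visited

def chunked_parallel(ctx_blocks, sb_batch):
    # What `pl.parallel(ctx_blocks, chunk=SB_BATCH)` expresses: the compiler
    # owns the chunking and the tail guard; the body just sees each sb once.
    return list(range(ctx_blocks))

# Both cover every block exactly once, including a ragged tail (10 % 4 != 0).
assert manual_chunked(10, 4) == chunked_parallel(10, 4)
```

The refactor moves the guard out of user code, which is why the `if sb < ctx_blocks` lines disappear from the diff.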
🧩 Analysis chain
🏁 Script executed:
find . -type f -name "*qwen3*.py" | head -20
Repository: hw-native-sys/pypto-lib
Length of output: 571
🏁 Script executed:
# Check the specific file and line ranges mentioned
wc -l examples/models/qwen3/qwen3_32b_decode_scope2.py
wc -l examples/models/qwen3/qwen3_32b_decode.py
Repository: hw-native-sys/pypto-lib
Length of output: 165
🏁 Script executed:
# Inspect the Stage 2-4 regions in qwen3_32b_decode_scope2.py
sed -n '150,220p' examples/models/qwen3/qwen3_32b_decode_scope2.py | cat -n
Repository: hw-native-sys/pypto-lib
Length of output: 5449
🏁 Script executed:
# Check the sibling file qwen3_32b_decode.py at the referenced lines
sed -n '245,295p' examples/models/qwen3/qwen3_32b_decode.py | cat -n
Repository: hw-native-sys/pypto-lib
Length of output: 3666
Add explicit zero-initialization before the chunked-parallel stages.
Stage 3 writes only the first Q_HEAD_BATCH rows of the all_exp_padded, all_raw_scores, and all_oi_tmp buffers, but Stage 4 immediately reads the full Q_HEAD_PAD tile from each. Without zeroing, the padded region contains uninitialized data that flows into Stage 4's matmul. The sibling file qwen3_32b_decode.py (lines 253–286) has already implemented this fix with an explicit pl.full(0.0) + pl.assemble loop over all ctx_blocks.
Apply the same initialization pattern here before Stage 2 begins.
Suggested fix
all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)
all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16)
all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
+with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+ for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
+ all_raw_scores = pl.assemble(
+ all_raw_scores,
+ pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0),
+ [sb * Q_HEAD_PAD, 0],
+ )
+ all_exp_padded = pl.assemble(
+ all_exp_padded,
+ pl.cast(pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
+ [sb * Q_HEAD_PAD, 0],
+ )
+ all_oi_tmp = pl.assemble(
+ all_oi_tmp,
+ pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0),
+ [sb * Q_HEAD_PAD, 0],
+ )
Also applies to: lines 172–191, 194–209
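The hazard this comment describes can be reproduced with plain NumPy (a sketch only; buffer names are borrowed from the diff, tile sizes are hypothetical): a stage that writes only the first `Q_HEAD_BATCH` rows of each `Q_HEAD_PAD`-row tile leaves the padding undefined, and a later full-tile matmul reads it unless the buffer is zeroed first.

```python
import numpy as np

Q_HEAD_BATCH, Q_HEAD_PAD, SEQ_TILE, HEAD_DIM = 4, 16, 8, 8
ctx_blocks = 3

# Simulate an uninitialized buffer; NaN makes any stale read visible.
all_exp_padded = np.full((ctx_blocks * Q_HEAD_PAD, SEQ_TILE), np.nan)

# The suggested fix: zero every padded tile before the stages run.
for sb in range(ctx_blocks):
    all_exp_padded[sb * Q_HEAD_PAD:(sb + 1) * Q_HEAD_PAD] = 0.0

# Stage 3 analogue: each block writes only its first Q_HEAD_BATCH rows.
for sb in range(ctx_blocks):
    all_exp_padded[sb * Q_HEAD_PAD: sb * Q_HEAD_PAD + Q_HEAD_BATCH] = 1.0

# Stage 4 analogue: the matmul consumes the *full* Q_HEAD_PAD tile.
v_tile = np.ones((SEQ_TILE, HEAD_DIM))
out = all_exp_padded[:Q_HEAD_PAD] @ v_tile

# With zero-init, the padded rows contribute exactly zero, never garbage.
assert np.isfinite(out).all()
assert (out[Q_HEAD_BATCH:] == 0.0).all()
```

Without the zeroing loop, the `NaN` rows would propagate straight into `out`, which is the failure mode the reviewer is flagging for Stage 4.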
Code Review
This pull request refactors the decoding logic in qwen3_32b_decode.py and qwen3_32b_decode_scope2.py by replacing manual chunked loops and explicit boundary checks with pl.parallel loops and the pl.chunked_loop_optimizer flag. These changes simplify the implementation and allow the compiler to more efficiently manage loop boundaries and core group transitions. I have no feedback to provide.
Force-pushed f721536 to 6557828
…and clean up scope2

- Replace hand-rolled `for sb0 in pl.range(0, ctx_blocks, SB_BATCH)` + inner `pl.range(SB_BATCH)` + `if sb < ctx_blocks` guard pattern with `pl.parallel(ctx_blocks, chunk=SB_BATCH)` in scope2 and decode
- Remove redundant zero-initialization loop for intermediate tensors in decode scope2 (all_raw_scores, all_exp_padded, etc. are overwritten by subsequent stages)
- Rename program class Qwen3Scope123 → Qwen3Decode
- Add qwen3_32b_decode_tile.py: InCore + Orchestration separated rewrite with explicit pl.load/pl.store/pl.move data movement
Force-pushed 6557828 to 5c3808d
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/models/qwen3/qwen3_32b_decode_tile.py`:
- Around line 205-271: kernel_rope_kv_cache_q_pad is only writing the first
query group per KV head (writing at q_pad_base + ki * Q_HEAD_PAD + qi), which
breaks when q_groups > 1; fix by either iterating over qg and writing into
all_q_padded at index (q_pad_base + (ki * q_groups + qg) * Q_HEAD_PAD + qi)
inside kernel_rope_kv_cache_q_pad (and the sibling region at 601-611), or
enforce/validate q_groups == 1 during construction and fail fast; update the
function(s) that reference q_pad_base/all_q_padded (kernel_rope_kv_cache_q_pad
and the similar block at 601-611) to implement the qg loop or the validation.
- Around line 31-35: Replace the Unicode multiplication sign "×" with ASCII "x"
in the documentation/comments — specifically update the "Scope 3:" comment block
(the line listing "Output projection: attn_out × wo") and the other occurrence
around line 681 to use "x" instead; search the file
examples/models/qwen3/qwen3_32b_decode_tile.py for any remaining "×" characters
and replace them only in comments/docstrings (do not change actual code logic or
variable names).
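The first inline comment above concerns grouped-query indexing. A small plain-Python sketch of the row layout it asks for (illustrative only, not the DSL kernel; names are taken from the review comment) shows why the missing `qg` loop only matters when `q_groups > 1`:

```python
def padded_row_indices(num_kv_heads, q_groups, q_head_batch, q_head_pad):
    # Each KV head owns q_groups query groups; every group needs its own
    # Q_HEAD_PAD-sized slot in the padded buffer. The inner index is the
    # fix suggested in the review: (ki * q_groups + qg) * Q_HEAD_PAD + qi.
    rows = []
    for ki in range(num_kv_heads):
        for qg in range(q_groups):          # the loop the kernel was missing
            for qi in range(q_head_batch):
                rows.append((ki * q_groups + qg) * q_head_pad + qi)
    return rows

# With q_groups == 1 this degenerates to the original ki * Q_HEAD_PAD + qi
# indexing, which is why the bug only surfaces for q_groups > 1.
assert padded_row_indices(2, 1, 2, 4) == [0, 1, 4, 5]
assert padded_row_indices(1, 2, 1, 4) == [0, 4]
```

The alternative the comment offers, validating `q_groups == 1` at construction time, trades generality for a fail-fast guarantee.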
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 917c578a-7e56-4a27-ac1d-eeced18e2dc7
📒 Files selected for processing (3)
- examples/models/qwen3/qwen3_32b_decode.py
- examples/models/qwen3/qwen3_32b_decode_scope2.py
- examples/models/qwen3/qwen3_32b_decode_tile.py
✅ Files skipped from review due to trivial changes (1)
- examples/models/qwen3/qwen3_32b_decode_scope2.py
Scope 3:
1. Output projection: attn_out × wo
2. Residual addition with hidden_states
3. Post-attention RMSNorm
4. MLP: gate/up projections, SiLU activation, down projection
Replace × with ASCII x in docs/comments.
Ruff already flags these multiplication signs as ambiguous Unicode, so they will keep tripping lint and are harder to search/copy/paste than plain ASCII.
Also applies to: 681-681
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 32-32: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF002)
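One way to act on this finding is a small helper script (a sketch under stated assumptions: it only handles line comments after the first `#`, which can misfire if `#` appears inside a string literal, and docstrings still need manual review):

```python
def find_multiplication_signs(text):
    # Report (line_number, line) pairs containing U+00D7, mirroring what
    # Ruff's RUF002 flags, so the fix can be verified before committing.
    return [(i, line) for i, line in enumerate(text.splitlines(), 1)
            if "\u00d7" in line]

def fix_comment_signs(line):
    # Replace the sign only in the comment portion of a line; code before
    # the first '#' is left untouched, per the review's instruction not to
    # change code logic or variable names.
    head, sep, comment = line.partition("#")
    return head + sep + comment.replace("\u00d7", "x")

assert fix_comment_signs("y = a * b  # attn_out \u00d7 wo") == "y = a * b  # attn_out x wo"
```

Running `find_multiplication_signs` over the file after the edit confirms no `×` characters remain.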
def build_qwen3_decode_program(
    batch: int = BATCH,
    max_seq: int = MAX_SEQ,
    hidden_size: int = HIDDEN,
    intermediate_size: int = INTERMEDIATE,
    num_heads: int = NUM_HEADS,
    num_kv_heads: int = NUM_KV_HEADS,
    head_dim: int = HEAD_DIM,
):
    hidden = hidden_size
    kv_hidden = num_kv_heads * head_dim
    inter = intermediate_size
    hidden_blocks = hidden // K_CHUNK
    q_out_blocks = hidden // Q_OUT_CHUNK
    kv_out_blocks = kv_hidden // KV_OUT_CHUNK
    mlp_out_blocks = inter // MLP_OUT_CHUNK
    cache_rows = batch * num_kv_heads * max_seq
    half_dim = head_dim // 2
    q_per_kv = num_heads // num_kv_heads
    q_groups = q_per_kv // Q_HEAD_BATCH
    total_q_groups = num_kv_heads * q_groups
    attn_scale = 1.0 / (head_dim ** 0.5)
    max_ctx_blocks = (max_seq + SEQ_TILE - 1) // SEQ_TILE
Validate the supported shape invariants up front.
This builder accepts arbitrary batch, hidden_size, intermediate_size, num_heads, num_kv_heads, and head_dim, but the implementation assumes aligned tiles everywhere: fixed BATCH_TILE loads, //-based block counts, and Q_HEAD_BATCH grouping. For unsupported inputs, this will either skip tail work or issue out-of-range fixed-size accesses. Please fail fast here with explicit checks, or narrow the public API to the single supported model shape.
Suggested guardrail
def build_qwen3_decode_program(
batch: int = BATCH,
max_seq: int = MAX_SEQ,
hidden_size: int = HIDDEN,
intermediate_size: int = INTERMEDIATE,
num_heads: int = NUM_HEADS,
num_kv_heads: int = NUM_KV_HEADS,
head_dim: int = HEAD_DIM,
):
+ if hidden_size != num_heads * head_dim:
+ raise ValueError("hidden_size must equal num_heads * head_dim")
+ if batch % BATCH_TILE != 0:
+ raise ValueError(f"batch must be a multiple of {BATCH_TILE}")
+ if hidden_size % K_CHUNK != 0 or hidden_size % Q_OUT_CHUNK != 0:
+ raise ValueError("hidden_size must align with K_CHUNK and Q_OUT_CHUNK")
+ if (num_kv_heads * head_dim) % KV_OUT_CHUNK != 0:
+ raise ValueError("num_kv_heads * head_dim must align with KV_OUT_CHUNK")
+ if intermediate_size % MLP_OUT_CHUNK != 0:
+ raise ValueError(f"intermediate_size must be a multiple of {MLP_OUT_CHUNK}")
+ if num_heads % num_kv_heads != 0:
+ raise ValueError("num_heads must be divisible by num_kv_heads")
+ if (num_heads // num_kv_heads) % Q_HEAD_BATCH != 0:
+ raise ValueError("num_heads // num_kv_heads must be divisible by Q_HEAD_BATCH")
+
hidden = hidden_size
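The derived-shape arithmetic being guarded can be checked in isolation. A plain-Python sketch (tile constants here are hypothetical stand-ins for the module-level `SEQ_TILE` and `Q_HEAD_BATCH`; the divisibility checks follow the suggested guardrail):

```python
SEQ_TILE, Q_HEAD_BATCH = 128, 4

def derive_shapes(max_seq, num_heads, num_kv_heads, head_dim):
    # Fail fast on unsupported shapes, as the review suggests, instead of
    # silently skipping tail work or issuing out-of-range tile accesses.
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    q_per_kv = num_heads // num_kv_heads
    if q_per_kv % Q_HEAD_BATCH != 0:
        raise ValueError("num_heads // num_kv_heads must be divisible by Q_HEAD_BATCH")
    return {
        "q_per_kv": q_per_kv,
        "q_groups": q_per_kv // Q_HEAD_BATCH,
        "attn_scale": 1.0 / (head_dim ** 0.5),
        # Ceiling division: a partial final tile still gets its own block.
        "max_ctx_blocks": (max_seq + SEQ_TILE - 1) // SEQ_TILE,
    }

shapes = derive_shapes(max_seq=1000, num_heads=64, num_kv_heads=8, head_dim=128)
assert shapes["q_per_kv"] == 8 and shapes["q_groups"] == 2
assert shapes["max_ctx_blocks"] == 8  # ceil(1000 / 128)
```

Raising `ValueError` at build time is cheap; the silent failure modes the reviewer lists are much harder to debug once they reach the device.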
Summary

- Replace hand-rolled `for sb0 in pl.range(0, ctx_blocks, SB_BATCH)` + inner `pl.range(SB_BATCH)` + `if sb < ctx_blocks` guard pattern with `pl.parallel(ctx_blocks, chunk=SB_BATCH)` using the compiler-managed guarded chunk policy
- Remove redundant zero-initialization loop (`all_raw_scores`, `all_exp_padded`, etc. are overwritten by subsequent stages)
- Rename program class `Qwen3Scope123` → `Qwen3Decode`
- Add `qwen3_32b_decode_tile.py`: InCore + Orchestration separated rewrite with explicit `pl.load`/`pl.store`/`pl.move` data movement

Files changed

- `qwen3_32b_decode_scope2.py` — 3 stages converted to `pl.parallel` chunk syntax
- `qwen3_32b_decode.py` — 3 stages converted + removed redundant init loop + class rename
- `qwen3_32b_decode_tile.py` — new tile DSL version of decode