Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 50 additions & 93 deletions examples/models/qwen3/qwen3_32b_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def build_qwen3_decode_program(
max_ctx_blocks = (max_seq + SEQ_TILE - 1) // SEQ_TILE

@pl.program
class Qwen3Scope123:
class Qwen3Decode:
@pl.function(type=pl.FunctionType.Opaque)
def qwen3_decode(
self,
Expand Down Expand Up @@ -249,106 +249,63 @@ def qwen3_decode(
q_base = kvh * q_per_kv + qg * Q_HEAD_BATCH
q_padded = pl.slice(all_q_padded, [Q_HEAD_PAD, head_dim], [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD, 0])

# Workaround
# Stage 2: QK matmul for all active sb blocks.
all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)
all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16)
all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
all_raw_scores = pl.assemble(
all_raw_scores,
pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0),
[sb * Q_HEAD_PAD, 0],
)
all_exp_padded = pl.assemble(
all_exp_padded,
pl.cast(pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
[sb * Q_HEAD_PAD, 0],
)
all_oi_tmp = pl.assemble(
all_oi_tmp,
pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0),
[sb * Q_HEAD_PAD, 0],
)
mi_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
all_cur_mi = pl.assemble(
all_cur_mi,
pl.reshape(mi_init_flat, [Q_HEAD_BATCH, 1]),
[sb * Q_HEAD_BATCH, 0],
)
li_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
all_cur_li = pl.assemble(
all_cur_li,
pl.reshape(li_init_flat, [Q_HEAD_BATCH, 1]),
[sb * Q_HEAD_BATCH, 0],
)

# Stage 2: QK matmul for all active sb blocks.
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
s0 = sb * SEQ_TILE
cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
k_tile = pl.slice(
k_cache,
[SEQ_TILE, head_dim],
[cache_row0, 0],
)
raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
s0 = sb * SEQ_TILE
cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
k_tile = pl.slice(
k_cache,
[SEQ_TILE, head_dim],
[cache_row0, 0],
)
raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])

# Stage 3: softmax for all active sb blocks.
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
s0 = sb * SEQ_TILE
valid_len = pl.min(SEQ_TILE, ctx_len - s0)
scores_valid = pl.slice(
all_raw_scores,
[Q_HEAD_BATCH, SEQ_TILE],
[sb * Q_HEAD_PAD, 0],
valid_shape=[Q_HEAD_BATCH, valid_len],
)
scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
scores = pl.mul(scores_padded, attn_scale)
cur_mi = pl.row_max(scores)
exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
cur_li = pl.row_sum(exp_scores_fp32)
all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0])
all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0])
all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0])
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
s0 = sb * SEQ_TILE
valid_len = pl.min(SEQ_TILE, ctx_len - s0)
scores_valid = pl.slice(
all_raw_scores,
[Q_HEAD_BATCH, SEQ_TILE],
[sb * Q_HEAD_PAD, 0],
valid_shape=[Q_HEAD_BATCH, valid_len],
)
scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
scores = pl.mul(scores_padded, attn_scale)
cur_mi = pl.row_max(scores)
exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
cur_li = pl.row_sum(exp_scores_fp32)
all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0])
all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0])
all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0])

# Stage 4: SV matmul for all active sb blocks.
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
s0 = sb * SEQ_TILE
cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
exp_tile = pl.slice(
all_exp_padded,
[Q_HEAD_PAD, SEQ_TILE],
[sb * Q_HEAD_PAD, 0],
)
v_tile = pl.slice(
v_cache,
[SEQ_TILE, head_dim],
[cache_row0, 0],
)
oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
s0 = sb * SEQ_TILE
cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
exp_tile = pl.slice(
all_exp_padded,
[Q_HEAD_PAD, SEQ_TILE],
[sb * Q_HEAD_PAD, 0],
)
v_tile = pl.slice(
v_cache,
[SEQ_TILE, head_dim],
[cache_row0, 0],
)
oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])

# Stage 5: online softmax accumulation and normalisation.
with pl.at(level=pl.Level.CORE_GROUP):
Expand Down Expand Up @@ -474,7 +431,7 @@ def qwen3_decode(

return out

return Qwen3Scope123
return Qwen3Decode


def build_tensor_specs(
Expand Down
103 changes: 47 additions & 56 deletions examples/models/qwen3/qwen3_32b_decode_scope2.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,66 +156,57 @@ def qwen3_scope2(
all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
s0 = sb * SEQ_TILE
cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
k_tile = pl.slice(
k_cache,
[SEQ_TILE, head_dim],
[cache_row0, 0],
)
raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
s0 = sb * SEQ_TILE
cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
k_tile = pl.slice(
k_cache,
[SEQ_TILE, head_dim],
[cache_row0, 0],
)
raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])
Comment on lines +159 to +169
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "*qwen3*.py" | head -20

Repository: hw-native-sys/pypto-lib

Length of output: 571


🏁 Script executed:

# Check the specific file and line ranges mentioned
wc -l examples/models/qwen3/qwen3_32b_decode_scope2.py
wc -l examples/models/qwen3/qwen3_32b_decode.py

Repository: hw-native-sys/pypto-lib

Length of output: 165


🏁 Script executed:

# Inspect the Stage 2-4 regions in qwen3_32b_decode_scope2.py
sed -n '150,220p' examples/models/qwen3/qwen3_32b_decode_scope2.py | cat -n

Repository: hw-native-sys/pypto-lib

Length of output: 5449


🏁 Script executed:

# Check the sibling file qwen3_32b_decode.py at the referenced lines
sed -n '245,295p' examples/models/qwen3/qwen3_32b_decode.py | cat -n

Repository: hw-native-sys/pypto-lib

Length of output: 3666


Add explicit zero-initialization before the chunked-parallel stages.

Stage 2's matmul fills full Q_HEAD_PAD-row tiles of all_raw_scores, but Stage 3 writes only the first Q_HEAD_BATCH rows of each Q_HEAD_PAD tile in all_exp_padded, and Stage 4 then reads the full Q_HEAD_PAD tile from all_exp_padded and writes all_oi_tmp. Without zeroing, the padded rows of all_exp_padded contain uninitialized data that flows into Stage 4's matmul. The sibling file qwen3_32b_decode.py (lines 253–286) has already implemented this fix with an explicit pl.full(0.0) + pl.assemble loop over all ctx_blocks.

Apply the same initialization pattern here before Stage 2 begins.

Suggested fix
all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)
all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16)
all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
+with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+    for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
+        all_raw_scores = pl.assemble(
+            all_raw_scores,
+            pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0),
+            [sb * Q_HEAD_PAD, 0],
+        )
+        all_exp_padded = pl.assemble(
+            all_exp_padded,
+            pl.cast(pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
+            [sb * Q_HEAD_PAD, 0],
+        )
+        all_oi_tmp = pl.assemble(
+            all_oi_tmp,
+            pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0),
+            [sb * Q_HEAD_PAD, 0],
+        )

Also applies to: lines 172–191, 194–209

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/models/qwen3/qwen3_32b_decode_scope2.py` around lines 159 - 169,
Zero-initialize the padded regions before the chunked-parallel stages by
creating a pl.full(0.0, [Q_HEAD_PAD, ...]) tile and using pl.assemble to write
it into all_exp_padded, all_raw_scores, and all_oi_tmp for every ctx block (loop
over ctx_blocks) so Stage 4 won't read uninitialized values; apply the same
pattern used in qwen3_32b_decode.py: allocate a zero tile sized to Q_HEAD_PAD
rows and assemble it into those buffers for each ctx block prior to the
block-parallel loops that use SEQ_TILE/Q_HEAD_BATCH.


# Stage 3: softmax for all active sb blocks.
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
s0 = sb * SEQ_TILE
valid_len = pl.min(SEQ_TILE, ctx_len - s0)
scores_valid = pl.slice(
all_raw_scores,
[Q_HEAD_BATCH, SEQ_TILE],
[sb * Q_HEAD_PAD, 0],
valid_shape=[Q_HEAD_BATCH, valid_len],
)
scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
scores = pl.mul(scores_padded, attn_scale)
cur_mi = pl.row_max(scores)
exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
cur_li = pl.row_sum(exp_scores_fp32)
all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0])
all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0])
all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0])
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
s0 = sb * SEQ_TILE
valid_len = pl.min(SEQ_TILE, ctx_len - s0)
scores_valid = pl.slice(
all_raw_scores,
[Q_HEAD_BATCH, SEQ_TILE],
[sb * Q_HEAD_PAD, 0],
valid_shape=[Q_HEAD_BATCH, valid_len],
)
scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
scores = pl.mul(scores_padded, attn_scale)
cur_mi = pl.row_max(scores)
exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
cur_li = pl.row_sum(exp_scores_fp32)
all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0])
all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0])
all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0])

# Stage 4: SV matmul for all active sb blocks.
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
s0 = sb * SEQ_TILE
cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
exp_tile = pl.slice(
all_exp_padded,
[Q_HEAD_PAD, SEQ_TILE],
[sb * Q_HEAD_PAD, 0],
)
v_tile = pl.slice(
v_cache,
[SEQ_TILE, head_dim],
[cache_row0, 0],
)
oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
s0 = sb * SEQ_TILE
cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
exp_tile = pl.slice(
all_exp_padded,
[Q_HEAD_PAD, SEQ_TILE],
[sb * Q_HEAD_PAD, 0],
)
v_tile = pl.slice(
v_cache,
[SEQ_TILE, head_dim],
[cache_row0, 0],
)
oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])

# Stage 5: online softmax accumulation and normalisation.
with pl.at(level=pl.Level.CORE_GROUP):
Expand Down
Loading
Loading