Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 51 additions & 44 deletions examples/models/qwen3/qwen3_32b_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
HIDDEN_INV = 1.0 / HIDDEN

# Scope 1 tiling constants.
K_CHUNK = 128
SCOPE1_K_CHUNK = 512
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这名字有点怪

Q_OUT_CHUNK = 64
KV_OUT_CHUNK = 64
BATCH_TILE = 16
Expand All @@ -54,6 +54,7 @@
SB_BATCH = 64

# Scope 3 tiling constants.
K_CHUNK = 128
MLP_OUT_CHUNK = 256


Expand All @@ -69,6 +70,7 @@ def build_qwen3_decode_program(
hidden = hidden_size
kv_hidden = num_kv_heads * head_dim
inter = intermediate_size
scope1_hidden_blocks = hidden // SCOPE1_K_CHUNK
hidden_blocks = hidden // K_CHUNK
q_out_blocks = hidden // Q_OUT_CHUNK
kv_out_blocks = kv_hidden // KV_OUT_CHUNK
Expand Down Expand Up @@ -114,10 +116,10 @@ def qwen3_decode(

with pl.at(level=pl.Level.CORE_GROUP):
partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0)
for kb in pl.range(hidden_blocks):
k0 = kb * K_CHUNK
for kb in pl.range(scope1_hidden_blocks):
k0 = kb * SCOPE1_K_CHUNK
x_chunk = pl.cast(
pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]),
pl.slice(hidden_states, [BATCH_TILE, SCOPE1_K_CHUNK], [b0, k0]),
target_type=pl.FP32,
)
partial_sq = pl.add(
Expand All @@ -132,51 +134,50 @@ def qwen3_decode(
)
inv_rms = pl.recip(pl.sqrt(variance))

for kb in pl.range(hidden_blocks):
k0 = kb * K_CHUNK
for kb in pl.range(scope1_hidden_blocks):
k0 = kb * SCOPE1_K_CHUNK
x_chunk = pl.cast(
pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]),
pl.slice(hidden_states, [BATCH_TILE, SCOPE1_K_CHUNK], [b0, k0]),
target_type=pl.FP32,
)
gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0])
gamma = pl.slice(input_rms_weight, [1, SCOPE1_K_CHUNK], [0, k0])
normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma)
normed_tile = pl.assemble(normed_tile, pl.cast(normed, target_type=pl.BF16), [0, k0])

for ob in pl.range(q_out_blocks):
q0 = ob * Q_OUT_CHUNK
with pl.at(level=pl.Level.CORE_GROUP):
tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
tile_b = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0])
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for ob in pl.parallel(q_out_blocks, chunk=4):
q0 = ob * Q_OUT_CHUNK
tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0])
tile_b = pl.slice(wq, [SCOPE1_K_CHUNK, Q_OUT_CHUNK], [0, q0])
q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)
for kb in pl.range(1, hidden_blocks):
k0 = kb * K_CHUNK
tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0])
tile_b_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0])
for kb in pl.range(1, scope1_hidden_blocks):
k0 = kb * SCOPE1_K_CHUNK
tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0])
tile_b_i = pl.slice(wq, [SCOPE1_K_CHUNK, Q_OUT_CHUNK], [k0, q0])
q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i)
q_proj = pl.assemble(q_proj, q_acc, [b0, q0])
q_proj = pl.assemble(q_proj, q_acc, [b0, q0])

for ob in pl.range(kv_out_blocks):
kv0 = ob * KV_OUT_CHUNK
with pl.at(level=pl.Level.CORE_GROUP):
tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0])
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for ob in pl.parallel(kv_out_blocks, chunk=4):
kv0 = ob * KV_OUT_CHUNK
tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0])
tile_wk = pl.slice(wk, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [0, kv0])
k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32)
for kb in pl.range(1, hidden_blocks):
k0 = kb * K_CHUNK
tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0])
tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
for kb in pl.range(1, scope1_hidden_blocks):
k0 = kb * SCOPE1_K_CHUNK
tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0])
tile_wk_i = pl.slice(wk, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i)
k_proj = pl.assemble(k_proj, k_acc, [b0, kv0])
with pl.at(level=pl.Level.CORE_GROUP):
tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0])
k_proj = pl.assemble(k_proj, k_acc, [b0, kv0])
tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0])
tile_wv = pl.slice(wv, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [0, kv0])
v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32)
for kb in pl.range(1, hidden_blocks):
k0 = kb * K_CHUNK
tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0])
tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
for kb in pl.range(1, scope1_hidden_blocks):
k0 = kb * SCOPE1_K_CHUNK
tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0])
tile_wv_i = pl.slice(wv, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i)
v_proj = pl.assemble(v_proj, v_acc, [b0, kv0])
v_proj = pl.assemble(v_proj, v_acc, [b0, kv0])

# ── Scope 2: RoPE + KV cache update + grouped-query attention ──
# Pad q
Expand Down Expand Up @@ -408,16 +409,22 @@ def qwen3_decode(
sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0)
for kb in pl.range(hidden_blocks):
k0 = kb * K_CHUNK
x_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, k0])
sq_sum = pl.add(sq_sum, pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, BATCH_TILE]))
resid_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, k0])
sq_sum = pl.add(
sq_sum,
pl.reshape(pl.row_sum(pl.mul(resid_chunk, resid_chunk)), [1, BATCH_TILE]),
)
inv_rms_s3 = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)))

for kb in pl.range(hidden_blocks):
k0 = kb * K_CHUNK
x_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, k0])
gamma = pl.slice(post_rms_weight, [1, K_CHUNK], [0, k0])
normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, pl.reshape(inv_rms_s3, [BATCH_TILE, 1])), gamma)
normed_bf16 = pl.cast(normed, target_type=pl.BF16)
resid_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, k0])
post_gamma = pl.slice(post_rms_weight, [1, K_CHUNK], [0, k0])
post_normed = pl.col_expand_mul(
pl.row_expand_mul(resid_chunk, pl.reshape(inv_rms_s3, [BATCH_TILE, 1])),
post_gamma,
)
normed_bf16 = pl.cast(post_normed, target_type=pl.BF16)
post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, k0])

# Stage 4 & 5 & 6: MLP: gate/up projections + SiLU.
Expand Down Expand Up @@ -622,8 +629,8 @@ def golden_qwen3_decode(tensors, params):
x_tile = hidden_states[b0:b_end, :].float()

sq_sum = torch.zeros(b_end - b0, 1, dtype=torch.float32)
for k0 in range(0, hidden_size, K_CHUNK):
x_chunk = x_tile[:, k0:k0 + K_CHUNK]
for k0 in range(0, hidden_size, SCOPE1_K_CHUNK):
x_chunk = x_tile[:, k0:k0 + SCOPE1_K_CHUNK]
sq_sum = sq_sum + (x_chunk ** 2).sum(dim=-1, keepdim=True)
variance = sq_sum / hidden_size + EPS
rms = torch.sqrt(variance)
Expand Down
21 changes: 10 additions & 11 deletions examples/models/qwen3/qwen3_32b_decode_scope1.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
HIDDEN_INV = 1.0 / HIDDEN

# Vector TILELET budget (2 KB = 2048 B, FP32 = 4 B/elem):
K_CHUNK = 128
K_CHUNK = 512
Q_OUT_CHUNK = 64
KV_OUT_CHUNK = 64
MLP_OUT_CHUNK = 64
Expand Down Expand Up @@ -102,10 +102,10 @@ def qwen3_scope1(
normed_tile = pl.assemble(normed_tile, pl.cast(normed, target_type=pl.BF16), [0, k0])

# Stage 2: Q projection (matmul + matmul_acc in single incore).
for ob in pl.range(q_out_blocks):
q0 = ob * Q_OUT_CHUNK
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for ob in pl.parallel(q_out_blocks, chunk=4):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The pl.parallel function in this repository is consistently used with at least two positional arguments for the start and stop indices (e.g., pl.parallel(0, q_out_blocks, ...)), as seen in other model examples. Using a single argument may not be supported by the DSL and could lead to incorrect loop bounds or runtime errors.

Suggested change
for ob in pl.parallel(q_out_blocks, chunk=4):
for ob in pl.parallel(0, q_out_blocks, chunk=4):

q0 = ob * Q_OUT_CHUNK

with pl.at(level=pl.Level.CORE_GROUP):
tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
tile_b = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0])
q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)
Expand All @@ -116,13 +116,13 @@ def qwen3_scope1(
tile_b_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0])
q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i)

q_proj = pl.assemble(q_proj, q_acc, [b0, q0])
q_proj = pl.assemble(q_proj, q_acc, [b0, q0])

# Stage 3: K/V projection (matmul + matmul_acc in single incore).
for ob in pl.range(kv_out_blocks):
kv0 = ob * KV_OUT_CHUNK
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for ob in pl.parallel(kv_out_blocks, chunk=4):
kv0 = ob * KV_OUT_CHUNK

with pl.at(level=pl.Level.CORE_GROUP):
tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0])
k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32)
Expand All @@ -133,9 +133,8 @@ def qwen3_scope1(
tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i)

k_proj = pl.assemble(k_proj, k_acc, [b0, kv0])
k_proj = pl.assemble(k_proj, k_acc, [b0, kv0])

with pl.at(level=pl.Level.CORE_GROUP):
tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0])
v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32)
Expand All @@ -146,7 +145,7 @@ def qwen3_scope1(
tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i)

v_proj = pl.assemble(v_proj, v_acc, [b0, kv0])
v_proj = pl.assemble(v_proj, v_acc, [b0, kv0])

return q_proj, k_proj, v_proj

Expand Down
Loading