Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 69 additions & 57 deletions examples/models/qwen3/qwen3_32b_decode_scope1.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@
HIDDEN_INV = 1.0 / HIDDEN

# Vector TILELET budget (2 KB = 2048 B, FP32 = 4 B/elem):
K_CHUNK = 128
K_CHUNK = 512
Q_OUT_CHUNK = 64
KV_OUT_CHUNK = 64
MLP_OUT_CHUNK = 64
@@ -67,86 +67,98 @@ def qwen3_scope1(
pl.Tensor[[batch, kv_hidden], pl.FP32],
]:
for b0 in pl.range(0, batch, BATCH_TILE):
normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16)
# Stage 1: compute per-chunk RMS partials, reduce once.
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
sq_partials = pl.create_tensor([hidden_blocks, BATCH_TILE], dtype=pl.FP32)
for kb_i in pl.parallel(hidden_blocks, chunk=2):
k0_i = kb_i * K_CHUNK
x_i = pl.cast(
pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0_i]),
target_type=pl.FP32,
)
partial_sq_i = pl.reshape(pl.row_sum(pl.mul(x_i, x_i)), [1, BATCH_TILE])
sq_partials = pl.assemble(sq_partials, partial_sq_i, [kb_i, 0])

# Stage 1: RMSNorm + apply weights (vector ops only).
with pl.at(level=pl.Level.CORE_GROUP):
partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0)
for kb in pl.range(hidden_blocks):
k0 = kb * K_CHUNK
x_chunk = pl.cast(
pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]),
target_type=pl.FP32,
)
partial_sq = pl.add(
partial_sq,
pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, BATCH_TILE]),
)
# Compute variance in [1, BATCH_TILE], then reshape to [BATCH_TILE, 1]
# for row_expand_mul broadcasting.
for kb_i in pl.range(hidden_blocks):
partial_sq_i = pl.slice(sq_partials, [1, BATCH_TILE], [kb_i, 0])
partial_sq = pl.add(partial_sq, partial_sq_i)
variance = pl.reshape(
pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS),
[BATCH_TILE, 1],
)

inv_rms = pl.recip(pl.sqrt(variance))

for kb in pl.range(hidden_blocks):
k0 = kb * K_CHUNK

# Stage 2: Q projection (vector normalization + matmul_acc).
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for ob in pl.parallel(q_out_blocks, chunk=3):
q0 = ob * Q_OUT_CHUNK
x_chunk = pl.cast(
pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]),
pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, 0]),
target_type=pl.FP32,
)
gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0])
normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma)
normed_tile = pl.assemble(normed_tile, pl.cast(normed, target_type=pl.BF16), [0, k0])

# Stage 2: Q projection (matmul + matmul_acc in single incore).
for ob in pl.range(q_out_blocks):
q0 = ob * Q_OUT_CHUNK

with pl.at(level=pl.Level.CORE_GROUP):
tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
tile_b = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0])
q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)
gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, 0])
normed_chunk = pl.cast(
pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma),
target_type=pl.BF16,
)
tile_b_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0])
q_acc = pl.matmul(normed_chunk, tile_b_i, out_dtype=pl.FP32)

for kb in pl.range(1, hidden_blocks):
k0 = kb * K_CHUNK
tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0])
x_chunk = pl.cast(
pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]),
target_type=pl.FP32,
)
gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0])
normed_chunk = pl.cast(
pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma),
target_type=pl.BF16,
)
tile_b_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0])
q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i)

q_proj = pl.assemble(q_proj, q_acc, [b0, q0])
q_acc = pl.matmul_acc(q_acc, normed_chunk, tile_b_i)

# Stage 3: K/V projection (matmul + matmul_acc in single incore).
for ob in pl.range(kv_out_blocks):
kv0 = ob * KV_OUT_CHUNK
q_proj = pl.assemble(q_proj, q_acc, [b0, q0])

with pl.at(level=pl.Level.CORE_GROUP):
tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0])
k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32)
# Stage 3: K/V projection (vector normalization + matmul_acc).
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for ob in pl.parallel(kv_out_blocks, chunk=3):
kv0 = ob * KV_OUT_CHUNK
x_chunk = pl.cast(
pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, 0]),
target_type=pl.FP32,
)
gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, 0])
normed_chunk = pl.cast(
pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma),
target_type=pl.BF16,
)
tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0])
k_acc = pl.matmul(normed_chunk, tile_wk_i, out_dtype=pl.FP32)
tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0])
v_acc = pl.matmul(normed_chunk, tile_wv_i, out_dtype=pl.FP32)

for kb in pl.range(1, hidden_blocks):
k0 = kb * K_CHUNK
tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0])
x_chunk = pl.cast(
pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]),
target_type=pl.FP32,
)
gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0])
normed_chunk = pl.cast(
pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma),
target_type=pl.BF16,
)
tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i)

k_proj = pl.assemble(k_proj, k_acc, [b0, kv0])

with pl.at(level=pl.Level.CORE_GROUP):
tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0])
v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32)

for kb in pl.range(1, hidden_blocks):
k0 = kb * K_CHUNK
tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0])
k_acc = pl.matmul_acc(k_acc, normed_chunk, tile_wk_i)
tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i)
v_acc = pl.matmul_acc(v_acc, normed_chunk, tile_wv_i)

v_proj = pl.assemble(v_proj, v_acc, [b0, kv0])
k_proj = pl.assemble(k_proj, k_acc, [b0, kv0])
v_proj = pl.assemble(v_proj, v_acc, [b0, kv0])

return q_proj, k_proj, v_proj

Expand Down