Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 5 additions & 22 deletions examples/models/qwen3/qwen3_32b_decode_scope3.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,7 @@ def scope3(
for b0 in pl.range(0, BATCH_CFG, BATCH_TILE):
resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32)

# Stage 0: Initialize resid1_tile accumulator in parallel.
with pl.auto_incore():
for ob in pl.parallel(0, Q_OUT_BLOCKS, chunk=8):
o0 = ob * Q_OUT_CHUNK
zero_resid1 = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0)
resid1_tile = pl.assemble(resid1_tile, zero_resid1, [0, o0])

# Stage 1: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
# Stage 0: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
for ob in pl.range(Q_OUT_BLOCKS):
o0 = ob * Q_OUT_CHUNK

Expand All @@ -84,7 +77,7 @@ def scope3(
w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0])
o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk)

# Stage 2: Residual addition with hidden_states
# Stage 1: Residual addition with hidden_states
with pl.incore():
resid = pl.cast(
pl.slice(hidden_states, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]),
Expand All @@ -93,7 +86,7 @@ def scope3(
resid_sum = pl.add(o_acc, resid)
resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0])

# Stage 3: Post-attention RMSNorm
# Stage 2: Post-attention RMSNorm
post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.BF16)
with pl.incore():
sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0)
Expand All @@ -111,18 +104,8 @@ def scope3(
normed_bf16 = pl.cast(normed, target_type=pl.BF16)
post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, k0])

# Stage 4: Initialize mlp_tile
# Stage 3 & 4 & 5: MLP: gate/up projections + SiLU.
mlp_tile = pl.create_tensor([BATCH_TILE, INTER_CFG], dtype=pl.BF16)
with pl.auto_incore():
for ob in pl.parallel(0, MLP_OUT_BLOCKS, chunk=8):
o0 = ob * MLP_OUT_CHUNK
zero_mlp = pl.cast(
pl.full([BATCH_TILE, MLP_OUT_CHUNK], dtype=pl.FP32, value=0.0),
target_type=pl.BF16,
)
mlp_tile = pl.assemble(mlp_tile, zero_mlp, [0, o0])

# Stage 5 & 6 & 7: MLP: gate/up projections + SiLU.
for ob in pl.range(MLP_OUT_BLOCKS):
o0 = ob * MLP_OUT_CHUNK
with pl.incore():
Expand Down Expand Up @@ -151,7 +134,7 @@ def scope3(
mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16)
mlp_tile = pl.assemble(mlp_tile, mlp_chunk_bf16, [0, o0])

# Stage 8 & 9: Down projection + final residual writeback.
# Stage 6 & 7: Down projection + final residual writeback.
for dob in pl.range(HIDDEN_BLOCKS):
d0 = dob * K_CHUNK
with pl.incore():
Expand Down
Loading