diff --git a/examples/models/qwen3/qwen3_32b_decode_scope3.py b/examples/models/qwen3/qwen3_32b_decode_scope3.py
index ed8249d..dfa3b0f 100644
--- a/examples/models/qwen3/qwen3_32b_decode_scope3.py
+++ b/examples/models/qwen3/qwen3_32b_decode_scope3.py
@@ -63,14 +63,7 @@ def scope3(
     for b0 in pl.range(0, BATCH_CFG, BATCH_TILE):
         resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32)
 
-        # Stage 0: Initialize resid1_tile accumulator in parallel.
-        with pl.auto_incore():
-            for ob in pl.parallel(0, Q_OUT_BLOCKS, chunk=8):
-                o0 = ob * Q_OUT_CHUNK
-                zero_resid1 = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0)
-                resid1_tile = pl.assemble(resid1_tile, zero_resid1, [0, o0])
-
-        # Stage 1: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
+        # Stage 0: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
         for ob in pl.range(Q_OUT_BLOCKS):
             o0 = ob * Q_OUT_CHUNK
 
@@ -84,7 +77,7 @@ def scope3(
                 w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0])
                 o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk)
 
-            # Stage 2: Residual addition with hidden_states
+            # Stage 1: Residual addition with hidden_states
             with pl.incore():
                 resid = pl.cast(
                     pl.slice(hidden_states, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]),
@@ -93,7 +86,7 @@ def scope3(
                 resid_sum = pl.add(o_acc, resid)
                 resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0])
 
-        # Stage 3: Post-attention RMSNorm
+        # Stage 2: Post-attention RMSNorm
         post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.BF16)
         with pl.incore():
             sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0)
@@ -111,18 +104,8 @@ def scope3(
                 normed_bf16 = pl.cast(normed, target_type=pl.BF16)
                 post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, k0])
 
-        # Stage 4: Initialize mlp_tile
+        # Stage 3 & 4 & 5: MLP: gate/up projections + SiLU.
         mlp_tile = pl.create_tensor([BATCH_TILE, INTER_CFG], dtype=pl.BF16)
-        with pl.auto_incore():
-            for ob in pl.parallel(0, MLP_OUT_BLOCKS, chunk=8):
-                o0 = ob * MLP_OUT_CHUNK
-                zero_mlp = pl.cast(
-                    pl.full([BATCH_TILE, MLP_OUT_CHUNK], dtype=pl.FP32, value=0.0),
-                    target_type=pl.BF16,
-                )
-                mlp_tile = pl.assemble(mlp_tile, zero_mlp, [0, o0])
-
-        # Stage 5 & 6 & 7: MLP: gate/up projections + SiLU.
         for ob in pl.range(MLP_OUT_BLOCKS):
             o0 = ob * MLP_OUT_CHUNK
             with pl.incore():
@@ -151,7 +134,7 @@ def scope3(
                 mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16)
                 mlp_tile = pl.assemble(mlp_tile, mlp_chunk_bf16, [0, o0])
 
-        # Stage 8 & 9: Down projection + final residual writeback.
+        # Stage 6 & 7: Down projection + final residual writeback.
        for dob in pl.range(HIDDEN_BLOCKS):
             d0 = dob * K_CHUNK
             with pl.incore():