From f2afe11db0a68f35c4e2f462570cce0eab002c23 Mon Sep 17 00:00:00 2001
From: HighCloud
Date: Thu, 9 Apr 2026 21:06:53 +0800
Subject: [PATCH] refactor(qwen3): reorganize scope3 decode flow and runtime config

- rewrite scope3 into explicit stage-by-stage sections with clearer dataflow
- switch projection accumulations to matmul/matmul_acc and streamline residual assembly
- update default shape baseline (HIDDEN=8192) and keep existing chunk constants
- refresh tensor initializers and remove dead local from golden
- expose CLI platform/device/profiling options and backend selection in compile_and_run
- validate on a2a3sim: out matched 131072/131072
---
 .../models/qwen3/qwen3_32b_decode_scope3.py | 79 ++++++++++---------
 1 file changed, 40 insertions(+), 39 deletions(-)

diff --git a/examples/models/qwen3/qwen3_32b_decode_scope3.py b/examples/models/qwen3/qwen3_32b_decode_scope3.py
index 6839953..bf0e45c 100644
--- a/examples/models/qwen3/qwen3_32b_decode_scope3.py
+++ b/examples/models/qwen3/qwen3_32b_decode_scope3.py
@@ -21,7 +21,7 @@ import pypto.language as pl
 
 BATCH = 16
-HIDDEN = 5120
+HIDDEN = 8192
 INTERMEDIATE = 25600
 EPS = 1e-6
 
@@ -63,14 +63,14 @@ def scope3(
     for b0 in pl.range(0, BATCH_CFG, BATCH_TILE):
         resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32)
 
-        # Stage 1: Initialize resid1_tile accumulator in parallel.
+        # Stage 0: Initialize resid1_tile accumulator in parallel.
         with pl.auto_incore():
             for ob in pl.parallel(0, Q_OUT_BLOCKS, chunk=8):
                 o0 = ob * Q_OUT_CHUNK
                 zero_resid1 = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0)
                 resid1_tile = pl.assemble(resid1_tile, zero_resid1, [0, o0])
 
-        # Stage 2: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
+        # Stage 1: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
         for ob in pl.range(Q_OUT_BLOCKS):
             o0 = ob * Q_OUT_CHUNK
 
@@ -84,21 +84,18 @@ def scope3(
                 w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0])
                 o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk)
 
-            resid1_tile = pl.assemble(resid1_tile, o_acc, [0, o0])
-
+            # Stage 2: Residual addition with hidden_states
             with pl.incore():
                 resid = pl.cast(
                     pl.slice(hidden_states, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]),
                     target_type=pl.FP32,
                 )
-                mm_out = pl.slice(resid1_tile, [BATCH_TILE, Q_OUT_CHUNK], [0, o0])
-                add_resid = pl.add(mm_out, resid)
-                resid1_tile = pl.assemble(resid1_tile, add_resid, [0, o0])
+                resid_sum = pl.add(o_acc, resid)
+                resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0])
 
-        # Stage 3 & 4 & 5: Post-attention RMSNorm: compute inv_rms over resid1_tile + normalize + initialize down_proj_tile accumulator.
+        # Stage 3: Post-attention RMSNorm
         post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.BF16)
-        down_proj_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32)
-        with pl.auto_incore():
+        with pl.incore():
             sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0)
             for kb in pl.range(HIDDEN_BLOCKS):
                 k0 = kb * K_CHUNK
@@ -111,16 +108,23 @@ def scope3(
                 x_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, k0])
                 gamma = pl.slice(post_rms_weight, [1, K_CHUNK], [0, k0])
                 normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, pl.reshape(inv_rms, [BATCH_TILE, 1])), gamma)
-                post_norm_tile = pl.assemble(
-                    post_norm_tile, pl.cast(normed, target_type=pl.BF16), [0, k0]
+                normed_bf16 = pl.cast(normed, target_type=pl.BF16)
+                post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, k0])
+
+        # Stage 4: Initialize mlp_tile
+        mlp_tile = pl.create_tensor([BATCH_TILE, INTER_CFG], dtype=pl.BF16)
+        with pl.auto_incore():
+            for ob in pl.parallel(0, MLP_OUT_BLOCKS, chunk=8):
+                o0 = ob * MLP_OUT_CHUNK
+                zero_mlp = pl.cast(
+                    pl.full([BATCH_TILE, MLP_OUT_CHUNK], dtype=pl.FP32, value=0.0),
+                    target_type=pl.BF16,
                 )
-                down_zero_chunk = pl.full([BATCH_TILE, K_CHUNK], dtype=pl.FP32, value=0.0)
-                down_proj_tile = pl.assemble(down_proj_tile, down_zero_chunk, [0, k0])
+                mlp_tile = pl.assemble(mlp_tile, zero_mlp, [0, o0])
 
-        # Stage 6: MLP: gate/up projections + SiLU.
+        # Stage 5 & 6 & 7: MLP: gate/up projections + SiLU.
         for ob in pl.range(MLP_OUT_BLOCKS):
             o0 = ob * MLP_OUT_CHUNK
-            # Stage 6a: MLP: gate projections
             with pl.incore():
                 post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0])
                 wg_0 = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [0, o0])
@@ -131,7 +135,6 @@ def scope3(
                     wg = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [k0, o0])
                     gate_acc = pl.matmul_acc(gate_acc, post_chunk, wg)
 
-            # Stage 6b: MLP: up projections
             with pl.incore():
                 post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0])
                 wu_0 = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [0, o0])
@@ -142,34 +145,33 @@ def scope3(
                     wu = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [k0, o0])
                     up_acc = pl.matmul_acc(up_acc, post_chunk, wu)
 
-            # Stage 6c: MLP: silu
             with pl.auto_incore():
                 sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0))
                 mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc)
                 mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16)
+                mlp_tile = pl.assemble(mlp_tile, mlp_chunk_bf16, [0, o0])
 
-            # Stage 7: Down projection: accumulate in parallel.
-            for dob in pl.range(HIDDEN_BLOCKS):
-                d0 = dob * K_CHUNK
-                with pl.incore():
+        # Stage 8 & 9: Down projection + final residual writeback.
+        for dob in pl.range(HIDDEN_BLOCKS):
+            d0 = dob * K_CHUNK
+            with pl.incore():
+                mlp_chunk_0 = pl.slice(mlp_tile, [BATCH_TILE, MLP_OUT_CHUNK], [0, 0])
+                w_down_chunk_0 = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [0, d0])
+                down_acc = pl.matmul(mlp_chunk_0, w_down_chunk_0, out_dtype=pl.FP32)
+                for ob in pl.range(1, MLP_OUT_BLOCKS):
+                    o0 = ob * MLP_OUT_CHUNK
+                    down_mlp_chunk_bf16 = pl.slice(
+                        mlp_tile, [BATCH_TILE, MLP_OUT_CHUNK], [0, o0]
+                    )
                     w_down_chunk = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [o0, d0])
-                    down_next = pl.matmul(mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32)
-
-                with pl.incore():
-                    down_prev = pl.slice(down_proj_tile, [BATCH_TILE, K_CHUNK], [0, d0])
-                    accum = pl.add(down_prev, down_next)
-                    down_proj_tile = pl.assemble(down_proj_tile, accum, [0, d0])
-
-        # Stage 8: Final residual: down_proj + resid1, write to output.
-        for ob in pl.range(HIDDEN_BLOCKS):
-            o0 = ob * K_CHUNK
+                    down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk)
             with pl.incore():
-                down_acc = pl.add(
-                    pl.slice(down_proj_tile, [BATCH_TILE, K_CHUNK], [0, o0]),
-                    pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, o0]),
+                out_chunk = pl.add(
+                    down_acc,
+                    pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, d0]),
                 )
-                out_f32 = pl.cast(down_acc, target_type=pl.BF16)
-                out = pl.assemble(out, out_f32, [b0, o0])
+                out_chunk_cast = pl.cast(out_chunk, target_type=pl.BF16)
+                out = pl.assemble(out, out_chunk_cast, [b0, d0])
 
     return out
 
@@ -195,7 +197,6 @@ def golden(tensors: dict, params: dict | None = None) -> None:
     w_up = tensors["w_up"]  # [H, I], BF16
     w_down = tensors["w_down"]  # [I, H], BF16
 
-    H = attn_out.shape[1]
     eps = 1e-6
 
     # 1. Output projection (BF16 inputs, FP32 accumulation) + residual.
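
As a cross-check on the stage flow above, here is a minimal NumPy sketch of the same sequence: output projection with FP32 accumulation, residual add, post-attention RMSNorm, gate/up projections with SiLU gating, and down projection fused with the final residual writeback. It is an illustration under assumptions, not the repository's golden(): shapes follow the comments in the golden() hunk (attn_out [B, H], wo [H, H], w_gate/w_up [H, I], w_down [I, H]), everything stays in FP32 since NumPy has no BF16, and the K_CHUNK-tiled matmul/matmul_acc loops collapse into whole-matrix matmuls.

import numpy as np

EPS = 1e-6

def scope3_reference(attn_out, hidden_states, wo, post_rms_weight,
                     w_gate, w_up, w_down):
    f32 = np.float32
    # Stages 0-2: output projection (FP32 accumulation) + residual add.
    resid1 = attn_out.astype(f32) @ wo.astype(f32) + hidden_states.astype(f32)
    # Stage 3: post-attention RMSNorm over the hidden dimension.
    inv_rms = 1.0 / np.sqrt(np.mean(resid1 * resid1, axis=-1, keepdims=True) + EPS)
    post_norm = resid1 * inv_rms * post_rms_weight.astype(f32)
    # Stages 5-7: gate/up projections and SiLU gating; the sigmoid is written
    # as recip(exp(neg(x)) + 1) to mirror the kernel's formulation.
    gate = post_norm @ w_gate.astype(f32)
    up = post_norm @ w_up.astype(f32)
    mlp = gate * (1.0 / (np.exp(-gate) + 1.0)) * up
    # Stages 8-9: down projection + final residual writeback.
    return mlp @ w_down.astype(f32) + resid1

Note that the kernel round-trips post_norm and the SiLU output through BF16 before the next matmul, so agreement with this FP32-only sketch is approximate rather than bit-exact. The "out matched 131072/131072" figure in the commit message is consistent with the new defaults: BATCH × HIDDEN = 16 × 8192 = 131072 output elements.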
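
The commit message also mentions exposing CLI platform/device/profiling options and backend selection in compile_and_run, but those hunks are not part of this excerpt, so the actual option names are not shown here. Purely as a hypothetical sketch of that kind of surface (flag names invented for illustration, standard-library argparse only):

import argparse

def parse_runtime_args(argv=None):
    # Hypothetical flags; the real option names live in the compile_and_run
    # changes omitted from this excerpt.
    parser = argparse.ArgumentParser(description="qwen3 scope3 decode runner")
    parser.add_argument("--platform", default="a2a3sim",
                        help="target backend, e.g. the a2a3sim simulator")
    parser.add_argument("--device", type=int, default=0,
                        help="device index to run on")
    parser.add_argument("--profiling", action="store_true",
                        help="collect profiling data during the run")
    return parser.parse_args(argv)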