From 44a886316b6c56f8e5e6b37f64f096c468a6048c Mon Sep 17 00:00:00 2001 From: ndleslx Date: Fri, 10 Apr 2026 14:58:19 +0800 Subject: [PATCH] Update: optimize Qwen3 scope1 decode scheduling - parallelize RMS partial reduction and Q/K/V output chunk loops - increase K_CHUNK to 512 and normalize chunks on demand to reduce wall time --- .../models/qwen3/qwen3_32b_decode_scope1.py | 126 ++++++++++-------- 1 file changed, 69 insertions(+), 57 deletions(-) diff --git a/examples/models/qwen3/qwen3_32b_decode_scope1.py b/examples/models/qwen3/qwen3_32b_decode_scope1.py index d6e3c6f..da2bc55 100644 --- a/examples/models/qwen3/qwen3_32b_decode_scope1.py +++ b/examples/models/qwen3/qwen3_32b_decode_scope1.py @@ -29,7 +29,7 @@ HIDDEN_INV = 1.0 / HIDDEN # Vector TILELET budget (2 KB = 2048 B, FP32 = 4 B/elem): -K_CHUNK = 128 +K_CHUNK = 512 Q_OUT_CHUNK = 64 KV_OUT_CHUNK = 64 MLP_OUT_CHUNK = 64 @@ -67,86 +67,98 @@ def qwen3_scope1( pl.Tensor[[batch, kv_hidden], pl.FP32], ]: for b0 in pl.range(0, batch, BATCH_TILE): - normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) + # Stage 1: compute per-chunk RMS partials, reduce once. + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + sq_partials = pl.create_tensor([hidden_blocks, BATCH_TILE], dtype=pl.FP32) + for kb_i in pl.parallel(hidden_blocks, chunk=2): + k0_i = kb_i * K_CHUNK + x_i = pl.cast( + pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0_i]), + target_type=pl.FP32, + ) + partial_sq_i = pl.reshape(pl.row_sum(pl.mul(x_i, x_i)), [1, BATCH_TILE]) + sq_partials = pl.assemble(sq_partials, partial_sq_i, [kb_i, 0]) - # Stage 1: RMSNorm + apply weights (vector ops only). 
with pl.at(level=pl.Level.CORE_GROUP): partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) - for kb in pl.range(hidden_blocks): - k0 = kb * K_CHUNK - x_chunk = pl.cast( - pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), - target_type=pl.FP32, - ) - partial_sq = pl.add( - partial_sq, - pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, BATCH_TILE]), - ) - # Compute variance in [1, BATCH_TILE], then reshape to [BATCH_TILE, 1] - # for row_expand_mul broadcasting. + for kb_i in pl.range(hidden_blocks): + partial_sq_i = pl.slice(sq_partials, [1, BATCH_TILE], [kb_i, 0]) + partial_sq = pl.add(partial_sq, partial_sq_i) variance = pl.reshape( pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS), [BATCH_TILE, 1], ) - inv_rms = pl.recip(pl.sqrt(variance)) - for kb in pl.range(hidden_blocks): - k0 = kb * K_CHUNK + inv_rms = pl.recip(pl.sqrt(variance)) + # Stage 2: Q projection (vector normalization + matmul_acc). + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for ob in pl.parallel(q_out_blocks, chunk=3): + q0 = ob * Q_OUT_CHUNK x_chunk = pl.cast( - pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), + pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, 0]), target_type=pl.FP32, ) - gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma) - normed_tile = pl.assemble(normed_tile, pl.cast(normed, target_type=pl.BF16), [0, k0]) - - # Stage 2: Q projection (matmul + matmul_acc in single incore). 
- for ob in pl.range(q_out_blocks): - q0 = ob * Q_OUT_CHUNK - - with pl.at(level=pl.Level.CORE_GROUP): - tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - tile_b = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0]) - q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) + gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, 0]) + normed_chunk = pl.cast( + pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma), + target_type=pl.BF16, + ) + tile_b_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0]) + q_acc = pl.matmul(normed_chunk, tile_b_i, out_dtype=pl.FP32) for kb in pl.range(1, hidden_blocks): k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0]) + x_chunk = pl.cast( + pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), + target_type=pl.FP32, + ) + gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) + normed_chunk = pl.cast( + pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma), + target_type=pl.BF16, + ) tile_b_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0]) - q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i) - - q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) + q_acc = pl.matmul_acc(q_acc, normed_chunk, tile_b_i) - # Stage 3: K/V projection (matmul + matmul_acc in single incore). - for ob in pl.range(kv_out_blocks): - kv0 = ob * KV_OUT_CHUNK + q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) - with pl.at(level=pl.Level.CORE_GROUP): - tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) - k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) + # Stage 3: K/V projection (vector normalization + matmul_acc). 
+ with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for ob in pl.parallel(kv_out_blocks, chunk=3): + kv0 = ob * KV_OUT_CHUNK + x_chunk = pl.cast( + pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, 0]), + target_type=pl.FP32, + ) + gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, 0]) + normed_chunk = pl.cast( + pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma), + target_type=pl.BF16, + ) + tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) + k_acc = pl.matmul(normed_chunk, tile_wk_i, out_dtype=pl.FP32) + tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) + v_acc = pl.matmul(normed_chunk, tile_wv_i, out_dtype=pl.FP32) for kb in pl.range(1, hidden_blocks): k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0]) + x_chunk = pl.cast( + pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), + target_type=pl.FP32, + ) + gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) + normed_chunk = pl.cast( + pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma), + target_type=pl.BF16, + ) tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) - - k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) - - with pl.at(level=pl.Level.CORE_GROUP): - tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) - v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) - - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0]) + k_acc = pl.matmul_acc(k_acc, normed_chunk, tile_wk_i) tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i) + v_acc = pl.matmul_acc(v_acc, normed_chunk, tile_wv_i) - v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) + k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) + v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) return 
q_proj, k_proj, v_proj