diff --git a/examples/models/qwen3/qwen3_32b_decode.py b/examples/models/qwen3/qwen3_32b_decode.py index 70964ca..e19c7a8 100644 --- a/examples/models/qwen3/qwen3_32b_decode.py +++ b/examples/models/qwen3/qwen3_32b_decode.py @@ -42,7 +42,7 @@ HIDDEN_INV = 1.0 / HIDDEN # Scope 1 tiling constants. -K_CHUNK = 128 +SCOPE1_K_CHUNK = 512 Q_OUT_CHUNK = 64 KV_OUT_CHUNK = 64 BATCH_TILE = 16 @@ -54,6 +54,7 @@ SB_BATCH = 64 # Scope 3 tiling constants. +K_CHUNK = 128 MLP_OUT_CHUNK = 256 @@ -69,6 +70,7 @@ def build_qwen3_decode_program( hidden = hidden_size kv_hidden = num_kv_heads * head_dim inter = intermediate_size + scope1_hidden_blocks = hidden // SCOPE1_K_CHUNK hidden_blocks = hidden // K_CHUNK q_out_blocks = hidden // Q_OUT_CHUNK kv_out_blocks = kv_hidden // KV_OUT_CHUNK @@ -114,10 +116,10 @@ def qwen3_decode( with pl.at(level=pl.Level.CORE_GROUP): partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) - for kb in pl.range(hidden_blocks): - k0 = kb * K_CHUNK + for kb in pl.range(scope1_hidden_blocks): + k0 = kb * SCOPE1_K_CHUNK x_chunk = pl.cast( - pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), + pl.slice(hidden_states, [BATCH_TILE, SCOPE1_K_CHUNK], [b0, k0]), target_type=pl.FP32, ) partial_sq = pl.add( @@ -132,51 +134,50 @@ def qwen3_decode( ) inv_rms = pl.recip(pl.sqrt(variance)) - for kb in pl.range(hidden_blocks): - k0 = kb * K_CHUNK + for kb in pl.range(scope1_hidden_blocks): + k0 = kb * SCOPE1_K_CHUNK x_chunk = pl.cast( - pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), + pl.slice(hidden_states, [BATCH_TILE, SCOPE1_K_CHUNK], [b0, k0]), target_type=pl.FP32, ) - gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) + gamma = pl.slice(input_rms_weight, [1, SCOPE1_K_CHUNK], [0, k0]) normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma) normed_tile = pl.assemble(normed_tile, pl.cast(normed, target_type=pl.BF16), [0, k0]) - for ob in pl.range(q_out_blocks): - q0 = ob * Q_OUT_CHUNK - with pl.at(level=pl.Level.CORE_GROUP): - tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - tile_b = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0]) + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for ob in pl.parallel(q_out_blocks, chunk=4): + q0 = ob * Q_OUT_CHUNK + tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0]) + tile_b = pl.slice(wq, [SCOPE1_K_CHUNK, Q_OUT_CHUNK], [0, q0]) q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0]) - tile_b_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0]) + for kb in pl.range(1, scope1_hidden_blocks): + k0 = kb * SCOPE1_K_CHUNK + tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0]) + tile_b_i = pl.slice(wq, [SCOPE1_K_CHUNK, Q_OUT_CHUNK], [k0, q0]) q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i) - q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) + q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) - for ob in pl.range(kv_out_blocks): - kv0 = ob * KV_OUT_CHUNK - with pl.at(level=pl.Level.CORE_GROUP): - tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for ob in pl.parallel(kv_out_blocks, chunk=4): + kv0 = ob * KV_OUT_CHUNK + tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0]) + tile_wk = pl.slice(wk, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [0, kv0]) k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0]) - tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) + for kb in pl.range(1, scope1_hidden_blocks): + k0 = kb * SCOPE1_K_CHUNK + tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0]) + tile_wk_i = pl.slice(wk, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) - k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) - with pl.at(level=pl.Level.CORE_GROUP): - tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) + k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) + tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0]) + tile_wv = pl.slice(wv, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [0, kv0]) v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0]) - tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) + for kb in pl.range(1, scope1_hidden_blocks): + k0 = kb * SCOPE1_K_CHUNK + tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0]) + tile_wv_i = pl.slice(wv, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i) - v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) + v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) # ── Scope 2: RoPE + KV cache update + grouped-query attention ── # Pad q @@ -408,16 +409,22 @@ def qwen3_decode( sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) for kb in pl.range(hidden_blocks): k0 = kb * K_CHUNK - x_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, k0]) - sq_sum = pl.add(sq_sum, pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, BATCH_TILE])) + resid_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, k0]) + sq_sum = pl.add( + sq_sum, + pl.reshape(pl.row_sum(pl.mul(resid_chunk, resid_chunk)), [1, BATCH_TILE]), + ) inv_rms_s3 = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS))) for kb in pl.range(hidden_blocks): k0 = kb * K_CHUNK - x_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, k0]) - gamma = pl.slice(post_rms_weight, [1, K_CHUNK], [0, k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, pl.reshape(inv_rms_s3, [BATCH_TILE, 1])), gamma) - normed_bf16 = pl.cast(normed, target_type=pl.BF16) + resid_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, k0]) + post_gamma = pl.slice(post_rms_weight, [1, K_CHUNK], [0, k0]) + post_normed = pl.col_expand_mul( + pl.row_expand_mul(resid_chunk, pl.reshape(inv_rms_s3, [BATCH_TILE, 1])), + post_gamma, + ) + normed_bf16 = pl.cast(post_normed, target_type=pl.BF16) post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, k0]) # Stage 4 & 5 & 6: MLP: gate/up projections + SiLU. @@ -622,8 +629,8 @@ def golden_qwen3_decode(tensors, params): x_tile = hidden_states[b0:b_end, :].float() sq_sum = torch.zeros(b_end - b0, 1, dtype=torch.float32) - for k0 in range(0, hidden_size, K_CHUNK): - x_chunk = x_tile[:, k0:k0 + K_CHUNK] + for k0 in range(0, hidden_size, SCOPE1_K_CHUNK): + x_chunk = x_tile[:, k0:k0 + SCOPE1_K_CHUNK] sq_sum = sq_sum + (x_chunk ** 2).sum(dim=-1, keepdim=True) variance = sq_sum / hidden_size + EPS rms = torch.sqrt(variance) diff --git a/examples/models/qwen3/qwen3_32b_decode_scope1.py b/examples/models/qwen3/qwen3_32b_decode_scope1.py index d6e3c6f..f44f427 100644 --- a/examples/models/qwen3/qwen3_32b_decode_scope1.py +++ b/examples/models/qwen3/qwen3_32b_decode_scope1.py @@ -29,7 +29,7 @@ HIDDEN_INV = 1.0 / HIDDEN # Vector TILELET budget (2 KB = 2048 B, FP32 = 4 B/elem): -K_CHUNK = 128 +K_CHUNK = 512 Q_OUT_CHUNK = 64 KV_OUT_CHUNK = 64 MLP_OUT_CHUNK = 64 @@ -102,10 +102,10 @@ def qwen3_scope1( normed_tile = pl.assemble(normed_tile, pl.cast(normed, target_type=pl.BF16), [0, k0]) # Stage 2: Q projection (matmul + matmul_acc in single incore). - for ob in pl.range(q_out_blocks): - q0 = ob * Q_OUT_CHUNK + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for ob in pl.parallel(q_out_blocks, chunk=4): + q0 = ob * Q_OUT_CHUNK - with pl.at(level=pl.Level.CORE_GROUP): tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) tile_b = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0]) q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) @@ -116,13 +116,13 @@ def qwen3_scope1( tile_b_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0]) q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i) - q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) + q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) # Stage 3: K/V projection (matmul + matmul_acc in single incore). - for ob in pl.range(kv_out_blocks): - kv0 = ob * KV_OUT_CHUNK + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for ob in pl.parallel(kv_out_blocks, chunk=4): + kv0 = ob * KV_OUT_CHUNK - with pl.at(level=pl.Level.CORE_GROUP): tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) @@ -133,9 +133,8 @@ def qwen3_scope1( tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) - k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) + k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) - with pl.at(level=pl.Level.CORE_GROUP): tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) @@ -146,7 +145,7 @@ def qwen3_scope1( tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i) - v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) + v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) return q_proj, k_proj, v_proj