From 5c3808dbbb628e3227bfe8bb9d732c1850268830 Mon Sep 17 00:00:00 2001
From: zhangqi-chen
Date: Mon, 13 Apr 2026 10:55:51 +0800
Subject: [PATCH] Refactor: replace manual chunked loops with pl.parallel chunk
 syntax and clean up scope2
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Replace hand-rolled `for sb0 in pl.range(0, ctx_blocks, SB_BATCH)` + inner
  `pl.range(SB_BATCH)` + `if sb < ctx_blocks` guard pattern with
  `pl.parallel(ctx_blocks, chunk=SB_BATCH)` in scope2 and decode
- Remove redundant zero-initialization loop for intermediate tensors in decode
  scope2 (all_raw_scores, all_exp_padded, etc. are overwritten by subsequent
  stages)
- Rename program class Qwen3Scope123 → Qwen3Decode
- Add qwen3_32b_decode_tile.py: InCore + Orchestration separated rewrite with
  explicit pl.load/pl.store/pl.move data movement
---
 examples/models/qwen3/qwen3_32b_decode.py     |  143 +--
 .../models/qwen3/qwen3_32b_decode_scope2.py   |  103 +-
 .../models/qwen3/qwen3_32b_decode_tile.py     | 1058 +++++++++++++++++
 3 files changed, 1155 insertions(+), 149 deletions(-)
 create mode 100644 examples/models/qwen3/qwen3_32b_decode_tile.py

diff --git a/examples/models/qwen3/qwen3_32b_decode.py b/examples/models/qwen3/qwen3_32b_decode.py
index 70964ca..7929238 100644
--- a/examples/models/qwen3/qwen3_32b_decode.py
+++ b/examples/models/qwen3/qwen3_32b_decode.py
@@ -82,7 +82,7 @@ def build_qwen3_decode_program(
     max_ctx_blocks = (max_seq + SEQ_TILE - 1) // SEQ_TILE
 
     @pl.program
-    class Qwen3Scope123:
+    class Qwen3Decode:
         @pl.function(type=pl.FunctionType.Opaque)
         def qwen3_decode(
             self,
@@ -249,106 +249,63 @@ def qwen3_decode(
                     q_base = kvh * q_per_kv + qg * Q_HEAD_BATCH
                     q_padded = pl.slice(all_q_padded, [Q_HEAD_PAD, head_dim], [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD, 0])
 
-                    # Workaround
+                    # Stage 2: QK matmul for all active sb blocks.
                     all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)
                     all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16)
                     all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
                     all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
                     all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
-                    for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
-                        with pl.at(level=pl.Level.CORE_GROUP):
-                            for si in pl.range(SB_BATCH):
-                                sb = sb0 + si
-                                if sb < ctx_blocks:
-                                    all_raw_scores = pl.assemble(
-                                        all_raw_scores,
-                                        pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0),
-                                        [sb * Q_HEAD_PAD, 0],
-                                    )
-                                    all_exp_padded = pl.assemble(
-                                        all_exp_padded,
-                                        pl.cast(pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
-                                        [sb * Q_HEAD_PAD, 0],
-                                    )
-                                    all_oi_tmp = pl.assemble(
-                                        all_oi_tmp,
-                                        pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0),
-                                        [sb * Q_HEAD_PAD, 0],
-                                    )
-                                    mi_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
-                                    all_cur_mi = pl.assemble(
-                                        all_cur_mi,
-                                        pl.reshape(mi_init_flat, [Q_HEAD_BATCH, 1]),
-                                        [sb * Q_HEAD_BATCH, 0],
-                                    )
-                                    li_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
-                                    all_cur_li = pl.assemble(
-                                        all_cur_li,
-                                        pl.reshape(li_init_flat, [Q_HEAD_BATCH, 1]),
-                                        [sb * Q_HEAD_BATCH, 0],
-                                    )
-
-                    # Stage 2: QK matmul for all active sb blocks.
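+                    # NOTE: `pl.parallel(n, chunk=c)` below is a drop-in for the old
+                    # hand-rolled pattern. Conceptually (a sketch of the intended
+                    # equivalence, not the code the optimizer actually emits):
+                    #
+                    #     for sb0 in pl.range(0, ctx_blocks, SB_BATCH):  # old outer chunk loop
+                    #         for si in pl.range(SB_BATCH):              # old inner lane loop
+                    #             sb = sb0 + si
+                    #             if sb < ctx_blocks:                    # old tail guard
+                    #                 body(sb)
+                    #
+                    #     for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
+                    #         body(sb)  # chunking and the tail guard are handled
+                    #                   # by pl.chunked_loop_optimizer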
-                    for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
-                        with pl.at(level=pl.Level.CORE_GROUP):
-                            for si in pl.range(SB_BATCH):
-                                sb = sb0 + si
-                                if sb < ctx_blocks:
-                                    s0 = sb * SEQ_TILE
-                                    cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
-                                    k_tile = pl.slice(
-                                        k_cache,
-                                        [SEQ_TILE, head_dim],
-                                        [cache_row0, 0],
-                                    )
-                                    raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
-                                    all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])
+                    with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+                        for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
+                            s0 = sb * SEQ_TILE
+                            cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
+                            k_tile = pl.slice(
+                                k_cache,
+                                [SEQ_TILE, head_dim],
+                                [cache_row0, 0],
+                            )
+                            raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
+                            all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])
 
                     # Stage 3: softmax for all active sb blocks.
-                    for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
-                        with pl.at(level=pl.Level.CORE_GROUP):
-                            for si in pl.range(SB_BATCH):
-                                sb = sb0 + si
-                                if sb < ctx_blocks:
-                                    s0 = sb * SEQ_TILE
-                                    valid_len = pl.min(SEQ_TILE, ctx_len - s0)
-                                    scores_valid = pl.slice(
-                                        all_raw_scores,
-                                        [Q_HEAD_BATCH, SEQ_TILE],
-                                        [sb * Q_HEAD_PAD, 0],
-                                        valid_shape=[Q_HEAD_BATCH, valid_len],
-                                    )
-                                    scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
-                                    scores = pl.mul(scores_padded, attn_scale)
-                                    cur_mi = pl.row_max(scores)
-                                    exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
-                                    exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
-                                    exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
-                                    cur_li = pl.row_sum(exp_scores_fp32)
-                                    all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0])
-                                    all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0])
-                                    all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0])
+                    with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+                        for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
+                            s0 = sb * SEQ_TILE
+                            valid_len = pl.min(SEQ_TILE, ctx_len - s0)
+                            scores_valid = pl.slice(
+                                all_raw_scores,
+                                [Q_HEAD_BATCH, SEQ_TILE],
+                                [sb * Q_HEAD_PAD, 0],
+                                valid_shape=[Q_HEAD_BATCH, valid_len],
+                            )
+                            scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
+                            scores = pl.mul(scores_padded, attn_scale)
+                            cur_mi = pl.row_max(scores)
+                            exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
+                            exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
+                            exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
+                            cur_li = pl.row_sum(exp_scores_fp32)
+                            all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0])
+                            all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0])
+                            all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0])
 
                     # Stage 4: SV matmul for all active sb blocks.
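+                    # Each sb block below contributes a partial output
+                    # oi_tmp = exp_scores @ V_tile; the partials are merged across
+                    # blocks in stage 5 by the online-softmax rescaling.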
-                    for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
-                        with pl.at(level=pl.Level.CORE_GROUP):
-                            for si in pl.range(SB_BATCH):
-                                sb = sb0 + si
-                                if sb < ctx_blocks:
-                                    s0 = sb * SEQ_TILE
-                                    cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
-                                    exp_tile = pl.slice(
-                                        all_exp_padded,
-                                        [Q_HEAD_PAD, SEQ_TILE],
-                                        [sb * Q_HEAD_PAD, 0],
-                                    )
-                                    v_tile = pl.slice(
-                                        v_cache,
-                                        [SEQ_TILE, head_dim],
-                                        [cache_row0, 0],
-                                    )
-                                    oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
-                                    all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])
+                    with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+                        for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
+                            s0 = sb * SEQ_TILE
+                            cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
+                            exp_tile = pl.slice(
+                                all_exp_padded,
+                                [Q_HEAD_PAD, SEQ_TILE],
+                                [sb * Q_HEAD_PAD, 0],
+                            )
+                            v_tile = pl.slice(
+                                v_cache,
+                                [SEQ_TILE, head_dim],
+                                [cache_row0, 0],
+                            )
+                            oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
+                            all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])
 
                     # Stage 5: online softmax accumulation and normalisation.
                     with pl.at(level=pl.Level.CORE_GROUP):
@@ -474,7 +431,7 @@ def qwen3_decode(
 
             return out
 
-    return Qwen3Scope123
+    return Qwen3Decode
 
 
 def build_tensor_specs(
diff --git a/examples/models/qwen3/qwen3_32b_decode_scope2.py b/examples/models/qwen3/qwen3_32b_decode_scope2.py
index 9b3d636..84ee337 100644
--- a/examples/models/qwen3/qwen3_32b_decode_scope2.py
+++ b/examples/models/qwen3/qwen3_32b_decode_scope2.py
@@ -156,66 +156,57 @@ def qwen3_scope2(
                     all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
                     all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
                     all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
-                    for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
-                        with pl.at(level=pl.Level.CORE_GROUP):
-                            for si in pl.range(SB_BATCH):
-                                sb = sb0 + si
-                                if sb < ctx_blocks:
-                                    s0 = sb * SEQ_TILE
-                                    cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
-                                    k_tile = pl.slice(
-                                        k_cache,
-                                        [SEQ_TILE, head_dim],
-                                        [cache_row0, 0],
-                                    )
-                                    raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
-                                    all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])
+                    with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+                        for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
+                            s0 = sb * SEQ_TILE
+                            cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
+                            k_tile = pl.slice(
+                                k_cache,
+                                [SEQ_TILE, head_dim],
+                                [cache_row0, 0],
+                            )
+                            raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
+                            all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])
 
                     # Stage 3: softmax for all active sb blocks.
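+                    # Numerically stable per-block softmax: m = rowmax(s);
+                    # p = exp(s - m); l = rowsum(p). m and l feed the stage-5
+                    # online accumulation.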
-                    for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
-                        with pl.at(level=pl.Level.CORE_GROUP):
-                            for si in pl.range(SB_BATCH):
-                                sb = sb0 + si
-                                if sb < ctx_blocks:
-                                    s0 = sb * SEQ_TILE
-                                    valid_len = pl.min(SEQ_TILE, ctx_len - s0)
-                                    scores_valid = pl.slice(
-                                        all_raw_scores,
-                                        [Q_HEAD_BATCH, SEQ_TILE],
-                                        [sb * Q_HEAD_PAD, 0],
-                                        valid_shape=[Q_HEAD_BATCH, valid_len],
-                                    )
-                                    scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
-                                    scores = pl.mul(scores_padded, attn_scale)
-                                    cur_mi = pl.row_max(scores)
-                                    exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
-                                    exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
-                                    exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
-                                    cur_li = pl.row_sum(exp_scores_fp32)
-                                    all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0])
-                                    all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0])
-                                    all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0])
+                    with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+                        for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
+                            s0 = sb * SEQ_TILE
+                            valid_len = pl.min(SEQ_TILE, ctx_len - s0)
+                            scores_valid = pl.slice(
+                                all_raw_scores,
+                                [Q_HEAD_BATCH, SEQ_TILE],
+                                [sb * Q_HEAD_PAD, 0],
+                                valid_shape=[Q_HEAD_BATCH, valid_len],
+                            )
+                            scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
+                            scores = pl.mul(scores_padded, attn_scale)
+                            cur_mi = pl.row_max(scores)
+                            exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
+                            exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
+                            exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
+                            cur_li = pl.row_sum(exp_scores_fp32)
+                            all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0])
+                            all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0])
+                            all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0])
 
                     # Stage 4: SV matmul for all active sb blocks.
-                    for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
-                        with pl.at(level=pl.Level.CORE_GROUP):
-                            for si in pl.range(SB_BATCH):
-                                sb = sb0 + si
-                                if sb < ctx_blocks:
-                                    s0 = sb * SEQ_TILE
-                                    cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
-                                    exp_tile = pl.slice(
-                                        all_exp_padded,
-                                        [Q_HEAD_PAD, SEQ_TILE],
-                                        [sb * Q_HEAD_PAD, 0],
-                                    )
-                                    v_tile = pl.slice(
-                                        v_cache,
-                                        [SEQ_TILE, head_dim],
-                                        [cache_row0, 0],
-                                    )
-                                    oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
-                                    all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])
+                    with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+                        for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
+                            s0 = sb * SEQ_TILE
+                            cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
+                            exp_tile = pl.slice(
+                                all_exp_padded,
+                                [Q_HEAD_PAD, SEQ_TILE],
+                                [sb * Q_HEAD_PAD, 0],
+                            )
+                            v_tile = pl.slice(
+                                v_cache,
+                                [SEQ_TILE, head_dim],
+                                [cache_row0, 0],
+                            )
+                            oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
+                            all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])
 
                     # Stage 5: online softmax accumulation and normalisation.
                     with pl.at(level=pl.Level.CORE_GROUP):
diff --git a/examples/models/qwen3/qwen3_32b_decode_tile.py b/examples/models/qwen3/qwen3_32b_decode_tile.py
new file mode 100644
index 0000000..fd47588
--- /dev/null
+++ b/examples/models/qwen3/qwen3_32b_decode_tile.py
@@ -0,0 +1,1058 @@
+# Copyright (c) PyPTO Contributors.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# -----------------------------------------------------------------------------------------------------------
+"""Qwen3-32B single-layer decode forward — tile DSL version.
+
+InCore + Orchestration separated rewrite of qwen3_32b_decode.py. Each
+``pl.at(level=pl.Level.CORE_GROUP)`` block is extracted into an explicit
+InCore kernel using ``pl.load`` / ``pl.store`` / ``pl.move`` for data
+movement, while the Orchestration function manages loops and scheduling.
+
+InCore kernels are defined at builder scope (inside build_qwen3_decode_program
+but outside the @pl.program class) and automatically added to the program when
+called from the Orchestration function.
+
+Scope 1:
+  1. RMSNorm of input hidden states
+  2. Q/K/V projection via matmul
+
+Scope 2:
+  1. K RoPE + cache write, V cache write, Q RoPE + pad
+  2. QK matmul
+  3. Softmax
+  4. SV matmul
+  5. Online-softmax accumulation + final normalisation
+
+Scope 3:
+  1. Output projection: attn_out × wo
+  2. Residual addition with hidden_states
+  3. Post-attention RMSNorm
+  4. MLP: gate/up projections, SiLU activation, down projection
+  5. Final residual addition
+"""
+from __future__ import annotations
+
+import pypto.language as pl
+
+BATCH = 16
+MAX_SEQ = 4096
+NUM_HEADS = 64
+NUM_KV_HEADS = 8
+HEAD_DIM = 128
+HIDDEN = NUM_HEADS * HEAD_DIM  # 8192
+INTERMEDIATE = 25600
+KV_HIDDEN = NUM_KV_HEADS * HEAD_DIM
+
+EPS = 1e-6
+HIDDEN_INV = 1.0 / HIDDEN
+
+# Scope 1 tiling constants.
+K_CHUNK = 128
+Q_OUT_CHUNK = 64
+KV_OUT_CHUNK = 64
+BATCH_TILE = 16
+
+# Scope 2 tiling constants.
+Q_HEAD_BATCH = 8
+Q_HEAD_PAD = 16
+SEQ_TILE = 64
+SB_BATCH = 64
+
+# Scope 3 tiling constants.
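+# MLP_OUT_CHUNK is both the gate/up output tile width and the reduction step
+# of the down projection (see kernel_mlp_matmul_reduce / kernel_down_proj_reduce).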
+MLP_OUT_CHUNK = 256
+
+
+def build_qwen3_decode_program(
+    batch: int = BATCH,
+    max_seq: int = MAX_SEQ,
+    hidden_size: int = HIDDEN,
+    intermediate_size: int = INTERMEDIATE,
+    num_heads: int = NUM_HEADS,
+    num_kv_heads: int = NUM_KV_HEADS,
+    head_dim: int = HEAD_DIM,
+):
+    hidden = hidden_size
+    kv_hidden = num_kv_heads * head_dim
+    inter = intermediate_size
+    hidden_blocks = hidden // K_CHUNK
+    q_out_blocks = hidden // Q_OUT_CHUNK
+    kv_out_blocks = kv_hidden // KV_OUT_CHUNK
+    mlp_out_blocks = inter // MLP_OUT_CHUNK
+    cache_rows = batch * num_kv_heads * max_seq
+    half_dim = head_dim // 2
+    q_per_kv = num_heads // num_kv_heads
+    q_groups = q_per_kv // Q_HEAD_BATCH
+    total_q_groups = num_kv_heads * q_groups
+    attn_scale = 1.0 / (head_dim ** 0.5)
+    max_ctx_blocks = (max_seq + SEQ_TILE - 1) // SEQ_TILE
+
+    # ── Scope 1 InCore kernels ─────────────────────────────────────────
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_build_normed_tile(
+        hidden_tile: pl.Tensor[[BATCH_TILE, hidden], pl.BF16],
+        input_rms_weight: pl.Tensor[[1, hidden], pl.FP32],
+        output: pl.Out[pl.Tensor[[BATCH_TILE, hidden], pl.BF16]],
+    ) -> pl.Tensor[[BATCH_TILE, hidden], pl.BF16]:
+        partial_sq = pl.create_tile([1, BATCH_TILE], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec)
+        partial_sq = pl.mul(partial_sq, 0.0)
+
+        for kb in pl.range(hidden_blocks):
+            k0 = kb * K_CHUNK
+            tile_x = pl.load(
+                hidden_tile, [0, k0], [BATCH_TILE, K_CHUNK],
+                target_memory=pl.MemorySpace.Vec,
+            )
+            tile_x_f32 = pl.cast(tile_x, target_type=pl.FP32)
+            squared = pl.mul(tile_x_f32, tile_x_f32)
+            tmp = pl.create_tile([BATCH_TILE, K_CHUNK], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec)
+            row_sum: pl.Tile[[BATCH_TILE, 1], pl.FP32] = pl.row_sum(squared, tmp)
+            partial_sq = pl.add(partial_sq, pl.reshape(row_sum, [1, BATCH_TILE]))
+
+        variance_t: pl.Tile[[1, BATCH_TILE], pl.FP32] = pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS)
+        variance: pl.Tile[[BATCH_TILE, 1], pl.FP32] = pl.reshape(variance_t, [BATCH_TILE, 1])
+        rms = pl.sqrt(variance)
+        inv_rms = pl.recip(rms)
+
+        for kb, (out_iter,) in pl.range(hidden_blocks, init_values=(output,)):
+            k0 = kb * K_CHUNK
+            tile_x = pl.load(
+                hidden_tile, [0, k0], [BATCH_TILE, K_CHUNK],
+                target_memory=pl.MemorySpace.Vec,
+            )
+            tile_x_f32 = pl.cast(tile_x, target_type=pl.FP32)
+            tile_gamma = pl.load(
+                input_rms_weight, [0, k0], [1, K_CHUNK],
+                target_memory=pl.MemorySpace.Vec,
+            )
+            scaled = pl.row_expand_mul(tile_x_f32, inv_rms)
+            weighted = pl.col_expand_mul(scaled, tile_gamma)
+            weighted_bf16 = pl.cast(weighted, target_type=pl.BF16)
+            out_next = pl.store(weighted_bf16, [0, k0], out_iter)
+            (out_carry,) = pl.yield_(out_next)
+
+        return out_carry
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_q_proj_reduce(
+        normed_tile: pl.Tensor[[BATCH_TILE, hidden], pl.BF16],
+        weight: pl.Tensor[[hidden, hidden], pl.BF16],
+        out_row: pl.Scalar[pl.INDEX],
+        out_col: pl.Scalar[pl.INDEX],
+        output: pl.Out[pl.Tensor[[batch, hidden], pl.FP32]],
+    ) -> pl.Tensor[[batch, hidden], pl.FP32]:
+        tile_a_l1 = pl.load(normed_tile, [0, 0], [BATCH_TILE, K_CHUNK], target_memory=pl.MemorySpace.Mat)
+        tile_b_l1 = pl.load(weight, [0, out_col], [K_CHUNK, Q_OUT_CHUNK], target_memory=pl.MemorySpace.Mat)
+        tile_a_l0a = pl.move(tile_a_l1, target_memory=pl.MemorySpace.Left)
+        tile_b_l0b = pl.move(tile_b_l1, target_memory=pl.MemorySpace.Right)
+        acc = pl.matmul(tile_a_l0a, tile_b_l0b)
+
+        for kb in pl.range(1, hidden_blocks):
+            k0 = kb * K_CHUNK
+            tile_a_i_l1 = pl.load(normed_tile, [0, k0], [BATCH_TILE, K_CHUNK], target_memory=pl.MemorySpace.Mat)
+            tile_b_i_l1 = pl.load(weight, [k0, out_col], [K_CHUNK, Q_OUT_CHUNK], target_memory=pl.MemorySpace.Mat)
+            tile_a_i_l0a = pl.move(tile_a_i_l1, target_memory=pl.MemorySpace.Left)
+            tile_b_i_l0b = pl.move(tile_b_i_l1, target_memory=pl.MemorySpace.Right)
+            acc = pl.matmul_acc(acc, tile_a_i_l0a, tile_b_i_l0b)
+
+        out = pl.store(acc, [out_row, out_col], output)
+        return out
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_kv_proj_reduce(
+        normed_tile: pl.Tensor[[BATCH_TILE, hidden], pl.BF16],
+        weight: pl.Tensor[[hidden, kv_hidden], pl.BF16],
+        out_row: pl.Scalar[pl.INDEX],
+        out_col: pl.Scalar[pl.INDEX],
+        output: pl.Out[pl.Tensor[[batch, kv_hidden], pl.FP32]],
+    ) -> pl.Tensor[[batch, kv_hidden], pl.FP32]:
+        tile_a_l1 = pl.load(normed_tile, [0, 0], [BATCH_TILE, K_CHUNK], target_memory=pl.MemorySpace.Mat)
+        tile_b_l1 = pl.load(weight, [0, out_col], [K_CHUNK, KV_OUT_CHUNK], target_memory=pl.MemorySpace.Mat)
+        tile_a_l0a = pl.move(tile_a_l1, target_memory=pl.MemorySpace.Left)
+        tile_b_l0b = pl.move(tile_b_l1, target_memory=pl.MemorySpace.Right)
+        acc = pl.matmul(tile_a_l0a, tile_b_l0b)
+
+        for kb in pl.range(1, hidden_blocks):
+            k0 = kb * K_CHUNK
+            tile_a_i_l1 = pl.load(normed_tile, [0, k0], [BATCH_TILE, K_CHUNK], target_memory=pl.MemorySpace.Mat)
+            tile_b_i_l1 = pl.load(weight, [k0, out_col], [K_CHUNK, KV_OUT_CHUNK], target_memory=pl.MemorySpace.Mat)
+            tile_a_i_l0a = pl.move(tile_a_i_l1, target_memory=pl.MemorySpace.Left)
+            tile_b_i_l0b = pl.move(tile_b_i_l1, target_memory=pl.MemorySpace.Right)
+            acc = pl.matmul_acc(acc, tile_a_i_l0a, tile_b_i_l0b)
+
+        out = pl.store(acc, [out_row, out_col], output)
+        return out
+
+    # ── Scope 2 InCore kernels ─────────────────────────────────────────
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_q_pad_init(
+        pad_offset: pl.Scalar[pl.INDEX],
+        output: pl.InOut[pl.Tensor[[batch * total_q_groups * Q_HEAD_PAD, head_dim], pl.BF16]],
+    ) -> pl.Tensor[[batch * total_q_groups * Q_HEAD_PAD, head_dim], pl.BF16]:
+        zero_tile = pl.create_tile([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec)
+        zero_tile = pl.mul(zero_tile, 0.0)
+        zero_bf16 = pl.cast(zero_tile, target_type=pl.BF16)
+        out = pl.store(zero_bf16, [pad_offset, 0], output)
+        return out
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_rope_kv_cache_q_pad(
+        k_proj: pl.Tensor[[batch, kv_hidden], pl.FP32],
+        v_proj: pl.Tensor[[batch, kv_hidden], pl.FP32],
+        q_proj: pl.Tensor[[batch, hidden], pl.FP32],
+        cos_lo: pl.Tensor[[1, half_dim], pl.FP32],
+        cos_hi: pl.Tensor[[1, half_dim], pl.FP32],
+        sin_lo: pl.Tensor[[1, half_dim], pl.FP32],
+        sin_hi: pl.Tensor[[1, half_dim], pl.FP32],
+        b: pl.Scalar[pl.INDEX],
+        ki: pl.Scalar[pl.INDEX],
+        cache_row: pl.Scalar[pl.INDEX],
+        q_pad_base: pl.Scalar[pl.INDEX],
+        q_base: pl.Scalar[pl.INDEX],
+        k_cache: pl.InOut[pl.Tensor[[cache_rows, head_dim], pl.BF16]],
+        v_cache: pl.InOut[pl.Tensor[[cache_rows, head_dim], pl.BF16]],
+        all_q_padded: pl.InOut[pl.Tensor[[batch * total_q_groups * Q_HEAD_PAD, head_dim], pl.BF16]],
+    ) -> tuple[
+        pl.Tensor[[cache_rows, head_dim], pl.BF16],
+        pl.Tensor[[cache_rows, head_dim], pl.BF16],
+        pl.Tensor[[batch * total_q_groups * Q_HEAD_PAD, head_dim], pl.BF16],
+    ]:
+        kv_col = ki * head_dim
+        # K RoPE + cache update.
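+        # Rotate-half RoPE on the (lo, hi) halves of the head, as computed below:
+        #   rot_lo = k_lo * cos_lo - k_hi * sin_lo
+        #   rot_hi = k_hi * cos_hi + k_lo * sin_hi
+        # evaluated at the single decode position pos = ctx_len - 1.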
+        k_lo = pl.load(k_proj, [b, kv_col], [1, half_dim], target_memory=pl.MemorySpace.Vec)
+        k_hi = pl.load(k_proj, [b, kv_col + half_dim], [1, half_dim], target_memory=pl.MemorySpace.Vec)
+        cos_lo_t = pl.load(cos_lo, [0, 0], [1, half_dim], target_memory=pl.MemorySpace.Vec)
+        cos_hi_t = pl.load(cos_hi, [0, 0], [1, half_dim], target_memory=pl.MemorySpace.Vec)
+        sin_lo_t = pl.load(sin_lo, [0, 0], [1, half_dim], target_memory=pl.MemorySpace.Vec)
+        sin_hi_t = pl.load(sin_hi, [0, 0], [1, half_dim], target_memory=pl.MemorySpace.Vec)
+
+        rot_lo = pl.sub(
+            pl.col_expand_mul(k_lo, cos_lo_t),
+            pl.col_expand_mul(k_hi, sin_lo_t),
+        )
+        rot_hi = pl.add(
+            pl.col_expand_mul(k_hi, cos_hi_t),
+            pl.col_expand_mul(k_lo, sin_hi_t),
+        )
+        rot_lo_bf16 = pl.cast(rot_lo, target_type=pl.BF16)
+        rot_hi_bf16 = pl.cast(rot_hi, target_type=pl.BF16)
+        k_cache_out = pl.store(rot_lo_bf16, [cache_row, 0], k_cache)
+        k_cache_out = pl.store(rot_hi_bf16, [cache_row, half_dim], k_cache_out)
+
+        # V cache update.
+        v_tile = pl.load(v_proj, [b, ki * head_dim], [1, head_dim], target_memory=pl.MemorySpace.Vec)
+        v_tile_bf16 = pl.cast(v_tile, target_type=pl.BF16)
+        v_cache_out = pl.store(v_tile_bf16, [cache_row, 0], v_cache)
+
+        # Q RoPE + pad.
+        for qi in pl.range(Q_HEAD_BATCH):
+            q_col = (q_base + qi) * head_dim
+            q_lo = pl.load(q_proj, [b, q_col], [1, half_dim], target_memory=pl.MemorySpace.Vec)
+            q_hi = pl.load(q_proj, [b, q_col + half_dim], [1, half_dim], target_memory=pl.MemorySpace.Vec)
+            q_rot_lo = pl.sub(
+                pl.col_expand_mul(q_lo, cos_lo_t),
+                pl.col_expand_mul(q_hi, sin_lo_t),
+            )
+            q_rot_hi = pl.add(
+                pl.col_expand_mul(q_hi, cos_hi_t),
+                pl.col_expand_mul(q_lo, sin_hi_t),
+            )
+            q_rot_lo_bf16 = pl.cast(q_rot_lo, target_type=pl.BF16)
+            q_rot_hi_bf16 = pl.cast(q_rot_hi, target_type=pl.BF16)
+            all_q_padded = pl.store(q_rot_lo_bf16, [q_pad_base + ki * Q_HEAD_PAD + qi, 0], all_q_padded)
+            all_q_padded = pl.store(q_rot_hi_bf16, [q_pad_base + ki * Q_HEAD_PAD + qi, half_dim], all_q_padded)
+
+        return k_cache_out, v_cache_out, all_q_padded
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_attn_qk_matmul(
+        q_padded: pl.Tensor[[Q_HEAD_PAD, head_dim], pl.BF16],
+        k_tile: pl.Tensor[[SEQ_TILE, head_dim], pl.BF16],
+        score_row: pl.Scalar[pl.INDEX],
+        all_raw_scores: pl.Out[pl.Tensor[[max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], pl.FP32]],
+    ) -> pl.Tensor[[max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], pl.FP32]:
+        q_l1 = pl.load(q_padded, [0, 0], [Q_HEAD_PAD, head_dim], target_memory=pl.MemorySpace.Mat)
+        k_l1 = pl.load(k_tile, [0, 0], [SEQ_TILE, head_dim], target_memory=pl.MemorySpace.Mat, transpose=True)
+        q_l0a = pl.move(q_l1, target_memory=pl.MemorySpace.Left)
+        k_l0b = pl.move(k_l1, target_memory=pl.MemorySpace.Right)
+        scores = pl.matmul(q_l0a, k_l0b)
+        out = pl.store(scores, [score_row, 0], all_raw_scores)
+        return out
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_attn_softmax_prepare(
+        all_raw_scores: pl.Tensor[[max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], pl.FP32],
+        sb: pl.Scalar[pl.INDEX],
+        valid_len: pl.Scalar[pl.INDEX],
+        all_exp_padded: pl.InOut[pl.Tensor[[max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], pl.BF16]],
+        all_cur_mi: pl.InOut[pl.Tensor[[max_ctx_blocks * Q_HEAD_BATCH, 1], pl.FP32]],
+        all_cur_li: pl.InOut[pl.Tensor[[max_ctx_blocks * Q_HEAD_BATCH, 1], pl.FP32]],
+    ) -> tuple[
+        pl.Tensor[[max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], pl.BF16],
+        pl.Tensor[[max_ctx_blocks * Q_HEAD_BATCH, 1], pl.FP32],
+        pl.Tensor[[max_ctx_blocks * Q_HEAD_BATCH, 1], pl.FP32],
+    ]:
+        scores_valid = pl.load(
+            all_raw_scores,
+            [sb * Q_HEAD_PAD, 0],
+            [Q_HEAD_BATCH, SEQ_TILE],
+            valid_shapes=[Q_HEAD_BATCH, valid_len],
+            target_memory=pl.MemorySpace.Vec,
+        )
+        scores_padded = pl.tile.fillpad(scores_valid, pad_value=pl.PadValue.min)
+        scores = pl.mul(scores_padded, attn_scale)
+        tmp = pl.create_tile([Q_HEAD_BATCH, SEQ_TILE], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec)
+        cur_mi = pl.row_max(scores, tmp)
+        exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
+        exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
+        exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
+        cur_li = pl.row_sum(exp_scores_fp32, tmp)
+
+        exp_out = pl.store(exp_scores_bf16, [sb * Q_HEAD_PAD, 0], all_exp_padded)
+        mi_out = pl.store(cur_mi, [sb * Q_HEAD_BATCH, 0], all_cur_mi)
+        li_out = pl.store(cur_li, [sb * Q_HEAD_BATCH, 0], all_cur_li)
+        return exp_out, mi_out, li_out
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_attn_sv_matmul(
+        all_exp_padded: pl.Tensor[[max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], pl.BF16],
+        v_tile: pl.Tensor[[SEQ_TILE, head_dim], pl.BF16],
+        sb: pl.Scalar[pl.INDEX],
+        all_oi_tmp: pl.Out[pl.Tensor[[max_ctx_blocks * Q_HEAD_PAD, head_dim], pl.FP32]],
+    ) -> pl.Tensor[[max_ctx_blocks * Q_HEAD_PAD, head_dim], pl.FP32]:
+        exp_l1 = pl.load(all_exp_padded, [sb * Q_HEAD_PAD, 0], [Q_HEAD_PAD, SEQ_TILE], target_memory=pl.MemorySpace.Mat)
+        v_l1 = pl.load(v_tile, [0, 0], [SEQ_TILE, head_dim], target_memory=pl.MemorySpace.Mat)
+        exp_l0a = pl.move(exp_l1, target_memory=pl.MemorySpace.Left)
+        v_l0b = pl.move(v_l1, target_memory=pl.MemorySpace.Right)
+        oi_tmp = pl.matmul(exp_l0a, v_l0b)
+        out = pl.store(oi_tmp, [sb * Q_HEAD_PAD, 0], all_oi_tmp)
+        return out
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_attn_online_update(
+        all_oi_tmp: pl.Tensor[[max_ctx_blocks * Q_HEAD_PAD, head_dim], pl.FP32],
+        all_cur_mi: pl.Tensor[[max_ctx_blocks * Q_HEAD_BATCH, 1], pl.FP32],
+        all_cur_li: pl.Tensor[[max_ctx_blocks * Q_HEAD_BATCH, 1], pl.FP32],
+        ctx_blocks: pl.Scalar[pl.INDEX],
+        output: pl.Out[pl.Tensor[[1, Q_HEAD_BATCH * head_dim], pl.BF16]],
+    ) -> pl.Tensor[[1, Q_HEAD_BATCH * head_dim], pl.BF16]:
+        oi = pl.load(all_oi_tmp, [0, 0], [Q_HEAD_BATCH, head_dim], target_memory=pl.MemorySpace.Vec)
+        mi = pl.load(all_cur_mi, [0, 0], [Q_HEAD_BATCH, 1], target_memory=pl.MemorySpace.Vec)
+        li = pl.load(all_cur_li, [0, 0], [Q_HEAD_BATCH, 1], target_memory=pl.MemorySpace.Vec)
+
+        for sb in pl.range(1, ctx_blocks):
+            oi_tmp_valid = pl.load(
+                all_oi_tmp, [sb * Q_HEAD_PAD, 0], [Q_HEAD_BATCH, head_dim],
+                target_memory=pl.MemorySpace.Vec,
+            )
+            cur_mi = pl.load(
+                all_cur_mi, [sb * Q_HEAD_BATCH, 0], [Q_HEAD_BATCH, 1],
+                target_memory=pl.MemorySpace.Vec,
+            )
+            cur_li = pl.load(
+                all_cur_li, [sb * Q_HEAD_BATCH, 0], [Q_HEAD_BATCH, 1],
+                target_memory=pl.MemorySpace.Vec,
+            )
+            mi_new = pl.maximum(mi, cur_mi)
+            alpha = pl.exp(pl.sub(mi, mi_new))
+            beta = pl.exp(pl.sub(cur_mi, mi_new))
+            li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li))
+            oi = pl.add(
+                pl.row_expand_mul(oi, alpha),
+                pl.row_expand_mul(oi_tmp_valid, beta),
+            )
+            mi = mi_new
+
+        ctx = pl.row_expand_div(oi, li)
+        ctx_flat = pl.reshape(ctx, [1, Q_HEAD_BATCH * head_dim])
+        ctx_flat_bf16 = pl.cast(ctx_flat, target_type=pl.BF16)
+        out = pl.store(ctx_flat_bf16, [0, 0], output)
+        return out
+
+    # ── Scope 3 InCore kernels ─────────────────────────────────────────
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_output_proj_reduce(
+        attn_out: pl.Tensor[[batch, hidden], pl.BF16],
+        wo: pl.Tensor[[hidden, hidden], pl.BF16],
+        b0: pl.Scalar[pl.INDEX],
+        out_col: pl.Scalar[pl.INDEX],
+        output: pl.Out[pl.Tensor[[BATCH_TILE, Q_OUT_CHUNK], pl.FP32]],
+    ) -> pl.Tensor[[BATCH_TILE, Q_OUT_CHUNK], pl.FP32]:
+        a_l1 = pl.load(attn_out, [b0, 0], [BATCH_TILE, K_CHUNK], target_memory=pl.MemorySpace.Mat)
+        w_l1 = pl.load(wo, [0, out_col], [K_CHUNK, Q_OUT_CHUNK], target_memory=pl.MemorySpace.Mat)
+        a_l0 = pl.move(a_l1, target_memory=pl.MemorySpace.Left)
+        w_l0 = pl.move(w_l1, target_memory=pl.MemorySpace.Right)
+        acc = pl.matmul(a_l0, w_l0)
+
+        for kb in pl.range(1, hidden_blocks):
+            k0 = kb * K_CHUNK
+            a_i_l1 = pl.load(attn_out, [b0, k0], [BATCH_TILE, K_CHUNK], target_memory=pl.MemorySpace.Mat)
+            w_i_l1 = pl.load(wo, [k0, out_col], [K_CHUNK, Q_OUT_CHUNK], target_memory=pl.MemorySpace.Mat)
+            a_i_l0 = pl.move(a_i_l1, target_memory=pl.MemorySpace.Left)
+            w_i_l0 = pl.move(w_i_l1, target_memory=pl.MemorySpace.Right)
+            acc = pl.matmul_acc(acc, a_i_l0, w_i_l0)
+
+        out = pl.store(acc, [0, 0], output)
+        return out
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_residual_add_store(
+        o_acc: pl.Tensor[[BATCH_TILE, Q_OUT_CHUNK], pl.FP32],
+        hidden_states: pl.Tensor[[batch, hidden], pl.BF16],
+        b0: pl.Scalar[pl.INDEX],
+        col: pl.Scalar[pl.INDEX],
+        resid1_tile: pl.Out[pl.Tensor[[BATCH_TILE, hidden], pl.FP32]],
+    ) -> pl.Tensor[[BATCH_TILE, hidden], pl.FP32]:
+        o_tile = pl.load(o_acc, [0, 0], [BATCH_TILE, Q_OUT_CHUNK], target_memory=pl.MemorySpace.Vec)
+        resid = pl.load(hidden_states, [b0, col], [BATCH_TILE, Q_OUT_CHUNK], target_memory=pl.MemorySpace.Vec)
+        resid_f32 = pl.cast(resid, target_type=pl.FP32)
+        resid_sum = pl.add(o_tile, resid_f32)
+        out = pl.store(resid_sum, [0, col], resid1_tile)
+        return out
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_post_rmsnorm(
+        resid1_tile: pl.Tensor[[BATCH_TILE, hidden], pl.FP32],
+        post_rms_weight: pl.Tensor[[1, hidden], pl.FP32],
+        output: pl.Out[pl.Tensor[[BATCH_TILE, hidden], pl.BF16]],
+    ) -> pl.Tensor[[BATCH_TILE, hidden], pl.BF16]:
+        sq_sum = pl.create_tile([1, BATCH_TILE], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec)
+        sq_sum = pl.mul(sq_sum, 0.0)
+
+        for kb in pl.range(hidden_blocks):
+            k0 = kb * K_CHUNK
+            x_chunk = pl.load(
+                resid1_tile, [0, k0], [BATCH_TILE, K_CHUNK],
+                target_memory=pl.MemorySpace.Vec,
+            )
+            squared = pl.mul(x_chunk, x_chunk)
+            tmp = pl.create_tile([BATCH_TILE, K_CHUNK], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec)
+            row_s: pl.Tile[[BATCH_TILE, 1], pl.FP32] = pl.row_sum(squared, tmp)
+            sq_sum = pl.add(sq_sum, pl.reshape(row_s, [1, BATCH_TILE]))
+
+        inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)))
+
+        for kb, (out_iter,) in pl.range(hidden_blocks, init_values=(output,)):
+            k0 = kb * K_CHUNK
+            x_chunk = pl.load(
+                resid1_tile, [0, k0], [BATCH_TILE, K_CHUNK],
+                target_memory=pl.MemorySpace.Vec,
+            )
+            gamma = pl.load(
+                post_rms_weight, [0, k0], [1, K_CHUNK],
+                target_memory=pl.MemorySpace.Vec,
+            )
+            scaled = pl.row_expand_mul(x_chunk, pl.reshape(inv_rms, [BATCH_TILE, 1]))
+            weighted = pl.col_expand_mul(scaled, gamma)
+            weighted_bf16 = pl.cast(weighted, target_type=pl.BF16)
+            out_next = pl.store(weighted_bf16, [0, k0], out_iter)
+            (out_carry,) = pl.yield_(out_next)
+
+        return out_carry
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_mlp_matmul_reduce(
+        post_norm_tile: pl.Tensor[[BATCH_TILE, hidden], pl.BF16],
+        weight: pl.Tensor[[hidden, inter], pl.BF16],
+        out_col: pl.Scalar[pl.INDEX],
+        output: pl.Out[pl.Tensor[[BATCH_TILE, MLP_OUT_CHUNK], pl.FP32]],
+    ) -> pl.Tensor[[BATCH_TILE, MLP_OUT_CHUNK], pl.FP32]:
+        a_l1 = pl.load(post_norm_tile, [0, 0], [BATCH_TILE, K_CHUNK], target_memory=pl.MemorySpace.Mat)
+        w_l1 = pl.load(weight, [0, out_col], [K_CHUNK, MLP_OUT_CHUNK], target_memory=pl.MemorySpace.Mat)
+        a_l0 = pl.move(a_l1, target_memory=pl.MemorySpace.Left)
+        w_l0 = pl.move(w_l1, target_memory=pl.MemorySpace.Right)
+        acc = pl.matmul(a_l0, w_l0)
+
+        for kb in pl.range(1, hidden_blocks):
+            k0 = kb * K_CHUNK
+            a_i_l1 = pl.load(post_norm_tile, [0, k0], [BATCH_TILE, K_CHUNK], target_memory=pl.MemorySpace.Mat)
+            w_i_l1 = pl.load(weight, [k0, out_col], [K_CHUNK, MLP_OUT_CHUNK], target_memory=pl.MemorySpace.Mat)
+            a_i_l0 = pl.move(a_i_l1, target_memory=pl.MemorySpace.Left)
+            w_i_l0 = pl.move(w_i_l1, target_memory=pl.MemorySpace.Right)
+            acc = pl.matmul_acc(acc, a_i_l0, w_i_l0)
+
+        out = pl.store(acc, [0, 0], output)
+        return out
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_silu_activation(
+        gate_acc: pl.Tensor[[BATCH_TILE, MLP_OUT_CHUNK], pl.FP32],
+        up_acc: pl.Tensor[[BATCH_TILE, MLP_OUT_CHUNK], pl.FP32],
+        out_col: pl.Scalar[pl.INDEX],
+        mlp_tile: pl.Out[pl.Tensor[[BATCH_TILE, inter], pl.BF16]],
+    ) -> pl.Tensor[[BATCH_TILE, inter], pl.BF16]:
+        gate_t = pl.load(gate_acc, [0, 0], [BATCH_TILE, MLP_OUT_CHUNK], target_memory=pl.MemorySpace.Vec)
+        up_t = pl.load(up_acc, [0, 0], [BATCH_TILE, MLP_OUT_CHUNK], target_memory=pl.MemorySpace.Vec)
+        sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_t)), 1.0))
+        mlp_chunk = pl.mul(pl.mul(gate_t, sigmoid), up_t)
+        mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16)
+        out = pl.store(mlp_chunk_bf16, [0, out_col], mlp_tile)
+        return out
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_down_proj_reduce(
+        mlp_tile: pl.Tensor[[BATCH_TILE, inter], pl.BF16],
+        w_down: pl.Tensor[[inter, hidden], pl.BF16],
+        out_col: pl.Scalar[pl.INDEX],
+        output: pl.Out[pl.Tensor[[BATCH_TILE, K_CHUNK], pl.FP32]],
+    ) -> pl.Tensor[[BATCH_TILE, K_CHUNK], pl.FP32]:
+        a_l1 = pl.load(mlp_tile, [0, 0], [BATCH_TILE, MLP_OUT_CHUNK], target_memory=pl.MemorySpace.Mat)
+        w_l1 = pl.load(w_down, [0, out_col], [MLP_OUT_CHUNK, K_CHUNK], target_memory=pl.MemorySpace.Mat)
+        a_l0 = pl.move(a_l1, target_memory=pl.MemorySpace.Left)
+        w_l0 = pl.move(w_l1, target_memory=pl.MemorySpace.Right)
+        acc = pl.matmul(a_l0, w_l0)
+
+        for ob in pl.range(1, mlp_out_blocks):
+            o0 = ob * MLP_OUT_CHUNK
+            a_i_l1 = pl.load(mlp_tile, [0, o0], [BATCH_TILE, MLP_OUT_CHUNK], target_memory=pl.MemorySpace.Mat)
+            w_i_l1 = pl.load(w_down, [o0, out_col], [MLP_OUT_CHUNK, K_CHUNK], target_memory=pl.MemorySpace.Mat)
+            a_i_l0 = pl.move(a_i_l1, target_memory=pl.MemorySpace.Left)
+            w_i_l0 = pl.move(w_i_l1, target_memory=pl.MemorySpace.Right)
+            acc = pl.matmul_acc(acc, a_i_l0, w_i_l0)
+
+        out = pl.store(acc, [0, 0], output)
+        return out
+
+    @pl.function(type=pl.FunctionType.InCore)
+    def kernel_final_residual_store(
+        down_acc: pl.Tensor[[BATCH_TILE, K_CHUNK], pl.FP32],
+        resid1_tile: pl.Tensor[[BATCH_TILE, hidden], pl.FP32],
+        col: pl.Scalar[pl.INDEX],
+        b0: pl.Scalar[pl.INDEX],
+        out: pl.Out[pl.Tensor[[batch, hidden], pl.BF16]],
+    ) -> pl.Tensor[[batch, hidden], pl.BF16]:
+        down_t = pl.load(down_acc, [0, 0], [BATCH_TILE, K_CHUNK], target_memory=pl.MemorySpace.Vec)
+        resid_t = pl.load(resid1_tile, [0, col], [BATCH_TILE, K_CHUNK], target_memory=pl.MemorySpace.Vec)
+        out_chunk = pl.add(down_t, resid_t)
+        out_chunk_bf16 = pl.cast(out_chunk, target_type=pl.BF16)
+        result = pl.store(out_chunk_bf16, [b0, col], out)
+        return result
+
+    # ── Program class (Orchestration only) ─────────────────────────────
+
+    @pl.program
+    class Qwen3DecodeTile:
+        @pl.function(type=pl.FunctionType.Orchestration)
+        def qwen3_decode(
+            self,
+            hidden_states: pl.Tensor[[batch, hidden], pl.BF16],
+            input_rms_weight: pl.Tensor[[1, hidden], pl.FP32],
+            wq: pl.Tensor[[hidden, hidden], pl.BF16],
+            wk: pl.Tensor[[hidden, kv_hidden], pl.BF16],
+            wv: pl.Tensor[[hidden, kv_hidden], pl.BF16],
+            seq_lens: pl.Tensor[[batch], pl.INT32],
+            rope_cos: pl.Tensor[[max_seq, head_dim], pl.FP32],
+            rope_sin: pl.Tensor[[max_seq, head_dim], pl.FP32],
+            k_cache: pl.Tensor[[cache_rows, head_dim], pl.BF16],
+            v_cache: pl.Tensor[[cache_rows, head_dim], pl.BF16],
+            wo: pl.Tensor[[hidden, hidden], pl.BF16],
+            post_rms_weight: pl.Tensor[[1, hidden], pl.FP32],
+            w_gate: pl.Tensor[[hidden, inter], pl.BF16],
+            w_up: pl.Tensor[[hidden, inter], pl.BF16],
+            w_down: pl.Tensor[[inter, hidden], pl.BF16],
+            out: pl.Out[pl.Tensor[[batch, hidden], pl.BF16]],
+        ) -> pl.Tensor[[batch, hidden], pl.BF16]:
+            # Intermediate FP32 tensors between scope 1 and scope 2.
+            q_proj = pl.create_tensor([batch, hidden], dtype=pl.FP32)
+            k_proj = pl.create_tensor([batch, kv_hidden], dtype=pl.FP32)
+            v_proj = pl.create_tensor([batch, kv_hidden], dtype=pl.FP32)
+
+            # ── Scope 1: input RMSNorm + Q/K/V projection ──
+            for b0 in pl.range(0, batch, BATCH_TILE):
+                hidden_tile = pl.slice(hidden_states, [BATCH_TILE, hidden], [b0, 0])
+                normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16)
+                normed_tile = kernel_build_normed_tile(hidden_tile, input_rms_weight, normed_tile)
+
+                for ob in pl.range(q_out_blocks):
+                    q0 = ob * Q_OUT_CHUNK
+                    q_proj = kernel_q_proj_reduce(normed_tile, wq, b0, q0, q_proj)
+
+                for ob in pl.range(kv_out_blocks):
+                    kv0 = ob * KV_OUT_CHUNK
+                    k_proj = kernel_kv_proj_reduce(normed_tile, wk, b0, kv0, k_proj)
+                    v_proj = kernel_kv_proj_reduce(normed_tile, wv, b0, kv0, v_proj)
+
+            # ── Scope 2: RoPE + KV cache update + grouped-query attention ──
+            # Pad q.
+            all_q_padded = pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16)
+            for idx in pl.range(batch * total_q_groups):
+                pad_offset = idx * Q_HEAD_PAD + Q_HEAD_BATCH
+                all_q_padded = kernel_q_pad_init(pad_offset, all_q_padded)
+
+            attn_out_tensor = pl.create_tensor([batch, hidden], dtype=pl.BF16)
+            for b in pl.range(batch):
+                ctx_len = pl.tensor.read(seq_lens, [b])
+                pos = ctx_len - 1
+                ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE
+                cos_row = pl.slice(rope_cos, [1, head_dim], [pos, 0])
+                sin_row = pl.slice(rope_sin, [1, head_dim], [pos, 0])
+                cos_lo = pl.slice(cos_row, [1, half_dim], [0, 0])
+                cos_hi = pl.slice(cos_row, [1, half_dim], [0, half_dim])
+                sin_lo = pl.slice(sin_row, [1, half_dim], [0, 0])
+                sin_hi = pl.slice(sin_row, [1, half_dim], [0, half_dim])
+
+                # Stage 1: K RoPE + cache update + V cache + Q RoPE + pad.
+                q_pad_base = b * total_q_groups * Q_HEAD_PAD
+                for ki in pl.range(num_kv_heads):
+                    cache_row = b * num_kv_heads * max_seq + ki * max_seq + pos
+                    q_base = ki * q_per_kv
+                    k_cache, v_cache, all_q_padded = kernel_rope_kv_cache_q_pad(
+                        k_proj, v_proj, q_proj,
+                        cos_lo, cos_hi, sin_lo, sin_hi,
+                        b, ki, cache_row, q_pad_base, q_base,
+                        k_cache, v_cache, all_q_padded,
+                    )
+
+                attn_row = pl.create_tensor([1, hidden], dtype=pl.BF16)
+                for gi in pl.range(total_q_groups):
+                    kvh = gi // q_groups
+                    qg = gi - kvh * q_groups
+                    q_base_gi = kvh * q_per_kv + qg * Q_HEAD_BATCH
+                    q_padded = pl.slice(
+                        all_q_padded,
+                        [Q_HEAD_PAD, head_dim],
+                        [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD, 0],
+                    )
+
+                    # Stage 2: QK matmul for all active sb blocks.
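+                    # Working buffers for one (b, gi) group, sized for the worst
+                    # case max_ctx_blocks; only the first ctx_blocks blocks are
+                    # written and read by stages 2-5.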
+                    all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)
+                    all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16)
+                    all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
+                    all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
+                    all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
+                    for sb in pl.range(ctx_blocks):
+                        s0 = sb * SEQ_TILE
+                        cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
+                        k_tile = pl.slice(
+                            k_cache,
+                            [SEQ_TILE, head_dim],
+                            [cache_row0, 0],
+                        )
+                        all_raw_scores = kernel_attn_qk_matmul(
+                            q_padded, k_tile, sb * Q_HEAD_PAD, all_raw_scores,
+                        )
+
+                    # Stage 3: softmax for all active sb blocks.
+                    for sb in pl.range(ctx_blocks):
+                        s0 = sb * SEQ_TILE
+                        valid_len = pl.min(SEQ_TILE, ctx_len - s0)
+                        all_exp_padded, all_cur_mi, all_cur_li = (
+                            kernel_attn_softmax_prepare(
+                                all_raw_scores, sb, valid_len,
+                                all_exp_padded, all_cur_mi, all_cur_li,
+                            )
+                        )
+
+                    # Stage 4: SV matmul for all active sb blocks.
+                    for sb in pl.range(ctx_blocks):
+                        s0 = sb * SEQ_TILE
+                        cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
+                        v_tile = pl.slice(
+                            v_cache,
+                            [SEQ_TILE, head_dim],
+                            [cache_row0, 0],
+                        )
+                        all_oi_tmp = kernel_attn_sv_matmul(
+                            all_exp_padded, v_tile, sb, all_oi_tmp,
+                        )
+
+                    # Stage 5: online softmax accumulation and normalisation.
+                    ctx_buf = pl.create_tensor([1, Q_HEAD_BATCH * head_dim], dtype=pl.BF16)
+                    ctx_flat = kernel_attn_online_update(
+                        all_oi_tmp, all_cur_mi, all_cur_li, ctx_blocks, ctx_buf,
+                    )
+                    attn_row = pl.assemble(
+                        attn_row, ctx_flat, [0, q_base_gi * head_dim],
+                    )
+
+                attn_out_tensor = pl.assemble(attn_out_tensor, attn_row, [b, 0])
+
+            # ── Scope 3: output projection + residual + post RMSNorm + MLP + residual ──
+            for b0 in pl.range(0, batch, BATCH_TILE):
+                resid1_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.FP32)
+
+                # Stage 1: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
+                for ob in pl.range(q_out_blocks):
+                    o0 = ob * Q_OUT_CHUNK
+                    o_acc_buf = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32)
+                    o_acc = kernel_output_proj_reduce(
+                        attn_out_tensor, wo, b0, o0, o_acc_buf,
+                    )
+                    # Stage 2: Residual addition with hidden_states.
+                    resid1_tile = kernel_residual_add_store(
+                        o_acc, hidden_states, b0, o0, resid1_tile,
+                    )
+
+                # Stage 3: Post-attention RMSNorm.
+                post_norm_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16)
+                post_norm_tile = kernel_post_rmsnorm(
+                    resid1_tile, post_rms_weight, post_norm_tile,
+                )
+
+                # Stages 4-6: MLP gate/up projections + SiLU activation.
+                mlp_tile = pl.create_tensor([BATCH_TILE, inter], dtype=pl.BF16)
+                for ob in pl.range(mlp_out_blocks):
+                    o0 = ob * MLP_OUT_CHUNK
+                    gate_buf = pl.create_tensor([BATCH_TILE, MLP_OUT_CHUNK], dtype=pl.FP32)
+                    gate_acc = kernel_mlp_matmul_reduce(
+                        post_norm_tile, w_gate, o0, gate_buf,
+                    )
+                    up_buf = pl.create_tensor([BATCH_TILE, MLP_OUT_CHUNK], dtype=pl.FP32)
+                    up_acc = kernel_mlp_matmul_reduce(
+                        post_norm_tile, w_up, o0, up_buf,
+                    )
+                    mlp_tile = kernel_silu_activation(
+                        gate_acc, up_acc, o0, mlp_tile,
+                    )
+
+                # Stages 7 & 8: Down projection + final residual writeback.
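+                # The down projection reduces over INTERMEDIATE in MLP_OUT_CHUNK
+                # steps and emits one K_CHUNK-wide column of the hidden output
+                # per dob iteration, fused with the final residual add.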
+                for dob in pl.range(hidden_blocks):
+                    d0 = dob * K_CHUNK
+                    down_buf = pl.create_tensor([BATCH_TILE, K_CHUNK], dtype=pl.FP32)
+                    down_acc = kernel_down_proj_reduce(
+                        mlp_tile, w_down, d0, down_buf,
+                    )
+                    out = kernel_final_residual_store(
+                        down_acc, resid1_tile, d0, b0, out,
+                    )
+
+            return out
+
+    return Qwen3DecodeTile
+
+
+def build_tensor_specs(
+    batch: int = BATCH,
+    max_seq: int = MAX_SEQ,
+    hidden_size: int = HIDDEN,
+    intermediate_size: int = INTERMEDIATE,
+    num_heads: int = NUM_HEADS,
+    num_kv_heads: int = NUM_KV_HEADS,
+    head_dim: int = HEAD_DIM,
+    use_max_seq: bool = False,
+):
+    import torch
+    from pypto.runtime import TensorSpec
+
+    hidden = num_heads * head_dim
+    kv_hidden = num_kv_heads * head_dim
+    inter = intermediate_size
+    cache_rows = batch * num_kv_heads * max_seq
+
+    def init_hidden_states():
+        return torch.rand(batch, hidden_size) - 0.5
+
+    def init_rms_weight():
+        return torch.rand(1, hidden_size) - 0.5
+
+    def init_wq():
+        return torch.rand(hidden_size, hidden_size) / hidden_size ** 0.5
+
+    def init_wk():
+        return torch.rand(hidden_size, kv_hidden) / hidden_size ** 0.5
+
+    def init_wv():
+        return torch.rand(hidden_size, kv_hidden) / hidden_size ** 0.5
+
+    def init_seq_lens():
+        if use_max_seq:
+            return torch.full((batch,), max_seq, dtype=torch.int32)
+        return torch.randint(1, max_seq + 1, (batch,), dtype=torch.int32)
+
+    def init_rope_cos():
+        return torch.rand(max_seq, head_dim) - 0.5
+
+    def init_rope_sin():
+        return torch.rand(max_seq, head_dim) - 0.5
+
+    def init_k_cache():
+        return torch.rand(cache_rows, head_dim) - 0.5
+
+    def init_v_cache():
+        return torch.rand(cache_rows, head_dim) - 0.5
+
+    def init_wo():
+        return (torch.rand(hidden_size, hidden_size) - 0.5) / hidden_size ** 0.5
+
+    def init_post_rms_weight():
+        return torch.ones(1, hidden_size)
+
+    def init_w_gate():
+        return (torch.rand(hidden_size, inter) - 0.5) / hidden_size ** 0.5
+
+    def init_w_up():
+        return (torch.rand(hidden_size, inter) - 0.5) / hidden_size ** 0.5
+
+    def init_w_down():
+        return (torch.rand(inter, hidden_size) - 0.5) / inter ** 0.5
+
+    return [
+        TensorSpec("hidden_states", [batch, hidden_size], torch.bfloat16,
+                   init_value=init_hidden_states),
+        TensorSpec("input_rms_weight", [1, hidden_size], torch.float32,
+                   init_value=init_rms_weight),
+        TensorSpec("wq", [hidden_size, hidden_size], torch.bfloat16,
+                   init_value=init_wq),
+        TensorSpec("wk", [hidden_size, kv_hidden], torch.bfloat16,
+                   init_value=init_wk),
+        TensorSpec("wv", [hidden_size, kv_hidden], torch.bfloat16,
+                   init_value=init_wv),
+        TensorSpec("seq_lens", [batch], torch.int32, init_value=init_seq_lens),
+        TensorSpec("rope_cos", [max_seq, head_dim], torch.float32,
+                   init_value=init_rope_cos),
+        TensorSpec("rope_sin", [max_seq, head_dim], torch.float32,
+                   init_value=init_rope_sin),
+        TensorSpec("k_cache", [cache_rows, head_dim], torch.bfloat16,
+                   init_value=init_k_cache),
+        TensorSpec("v_cache", [cache_rows, head_dim], torch.bfloat16,
+                   init_value=init_v_cache),
+        TensorSpec("wo", [hidden_size, hidden_size], torch.bfloat16,
+                   init_value=init_wo),
+        TensorSpec("post_rms_weight", [1, hidden_size], torch.float32,
+                   init_value=init_post_rms_weight),
+        TensorSpec("w_gate", [hidden_size, inter], torch.bfloat16,
+                   init_value=init_w_gate),
+        TensorSpec("w_up", [hidden_size, inter], torch.bfloat16,
+                   init_value=init_w_up),
+        TensorSpec("w_down", [inter, hidden_size], torch.bfloat16,
+                   init_value=init_w_down),
+        TensorSpec("out", [batch, hidden], torch.bfloat16, is_output=True),
+    ]
+
+
+def golden_qwen3_decode(tensors, params):
+    """PyTorch reference: scope1 (RMSNorm + projection), scope2 (attention), scope3 (output + MLP)."""
+    import math
+
+    import torch
+
+    hidden_states = tensors["hidden_states"]
+    input_rms_weight = tensors["input_rms_weight"]
+    wq = tensors["wq"]
+    wk = tensors["wk"]
+    wv = tensors["wv"]
+    seq_lens = tensors["seq_lens"]
+    rope_cos = tensors["rope_cos"]
+    rope_sin = tensors["rope_sin"]
+    k_cache = tensors["k_cache"].clone()
+    v_cache = tensors["v_cache"].clone()
+    wo = tensors["wo"]
+    post_rms_weight = tensors["post_rms_weight"]
+    w_gate = tensors["w_gate"]
+    w_up = tensors["w_up"]
+    w_down = tensors["w_down"]
+
+    batch = hidden_states.shape[0]
+    hidden_size = hidden_states.shape[1]
+    kv_hidden = wk.shape[1]
+    head_dim = rope_cos.shape[1]
+    max_seq = rope_cos.shape[0]
+    num_kv_heads = kv_hidden // head_dim
+    num_heads = hidden_size // head_dim
+    q_per_kv = num_heads // num_kv_heads
+    q_groups = q_per_kv // Q_HEAD_BATCH
+    half = head_dim // 2
+    scale = 1.0 / math.sqrt(head_dim)
+    inter = w_gate.shape[1]
+    eps = 1e-6
+
+    # ── Scope 1 golden: RMSNorm + Q/K/V projection ──
+    q_proj = torch.zeros(batch, hidden_size, dtype=torch.float32)
+    k_proj = torch.zeros(batch, kv_hidden, dtype=torch.float32)
+    v_proj = torch.zeros(batch, kv_hidden, dtype=torch.float32)
+
+    for b0 in range(0, batch, BATCH_TILE):
+        b_end = min(b0 + BATCH_TILE, batch)
+        x_tile = hidden_states[b0:b_end, :].float()
+
+        sq_sum = torch.zeros(b_end - b0, 1, dtype=torch.float32)
+        for k0 in range(0, hidden_size, K_CHUNK):
+            x_chunk = x_tile[:, k0:k0 + K_CHUNK]
+            sq_sum = sq_sum + (x_chunk ** 2).sum(dim=-1, keepdim=True)
+        variance = sq_sum / hidden_size + EPS
+        rms = torch.sqrt(variance)
+        normed = (x_tile / rms * input_rms_weight.float()).bfloat16()
+
+        q_proj[b0:b_end, :] = (normed.float() @ wq.float()).float()
+        k_proj[b0:b_end, :] = (normed.float() @ wk.float()).float()
+        v_proj[b0:b_end, :] = (normed.float() @ wv.float()).float()
+
+    # ── Scope 2 golden: RoPE + cache update + attention ──
+    attn_out = torch.zeros(batch, hidden_size, dtype=torch.float32)
+
+    for b in range(batch):
+        ctx_len = seq_lens[b].item()
+        pos = ctx_len - 1
+        ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE
+
+        cos_row = rope_cos[pos : pos + 1, :]
+        sin_row = rope_sin[pos : pos + 1, :]
+        cos_lo, cos_hi = cos_row[:, :half], cos_row[:, half:]
+        sin_lo, sin_hi = sin_row[:, :half], sin_row[:, half:]
+
+        k_heads = k_proj[b].view(num_kv_heads, head_dim)
+        k_lo_h, k_hi_h = k_heads[:, :half], k_heads[:, half:]
+        k_rot = torch.cat([k_lo_h * cos_lo - k_hi_h * sin_lo, k_hi_h * cos_hi + k_lo_h * sin_hi], dim=-1)
+
+        for ki in range(num_kv_heads):
+            cr = b * num_kv_heads * max_seq + ki * max_seq + pos
+            k_cache[cr, :] = k_rot[ki].to(torch.bfloat16)
+            v_cache[cr, :] = v_proj[b, ki * head_dim : (ki + 1) * head_dim].to(torch.bfloat16)
+
+        q_heads = q_proj[b].view(num_heads, head_dim)
+        q_lo_h, q_hi_h = q_heads[:, :half], q_heads[:, half:]
+        q_rot = torch.cat([q_lo_h * cos_lo - q_hi_h * sin_lo, q_hi_h * cos_hi + q_lo_h * sin_hi], dim=-1)
+
+        for kvh in range(num_kv_heads):
+            for qg in range(q_groups):
+                q_base = kvh * q_per_kv + qg * Q_HEAD_BATCH
+                q_grp_bf16 = q_rot[q_base : q_base + Q_HEAD_BATCH, :].to(torch.bfloat16)
+
+                oi = torch.zeros(Q_HEAD_BATCH, head_dim, dtype=torch.float32)
+                li = torch.zeros(Q_HEAD_BATCH, 1, dtype=torch.float32)
+                mi = torch.zeros(Q_HEAD_BATCH, 1, dtype=torch.float32)
+
+                for sb in range(ctx_blocks):
+                    s0 = sb * SEQ_TILE
+                    valid_len = min(SEQ_TILE, ctx_len - s0)
+                    cb = b * num_kv_heads * max_seq + kvh * max_seq + s0
+
+                    k_tile = k_cache[cb : cb + SEQ_TILE, :]
+                    v_tile = v_cache[cb : cb + SEQ_TILE, :]
+
+                    raw_scores = q_grp_bf16.float() @ k_tile.float().T
+                    if valid_len < SEQ_TILE:
+                        raw_scores[:, valid_len:] = torch.finfo(torch.float32).min
+                    scores = raw_scores * scale
+
+                    cur_mi = scores.max(dim=-1, keepdim=True).values
+                    exp_scores = torch.exp(scores - cur_mi)
+                    exp_scores_bf16 = exp_scores.to(torch.bfloat16)
+                    cur_li = exp_scores_bf16.float().sum(dim=-1, keepdim=True)
+
+                    oi_tmp = exp_scores_bf16.float() @ v_tile.float()
+
+                    if sb == 0:
+                        oi = oi_tmp
+                        li = cur_li
+                        mi = cur_mi
+                    else:
+                        mi_new = torch.maximum(mi, cur_mi)
+                        alpha = torch.exp(mi - mi_new)
+                        beta = torch.exp(cur_mi - mi_new)
+                        li = alpha * li + beta * cur_li
+                        oi = oi * alpha + oi_tmp * beta
+                        mi = mi_new
+
+                ctx = oi / li
+                for qi in range(Q_HEAD_BATCH):
+                    qh = q_base + qi
+                    attn_out[b, qh * head_dim : (qh + 1) * head_dim] = ctx[qi]
+
+    # ── Scope 3 golden: output projection + residual + post RMSNorm + MLP + residual ──
+    # 1. Output projection (BF16 inputs, FP32 accumulation) + residual.
+    o_proj = torch.matmul(attn_out.float(), wo.float())
+    resid1 = o_proj + hidden_states.float()
+
+    # 2. Post-attention RMSNorm.
+    variance = resid1.pow(2).mean(dim=-1, keepdim=True)
+    inv_rms = torch.rsqrt(variance + eps)
+    normed_bf16 = (resid1 * inv_rms * post_rms_weight).bfloat16()
+
+    # 3. SwiGLU MLP: gate/up projections, SiLU activation, down projection.
+    gate = torch.matmul(normed_bf16.float(), w_gate.float())
+    up = torch.matmul(normed_bf16.float(), w_up.float())
+    mlp_bf16 = (gate * torch.sigmoid(gate) * up).bfloat16()
+    down = torch.matmul(mlp_bf16.float(), w_down.float())
+
+    # 4. Final residual + cast to BF16.
+    tensors["out"][:] = (down + resid1).bfloat16()
+
+
+def compile_and_run(
+    batch: int = BATCH,
+    max_seq: int = MAX_SEQ,
+    hidden_size: int = HIDDEN,
+    intermediate_size: int = INTERMEDIATE,
+    num_heads: int = NUM_HEADS,
+    num_kv_heads: int = NUM_KV_HEADS,
+    head_dim: int = HEAD_DIM,
+    use_max_seq: bool = False,
+    platform: str = "a5",
+    device_id: int = 0,
+    dump_passes: bool = True,
+    runtime_profiling: bool = False,
+):
+    from pypto.backend import BackendType
+    from pypto.ir.pass_manager import OptimizationStrategy
+    from pypto.runtime import RunConfig, run
+
+    backend = BackendType.Ascend950 if platform.startswith("a5") else BackendType.Ascend910B
+
+    program = build_qwen3_decode_program(
+        batch=batch,
+        max_seq=max_seq,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+    )
+    tensor_specs = build_tensor_specs(
+        batch=batch,
+        max_seq=max_seq,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        use_max_seq=use_max_seq,
+    )
+
+    result = run(
+        program=program,
+        tensor_specs=tensor_specs,
+        golden=golden_qwen3_decode,
+        config=RunConfig(
+            platform=platform,
+            device_id=device_id,
+            rtol=3e-3,
+            atol=3e-3,
+            strategy=OptimizationStrategy.Default,
+            dump_passes=dump_passes,
+            backend_type=backend,
+            runtime_profiling=runtime_profiling,
+        ),
+    )
+    return result
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-p", "--platform", type=str, default="a5",
+                        choices=["a2a3", "a2a3sim", "a5", "a5sim"])
+    parser.add_argument("-d", "--device", type=int, default=0)
+    parser.add_argument("--runtime-profiling", action="store_true", default=False)
+    parser.add_argument("--max-seq", action="store_true", default=False,
+                        help="set all seq_lens to MAX_SEQ (default: random)")
+    args = parser.parse_args()
+
+    result = compile_and_run(
+        platform=args.platform,
+        device_id=args.device,
+        use_max_seq=args.max_seq,
+        runtime_profiling=args.runtime_profiling,
+    )
+    if not result.passed:
+        if result.error:
+            print(f"Result: {result.error}")
+        raise SystemExit(1)