diff --git a/.claude/rules/coding-style.md b/.claude/rules/coding-style.md index 9e8d37e..2d0e206 100644 --- a/.claude/rules/coding-style.md +++ b/.claude/rules/coding-style.md @@ -32,7 +32,7 @@ PyPTO supports three function types: |------|---------|-------------| | `pl.FunctionType.InCore` | Runs on AICore. Manually load/store between GM and UB. | Explicit control over data movement and memory placement | | `pl.FunctionType.Orchestration` | Host/AICPU scheduling. Calls InCore kernels, manages tensor allocation. | Composing InCore kernels into a computation graph | -| `pl.FunctionType.Opaque` | Compiler decides InCore/Orchestration boundary. Use with `pl.auto_incore()`. | When you don't need manual placement control | +| `pl.FunctionType.Opaque` | Compiler decides InCore/Orchestration boundary. Use with `pl.at()`. | When you don't need manual placement control | ### Explicit InCore + Orchestration (pypto standard style) @@ -70,7 +70,7 @@ class HelloWorldProgram: class SoftmaxProgram: @pl.function(type=pl.FunctionType.Opaque) def softmax(self, input_tensor: pl.Tensor[[B, S, H], pl.FP32], ...): - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for b in pl.parallel(0, B, 1, chunk=4): ... ``` diff --git a/docs/para_for.md b/docs/para_for.md index 2623bd4..5a41902 100644 --- a/docs/para_for.md +++ b/docs/para_for.md @@ -151,7 +151,7 @@ When a **`with pl.incore`** scope **covers** nested chunked loops, each such chu **Example:** two chunked loops inside one incore scope. ```python -with pl.incore(): +with pl.at(level=pl.Level.CORE_GROUP): for i in pl.range(0, 4096, chunk=1024): for j in pl.range(0, 2048, chunk=512): @@ -179,7 +179,7 @@ The compiler should **try to interchange** loop nesting so that: for c_1 in ...: # chunk loop 1 for c_2 in ...: # chunk loop 2 ... 
- with pl.incore(): # incore placed here (see §5.4) + with pl.at(level=pl.Level.CORE_GROUP): # incore placed here (see §5.4) for i in ...: # in_chunk loop 1 for j in ...: # in_chunk loop 2 @@ -214,7 +214,7 @@ for c in pl.range(0, 4): t_end = min(t_start + 1024, 4096) r_start = t_start * TILE_M # or derived from t_start x_tile = pl.slice(x, [TILE_M, N], [r_start, 0]) - with pl.incore(): # placed to encompass only the in_chunk loop + body + with pl.at(level=pl.Level.CORE_GROUP): # placed to encompass only the in_chunk loop + body for t in pl.range(t_start, t_end): # in_chunk loop inside incore # body: e.g. load(x_tile), softmax, store ... diff --git a/docs/pto2_rt.md b/docs/pto2_rt.md index 6a2baa3..54da4bc 100644 --- a/docs/pto2_rt.md +++ b/docs/pto2_rt.md @@ -765,7 +765,7 @@ The `docs/pypto-frontend-coding-style.md` describes the Python-to-C++ code gener | Type | Description | |------|-------------| -| **Opaque** | Default function type; may contain `pl.incore()` calls | +| **Opaque** | Default function type; may contain `pl.at()` scopes | | **Orchestration** | Host/AICPU orchestration function; calls InCore functions | | **InCore** | AICore kernel subgraph (load/compute/store) | diff --git a/docs/pypto-frontend-coding-style.md b/docs/pypto-frontend-coding-style.md index cb815e5..0f29c59 100644 --- a/docs/pypto-frontend-coding-style.md +++ b/docs/pypto-frontend-coding-style.md @@ -131,11 +131,11 @@ def BuildExampleGraph( --- -## 5. InCore 作用域与匿名 InCore(with pl.incore()) +## 5. 
InCore 作用域与匿名 InCore(with pl.at()) ### 5.1 语法 -在 **Opaque** 函数内用 `with pl.incore():` 标记一段“匿名” InCore 区域;解析后生成 `ScopeStmt(scope_type=InCore)`。 +在 **Opaque** 函数内用 `with pl.at(level=pl.Level.CORE_GROUP):` 标记一段“匿名” InCore 区域;解析后生成 `ScopeStmt(scope_type=InCore)`。 ```python @pl.program @@ -143,7 +143,7 @@ class Before: @pl.function # 默认 Opaque def main(self, x: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: y = x + 1 - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): tile = pl.load(y, [0], [64]) tile_sq = pl.mul(tile, tile) result = pl.store(tile_sq, [0], [64], x) @@ -163,7 +163,7 @@ class Before: | 类型 | 写法 | 用途 | |------|------|------| -| Opaque | 默认 / `pl.FunctionType.Opaque` | 未指定,可含 `pl.incore()` 待 outline | +| Opaque | 默认 / `pl.FunctionType.Opaque` | 未指定,可含 `pl.at()` 待 outline | | Orchestration | `pl.FunctionType.Orchestration` | Host/AICPU 编排,调用 InCore | | InCore | `pl.FunctionType.InCore` | AICore 上的子图(load/compute/store) | diff --git a/examples/beginner/hello_world.py b/examples/beginner/hello_world.py index 0d8ac95..276c710 100644 --- a/examples/beginner/hello_world.py +++ b/examples/beginner/hello_world.py @@ -38,7 +38,7 @@ def add_one( x: pl.Tensor[[rows, cols], pl.FP32], y: pl.Out[pl.Tensor[[rows, cols], pl.FP32]], ) -> pl.Tensor[[rows, cols], pl.FP32]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for r in pl.parallel(0, rows, 1, chunk=row_chunk): tile_x = pl.slice(x, [1, cols], [r, 0]) tile_y = pl.add(tile_x, 1.0) diff --git a/examples/beginner/matmul.py b/examples/beginner/matmul.py index 4693567..0d5db5a 100644 --- a/examples/beginner/matmul.py +++ b/examples/beginner/matmul.py @@ -48,7 +48,7 @@ def matmul( b: pl.Tensor[[k, n], pl.FP32], c: pl.Out[pl.Tensor[[m, n], pl.FP32]], ) -> pl.Tensor[[m, n], pl.FP32]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for mb in pl.parallel(0, m, m_tile, chunk=m_chunk): for nb in
pl.parallel(0, n, n_tile, chunk=n_chunk): tile_a = pl.slice(a, [m_tile, k], [mb, 0]) diff --git a/examples/intermediate/gemm.py b/examples/intermediate/gemm.py index d780e2f..91e0220 100644 --- a/examples/intermediate/gemm.py +++ b/examples/intermediate/gemm.py @@ -53,7 +53,7 @@ def gemm( b: pl.Tensor[[k, n], pl.FP32], c: pl.Out[pl.Tensor[[m, n], pl.FP32]], ) -> pl.Tensor[[m, n], pl.FP32]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for mb in pl.parallel(0, m, m_tile, chunk=m_chunk): for nb in pl.parallel(0, n, n_tile, chunk=n_chunk): # First K-tile: initialize accumulator via matmul diff --git a/examples/intermediate/layer_norm.py b/examples/intermediate/layer_norm.py index ec69b25..c2f691b 100644 --- a/examples/intermediate/layer_norm.py +++ b/examples/intermediate/layer_norm.py @@ -44,7 +44,7 @@ def layer_norm( beta: pl.Tensor[[1, hidden], pl.FP32], y: pl.Out[pl.Tensor[[rows, hidden], pl.FP32]], ) -> pl.Tensor[[rows, hidden], pl.FP32]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for r in pl.parallel(0, rows, row_chunk, chunk=1): tile_x = pl.slice(x, [row_chunk, hidden], [r, 0]) gamma_tile = pl.slice(gamma, [1, hidden], [0, 0]) diff --git a/examples/intermediate/rms_norm.py b/examples/intermediate/rms_norm.py index 8006873..6fe73f2 100644 --- a/examples/intermediate/rms_norm.py +++ b/examples/intermediate/rms_norm.py @@ -50,7 +50,7 @@ def rms_norm( gamma: pl.Tensor[[1, hidden], pl.FP32], y: pl.Out[pl.Tensor[[rows, hidden], pl.FP32]], ) -> pl.Tensor[[rows, hidden], pl.FP32]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for r in pl.parallel(0, rows, row_chunk, chunk=1): # Pass 1: accumulate sum(x^2) across hidden chunks # row_sum produces [row_chunk, 1] col_major; scalar ops diff --git a/examples/intermediate/rope.py b/examples/intermediate/rope.py index 0b30349..2e6a43d 100644 --- 
a/examples/intermediate/rope.py +++ b/examples/intermediate/rope.py @@ -59,7 +59,7 @@ def rope( sin: pl.Tensor[[1, head_dim], pl.FP32], y: pl.Out[pl.Tensor[[total_rows, head_dim], pl.FP32]], ) -> pl.Tensor[[total_rows, head_dim], pl.FP32]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for b in pl.parallel(0, batch, 1, chunk=batch_chunk): # Slice cos/sin lo/hi halves directly from tensor # so each becomes a separate tile.load (no textract). diff --git a/examples/intermediate/softmax.py b/examples/intermediate/softmax.py index a96eadb..5be36c6 100644 --- a/examples/intermediate/softmax.py +++ b/examples/intermediate/softmax.py @@ -38,7 +38,7 @@ def softmax( x: pl.Tensor[[rows, cols], pl.FP32], y: pl.Out[pl.Tensor[[rows, cols], pl.FP32]], ) -> pl.Tensor[[rows, cols], pl.FP32]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for r in pl.parallel(0, rows, row_chunk, chunk=1): tile_x = pl.slice(x, [row_chunk, cols], [r, 0]) diff --git a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_back.py b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_back.py index 47719b9..ff93315 100644 --- a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_back.py +++ b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_back.py @@ -76,7 +76,7 @@ def deepseek_v3_2_decode_back_layer( w_down: pl.Tensor[[INTER_CFG, HIDDEN_CFG], pl.BF16], out: pl.Tensor[[BATCH_CFG, HIDDEN_CFG], pl.BF16], ) -> pl.Tensor[[BATCH_CFG, HIDDEN_CFG], pl.BF16]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): node_id = pl.tensor.read(node_id_t, [0]) combined = pl.create_tensor([BATCH_CFG, ATTN_OUT_CFG], dtype=pl.FP32) # Read combine results from this node view. 
diff --git a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front.py b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front.py index ad6c978..f23fbc7 100644 --- a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front.py +++ b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front.py @@ -156,7 +156,7 @@ def deepseek_v3_2_decode_front_layer( qr = pl.create_tensor([BATCH_CFG, Q_LORA_RANK_CFG], dtype=pl.BF16) q_proj = pl.create_tensor([BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], dtype=pl.BF16) kv_a = pl.create_tensor([BATCH_CFG, KV_A_OUT], dtype=pl.BF16) - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): sq_sum = pl.create_tensor([BATCH_CFG, 1], dtype=pl.FP32) sq_sum = pl.mul(sq_sum, 0) # Keep an explicit local Vec pad tensor alive in this scope so @@ -233,7 +233,7 @@ def deepseek_v3_2_decode_front_layer( # - C: sparse attention consumes merged topk immediately # This avoids materializing topk intermediates across kernel boundaries. 
attn_front = pl.create_tensor([BATCH_CFG, ATTN_OUT_CFG], dtype=pl.FP32) - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): layer_id = pl.tensor.read(layer_id_t, [0]) for b in pl.parallel(0, BATCH_CFG, 1, chunk=4): ctx_len = pl.tensor.read(seq_lens, [b]) diff --git a/examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back.py b/examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back.py index 7644e43..1bd8edb 100644 --- a/examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back.py +++ b/examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back.py @@ -75,7 +75,7 @@ def deepseek_v3_2_prefill_back_layer( w_down: pl.Tensor[[INTER_CFG, HIDDEN_CFG], pl.BF16], out: pl.Tensor[[BATCH_CFG, MAX_SEQ_CFG, HIDDEN_CFG], pl.BF16], ) -> pl.Tensor[[BATCH_CFG, MAX_SEQ_CFG, HIDDEN_CFG], pl.BF16]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): node_id = pl.tensor.read(node_id_t, [0]) for b in pl.parallel(0, BATCH_CFG, 1, chunk=4): seq_len_b = pl.tensor.read(seq_lens, [b]) diff --git a/examples/models/deepseek_v3_2/deepseek_v3_2_prefill_front.py b/examples/models/deepseek_v3_2/deepseek_v3_2_prefill_front.py index 36af3bb..b52bd59 100644 --- a/examples/models/deepseek_v3_2/deepseek_v3_2_prefill_front.py +++ b/examples/models/deepseek_v3_2/deepseek_v3_2_prefill_front.py @@ -141,7 +141,7 @@ def deepseek_v3_2_prefill_front_layer( w_latent_to_v: pl.Tensor[[NUM_HEADS_CFG, KV_LORA_RANK_CFG, V_HEAD_DIM_CFG], pl.BF16], dispatch_buf: pl.Tensor[[EP_NODES_CFG, BATCH_CFG, MAX_SEQ_CFG, ATTN_OUT_CFG], pl.BF16], ) -> pl.Tensor[[EP_NODES_CFG, BATCH_CFG, MAX_SEQ_CFG, ATTN_OUT_CFG], pl.BF16]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): layer_id = pl.tensor.read(layer_id_t, [0]) for b in pl.parallel(0, BATCH_CFG, 1, chunk=4): diff --git a/examples/models/kimi/kimi_k2_decode.py b/examples/models/kimi/kimi_k2_decode.py index 
c043c8b..750ff20 100644 --- a/examples/models/kimi/kimi_k2_decode.py +++ b/examples/models/kimi/kimi_k2_decode.py @@ -137,7 +137,7 @@ def kimi_k2_decode_layer( # ========================================================================= # Scope 1: Input RMSNorm + QKV Projection # ========================================================================= - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): sq_sum = pl.create_tensor([BATCH_CFG, 1], dtype=pl.FP32) sq_sum = pl.mul(sq_sum, 0.0) @@ -241,7 +241,7 @@ def kimi_k2_decode_layer( ) # Flash Decoding Attention per head - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): attn_row = pl.create_tensor([1, HIDDEN_CFG], dtype=pl.FP32) attn_row = pl.mul(attn_row, 0.0) @@ -325,7 +325,7 @@ def kimi_k2_decode_layer( # ========================================================================= # Scope 3: Output Projection + Residual + Post RMSNorm + MoE # ========================================================================= - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for b0 in pl.range(0, BATCH_CFG, BATCH_TILE): # Output projection + residual resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32) diff --git a/examples/models/milm/milm_decode.py b/examples/models/milm/milm_decode.py index dd3c882..77b19f5 100644 --- a/examples/models/milm/milm_decode.py +++ b/examples/models/milm/milm_decode.py @@ -128,7 +128,7 @@ def milm_decode_layer( # Scope 1: Input RMSNorm + QKV Projection # Optimized with chunked computation to reduce InCore pressure # ========================================================================= - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): # Compute sum of squares for RMSNorm sq_sum = pl.create_tensor([BATCH_CFG, 1], dtype=pl.FP32) sq_sum = pl.mul(sq_sum, 0.0) @@ 
-234,7 +234,7 @@ def milm_decode_layer( ) # Flash Decoding Attention (per head with GQA) - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): attn_row = pl.create_tensor([1, HIDDEN_CFG], dtype=pl.FP32) attn_row = pl.mul(attn_row, 0.0) @@ -318,7 +318,7 @@ def milm_decode_layer( # ========================================================================= # Scope 3: Output Projection + Residual + Post RMSNorm + SwiGLU MLP # ========================================================================= - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for b0 in pl.range(0, BATCH_CFG, BATCH_TILE): # Output projection + residual (first residual connection) resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32) diff --git a/examples/models/qwen3/qwen3-32b.py b/examples/models/qwen3/qwen3-32b.py index 61373ff..d53454f 100644 --- a/examples/models/qwen3/qwen3-32b.py +++ b/examples/models/qwen3/qwen3-32b.py @@ -106,7 +106,7 @@ def qwen3_decode_layer( attn_out = pl.create_tensor([BATCH_CFG, HIDDEN_CFG], dtype=pl.FP32) # Scope 1: input RMSNorm + Q/K/V projection. - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): sq_sum = pl.create_tensor([BATCH_CFG, 1], dtype=pl.FP32) sq_sum = pl.mul(sq_sum, 0.0) @@ -197,7 +197,7 @@ def qwen3_decode_layer( [cache_row, 0], ) - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): attn_row = pl.create_tensor([1, HIDDEN_CFG], dtype=pl.FP32) attn_row = pl.mul(attn_row, 0.0) @@ -269,7 +269,7 @@ def qwen3_decode_layer( attn_out = pl.assemble(attn_out, attn_row, [b, 0]) # Scope 3: output projection + residual + post RMSNorm + MLP + residual. 
- with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for b0 in pl.range(0, BATCH_CFG, BATCH_TILE): resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32) diff --git a/examples/models/qwen3/qwen3_32b_decode.py b/examples/models/qwen3/qwen3_32b_decode.py index 5f4057f..70964ca 100644 --- a/examples/models/qwen3/qwen3_32b_decode.py +++ b/examples/models/qwen3/qwen3_32b_decode.py @@ -112,7 +112,7 @@ def qwen3_decode( for b0 in pl.range(0, batch, BATCH_TILE): normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) for kb in pl.range(hidden_blocks): k0 = kb * K_CHUNK @@ -144,7 +144,7 @@ def qwen3_decode( for ob in pl.range(q_out_blocks): q0 = ob * Q_OUT_CHUNK - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) tile_b = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0]) q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) @@ -157,7 +157,7 @@ def qwen3_decode( for ob in pl.range(kv_out_blocks): kv0 = ob * KV_OUT_CHUNK - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) @@ -167,7 +167,7 @@ def qwen3_decode( tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) @@ -181,7 +181,7 @@ def qwen3_decode( # ── Scope 2: RoPE + KV cache update + grouped-query attention ── # Pad q all_q_padded = 
pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for idx in pl.range(batch * total_q_groups): all_q_padded = pl.assemble( all_q_padded, @@ -202,7 +202,7 @@ def qwen3_decode( sin_hi = pl.slice(sin_row, [1, half_dim], [0, half_dim]) # Stage 1: K RoPE + cache update + V cache + Q RoPE + pad. - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for ki in pl.parallel(0, num_kv_heads, chunk=8): # K RoPE + cache update. kv_col = ki * head_dim @@ -256,7 +256,7 @@ def qwen3_decode( all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32) all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32) for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for si in pl.range(SB_BATCH): sb = sb0 + si if sb < ctx_blocks: @@ -290,7 +290,7 @@ def qwen3_decode( # Stage 2: QK matmul for all active sb blocks. for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for si in pl.range(SB_BATCH): sb = sb0 + si if sb < ctx_blocks: @@ -306,7 +306,7 @@ def qwen3_decode( # Stage 3: softmax for all active sb blocks. for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for si in pl.range(SB_BATCH): sb = sb0 + si if sb < ctx_blocks: @@ -331,7 +331,7 @@ def qwen3_decode( # Stage 4: SV matmul for all active sb blocks. for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for si in pl.range(SB_BATCH): sb = sb0 + si if sb < ctx_blocks: @@ -351,7 +351,7 @@ def qwen3_decode( all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0]) # Stage 5: online softmax accumulation and normalisation. 
- with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): oi = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [0, 0]) mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [0, 0]) li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [0, 0]) @@ -383,7 +383,7 @@ def qwen3_decode( for ob in pl.range(q_out_blocks): o0 = ob * Q_OUT_CHUNK - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): a_chunk_0 = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, 0]) w_chunk_0 = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [0, o0]) o_acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32) @@ -394,7 +394,7 @@ def qwen3_decode( o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk) # Stage 2: Residual addition with hidden_states - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): resid = pl.cast( pl.slice(hidden_states, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]), target_type=pl.FP32, @@ -404,7 +404,7 @@ def qwen3_decode( # Stage 3: Post-attention RMSNorm post_norm_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) for kb in pl.range(hidden_blocks): k0 = kb * K_CHUNK @@ -424,7 +424,7 @@ def qwen3_decode( mlp_tile = pl.create_tensor([BATCH_TILE, inter], dtype=pl.BF16) for ob in pl.range(mlp_out_blocks): o0 = ob * MLP_OUT_CHUNK - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) wg_0 = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [0, o0]) gate_acc = pl.matmul(post_chunk_0, wg_0, out_dtype=pl.FP32) @@ -434,7 +434,7 @@ def qwen3_decode( wg = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [k0, o0]) gate_acc = pl.matmul_acc(gate_acc, post_chunk, wg) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) wu_0 = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [0, o0]) up_acc = pl.matmul(post_chunk_0, wu_0, out_dtype=pl.FP32) @@ -444,7 +444,7 @@ 
def qwen3_decode( wu = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [k0, o0]) up_acc = pl.matmul_acc(up_acc, post_chunk, wu) - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0)) mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) @@ -453,7 +453,7 @@ def qwen3_decode( # Stage 7 & 8: Down projection + final residual writeback. for dob in pl.range(hidden_blocks): d0 = dob * K_CHUNK - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): mlp_chunk_0 = pl.slice(mlp_tile, [BATCH_TILE, MLP_OUT_CHUNK], [0, 0]) w_down_chunk_0 = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [0, d0]) down_acc = pl.matmul(mlp_chunk_0, w_down_chunk_0, out_dtype=pl.FP32) @@ -464,7 +464,7 @@ def qwen3_decode( ) w_down_chunk = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [o0, d0]) down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): out_chunk = pl.add( down_acc, pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, d0]), diff --git a/examples/models/qwen3/qwen3_32b_decode_mixed.py b/examples/models/qwen3/qwen3_32b_decode_mixed.py index 552fdd1..e4ffc76 100644 --- a/examples/models/qwen3/qwen3_32b_decode_mixed.py +++ b/examples/models/qwen3/qwen3_32b_decode_mixed.py @@ -153,7 +153,7 @@ def qwen3_decode_layer( normed_buf = pl.create_tensor([BATCH_CFG, HIDDEN_CFG], dtype=pl.BF16) # Initialize intermediate tensors to zero so assemble generates inout. 
- with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for ob in pl.range(Q_OUT_BLOCKS): q0 = ob * Q_OUT_CHUNK zero_1 = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0) @@ -161,7 +161,7 @@ def qwen3_decode_layer( q_proj = pl.assemble(q_proj, zero_1, [0, q0]) attn_out = pl.assemble(attn_out, zero_1_bf, [0, q0]) normed_buf = pl.assemble(normed_buf, zero_1_bf, [0, q0]) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for ob in pl.range(KV_OUT_BLOCKS): kv0 = ob * KV_OUT_CHUNK zero_2 = pl.full([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32, value=0.0) @@ -170,7 +170,7 @@ def qwen3_decode_layer( # Scope 1 input RMSNorm + Q/K/V projection # Stage 1: RMSNorm — two-pass over all batch tiles, results in normed_buf. - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for b0 in pl.parallel(0, BATCH_CFG, BATCH_TILE, chunk=1): # Phase 1: accumulate squared sum in [1, BATCH_TILE], compute inv_rms. sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) @@ -248,7 +248,7 @@ def qwen3_decode_layer( sin_lo = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, 0]) sin_hi = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): # Stage 1: per-head K gather + RoPE + cache update. for ki in pl.range(NUM_KV_HEADS_CFG): kv_col = ki * HEAD_DIM_CFG @@ -295,7 +295,7 @@ def qwen3_decode_layer( # Pad Q for cube fractal alignment. q_padded = pl.create_tensor([Q_HEAD_PAD, HEAD_DIM_CFG], dtype=pl.BF16) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): # Stage 2: per-head Q gather + RoPE + pad + init accumulators. 
for qi in pl.range(Q_HEAD_BATCH): q_col = (q_base + qi) * HEAD_DIM_CFG @@ -332,7 +332,7 @@ def qwen3_decode_layer( cache_row0 = b * NUM_KV_HEADS_CFG * MAX_SEQ_CFG + kvh * MAX_SEQ_CFG + s0 raw_scores_pad = pl.create_tensor([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): # QK matmul: padded Q × K^T. k_tile = pl.slice( k_cache, @@ -342,7 +342,7 @@ def qwen3_decode_layer( raw_scores_pad = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32) exp_padded = pl.create_tensor([Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): # Softmax: slice valid rows from padded scores. scores_valid = pl.slice( raw_scores_pad, @@ -361,7 +361,7 @@ def qwen3_decode_layer( exp_padded = pl.assemble(exp_padded, exp_scores_bf16, [0, 0]) oi_tmp_pad = pl.create_tensor([Q_HEAD_PAD, HEAD_DIM_CFG], dtype=pl.FP32) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): # SV matmul: padded exp_scores × V. v_tile = pl.slice( v_cache, @@ -370,7 +370,7 @@ def qwen3_decode_layer( ) oi_tmp_pad = pl.matmul(exp_padded, v_tile, out_dtype=pl.FP32) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): # Slice valid rows from padded SV result. 
oi_tmp = pl.slice(oi_tmp_pad, [Q_HEAD_BATCH, HEAD_DIM_CFG], [0, 0]) if sb == 0: @@ -386,7 +386,7 @@ def qwen3_decode_layer( pl.row_expand_mul(oi_tmp, beta)) mi = mi_new - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): ctx = pl.row_expand_div(oi, li) ctx_flat = pl.reshape(ctx, [1, Q_HEAD_BATCH * HEAD_DIM_CFG]) ctx_flat_bf16 = pl.cast(ctx_flat, target_type=pl.BF16) diff --git a/examples/models/qwen3/qwen3_32b_decode_scope1.py b/examples/models/qwen3/qwen3_32b_decode_scope1.py index d431aab..d6e3c6f 100644 --- a/examples/models/qwen3/qwen3_32b_decode_scope1.py +++ b/examples/models/qwen3/qwen3_32b_decode_scope1.py @@ -70,7 +70,7 @@ def qwen3_scope1( normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) # Stage 1: RMSNorm + apply weights (vector ops only). - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) for kb in pl.range(hidden_blocks): k0 = kb * K_CHUNK @@ -105,7 +105,7 @@ def qwen3_scope1( for ob in pl.range(q_out_blocks): q0 = ob * Q_OUT_CHUNK - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) tile_b = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0]) q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) @@ -122,7 +122,7 @@ def qwen3_scope1( for ob in pl.range(kv_out_blocks): kv0 = ob * KV_OUT_CHUNK - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) @@ -135,7 +135,7 @@ def qwen3_scope1( k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) diff --git 
a/examples/models/qwen3/qwen3_32b_decode_scope12.py b/examples/models/qwen3/qwen3_32b_decode_scope12.py deleted file mode 100644 index 49973f3..0000000 --- a/examples/models/qwen3/qwen3_32b_decode_scope12.py +++ /dev/null @@ -1,638 +0,0 @@ -# Copyright (c) PyPTO Contributors. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ----------------------------------------------------------------------------------------------------------- -"""Qwen3-32B decode Scope 1+2 — RMSNorm + projection + RoPE + attention. - -Scope 1: - 1. RMSNorm of input hidden states - 2. Q/K/V projection via matmul - -Scope 2: - 1. K RoPE + cache write, V cache write, Q RoPE + pad - 2. QK matmul - 3. Softmax - 4. SV matmul - 5. Online-softmax accumulation + final normalisation -Intermediate q_proj/k_proj/v_proj are FP32 GM tensors between the two scopes. -""" -from __future__ import annotations - -import pypto.language as pl - -BATCH = 16 -MAX_SEQ = 4096 -NUM_HEADS = 64 -NUM_KV_HEADS = 8 -HEAD_DIM = 128 -HIDDEN = NUM_HEADS * HEAD_DIM # 8192 -KV_HIDDEN = NUM_KV_HEADS * HEAD_DIM - -EPS = 1e-6 -HIDDEN_INV = 1.0 / HIDDEN - -# Scope 1 tiling constants. -K_CHUNK = 128 -Q_OUT_CHUNK = 64 -KV_OUT_CHUNK = 64 -BATCH_TILE = 16 - -# Scope 2 tiling constants. 
-Q_HEAD_BATCH = 8 -Q_HEAD_PAD = 16 -SEQ_TILE = 64 -SB_BATCH = 64 - - -def build_qwen3_scope12_program( - batch: int = BATCH, - max_seq: int = MAX_SEQ, - hidden_size: int = HIDDEN, - num_heads: int = NUM_HEADS, - num_kv_heads: int = NUM_KV_HEADS, - head_dim: int = HEAD_DIM, -): - hidden = hidden_size - kv_hidden = num_kv_heads * head_dim - hidden_blocks = hidden // K_CHUNK - q_out_blocks = hidden // Q_OUT_CHUNK - kv_out_blocks = kv_hidden // KV_OUT_CHUNK - cache_rows = batch * num_kv_heads * max_seq - half_dim = head_dim // 2 - q_per_kv = num_heads // num_kv_heads - q_groups = q_per_kv // Q_HEAD_BATCH - total_q_groups = num_kv_heads * q_groups - attn_scale = 1.0 / (head_dim ** 0.5) - max_ctx_blocks = (max_seq + SEQ_TILE - 1) // SEQ_TILE - - @pl.program - class Qwen3Scope12: - @pl.function(type=pl.FunctionType.Opaque) - def qwen3_scope12( - self, - hidden_states: pl.Tensor[[batch, hidden], pl.BF16], - input_rms_weight: pl.Tensor[[1, hidden], pl.FP32], - wq: pl.Tensor[[hidden, hidden], pl.BF16], - wk: pl.Tensor[[hidden, kv_hidden], pl.BF16], - wv: pl.Tensor[[hidden, kv_hidden], pl.BF16], - seq_lens: pl.Tensor[[batch], pl.INT32], - rope_cos: pl.Tensor[[max_seq, head_dim], pl.FP32], - rope_sin: pl.Tensor[[max_seq, head_dim], pl.FP32], - k_cache: pl.Tensor[[cache_rows, head_dim], pl.BF16], - v_cache: pl.Tensor[[cache_rows, head_dim], pl.BF16], - attn_out: pl.Out[pl.Tensor[[batch, hidden], pl.BF16]], - ) -> pl.Tensor[[batch, hidden], pl.BF16]: - # Intermediate FP32 tensors between scope 1 and scope 2. 
- q_proj = pl.create_tensor([batch, hidden], dtype=pl.FP32) - k_proj = pl.create_tensor([batch, kv_hidden], dtype=pl.FP32) - v_proj = pl.create_tensor([batch, kv_hidden], dtype=pl.FP32) - - # ── Scope 1: input RMSNorm + Q/K/V projection ── - for b0 in pl.range(0, batch, BATCH_TILE): - normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) - - with pl.incore(): - partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) - for kb in pl.range(hidden_blocks): - k0 = kb * K_CHUNK - x_chunk = pl.cast( - pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), - target_type=pl.FP32, - ) - partial_sq = pl.add( - partial_sq, - pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, BATCH_TILE]), - ) - # Compute variance in [1, BATCH_TILE], then reshape to [BATCH_TILE, 1] - # for row_expand_mul broadcasting. - variance = pl.reshape( - pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS), - [BATCH_TILE, 1], - ) - inv_rms = pl.recip(pl.sqrt(variance)) - - for kb in pl.range(hidden_blocks): - k0 = kb * K_CHUNK - x_chunk = pl.cast( - pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), - target_type=pl.FP32, - ) - gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma) - normed_tile = pl.assemble(normed_tile, pl.cast(normed, target_type=pl.BF16), [0, k0]) - - for ob in pl.range(q_out_blocks): - q0 = ob * Q_OUT_CHUNK - with pl.incore(): - tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - tile_b = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0]) - q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0]) - tile_b_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0]) - q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i) - q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) - - for ob in pl.range(kv_out_blocks): - kv0 = ob * KV_OUT_CHUNK - with pl.incore(): - tile_a = 
pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) - k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0]) - tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) - k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) - with pl.incore(): - tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) - v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0]) - tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i) - v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) - - # Padding q - all_q_padded = pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16) - with pl.incore(): - for idx in pl.range(batch * total_q_groups): - all_q_padded = pl.assemble( - all_q_padded, - pl.cast(pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16), - [idx * Q_HEAD_PAD + Q_HEAD_BATCH, 0], - ) - - # ── Scope 2: RoPE + KV cache update + grouped-query attention ── - for b in pl.range(batch): - ctx_len = pl.tensor.read(seq_lens, [b]) - pos = ctx_len - 1 - ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE - cos_row = pl.slice(rope_cos, [1, head_dim], [pos, 0]) - sin_row = pl.slice(rope_sin, [1, head_dim], [pos, 0]) - cos_lo = pl.slice(cos_row, [1, half_dim], [0, 0]) - cos_hi = pl.slice(cos_row, [1, half_dim], [0, half_dim]) - sin_lo = pl.slice(sin_row, [1, half_dim], [0, 0]) - sin_hi = pl.slice(sin_row, [1, half_dim], [0, half_dim]) - - # Stage 1: K RoPE + cache update + V cache + Q RoPE + pad. 
- with pl.auto_incore(): - for ki in pl.parallel(0, num_kv_heads, chunk=8): - # K RoPE + cache update. - kv_col = ki * head_dim - k_lo = pl.slice(k_proj, [1, half_dim], [b, kv_col]) - k_hi = pl.slice(k_proj, [1, half_dim], [b, kv_col + half_dim]) - rot_lo = pl.sub( - pl.col_expand_mul(k_lo, cos_lo), - pl.col_expand_mul(k_hi, sin_lo), - ) - rot_hi = pl.add( - pl.col_expand_mul(k_hi, cos_hi), - pl.col_expand_mul(k_lo, sin_hi), - ) - cache_row = b * num_kv_heads * max_seq + ki * max_seq + pos - k_cache = pl.assemble(k_cache, pl.cast(rot_lo, target_type=pl.BF16), [cache_row, 0]) - k_cache = pl.assemble(k_cache, pl.cast(rot_hi, target_type=pl.BF16), [cache_row, half_dim]) - # V cache update. - v_cache = pl.assemble( - v_cache, - pl.cast(pl.slice(v_proj, [1, head_dim], [b, ki * head_dim]), target_type=pl.BF16), - [cache_row, 0], - ) - # Q RoPE + pad (ki == kvh since q_groups == 1). - q_base = ki * q_per_kv - for qi in pl.range(Q_HEAD_BATCH): - q_col = (q_base + qi) * head_dim - q_lo = pl.slice(q_proj, [1, half_dim], [b, q_col]) - q_hi = pl.slice(q_proj, [1, half_dim], [b, q_col + half_dim]) - rot_lo_bf16 = pl.cast( - pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), - target_type=pl.BF16, - ) - rot_hi_bf16 = pl.cast( - pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), - target_type=pl.BF16, - ) - all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, 0]) - all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, half_dim]) - - attn_row = pl.create_tensor([1, hidden], dtype=pl.BF16) - for gi in pl.range(total_q_groups): - kvh = gi // q_groups - qg = gi - kvh * q_groups - q_base = kvh * q_per_kv + qg * Q_HEAD_BATCH - q_padded = pl.slice(all_q_padded, [Q_HEAD_PAD, head_dim], [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD, 0]) - - # Workaround - all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 
SEQ_TILE], dtype=pl.FP32) - all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16) - all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32) - all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32) - all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32) - for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): - for si in pl.range(SB_BATCH): - sb = sb0 + si - if sb < ctx_blocks: - all_raw_scores = pl.assemble( - all_raw_scores, - pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), - [sb * Q_HEAD_PAD, 0], - ) - all_exp_padded = pl.assemble( - all_exp_padded, - pl.cast(pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), target_type=pl.BF16), - [sb * Q_HEAD_PAD, 0], - ) - all_oi_tmp = pl.assemble( - all_oi_tmp, - pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0), - [sb * Q_HEAD_PAD, 0], - ) - mi_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0) - all_cur_mi = pl.assemble( - all_cur_mi, - pl.reshape(mi_init_flat, [Q_HEAD_BATCH, 1]), - [sb * Q_HEAD_BATCH, 0], - ) - li_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0) - all_cur_li = pl.assemble( - all_cur_li, - pl.reshape(li_init_flat, [Q_HEAD_BATCH, 1]), - [sb * Q_HEAD_BATCH, 0], - ) - - # Stage 2: QK matmul for all active sb blocks. - for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): - for si in pl.range(SB_BATCH): - sb = sb0 + si - if sb < ctx_blocks: - s0 = sb * SEQ_TILE - cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0 - k_tile = pl.slice( - k_cache, - [SEQ_TILE, head_dim], - [cache_row0, 0], - ) - raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32) - all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0]) - - # Stage 3: softmax for all active sb blocks. 
- for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): - for si in pl.range(SB_BATCH): - sb = sb0 + si - if sb < ctx_blocks: - s0 = sb * SEQ_TILE - valid_len = pl.min(SEQ_TILE, ctx_len - s0) - scores_valid = pl.slice( - all_raw_scores, - [Q_HEAD_BATCH, SEQ_TILE], - [sb * Q_HEAD_PAD, 0], - valid_shape=[Q_HEAD_BATCH, valid_len], - ) - scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min) - scores = pl.mul(scores_padded, attn_scale) - cur_mi = pl.row_max(scores) - exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi)) - exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16) - exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32) - cur_li = pl.row_sum(exp_scores_fp32) - all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0]) - all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0]) - all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0]) - - # Stage 4: SV matmul for all active sb blocks. - for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): - for si in pl.range(SB_BATCH): - sb = sb0 + si - if sb < ctx_blocks: - s0 = sb * SEQ_TILE - cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0 - exp_tile = pl.slice( - all_exp_padded, - [Q_HEAD_PAD, SEQ_TILE], - [sb * Q_HEAD_PAD, 0], - ) - v_tile = pl.slice( - v_cache, - [SEQ_TILE, head_dim], - [cache_row0, 0], - ) - oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32) - all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0]) - - # Stage 5: online softmax accumulation and normalisation. 
- with pl.incore(): - oi = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [0, 0]) - mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [0, 0]) - li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [0, 0]) - for sb in pl.range(1, ctx_blocks): - oi_tmp_valid = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [sb * Q_HEAD_PAD, 0]) - cur_mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0]) - cur_li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0]) - mi_new = pl.maximum(mi, cur_mi) - alpha = pl.exp(pl.sub(mi, mi_new)) - beta = pl.exp(pl.sub(cur_mi, mi_new)) - li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) - oi = pl.add(pl.row_expand_mul(oi, alpha), - pl.row_expand_mul(oi_tmp_valid, beta)) - mi = mi_new - ctx = pl.row_expand_div(oi, li) - ctx_flat = pl.reshape(ctx, [1, Q_HEAD_BATCH * head_dim]) - ctx_flat_bf16 = pl.cast(ctx_flat, target_type=pl.BF16) - attn_row = pl.assemble( - attn_row, ctx_flat_bf16, [0, q_base * head_dim], - ) - - attn_out = pl.assemble(attn_out, attn_row, [b, 0]) - - return attn_out - - return Qwen3Scope12 - - -def build_tensor_specs( - batch: int = BATCH, - max_seq: int = MAX_SEQ, - hidden_size: int = HIDDEN, - num_heads: int = NUM_HEADS, - num_kv_heads: int = NUM_KV_HEADS, - head_dim: int = HEAD_DIM, - use_max_seq: bool = False, -): - import torch - from pypto.runtime import TensorSpec - - hidden = num_heads * head_dim - kv_hidden = num_kv_heads * head_dim - cache_rows = batch * num_kv_heads * max_seq - - def init_hidden_states(): - return torch.rand(batch, hidden_size) - 0.5 - - def init_rms_weight(): - return torch.rand(1, hidden_size) - 0.5 - - def init_wq(): - return torch.rand(hidden_size, hidden_size) / hidden_size ** 0.5 - - def init_wk(): - return torch.rand(hidden_size, kv_hidden) / hidden_size ** 0.5 - - def init_wv(): - return torch.rand(hidden_size, kv_hidden) / hidden_size ** 0.5 - - def init_seq_lens(): - if use_max_seq: - return torch.full((batch,), max_seq, dtype=torch.int32) - return torch.randint(1, max_seq 
+ 1, (batch,), dtype=torch.int32) - - def init_rope_cos(): - return torch.rand(max_seq, head_dim) - 0.5 - - def init_rope_sin(): - return torch.rand(max_seq, head_dim) - 0.5 - - def init_k_cache(): - return torch.rand(cache_rows, head_dim) - 0.5 - - def init_v_cache(): - return torch.rand(cache_rows, head_dim) - 0.5 - - return [ - TensorSpec("hidden_states", [batch, hidden_size], torch.bfloat16, - init_value=init_hidden_states), - TensorSpec("input_rms_weight", [1, hidden_size], torch.float32, - init_value=init_rms_weight), - TensorSpec("wq", [hidden_size, hidden_size], torch.bfloat16, - init_value=init_wq), - TensorSpec("wk", [hidden_size, kv_hidden], torch.bfloat16, - init_value=init_wk), - TensorSpec("wv", [hidden_size, kv_hidden], torch.bfloat16, - init_value=init_wv), - TensorSpec("seq_lens", [batch], torch.int32, init_value=init_seq_lens), - TensorSpec("rope_cos", [max_seq, head_dim], torch.float32, - init_value=init_rope_cos), - TensorSpec("rope_sin", [max_seq, head_dim], torch.float32, - init_value=init_rope_sin), - TensorSpec("k_cache", [cache_rows, head_dim], torch.bfloat16, - init_value=init_k_cache), - TensorSpec("v_cache", [cache_rows, head_dim], torch.bfloat16, - init_value=init_v_cache), - TensorSpec("attn_out", [batch, hidden], torch.bfloat16, is_output=True), - ] - - -def golden_qwen3_scope12(tensors, params): - """PyTorch reference: scope1 (RMSNorm + projection) then scope2 (attention).""" - import math - - import torch - - hidden_states = tensors["hidden_states"] - input_rms_weight = tensors["input_rms_weight"] - wq = tensors["wq"] - wk = tensors["wk"] - wv = tensors["wv"] - seq_lens = tensors["seq_lens"] - rope_cos = tensors["rope_cos"] - rope_sin = tensors["rope_sin"] - k_cache = tensors["k_cache"].clone() - v_cache = tensors["v_cache"].clone() - - batch = hidden_states.shape[0] - hidden_size = hidden_states.shape[1] - kv_hidden = wk.shape[1] - head_dim = rope_cos.shape[1] - max_seq = rope_cos.shape[0] - num_kv_heads = kv_hidden // head_dim - 
num_heads = hidden_size // head_dim - q_per_kv = num_heads // num_kv_heads - q_groups = q_per_kv // Q_HEAD_BATCH - half = head_dim // 2 - scale = 1.0 / math.sqrt(head_dim) - - # ── Scope 1 golden: RMSNorm + Q/K/V projection ── - q_proj = torch.zeros(batch, hidden_size, dtype=torch.float32) - k_proj = torch.zeros(batch, kv_hidden, dtype=torch.float32) - v_proj = torch.zeros(batch, kv_hidden, dtype=torch.float32) - - for b0 in range(0, batch, BATCH_TILE): - b_end = min(b0 + BATCH_TILE, batch) - x_tile = hidden_states[b0:b_end, :].float() - - sq_sum = torch.zeros(b_end - b0, 1, dtype=torch.float32) - for k0 in range(0, hidden_size, K_CHUNK): - x_chunk = x_tile[:, k0:k0 + K_CHUNK] - sq_sum = sq_sum + (x_chunk ** 2).sum(dim=-1, keepdim=True) - variance = sq_sum / hidden_size + EPS - rms = torch.sqrt(variance) - normed = (x_tile / rms * input_rms_weight.float()).bfloat16() - - q_proj[b0:b_end, :] = (normed.float() @ wq.float()).float() - k_proj[b0:b_end, :] = (normed.float() @ wk.float()).float() - v_proj[b0:b_end, :] = (normed.float() @ wv.float()).float() - - # ── Scope 2 golden: RoPE + cache update + attention ── - attn_out = torch.zeros(batch, hidden_size, dtype=torch.float32) - - for b in range(batch): - ctx_len = seq_lens[b].item() - pos = ctx_len - 1 - ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE - - cos_row = rope_cos[pos : pos + 1, :] - sin_row = rope_sin[pos : pos + 1, :] - cos_lo, cos_hi = cos_row[:, :half], cos_row[:, half:] - sin_lo, sin_hi = sin_row[:, :half], sin_row[:, half:] - - k_heads = k_proj[b].view(num_kv_heads, head_dim) - k_lo_h, k_hi_h = k_heads[:, :half], k_heads[:, half:] - k_rot = torch.cat([k_lo_h * cos_lo - k_hi_h * sin_lo, k_hi_h * cos_hi + k_lo_h * sin_hi], dim=-1) - - for ki in range(num_kv_heads): - cr = b * num_kv_heads * max_seq + ki * max_seq + pos - k_cache[cr, :] = k_rot[ki].to(torch.bfloat16) - v_cache[cr, :] = v_proj[b, ki * head_dim : (ki + 1) * head_dim].to(torch.bfloat16) - - q_heads = q_proj[b].view(num_heads, head_dim) - 
q_lo_h, q_hi_h = q_heads[:, :half], q_heads[:, half:] - q_rot = torch.cat([q_lo_h * cos_lo - q_hi_h * sin_lo, q_hi_h * cos_hi + q_lo_h * sin_hi], dim=-1) - - for kvh in range(num_kv_heads): - for qg in range(q_groups): - q_base = kvh * q_per_kv + qg * Q_HEAD_BATCH - q_grp_bf16 = q_rot[q_base : q_base + Q_HEAD_BATCH, :].to(torch.bfloat16) - - oi = torch.zeros(Q_HEAD_BATCH, head_dim, dtype=torch.float32) - li = torch.zeros(Q_HEAD_BATCH, 1, dtype=torch.float32) - mi = torch.zeros(Q_HEAD_BATCH, 1, dtype=torch.float32) - - for sb in range(ctx_blocks): - s0 = sb * SEQ_TILE - valid_len = min(SEQ_TILE, ctx_len - s0) - cb = b * num_kv_heads * max_seq + kvh * max_seq + s0 - - k_tile = k_cache[cb : cb + SEQ_TILE, :] - v_tile = v_cache[cb : cb + SEQ_TILE, :] - - raw_scores = q_grp_bf16.float() @ k_tile.float().T - if valid_len < SEQ_TILE: - raw_scores[:, valid_len:] = torch.finfo(torch.float32).min - scores = raw_scores * scale - - cur_mi = scores.max(dim=-1, keepdim=True).values - exp_scores = torch.exp(scores - cur_mi) - exp_scores_bf16 = exp_scores.to(torch.bfloat16) - cur_li = exp_scores_bf16.float().sum(dim=-1, keepdim=True) - - oi_tmp = exp_scores_bf16.float() @ v_tile.float() - - if sb == 0: - oi = oi_tmp - li = cur_li - mi = cur_mi - else: - mi_new = torch.maximum(mi, cur_mi) - alpha = torch.exp(mi - mi_new) - beta = torch.exp(cur_mi - mi_new) - li = alpha * li + beta * cur_li - oi = oi * alpha + oi_tmp * beta - mi = mi_new - - ctx = oi / li - for qi in range(Q_HEAD_BATCH): - qh = q_base + qi - attn_out[b, qh * head_dim : (qh + 1) * head_dim] = ctx[qi] - - tensors["attn_out"][:] = attn_out.to(torch.bfloat16) - - -def compile_and_run( - batch: int = BATCH, - max_seq: int = MAX_SEQ, - hidden_size: int = HIDDEN, - num_heads: int = NUM_HEADS, - num_kv_heads: int = NUM_KV_HEADS, - head_dim: int = HEAD_DIM, - use_max_seq: bool = False, - platform: str = "a5", - device_id: int = 0, - dump_passes: bool = True, - runtime_profiling: bool = False, -): - from pypto.backend import 
BackendType - from pypto.ir.pass_manager import OptimizationStrategy - from pypto.runtime import RunConfig, run - - backend = BackendType.Ascend950 if platform.startswith("a5") else BackendType.Ascend910B - - program = build_qwen3_scope12_program( - batch=batch, - max_seq=max_seq, - hidden_size=hidden_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - ) - tensor_specs = build_tensor_specs( - batch=batch, - max_seq=max_seq, - hidden_size=hidden_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - use_max_seq=use_max_seq, - ) - - result = run( - program=program, - tensor_specs=tensor_specs, - golden=golden_qwen3_scope12, - config=RunConfig( - platform=platform, - device_id=device_id, - rtol=1e-3, - atol=1e-3, - strategy=OptimizationStrategy.Default, - dump_passes=dump_passes, - backend_type=backend, - runtime_profiling=runtime_profiling, - ), - ) - return result - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("-p", "--platform", type=str, default="a5", - choices=["a2a3", "a2a3sim", "a5", "a5sim"]) - parser.add_argument("-d", "--device", type=int, default=0) - parser.add_argument("--runtime-profiling", action="store_true", default=False) - parser.add_argument("--max-seq", action="store_true", default=False, - help="set all seq_lens to MAX_SEQ (default: random)") - args = parser.parse_args() - - result = compile_and_run( - platform=args.platform, - device_id=args.device, - use_max_seq=args.max_seq, - runtime_profiling=args.runtime_profiling, - ) - if not result.passed: - if result.error: - print(f"Result: {result.error}") - raise SystemExit(1) diff --git a/examples/models/qwen3/qwen3_32b_decode_scope2.py b/examples/models/qwen3/qwen3_32b_decode_scope2.py index 45b7539..9b3d636 100644 --- a/examples/models/qwen3/qwen3_32b_decode_scope2.py +++ b/examples/models/qwen3/qwen3_32b_decode_scope2.py @@ -66,7 +66,7 @@ def qwen3_scope2( ) -> pl.Tensor[[batch, 
hidden], pl.BF16]: # Padding q all_q_padded = pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for idx in pl.range(batch * total_q_groups): all_q_padded = pl.assemble( all_q_padded, @@ -86,7 +86,7 @@ def qwen3_scope2( sin_hi = pl.slice(sin_row, [1, half_dim], [0, half_dim]) # Stage 1: K RoPE + cache update + V cache + Q RoPE + pad. - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for ki in pl.parallel(0, num_kv_heads, chunk=8): # K RoPE + cache update. kv_col = ki * head_dim @@ -157,7 +157,7 @@ def qwen3_scope2( all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32) all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32) for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for si in pl.range(SB_BATCH): sb = sb0 + si if sb < ctx_blocks: @@ -173,7 +173,7 @@ def qwen3_scope2( # Stage 3: softmax for all active sb blocks. for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for si in pl.range(SB_BATCH): sb = sb0 + si if sb < ctx_blocks: @@ -198,7 +198,7 @@ def qwen3_scope2( # Stage 4: SV matmul for all active sb blocks. for sb0 in pl.range(0, ctx_blocks, SB_BATCH): - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for si in pl.range(SB_BATCH): sb = sb0 + si if sb < ctx_blocks: @@ -218,7 +218,7 @@ def qwen3_scope2( all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0]) # Stage 5: online softmax accumulation and normalisation. 
- with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): oi = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [0, 0]) mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [0, 0]) li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [0, 0]) diff --git a/examples/models/qwen3/qwen3_32b_decode_scope3.py b/examples/models/qwen3/qwen3_32b_decode_scope3.py index dfa3b0f..9eb453c 100644 --- a/examples/models/qwen3/qwen3_32b_decode_scope3.py +++ b/examples/models/qwen3/qwen3_32b_decode_scope3.py @@ -67,7 +67,7 @@ def scope3( for ob in pl.range(Q_OUT_BLOCKS): o0 = ob * Q_OUT_CHUNK - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): a_chunk_0 = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, 0]) w_chunk_0 = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [0, o0]) o_acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32) @@ -78,7 +78,7 @@ def scope3( o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk) # Stage 1: Residual addition with hidden_states - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): resid = pl.cast( pl.slice(hidden_states, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]), target_type=pl.FP32, @@ -88,7 +88,7 @@ def scope3( # Stage 2: Post-attention RMSNorm post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.BF16) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) for kb in pl.range(HIDDEN_BLOCKS): k0 = kb * K_CHUNK @@ -108,7 +108,7 @@ def scope3( mlp_tile = pl.create_tensor([BATCH_TILE, INTER_CFG], dtype=pl.BF16) for ob in pl.range(MLP_OUT_BLOCKS): o0 = ob * MLP_OUT_CHUNK - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) wg_0 = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [0, o0]) gate_acc = pl.matmul(post_chunk_0, wg_0, out_dtype=pl.FP32) @@ -118,7 +118,7 @@ def scope3( wg = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [k0, o0]) gate_acc = pl.matmul_acc(gate_acc, post_chunk, wg) - with pl.incore(): + with 
pl.at(level=pl.Level.CORE_GROUP): post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) wu_0 = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [0, o0]) up_acc = pl.matmul(post_chunk_0, wu_0, out_dtype=pl.FP32) @@ -128,7 +128,7 @@ def scope3( wu = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [k0, o0]) up_acc = pl.matmul_acc(up_acc, post_chunk, wu) - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0)) mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) @@ -137,7 +137,7 @@ def scope3( # Stage 6 & 7: Down projection + final residual writeback. for dob in pl.range(HIDDEN_BLOCKS): d0 = dob * K_CHUNK - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): mlp_chunk_0 = pl.slice(mlp_tile, [BATCH_TILE, MLP_OUT_CHUNK], [0, 0]) w_down_chunk_0 = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [0, d0]) down_acc = pl.matmul(mlp_chunk_0, w_down_chunk_0, out_dtype=pl.FP32) @@ -148,7 +148,7 @@ def scope3( ) w_down_chunk = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [o0, d0]) down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): out_chunk = pl.add( down_acc, pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, d0]), diff --git a/examples/models/qwen3/qwen3_32b_prefill.py b/examples/models/qwen3/qwen3_32b_prefill.py index c657e9a..af377fb 100644 --- a/examples/models/qwen3/qwen3_32b_prefill.py +++ b/examples/models/qwen3/qwen3_32b_prefill.py @@ -107,7 +107,7 @@ def qwen3_prefill_layer( # Uses full [TOK_TILE, ...] views from hidden_states even on the # tail tile — padding rows map to allocated-but-unused MAX_SEQ # slots, keeping every GM view >= 512 B aligned. 
- with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): sq_sum = pl.create_tensor([TOK_TILE, 1], dtype=pl.FP32) sq_sum = pl.mul(sq_sum, 0.0) for kb in pl.range(HIDDEN_BLOCKS): @@ -191,7 +191,7 @@ def qwen3_prefill_layer( sin_lo = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, 0]) sin_hi = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): attn_row = pl.create_tensor([1, HIDDEN_CFG], dtype=pl.FP32) attn_row = pl.mul(attn_row, 0.0) # First loop: update KV cache (separate from attention to avoid tensor view issues) @@ -301,7 +301,7 @@ def qwen3_prefill_layer( attn_tile = pl.assemble(attn_tile, attn_row, [ti, 0]) # Scope 3: output projection + residual + post-rms + MLP + residual. - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): resid1_tile = pl.create_tensor([TOK_TILE, HIDDEN_CFG], dtype=pl.FP32) for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8): o0 = ob * Q_OUT_CHUNK diff --git a/examples/models/qwen3/qwen3_32b_prefill_tilelet.py b/examples/models/qwen3/qwen3_32b_prefill_tilelet.py index 309ba9e..1d02c67 100644 --- a/examples/models/qwen3/qwen3_32b_prefill_tilelet.py +++ b/examples/models/qwen3/qwen3_32b_prefill_tilelet.py @@ -110,7 +110,7 @@ def qwen3_prefill_layer( # Uses full [TOK_TILE, ...] views from hidden_states even on the # tail tile — padding rows map to allocated-but-unused MAX_SEQ # slots, keeping every GM view >= 512 B aligned. 
- with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): sq_sum = pl.create_tensor([TOK_TILE, 1], dtype=pl.FP32) sq_sum = pl.mul(sq_sum, 0.0) for kb in pl.range(HIDDEN_BLOCKS): @@ -195,7 +195,7 @@ def qwen3_prefill_layer( sin_hi = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) k_group = pl.create_tensor([NUM_KV_HEADS_CFG, HEAD_DIM_CFG], dtype=pl.FP32) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for ki in pl.range(NUM_KV_HEADS_CFG): kv_col = ki * HEAD_DIM_CFG k_group = pl.assemble( @@ -205,7 +205,7 @@ def qwen3_prefill_layer( [ki, 0], ) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): k_lo = pl.slice(k_group, [NUM_KV_HEADS_CFG, HEAD_DIM_CFG // 2], [0, 0]) k_hi = pl.slice(k_group, [NUM_KV_HEADS_CFG, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) @@ -230,7 +230,7 @@ def qwen3_prefill_layer( ) attn_row = pl.create_tensor([1, HIDDEN_CFG], dtype=pl.FP32) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): attn_row = pl.mul(attn_row, 0.0) for gi in pl.parallel(0, TOTAL_Q_GROUPS, 1): @@ -239,7 +239,7 @@ def qwen3_prefill_layer( q_base = kvh * Q_PER_KV_CFG + qg * Q_HEAD_BATCH q_group = pl.create_tensor([Q_HEAD_BATCH, HEAD_DIM_CFG], dtype=pl.FP32) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): for qi in pl.range(Q_HEAD_BATCH): q_col = (q_base + qi) * HEAD_DIM_CFG q_group = pl.assemble( @@ -249,7 +249,7 @@ def qwen3_prefill_layer( [qi, 0], ) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): q_lo = pl.slice(q_group, [Q_HEAD_BATCH, HEAD_DIM_CFG // 2], [0, 0]) q_hi = pl.slice(q_group, [Q_HEAD_BATCH, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) @@ -273,7 +273,7 @@ def qwen3_prefill_layer( valid_len = pl.min(SEQ_TILE, ctx_len - s0) cache_row0 = b * NUM_KV_HEADS_CFG * MAX_SEQ_CFG + kvh * MAX_SEQ_CFG + s0 - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): k_tile = pl.slice( k_cache, [SEQ_TILE, HEAD_DIM_CFG], @@ -282,7 +282,7 @@ def 
qwen3_prefill_layer( ) raw_scores = pl.matmul(q_rot_bf16, k_tile, b_trans=True, out_dtype=pl.FP32) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): scores = pl.mul(raw_scores, ATTN_SCALE) scores_valid = pl.slice( scores, @@ -298,7 +298,7 @@ def qwen3_prefill_layer( exp_pad = pl.assemble(exp_pad, exp_scores, [0, 0]) exp_pad_bf16 = pl.cast(exp_pad, target_type=pl.BF16) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): v_tile = pl.slice( v_cache, [SEQ_TILE, HEAD_DIM_CFG], @@ -307,7 +307,7 @@ def qwen3_prefill_layer( ) oi_tmp = pl.matmul(exp_pad_bf16, v_tile, out_dtype=pl.FP32) - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): if sb == 0: oi = oi_tmp li = cur_li @@ -321,7 +321,7 @@ def qwen3_prefill_layer( pl.row_expand_mul(oi_tmp, beta)) mi = mi_new - with pl.incore(): + with pl.at(level=pl.Level.CORE_GROUP): ctx = pl.row_expand_div(oi, li) for qi in pl.range(Q_HEAD_BATCH): q_col = (q_base + qi) * HEAD_DIM_CFG @@ -334,7 +334,7 @@ def qwen3_prefill_layer( attn_tile = pl.assemble(attn_tile, attn_row, [ti, 0]) # Scope 3: output projection + residual + post-rms + MLP + residual. 
- with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): resid1_tile = pl.create_tensor([TOK_TILE, HIDDEN_CFG], dtype=pl.FP32) for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8): o0 = ob * Q_OUT_CHUNK diff --git a/examples/models/qwen3/qwen3_32b_training_forward_and_backward.py b/examples/models/qwen3/qwen3_32b_training_forward_and_backward.py index c0f2368..cdc2c69 100644 --- a/examples/models/qwen3/qwen3_32b_training_forward_and_backward.py +++ b/examples/models/qwen3/qwen3_32b_training_forward_and_backward.py @@ -121,7 +121,7 @@ def qwen3_32b_training_forward_and_backward_layer( pl.Tensor[[BATCH_CFG, MAX_SEQ_CFG, HIDDEN_CFG], pl.BF16], pl.Tensor[[1], pl.FP32], ]: - with pl.auto_incore(): + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): grad_wq = pl.mul(grad_wq, 0.0) grad_wk = pl.mul(grad_wk, 0.0) grad_wv = pl.mul(grad_wv, 0.0)