[A2A3/A5 E2E] qwen3_32b_prefill_tilelet — Full Prefill Layer Validation #71

@zhangqi-chen

Description

Objective

End-to-end validation of examples/models/qwen3/qwen3_32b_prefill_tilelet.py on both the A2A3 and A5 platforms, covering all three pl.auto_incore() / pl.incore() scopes of the prefill layer.

Each session in the batch has a variable input sequence length (up to MAX_SEQ=4096). Tokens are processed in TOK_TILE=4 chunks, and only the valid tokens of each session are computed. Tensors are padded to MAX_SEQ on the sequence axis; the padding rows are never read, so their contents are irrelevant.

Scopes

Scope 1 — Input RMSNorm + Q/K/V Projection (lines ~113–177)

  • Input RMSNorm: per-row squared-sum accumulated in [TOK_TILE=4, K_CHUNK=128] FP32 chunks via pl.row_sum, followed by pl.rsqrt to compute inv_rms (shape [TOK_TILE, 1]).
  • Q projection: for each of the Q_OUT_BLOCKS=80 output blocks, matmul over all HIDDEN_BLOCKS=40 K-chunks using the pl.add(q_acc, pl.matmul(...)) accumulation pattern, assembled into q_proj_tile (shape [TOK_TILE, HIDDEN] BF16). Parallelised with chunk=8.
  • K/V projection: K and V fused in the same inner loop; each of the KV_OUT_BLOCKS=8 output blocks accumulates k_acc and v_acc, assembled into k_proj_tile / v_proj_tile (shape [TOK_TILE, KV_HIDDEN] BF16). Parallelised with chunk=8.
  • 3D → 2D reshape: hidden_states is 3D [BATCH, MAX_SEQ, HIDDEN]; slices are [1, TOK_TILE, K_CHUNK] with valid_shape=[1, valid_tok, K_CHUNK], then pl.reshape to [TOK_TILE, K_CHUNK] for 2D matmul.
  • Scope style: pl.auto_incore() — compiler decides incore/orchestration boundary.

Key tiling constants: TOK_TILE=4, K_CHUNK=128, Q_OUT_CHUNK=64, KV_OUT_CHUNK=64.
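
The chunked RMSNorm above can be sketched in NumPy as a reference model of the arithmetic (not the pl DSL itself; all variable names are illustrative, and eps is an assumed epsilon):

```python
import numpy as np

TOK_TILE, K_CHUNK, HIDDEN = 4, 128, 5120
HIDDEN_BLOCKS = HIDDEN // K_CHUNK  # 40

rng = np.random.default_rng(0)
x = rng.standard_normal((TOK_TILE, HIDDEN)).astype(np.float32)
gamma = np.ones(HIDDEN, dtype=np.float32)
eps = 1e-6

# Accumulate the per-row squared sum one [TOK_TILE, K_CHUNK] FP32 chunk at a
# time, mirroring the pl.row_sum accumulation in Scope 1.
sq_sum = np.zeros((TOK_TILE, 1), dtype=np.float32)
for kb in range(HIDDEN_BLOCKS):
    chunk = x[:, kb * K_CHUNK:(kb + 1) * K_CHUNK]
    sq_sum += (chunk * chunk).sum(axis=1, keepdims=True)

inv_rms = 1.0 / np.sqrt(sq_sum / HIDDEN + eps)  # [TOK_TILE, 1], cf. pl.rsqrt
normed = (x * inv_rms) * gamma

# The chunked accumulation matches an unchunked RMSNorm
ref = x / np.sqrt((x * x).mean(axis=1, keepdims=True) + eps) * gamma
assert np.allclose(normed, ref, atol=1e-4)
```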

Scope 2 — RoPE + KV Cache Update + Causal Attention (lines ~184–334)

  • Per-token iteration: for ti in pl.range(valid_tok) — each token is processed individually (causal; context length = pos + 1).
  • K gather + RoPE: explicit pl.incore() gathers K heads from k_proj_tile into k_group (shape [NUM_KV_HEADS=8, HEAD_DIM=128] FP32), then applies RoPE rotation via pl.concat(rot_lo, rot_hi) and writes to k_cache / v_cache.
  • Q gather + RoPE: per attention group, gathers Q_HEAD_BATCH=4 Q heads from q_proj_tile, applies RoPE, produces q_rot_bf16.
  • Causal attention (online softmax): for each ctx_blocks KV tile ([SEQ_TILE=64, HEAD_DIM=128] BF16 = 16 KB = TILE MAX with valid_shape), four separate pl.incore() stages:
    1. QK matmul → raw_scores
    2. Scale → row_max → exp → row_sum → zero-pad → BF16 cast
    3. SV matmul → oi_tmp
    4. Online rescale (flash-attention style mi/li/oi update)
  • Result assembly: row_expand_div(oi, li), scatter per-head results into attn_row, then assemble into attn_tile.
  • Scope style: explicit pl.incore() blocks — multiple small kernels per token.
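
The RoPE rotation applied to K and Q above can be modelled in NumPy, assuming the rotate_half convention with base 10000 (an assumption — the file may use different constants; names are illustrative). The rotated halves are computed separately and reassembled, as in the pl.concat(rot_lo, rot_hi) step:

```python
import numpy as np

HEAD_DIM = 128
half = HEAD_DIM // 2
pos = 7  # absolute token position

# Per-dimension rotation angles (assumed base 10000)
inv_freq = 1.0 / (10000.0 ** (np.arange(half, dtype=np.float32) / half))
angles = pos * inv_freq
cos, sin = np.cos(angles), np.sin(angles)

rng = np.random.default_rng(2)
x = rng.standard_normal(HEAD_DIM).astype(np.float32)
lo, hi = x[:half], x[half:]

# rotate_half-style RoPE: each (lo[i], hi[i]) pair is rotated by angles[i]
rot_lo = lo * cos - hi * sin
rot_hi = hi * cos + lo * sin
x_rot = np.concatenate([rot_lo, rot_hi])  # cf. pl.concat(rot_lo, rot_hi)

# Rotation preserves the norm of each (lo, hi) pair
assert np.allclose(rot_lo**2 + rot_hi**2, lo**2 + hi**2, atol=1e-4)
```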

Key tiling constants: Q_HEAD_BATCH=4, SEQ_TILE=64, HEAD_DIM=128, Q_GROUPS=2, TOTAL_Q_GROUPS=16.
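
The four attention stages and the mi/li/oi rescale reduce to the standard online-softmax recurrence. A minimal NumPy sketch of one Q-head batch over several KV tiles (a reference model of the math, not the per-stage pl.incore() kernels):

```python
import numpy as np

Q_HEAD_BATCH, SEQ_TILE, HEAD_DIM = 4, 64, 128
ctx_len = 3 * SEQ_TILE  # pretend context of three KV tiles
rng = np.random.default_rng(1)
q = rng.standard_normal((Q_HEAD_BATCH, HEAD_DIM)).astype(np.float32)
k = rng.standard_normal((ctx_len, HEAD_DIM)).astype(np.float32)
v = rng.standard_normal((ctx_len, HEAD_DIM)).astype(np.float32)
scale = 1.0 / np.sqrt(HEAD_DIM)

# Online-softmax running state: row max mi, row sum li, unnormalised output oi
mi = np.full((Q_HEAD_BATCH, 1), -np.inf, dtype=np.float32)
li = np.zeros((Q_HEAD_BATCH, 1), dtype=np.float32)
oi = np.zeros((Q_HEAD_BATCH, HEAD_DIM), dtype=np.float32)

for blk in range(ctx_len // SEQ_TILE):
    kt = k[blk * SEQ_TILE:(blk + 1) * SEQ_TILE]
    vt = v[blk * SEQ_TILE:(blk + 1) * SEQ_TILE]
    raw = q @ kt.T * scale                                 # 1. QK matmul
    m_new = np.maximum(mi, raw.max(axis=1, keepdims=True))  # 2. row_max
    p = np.exp(raw - m_new)                                 #    exp
    l_blk = p.sum(axis=1, keepdims=True)                    #    row_sum
    corr = np.exp(mi - m_new)       # 4. rescale previous state to the new max
    li = li * corr + l_blk
    oi = oi * corr + p @ vt                                 # 3. SV matmul
    mi = m_new

attn = oi / li  # final row_expand_div(oi, li)

# Matches a one-shot softmax attention over the full context
scores = q @ k.T * scale
w = np.exp(scores - scores.max(axis=1, keepdims=True))
ref = (w / w.sum(axis=1, keepdims=True)) @ v
assert np.allclose(attn, ref, atol=1e-4)
```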

Scope 3 — Output Projection + Post-RMSNorm + MLP + Residual (lines ~337–421)

  • Output projection: matmul attn_tile × wo accumulated in [TOK_TILE=4, Q_OUT_CHUNK=64] FP32 tiles via the pl.add(o_acc, pl.matmul(...)) accumulation pattern; the first residual add (hidden_states + o_proj) is assembled into resid1_tile. Parallelised with chunk=8.
  • Post RMSNorm: per-row squared-sum over resid1_tile, pl.rsqrt, gamma-scaled, assembled into post_norm_tile (BF16).
  • MLP gate/up projections: for each of the MLP_OUT_BLOCKS=400 output blocks, accumulate gate_acc and up_acc via matmuls over all HIDDEN_BLOCKS K-chunks; apply SiLU (gate × sigmoid(gate) × up).
  • Down projection: accumulate w_down matmul result into down_proj_tile ([TOK_TILE, HIDDEN] FP32) in [MLP_OUT_CHUNK=64, K_CHUNK=128] BF16 tiles = 16 KB = TILE MAX. Inner loop parallelised with chunk=4.
  • Second residual add: down_proj + resid1 cast to BF16 and assembled into 3D output tensor out at [b, p0, o0].
  • Scope style: pl.auto_incore() — compiler decides incore/orchestration boundary.

Key tiling constants: TOK_TILE=4, K_CHUNK=128, Q_OUT_CHUNK=64, MLP_OUT_CHUNK=64.
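
The fused gate/up accumulation and SiLU step can be sketched for a single MLP output block (NumPy reference of the arithmetic; dimensions are scaled down from HIDDEN_BLOCKS=40 for brevity, and all names are illustrative):

```python
import numpy as np

TOK_TILE, K_CHUNK, MLP_OUT_CHUNK = 4, 128, 64
HIDDEN_BLOCKS = 4                 # scaled down from 40 for the sketch
HIDDEN = HIDDEN_BLOCKS * K_CHUNK

rng = np.random.default_rng(3)
x = rng.standard_normal((TOK_TILE, HIDDEN)).astype(np.float32)
w_gate = (rng.standard_normal((HIDDEN, MLP_OUT_CHUNK)) * 0.02).astype(np.float32)
w_up = (rng.standard_normal((HIDDEN, MLP_OUT_CHUNK)) * 0.02).astype(np.float32)

# One MLP output block: accumulate gate_acc and up_acc over all K-chunks in
# the same inner loop, mirroring the fused gate+up accumulation in Scope 3
gate_acc = np.zeros((TOK_TILE, MLP_OUT_CHUNK), dtype=np.float32)
up_acc = np.zeros_like(gate_acc)
for kb in range(HIDDEN_BLOCKS):
    xk = x[:, kb * K_CHUNK:(kb + 1) * K_CHUNK]
    gate_acc += xk @ w_gate[kb * K_CHUNK:(kb + 1) * K_CHUNK]
    up_acc += xk @ w_up[kb * K_CHUNK:(kb + 1) * K_CHUNK]

# SiLU activation: gate * sigmoid(gate) * up
act = gate_acc * (1.0 / (1.0 + np.exp(-gate_acc))) * up_acc

# The chunked accumulation matches the full-width matmuls
assert np.allclose(gate_acc, x @ w_gate, atol=1e-4)
assert np.allclose(up_acc, x @ w_up, atol=1e-4)
```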

TILELET / TILE Budget

Vector TILELET budget (2 KB = 2048 B, FP32 = 4 B/elem):
  [TOK_TILE, K_CHUNK]       FP32 = [4,128] × 4 =  2048 B = 2 KB  ✓ MAX
  [TOK_TILE, Q_OUT_CHUNK]   FP32 = [4, 64] × 4 =  1024 B = 1 KB  (50%)
  [TOK_TILE, KV_OUT_CHUNK]  FP32 = [4, 64] × 4 =  1024 B = 1 KB  (50%)
  [TOK_TILE, MLP_OUT_CHUNK] FP32 = [4, 64] × 4 =  1024 B = 1 KB  (50%)
  [Q_HEAD_BATCH, HEAD_DIM]  FP32 = [4,128] × 4 =  2048 B = 2 KB  ✓ MAX (attn)
  [Q_HEAD_BATCH, SEQ_TILE]  FP32 = [4, 64] × 4 =  1024 B = 1 KB  (attn scores)
  [NUM_KV_HEADS, HEAD_DIM]  FP32 = [8,128] × 4 =  4096 B = 4 KB  (K RoPE, 2×TILELET)

Cube TILE budget (16 KB = 16384 B, BF16 = 2 B/elem):
  [K_CHUNK, Q_OUT_CHUNK]    BF16 = [128, 64] × 2 = 16384 B = 16 KB ✓ MAX
  [K_CHUNK, KV_OUT_CHUNK]   BF16 = [128, 64] × 2 = 16384 B = 16 KB ✓ MAX
  [K_CHUNK, MLP_OUT_CHUNK]  BF16 = [128, 64] × 2 = 16384 B = 16 KB ✓ MAX
  [SEQ_TILE, HEAD_DIM]      BF16 = [ 64,128] × 2 = 16384 B = 16 KB ✓ MAX (attn)
  [MLP_OUT_CHUNK, K_CHUNK]  BF16 = [ 64,128] × 2 = 16384 B = 16 KB ✓ MAX (down proj)
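
The byte arithmetic behind both tables can be checked with a few lines of Python (the budget constants are as quoted above):

```python
# Sanity-check the TILELET / TILE byte budgets quoted above
def tile_bytes(shape, elem_bytes):
    n = 1
    for d in shape:
        n *= d
    return n * elem_bytes

FP32, BF16 = 4, 2
assert tile_bytes((4, 128), FP32) == 2048    # vector TILELET max (2 KB)
assert tile_bytes((4, 64), FP32) == 1024     # 50% of the vector budget
assert tile_bytes((8, 128), FP32) == 4096    # K RoPE group: 2x TILELET
assert tile_bytes((128, 64), BF16) == 16384  # cube TILE max (16 KB)
assert tile_bytes((64, 128), BF16) == 16384  # attn / down-proj tiles
```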

Platform Targets

Platform  Status  Notes
A2A3      TBD     Default platform (compile_and_run defaults to a2a3)
A5        TBD     BackendType.Ascend950; requires --platform a5

File

examples/models/qwen3/qwen3_32b_prefill_tilelet.py

Status

In Progress