-
Notifications
You must be signed in to change notification settings - Fork 19
Qwen3 Scope3 refactor with performance improvements #90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -21,7 +21,7 @@ | |||||
| import pypto.language as pl | ||||||
|
|
||||||
| BATCH = 16 | ||||||
| HIDDEN = 5120 | ||||||
| HIDDEN = 8192 | ||||||
| INTERMEDIATE = 25600 | ||||||
|
|
||||||
| EPS = 1e-6 | ||||||
|
|
@@ -63,14 +63,14 @@ def scope3( | |||||
| for b0 in pl.range(0, BATCH_CFG, BATCH_TILE): | ||||||
| resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32) | ||||||
|
|
||||||
| # Stage 1: Initialize resid1_tile accumulator in parallel. | ||||||
| # Stage 0: Initialize resid1_tile accumulator in parallel. | ||||||
| with pl.auto_incore(): | ||||||
| for ob in pl.parallel(0, Q_OUT_BLOCKS, chunk=8): | ||||||
| o0 = ob * Q_OUT_CHUNK | ||||||
| zero_resid1 = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0) | ||||||
| resid1_tile = pl.assemble(resid1_tile, zero_resid1, [0, o0]) | ||||||
|
|
||||||
| # Stage 2: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK. | ||||||
| # Stage 1: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK. | ||||||
| for ob in pl.range(Q_OUT_BLOCKS): | ||||||
| o0 = ob * Q_OUT_CHUNK | ||||||
|
|
||||||
|
|
@@ -84,21 +84,18 @@ def scope3( | |||||
| w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0]) | ||||||
| o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk) | ||||||
|
|
||||||
| resid1_tile = pl.assemble(resid1_tile, o_acc, [0, o0]) | ||||||
|
|
||||||
| # Stage 2: Residual addition with hidden_states | ||||||
| with pl.incore(): | ||||||
| resid = pl.cast( | ||||||
| pl.slice(hidden_states, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]), | ||||||
| target_type=pl.FP32, | ||||||
| ) | ||||||
| mm_out = pl.slice(resid1_tile, [BATCH_TILE, Q_OUT_CHUNK], [0, o0]) | ||||||
| add_resid = pl.add(mm_out, resid) | ||||||
| resid1_tile = pl.assemble(resid1_tile, add_resid, [0, o0]) | ||||||
| resid_sum = pl.add(o_acc, resid) | ||||||
| resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0]) | ||||||
|
|
||||||
| # Stage 3 & 4 & 5: Post-attention RMSNorm: compute inv_rms over resid1_tile + normalize + initialize down_proj_tile accumulator. | ||||||
| # Stage 3: Post-attention RMSNorm | ||||||
| post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.BF16) | ||||||
| down_proj_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32) | ||||||
| with pl.auto_incore(): | ||||||
| with pl.incore(): | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Switching from
Suggested change
|
||||||
| sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) | ||||||
| for kb in pl.range(HIDDEN_BLOCKS): | ||||||
| k0 = kb * K_CHUNK | ||||||
|
|
@@ -111,16 +108,23 @@ def scope3( | |||||
| x_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, k0]) | ||||||
| gamma = pl.slice(post_rms_weight, [1, K_CHUNK], [0, k0]) | ||||||
| normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, pl.reshape(inv_rms, [BATCH_TILE, 1])), gamma) | ||||||
| post_norm_tile = pl.assemble( | ||||||
| post_norm_tile, pl.cast(normed, target_type=pl.BF16), [0, k0] | ||||||
| normed_bf16 = pl.cast(normed, target_type=pl.BF16) | ||||||
| post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, k0]) | ||||||
|
|
||||||
| # Stage 4: Initialize mlp_tile | ||||||
| mlp_tile = pl.create_tensor([BATCH_TILE, INTER_CFG], dtype=pl.BF16) | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||||||
| with pl.auto_incore(): | ||||||
| for ob in pl.parallel(0, MLP_OUT_BLOCKS, chunk=8): | ||||||
| o0 = ob * MLP_OUT_CHUNK | ||||||
| zero_mlp = pl.cast( | ||||||
| pl.full([BATCH_TILE, MLP_OUT_CHUNK], dtype=pl.FP32, value=0.0), | ||||||
| target_type=pl.BF16, | ||||||
| ) | ||||||
| down_zero_chunk = pl.full([BATCH_TILE, K_CHUNK], dtype=pl.FP32, value=0.0) | ||||||
| down_proj_tile = pl.assemble(down_proj_tile, down_zero_chunk, [0, k0]) | ||||||
| mlp_tile = pl.assemble(mlp_tile, zero_mlp, [0, o0]) | ||||||
|
|
||||||
| # Stage 6: MLP: gate/up projections + SiLU. | ||||||
| # Stage 5 & 6 & 7: MLP: gate/up projections + SiLU. | ||||||
| for ob in pl.range(MLP_OUT_BLOCKS): | ||||||
| o0 = ob * MLP_OUT_CHUNK | ||||||
| # Stage 6a: MLP: gate projections | ||||||
| with pl.incore(): | ||||||
| post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) | ||||||
| wg_0 = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [0, o0]) | ||||||
|
|
@@ -131,7 +135,6 @@ def scope3( | |||||
| wg = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [k0, o0]) | ||||||
| gate_acc = pl.matmul_acc(gate_acc, post_chunk, wg) | ||||||
|
|
||||||
| # Stage 6b: MLP: up projections | ||||||
| with pl.incore(): | ||||||
| post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) | ||||||
| wu_0 = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [0, o0]) | ||||||
|
|
@@ -142,34 +145,33 @@ def scope3( | |||||
| wu = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [k0, o0]) | ||||||
| up_acc = pl.matmul_acc(up_acc, post_chunk, wu) | ||||||
|
|
||||||
| # Stage 6c: MLP: silu | ||||||
| with pl.auto_incore(): | ||||||
| sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0)) | ||||||
| mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) | ||||||
| mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) | ||||||
| mlp_tile = pl.assemble(mlp_tile, mlp_chunk_bf16, [0, o0]) | ||||||
|
|
||||||
| # Stage 7: Down projection: accumulate in parallel. | ||||||
| for dob in pl.range(HIDDEN_BLOCKS): | ||||||
| d0 = dob * K_CHUNK | ||||||
| with pl.incore(): | ||||||
| # Stage 8 & 9: Down projection + final residual writeback. | ||||||
| for dob in pl.range(HIDDEN_BLOCKS): | ||||||
| d0 = dob * K_CHUNK | ||||||
| with pl.incore(): | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||
| mlp_chunk_0 = pl.slice(mlp_tile, [BATCH_TILE, MLP_OUT_CHUNK], [0, 0]) | ||||||
| w_down_chunk_0 = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [0, d0]) | ||||||
| down_acc = pl.matmul(mlp_chunk_0, w_down_chunk_0, out_dtype=pl.FP32) | ||||||
| for ob in pl.range(1, MLP_OUT_BLOCKS): | ||||||
| o0 = ob * MLP_OUT_CHUNK | ||||||
| down_mlp_chunk_bf16 = pl.slice( | ||||||
| mlp_tile, [BATCH_TILE, MLP_OUT_CHUNK], [0, o0] | ||||||
| ) | ||||||
| w_down_chunk = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [o0, d0]) | ||||||
| down_next = pl.matmul(mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32) | ||||||
|
|
||||||
| with pl.incore(): | ||||||
| down_prev = pl.slice(down_proj_tile, [BATCH_TILE, K_CHUNK], [0, d0]) | ||||||
| accum = pl.add(down_prev, down_next) | ||||||
| down_proj_tile = pl.assemble(down_proj_tile, accum, [0, d0]) | ||||||
|
|
||||||
| # Stage 8: Final residual: down_proj + resid1, write to output. | ||||||
| for ob in pl.range(HIDDEN_BLOCKS): | ||||||
| o0 = ob * K_CHUNK | ||||||
| down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) | ||||||
| with pl.incore(): | ||||||
| down_acc = pl.add( | ||||||
| pl.slice(down_proj_tile, [BATCH_TILE, K_CHUNK], [0, o0]), | ||||||
| pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, o0]), | ||||||
| out_chunk = pl.add( | ||||||
| down_acc, | ||||||
| pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, d0]), | ||||||
| ) | ||||||
| out_f32 = pl.cast(down_acc, target_type=pl.BF16) | ||||||
| out = pl.assemble(out, out_f32, [b0, o0]) | ||||||
| out_chunk_cast = pl.cast(out_chunk, target_type=pl.BF16) | ||||||
| out = pl.assemble(out, out_chunk_cast, [b0, d0]) | ||||||
|
|
||||||
| return out | ||||||
|
|
||||||
|
|
@@ -195,7 +197,6 @@ def golden(tensors: dict, params: dict | None = None) -> None: | |||||
| w_up = tensors["w_up"] # [H, I], BF16 | ||||||
| w_down = tensors["w_down"] # [I, H], BF16 | ||||||
|
|
||||||
| H = attn_out.shape[1] | ||||||
| eps = 1e-6 | ||||||
|
|
||||||
| # 1. Output projection (BF16 inputs, FP32 accumulation) + residual. | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Increasing `HIDDEN` to 8192 while maintaining `BATCH_TILE = 16` (line 33) will likely cause a memory overflow in the on-chip vector buffer. For instance, `resid1_tile` (FP32) now requires $16 \times 8192 \times 4 = 512$ KB, which exceeds the 248 KB limit mentioned in the `qwen3_tilelet.md` documentation. This is consistent with the `507018` runtime error reported in the PR description. Consider reducing `BATCH_TILE` to 4 to stay within hardware constraints, as seen in the main `qwen3_32b_decode.py` file.