-
Notifications
You must be signed in to change notification settings - Fork 21
Refactor layer_norm to two-pass column-chunking pattern #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,13 +6,19 @@ | |||||||||||||
| # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||||||||||||||
| # See LICENSE in the root of the software repository for the full text of the License. | ||||||||||||||
| # ----------------------------------------------------------------------------------------------------------- | ||||||||||||||
| """LayerNorm — full layer normalization with row-only tiling. | ||||||||||||||
| """LayerNorm — full layer normalization with row + column tiling. | ||||||||||||||
|
|
||||||||||||||
| output[r, c] = (x[r, c] - mean(x[r, :])) / sqrt(var(x[r, :]) + eps) * gamma[c] + beta[c] | ||||||||||||||
|
|
||||||||||||||
| Rows are parallelised via pl.parallel (batch dimension). | ||||||||||||||
| The hidden dimension is loaded in full per tile (no column chunking), | ||||||||||||||
| keeping the kernel simple and single-pass friendly. | ||||||||||||||
| The hidden dimension is chunked with pl.range to accumulate the | ||||||||||||||
| sum and squared-sum reductions, then a second pass centres, normalises, | ||||||||||||||
| and applies gamma/beta. | ||||||||||||||
|
|
||||||||||||||
| This two-pass column-chunking pattern follows the same approach used | ||||||||||||||
| by rms_norm.py and the production LLM kernels (qwen3/deepseek). | ||||||||||||||
| Variance is computed via E[x^2] - E[x]^2 to avoid materialising the | ||||||||||||||
| centred tensor during the accumulation pass. | ||||||||||||||
|
|
||||||||||||||
| Input and output are FP32; gamma and beta are [1, hidden] weight vectors. | ||||||||||||||
| """ | ||||||||||||||
|
|
@@ -21,17 +27,20 @@ | |||||||||||||
| import pypto.language as pl | ||||||||||||||
|
|
||||||||||||||
| ROWS = 512 # batch / sequence length | ||||||||||||||
| HIDDEN = 256 # hidden dimension (normalised axis, fits in one tile) | ||||||||||||||
| HIDDEN = 512 # hidden dimension (normalised axis) | ||||||||||||||
| ROW_CHUNK = 32 # rows per parallel tile | ||||||||||||||
| HIDDEN_CHUNK = 64 # columns per sequential chunk | ||||||||||||||
| EPS = 1e-5 | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| def build_layer_norm_program( | ||||||||||||||
| rows: int = ROWS, | ||||||||||||||
| hidden: int = HIDDEN, | ||||||||||||||
| row_chunk: int = ROW_CHUNK, | ||||||||||||||
| hidden_chunk: int = HIDDEN_CHUNK, | ||||||||||||||
| eps: float = EPS, | ||||||||||||||
| ): | ||||||||||||||
| hidden_blocks = hidden // hidden_chunk | ||||||||||||||
| hidden_inv = 1.0 / hidden | ||||||||||||||
|
|
||||||||||||||
| @pl.program | ||||||||||||||
|
|
@@ -46,31 +55,49 @@ def layer_norm( | |||||||||||||
| ) -> pl.Tensor[[rows, hidden], pl.FP32]: | ||||||||||||||
| with pl.auto_incore(): | ||||||||||||||
| for r in pl.parallel(0, rows, row_chunk, chunk=1): | ||||||||||||||
| tile_x = pl.slice(x, [row_chunk, hidden], [r, 0]) | ||||||||||||||
| gamma_tile = pl.slice(gamma, [1, hidden], [0, 0]) | ||||||||||||||
| beta_tile = pl.slice(beta, [1, hidden], [0, 0]) | ||||||||||||||
|
|
||||||||||||||
| # Step 1: row mean — pre-scale before row_sum, no reshape | ||||||||||||||
| mean = pl.row_sum(pl.mul(tile_x, hidden_inv)) | ||||||||||||||
|
|
||||||||||||||
| # Step 2: row variance + eps — pre-scale and pre-add | ||||||||||||||
| centred = pl.row_expand_sub(tile_x, mean) | ||||||||||||||
| var_eps = pl.row_sum( | ||||||||||||||
| pl.mul(pl.add(pl.mul(centred, centred), eps), hidden_inv) | ||||||||||||||
| # Pass 1: accumulate sum(x) and sum(x^2) across hidden chunks | ||||||||||||||
| # row_sum produces [row_chunk, 1] col_major; accumulate | ||||||||||||||
| # in [1, row_chunk] for scalar ops (same as rms_norm). | ||||||||||||||
| x_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32) | ||||||||||||||
| x_sum = pl.mul(x_sum, 0.0) | ||||||||||||||
| sq_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32) | ||||||||||||||
| sq_sum = pl.mul(sq_sum, 0.0) | ||||||||||||||
|
Comment on lines
+61
to
+64
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The initialization of
Suggested change
|
||||||||||||||
| for hb in pl.range(hidden_blocks): | ||||||||||||||
| h0 = hb * hidden_chunk | ||||||||||||||
| x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0]) | ||||||||||||||
| x_sum = pl.add( | ||||||||||||||
| x_sum, pl.reshape(pl.row_sum(x_chunk), [1, row_chunk]) | ||||||||||||||
| ) | ||||||||||||||
| sq_sum = pl.add( | ||||||||||||||
| sq_sum, | ||||||||||||||
| pl.reshape( | ||||||||||||||
| pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, row_chunk] | ||||||||||||||
| ), | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| # mean and inv_std via E[x^2] - E[x]^2 | ||||||||||||||
| mean_T = pl.mul(x_sum, hidden_inv) | ||||||||||||||
| var_T = pl.sub( | ||||||||||||||
| pl.mul(sq_sum, hidden_inv), pl.mul(mean_T, mean_T) | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| # Step 3: normalise — single reshape pair for sqrt | ||||||||||||||
| std = pl.reshape( | ||||||||||||||
| pl.sqrt(pl.reshape(var_eps, [1, row_chunk])), | ||||||||||||||
| [row_chunk, 1], | ||||||||||||||
| ) | ||||||||||||||
| normed = pl.row_expand_div(centred, std) | ||||||||||||||
|
|
||||||||||||||
| # Step 4: apply gamma scale and beta offset | ||||||||||||||
| scaled = pl.col_expand_mul(normed, gamma_tile) | ||||||||||||||
| ones = pl.add(pl.sub(tile_x, tile_x), 1.0) | ||||||||||||||
| result = pl.add(scaled, pl.col_expand_mul(ones, beta_tile)) | ||||||||||||||
| y = pl.assemble(y, result, [r, 0]) | ||||||||||||||
| inv_std_T = pl.rsqrt(pl.add(var_T, eps)) | ||||||||||||||
| mean = pl.reshape(mean_T, [row_chunk, 1]) | ||||||||||||||
| inv_std = pl.reshape(inv_std_T, [row_chunk, 1]) | ||||||||||||||
|
|
||||||||||||||
| # Pass 2: centre, normalise, apply gamma/beta | ||||||||||||||
| for hb in pl.range(hidden_blocks): | ||||||||||||||
| h0 = hb * hidden_chunk | ||||||||||||||
| x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0]) | ||||||||||||||
| gamma_chunk = pl.slice(gamma, [1, hidden_chunk], [0, h0]) | ||||||||||||||
| beta_chunk = pl.slice(beta, [1, hidden_chunk], [0, h0]) | ||||||||||||||
| centred = pl.row_expand_sub(x_chunk, mean) | ||||||||||||||
| normed = pl.row_expand_mul(centred, inv_std) | ||||||||||||||
| scaled = pl.col_expand_mul(normed, gamma_chunk) | ||||||||||||||
| ones = pl.add(pl.sub(x_chunk, x_chunk), 1.0) | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||||||||||||||
| result = pl.add( | ||||||||||||||
| scaled, pl.col_expand_mul(ones, beta_chunk) | ||||||||||||||
| ) | ||||||||||||||
| y = pl.assemble(y, result, [r, h0]) | ||||||||||||||
|
|
||||||||||||||
| return y | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -107,6 +134,7 @@ def compile_and_run( | |||||||||||||
| rows: int = ROWS, | ||||||||||||||
| hidden: int = HIDDEN, | ||||||||||||||
| row_chunk: int = ROW_CHUNK, | ||||||||||||||
| hidden_chunk: int = HIDDEN_CHUNK, | ||||||||||||||
| platform: str = "a2a3", | ||||||||||||||
| device_id: int = 11, | ||||||||||||||
| dump_passes: bool = True, | ||||||||||||||
|
|
@@ -119,6 +147,7 @@ def compile_and_run( | |||||||||||||
| rows=rows, | ||||||||||||||
| hidden=hidden, | ||||||||||||||
| row_chunk=row_chunk, | ||||||||||||||
| hidden_chunk=hidden_chunk, | ||||||||||||||
| ) | ||||||||||||||
| tensor_specs = build_tensor_specs( | ||||||||||||||
| rows=rows, | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard `hidden_chunk` values that do not evenly tile `hidden`. `hidden_blocks = hidden // hidden_chunk` truncates. For any `hidden % hidden_chunk != 0`, pass 1 drops the tail from the mean/variance reduction and pass 2 never writes the tail columns back to `y`. If `hidden_chunk > hidden`, both loops skip entirely and the output stays unwritten. Add a precondition here or explicit tail handling before building the loops.

🛠️ Suggested guard
def build_layer_norm_program(
    rows: int = ROWS,
    hidden: int = HIDDEN,
    row_chunk: int = ROW_CHUNK,
    hidden_chunk: int = HIDDEN_CHUNK,
    eps: float = EPS,
):
+   if hidden <= 0:
+       raise ValueError(f"`hidden` must be > 0, got {hidden}")
+   if hidden_chunk <= 0 or hidden % hidden_chunk != 0:
+       raise ValueError(
+           "`hidden_chunk` must be a positive divisor of `hidden` "
+           f"(got hidden={hidden}, hidden_chunk={hidden_chunk})"
+       )
    hidden_blocks = hidden // hidden_chunk
    hidden_inv = 1.0 / hidden

🤖 Prompt for AI Agents