Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 57 additions & 28 deletions examples/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""LayerNorm — full layer normalization with row-only tiling.
"""LayerNorm — full layer normalization with row + column tiling.

output[r, c] = (x[r, c] - mean(x[r, :])) / sqrt(var(x[r, :]) + eps) * gamma[c] + beta[c]

Rows are parallelised via pl.parallel (batch dimension).
The hidden dimension is loaded in full per tile (no column chunking),
keeping the kernel simple and single-pass friendly.
The hidden dimension is chunked with pl.range to accumulate the
sum and squared-sum reductions, then a second pass centres, normalises,
and applies gamma/beta.

This two-pass column-chunking pattern follows the same approach used
by rms_norm.py and the production LLM kernels (qwen3/deepseek).
Variance is computed via E[x^2] - E[x]^2 to avoid materialising the
centred tensor during the accumulation pass. (Note: this formulation can
lose precision to catastrophic cancellation when the variance is small
relative to the squared mean; acceptable here for FP32 layer-norm inputs.)

Input and output are FP32; gamma and beta are [1, hidden] weight vectors.
"""
Expand All @@ -21,17 +27,20 @@
import pypto.language as pl

ROWS = 512 # batch / sequence length
HIDDEN = 256 # hidden dimension (normalised axis, fits in one tile)
HIDDEN = 512 # hidden dimension (normalised axis)
ROW_CHUNK = 32 # rows per parallel tile
HIDDEN_CHUNK = 64 # columns per sequential chunk
EPS = 1e-5


def build_layer_norm_program(
rows: int = ROWS,
hidden: int = HIDDEN,
row_chunk: int = ROW_CHUNK,
hidden_chunk: int = HIDDEN_CHUNK,
eps: float = EPS,
):
hidden_blocks = hidden // hidden_chunk
hidden_inv = 1.0 / hidden
Comment on lines 36 to 44
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Guard hidden_chunk values that do not evenly tile hidden.

hidden_blocks = hidden // hidden_chunk truncates. For any hidden % hidden_chunk != 0, pass 1 drops the tail from the mean/variance reduction and pass 2 never writes the tail columns back to y. If hidden_chunk > hidden, both loops skip entirely and the output stays unwritten. Add a precondition here or explicit tail handling before building the loops.

🛠️ Suggested guard
 def build_layer_norm_program(
     rows: int = ROWS,
     hidden: int = HIDDEN,
     row_chunk: int = ROW_CHUNK,
     hidden_chunk: int = HIDDEN_CHUNK,
     eps: float = EPS,
 ):
+    if hidden <= 0:
+        raise ValueError(f"`hidden` must be > 0, got {hidden}")
+    if hidden_chunk <= 0 or hidden % hidden_chunk != 0:
+        raise ValueError(
+            "`hidden_chunk` must be a positive divisor of `hidden` "
+            f"(got hidden={hidden}, hidden_chunk={hidden_chunk})"
+        )
     hidden_blocks = hidden // hidden_chunk
     hidden_inv = 1.0 / hidden
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/layer_norm.py` around lines 36 - 44, The build_layer_norm_program
currently computes hidden_blocks = hidden // hidden_chunk which silently drops a
tail when hidden % hidden_chunk != 0 (and fails entirely if hidden_chunk >
hidden); update build_layer_norm_program to guard and handle tails: either
validate the inputs and raise a clear exception if hidden % hidden_chunk != 0 or
compute hidden_blocks = math.ceil(hidden / hidden_chunk) and add explicit
per-block logic in the loops (using hidden_chunk for full blocks and computing a
last_block_size = hidden - (hidden_blocks-1)*hidden_chunk for the final partial
block) so mean/variance reductions and writes to y only process the actual
remaining columns and use the correct normalization factor (use hidden or the
per-row actual element count) for the tail; reference the symbols hidden_blocks,
hidden_chunk, hidden_inv, and hidden when making the change.


@pl.program
Expand All @@ -46,31 +55,49 @@ def layer_norm(
) -> pl.Tensor[[rows, hidden], pl.FP32]:
with pl.auto_incore():
for r in pl.parallel(0, rows, row_chunk, chunk=1):
tile_x = pl.slice(x, [row_chunk, hidden], [r, 0])
gamma_tile = pl.slice(gamma, [1, hidden], [0, 0])
beta_tile = pl.slice(beta, [1, hidden], [0, 0])

# Step 1: row mean — pre-scale before row_sum, no reshape
mean = pl.row_sum(pl.mul(tile_x, hidden_inv))

# Step 2: row variance + eps — pre-scale and pre-add
centred = pl.row_expand_sub(tile_x, mean)
var_eps = pl.row_sum(
pl.mul(pl.add(pl.mul(centred, centred), eps), hidden_inv)
# Pass 1: accumulate sum(x) and sum(x^2) across hidden chunks
# row_sum produces [row_chunk, 1] col_major; accumulate
# in [1, row_chunk] for scalar ops (same as rms_norm).
x_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32)
x_sum = pl.mul(x_sum, 0.0)
sq_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32)
sq_sum = pl.mul(sq_sum, 0.0)
Comment on lines +61 to +64
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The initialization of x_sum and sq_sum can be made more concise by combining the tensor creation and the zeroing operation into a single statement for each tensor. This improves readability and reduces redundancy.

Suggested change
x_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32)
x_sum = pl.mul(x_sum, 0.0)
sq_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32)
sq_sum = pl.mul(sq_sum, 0.0)
x_sum = pl.mul(pl.create_tensor([1, row_chunk], dtype=pl.FP32), 0.0)
sq_sum = pl.mul(pl.create_tensor([1, row_chunk], dtype=pl.FP32), 0.0)

for hb in pl.range(hidden_blocks):
h0 = hb * hidden_chunk
x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0])
x_sum = pl.add(
x_sum, pl.reshape(pl.row_sum(x_chunk), [1, row_chunk])
)
sq_sum = pl.add(
sq_sum,
pl.reshape(
pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, row_chunk]
),
)

# mean and inv_std via E[x^2] - E[x]^2
mean_T = pl.mul(x_sum, hidden_inv)
var_T = pl.sub(
pl.mul(sq_sum, hidden_inv), pl.mul(mean_T, mean_T)
)

# Step 3: normalise — single reshape pair for sqrt
std = pl.reshape(
pl.sqrt(pl.reshape(var_eps, [1, row_chunk])),
[row_chunk, 1],
)
normed = pl.row_expand_div(centred, std)

# Step 4: apply gamma scale and beta offset
scaled = pl.col_expand_mul(normed, gamma_tile)
ones = pl.add(pl.sub(tile_x, tile_x), 1.0)
result = pl.add(scaled, pl.col_expand_mul(ones, beta_tile))
y = pl.assemble(y, result, [r, 0])
inv_std_T = pl.rsqrt(pl.add(var_T, eps))
mean = pl.reshape(mean_T, [row_chunk, 1])
inv_std = pl.reshape(inv_std_T, [row_chunk, 1])

# Pass 2: centre, normalise, apply gamma/beta
for hb in pl.range(hidden_blocks):
h0 = hb * hidden_chunk
x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0])
gamma_chunk = pl.slice(gamma, [1, hidden_chunk], [0, h0])
beta_chunk = pl.slice(beta, [1, hidden_chunk], [0, h0])
centred = pl.row_expand_sub(x_chunk, mean)
normed = pl.row_expand_mul(centred, inv_std)
scaled = pl.col_expand_mul(normed, gamma_chunk)
ones = pl.add(pl.sub(x_chunk, x_chunk), 1.0)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The ones tensor is being recreated in every iteration of the for hb in pl.range(hidden_blocks): loop. Since its shape and value are constant within this loop, it can be created once before the loop begins to avoid redundant computation and improve performance. You could create a template tensor of shape [row_chunk, hidden_chunk] before the loop and use that to generate the ones tensor.

result = pl.add(
scaled, pl.col_expand_mul(ones, beta_chunk)
)
y = pl.assemble(y, result, [r, h0])

return y

Expand Down Expand Up @@ -107,6 +134,7 @@ def compile_and_run(
rows: int = ROWS,
hidden: int = HIDDEN,
row_chunk: int = ROW_CHUNK,
hidden_chunk: int = HIDDEN_CHUNK,
platform: str = "a2a3",
device_id: int = 11,
dump_passes: bool = True,
Expand All @@ -119,6 +147,7 @@ def compile_and_run(
rows=rows,
hidden=hidden,
row_chunk=row_chunk,
hidden_chunk=hidden_chunk,
)
tensor_specs = build_tensor_specs(
rows=rows,
Expand Down
Loading