Refactor layer_norm to two-pass column-chunking pattern #31
zhangqi-chen wants to merge 1 commit into hw-native-sys:main
Conversation
Replace the single-pass row-only tiling with a two-pass approach that chunks the hidden dimension, matching the pattern used by rms_norm.py and the production LLM kernels (qwen3/deepseek). Pass 1 accumulates sum(x) and sum(x^2) across hidden chunks, then computes mean and inv_std via E[x^2] - E[x]^2. Pass 2 centres, normalises, and applies gamma/beta per chunk. This enables larger hidden dimensions (HIDDEN bumped from 256 to 512) by avoiding loading the full hidden axis in a single tile.
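The two-pass scheme described above can be sketched as a plain-NumPy reference model (this is an illustration, not the `pl` DSL the kernel actually uses; the function name and chunk loop are assumptions for clarity):

```python
import numpy as np

def layer_norm_two_pass(x, gamma, beta, hidden_chunk=64, eps=1e-5):
    """Reference sketch: pass 1 accumulates sum(x) and sum(x^2) per row
    across hidden chunks, pass 2 centres/normalises/scales each chunk."""
    rows, hidden = x.shape
    assert hidden % hidden_chunk == 0, "sketch assumes an even tiling"
    x_sum = np.zeros(rows)
    sq_sum = np.zeros(rows)
    # Pass 1: chunked accumulation over the hidden axis.
    for h0 in range(0, hidden, hidden_chunk):
        chunk = x[:, h0:h0 + hidden_chunk]
        x_sum += chunk.sum(axis=1)
        sq_sum += (chunk * chunk).sum(axis=1)
    mean = x_sum / hidden
    var = sq_sum / hidden - mean * mean          # E[x^2] - E[x]^2
    inv_std = 1.0 / np.sqrt(var + eps)
    # Pass 2: centre, normalise, apply gamma/beta per chunk.
    y = np.empty_like(x)
    for h0 in range(0, hidden, hidden_chunk):
        chunk = x[:, h0:h0 + hidden_chunk]
        y[:, h0:h0 + hidden_chunk] = (
            (chunk - mean[:, None]) * inv_std[:, None]
            * gamma[h0:h0 + hidden_chunk] + beta[h0:h0 + hidden_chunk]
        )
    return y
```

The point of the split is that no pass ever holds the full hidden axis in one tile; only the per-row scalars (sums, mean, inv_std) persist between chunks.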
Code Review
This pull request refactors layer_norm.py to a more scalable two-pass column-chunking pattern, which is a solid improvement for handling larger hidden dimensions. The implementation correctly uses the E[x^2] - E[x]^2 formula for variance, which is memory-efficient. The code is clear and aligns well with existing patterns in the codebase. I've included a couple of suggestions for minor performance and style enhancements.
```python
x_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32)
x_sum = pl.mul(x_sum, 0.0)
sq_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32)
sq_sum = pl.mul(sq_sum, 0.0)
```
The initialization of x_sum and sq_sum can be made more concise by combining the tensor creation and the zeroing operation into a single statement for each tensor. This improves readability and reduces redundancy.
```diff
-x_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32)
-x_sum = pl.mul(x_sum, 0.0)
-sq_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32)
-sq_sum = pl.mul(sq_sum, 0.0)
+x_sum = pl.mul(pl.create_tensor([1, row_chunk], dtype=pl.FP32), 0.0)
+sq_sum = pl.mul(pl.create_tensor([1, row_chunk], dtype=pl.FP32), 0.0)
```
```python
centred = pl.row_expand_sub(x_chunk, mean)
normed = pl.row_expand_mul(centred, inv_std)
scaled = pl.col_expand_mul(normed, gamma_chunk)
ones = pl.add(pl.sub(x_chunk, x_chunk), 1.0)
```
The ones tensor is being recreated in every iteration of the for hb in pl.range(hidden_blocks): loop. Since its shape and value are constant within this loop, it can be created once before the loop begins to avoid redundant computation and improve performance. You could create a template tensor of shape [row_chunk, hidden_chunk] before the loop and use that to generate the ones tensor.
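A hedged sketch of the suggested hoist, reusing only `pl` calls that already appear in this diff (the exact `pl` semantics are assumed here, so treat this as pseudocode rather than a verified rewrite):

```python
# Hypothetical: build the constant ones tile once, before the hidden-block loop,
# using the same create-then-zero idiom as x_sum/sq_sum above.
ones = pl.add(pl.mul(pl.create_tensor([row_chunk, hidden_chunk], dtype=pl.FP32), 0.0), 1.0)
for hb in pl.range(hidden_blocks):
    centred = pl.row_expand_sub(x_chunk, mean)
    normed = pl.row_expand_mul(centred, inv_std)
    scaled = pl.col_expand_mul(normed, gamma_chunk)
    # `ones` is loop-invariant: no per-iteration pl.sub/pl.add rebuild needed.
```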
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/layer_norm.py`:
- Around line 36-44: The build_layer_norm_program currently computes
hidden_blocks = hidden // hidden_chunk which silently drops a tail when hidden %
hidden_chunk != 0 (and fails entirely if hidden_chunk > hidden); update
build_layer_norm_program to guard and handle tails: either validate the inputs
and raise a clear exception if hidden % hidden_chunk != 0 or compute
hidden_blocks = math.ceil(hidden / hidden_chunk) and add explicit per-block
logic in the loops (using hidden_chunk for full blocks and computing a
last_block_size = hidden - (hidden_blocks-1)*hidden_chunk for the final partial
block) so mean/variance reductions and writes to y only process the actual
remaining columns and use the correct normalization factor (use hidden or the
per-row actual element count) for the tail; reference the symbols hidden_blocks,
hidden_chunk, hidden_inv, and hidden when making the change.
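The ceil-division option from the prompt above can be sketched as a small helper (an illustration only; the helper name is hypothetical and not part of the PR):

```python
import math

def hidden_block_sizes(hidden: int, hidden_chunk: int) -> list[int]:
    """Split `hidden` into ceil(hidden / hidden_chunk) blocks, where every
    block is hidden_chunk wide except a possibly smaller final tail block."""
    if hidden <= 0 or hidden_chunk <= 0:
        raise ValueError("hidden and hidden_chunk must be positive")
    hidden_blocks = math.ceil(hidden / hidden_chunk)
    sizes = [hidden_chunk] * (hidden_blocks - 1)
    # Tail block: whatever columns remain after the full blocks.
    sizes.append(hidden - (hidden_blocks - 1) * hidden_chunk)
    return sizes
```

Pass 1 and pass 2 would then iterate over these per-block widths, and the normalization factor stays `1.0 / hidden` since the blocks together cover exactly the hidden axis.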
📒 Files selected for processing (1)
examples/layer_norm.py
```python
def build_layer_norm_program(
    rows: int = ROWS,
    hidden: int = HIDDEN,
    row_chunk: int = ROW_CHUNK,
    hidden_chunk: int = HIDDEN_CHUNK,
    eps: float = EPS,
):
    hidden_blocks = hidden // hidden_chunk
    hidden_inv = 1.0 / hidden
```
Guard hidden_chunk values that do not evenly tile hidden.
hidden_blocks = hidden // hidden_chunk truncates. For any hidden % hidden_chunk != 0, pass 1 drops the tail from the mean/variance reduction and pass 2 never writes the tail columns back to y. If hidden_chunk > hidden, both loops skip entirely and the output stays unwritten. Add a precondition here or explicit tail handling before building the loops.
🛠️ Suggested guard

```diff
 def build_layer_norm_program(
     rows: int = ROWS,
     hidden: int = HIDDEN,
     row_chunk: int = ROW_CHUNK,
     hidden_chunk: int = HIDDEN_CHUNK,
     eps: float = EPS,
 ):
+    if hidden <= 0:
+        raise ValueError(f"`hidden` must be > 0, got {hidden}")
+    if hidden_chunk <= 0 or hidden % hidden_chunk != 0:
+        raise ValueError(
+            "`hidden_chunk` must be a positive divisor of `hidden` "
+            f"(got hidden={hidden}, hidden_chunk={hidden_chunk})"
+        )
     hidden_blocks = hidden // hidden_chunk
     hidden_inv = 1.0 / hidden
```
Summary

- Refactor `layer_norm.py` from single-pass row-only tiling to a two-pass row+column chunking pattern
- Pass 1 accumulates `sum(x)` and `sum(x²)` across hidden chunks; Pass 2 centres, normalises, and applies gamma/beta per chunk
- Compute variance via `E[x²] - E[x]²` to avoid materialising the centred tensor during accumulation
- Match the pattern used by `rms_norm.py` and production LLM kernels (qwen3/deepseek)
- Add a `HIDDEN_CHUNK = 64` constant

Testing