/runa5 test/basic/rmsnorm_incore_0.pto --pto-level=level3
Code Review
This pull request introduces a new test file, test/basic/rmsnorm_incore_0.pto, which implements an in-core RMSNorm kernel using the PTO dialect. A critical issue was identified in the implementation: the logic is missing the reciprocal square root operation. Currently, the kernel multiplies the input by the mean of squares plus epsilon, whereas the RMSNorm algorithm requires multiplication by the reciprocal square root of that value to ensure mathematical correctness.
%3 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
%variance__rm_a0_tmp_v5 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%variance__row_major_tmp_v6 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tadds ins(%variance__rm_a0_tmp_v5, %cst_2 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%variance__row_major_tmp_v6 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
The RMSNorm implementation appears to be missing the reciprocal square root operation. The current logic calculates the mean of squares plus epsilon (mean(x^2) + eps) and then multiplies the input by this value in the second loop (line 64). For a correct RMSNorm, the input should be multiplied by rsqrt(mean(x^2) + eps). Consider adding a pto.trsqrt operation after the epsilon addition to maintain semantic correctness for an RMSNorm test.
pto.tadds ins(%variance__rm_a0_tmp_v5, %cst_2 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%variance__row_major_tmp_v6 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
pto.trsqrt ins(%variance__row_major_tmp_v6 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%variance__row_major_tmp_v6 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
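To make the numerical difference concrete, here is a minimal Python sketch that compares the scaling described above with the reference scaling. It is not derived from the kernel itself: the function names, the sample row, and the `eps` value are assumptions chosen for illustration, and no learned weight is applied because the excerpt does not show one.

```python
import math


def rmsnorm_reference(row, eps=1e-6):
    """Reference scaling: x * rsqrt(mean(x^2) + eps). eps is an assumed value."""
    mean_sq = sum(x * x for x in row) / len(row)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [x * scale for x in row]


def rmsnorm_without_rsqrt(row, eps=1e-6):
    """Scaling as described in the review: x * (mean(x^2) + eps), no rsqrt."""
    mean_sq = sum(x * x for x in row) / len(row)
    scale = mean_sq + eps
    return [x * scale for x in row]


def mean_square(row):
    """Mean of squares of a row; close to 1.0 after a correct RMSNorm."""
    return sum(x * x for x in row) / len(row)


row = [1.0, 2.0, 3.0, 4.0]  # hypothetical input row
ref = rmsnorm_reference(row)
bad = rmsnorm_without_rsqrt(row)

print("reference mean square:", mean_square(ref))
print("no-rsqrt mean square: ", mean_square(bad))
```

With the reference scaling the output row has a mean square close to 1, while the variant without the reciprocal square root amplifies the row instead of normalizing it. A check of this kind could serve as a host-side reference when validating the kernel output.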
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f42ebae0cb
@@ -0,0 +1,74 @@
module attributes {pto.target_arch = "a5"} {
Add RUN directive so the new basic test actually executes
This new file in test/basic has no // RUN: directive, so it is not wired into the usual lit-style execution path used by other .pto tests in this directory. As written, the test content can be present in-tree without ever being exercised, which means regressions in this RMSNorm lowering path can slip through unnoticed; please add at least a compile RUN line (and preferably FileCheck assertions) to make this test actionable.
Summary
test/basic/rmsnorm_incore_0.pto as a standalone PTO basic test
Validation
/home/gpt/PTOAS/build/tools/ptoas/ptoas --pto-level=level3 test/basic/rmsnorm_incore_0.pto