-
Notifications
You must be signed in to change notification settings - Fork 30
test: add rmsnorm PTO basic test #435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| module attributes {pto.target_arch = "a5"} { | ||
| func.func @rmsnorm_incore_0(%arg0: !pto.ptr<bf16>, %arg1: !pto.ptr<f32>, %arg2: !pto.ptr<bf16>, %arg3: index) attributes {pto.kernel_kind = #pto.kernel_kind<vector>} { | ||
| %c0i = arith.constant 0 : i64 | ||
| %c64 = arith.constant 64 : i64 | ||
| %c4160 = arith.constant 4160 : i64 | ||
| %c12352 = arith.constant 12352 : i64 | ||
| %c20544 = arith.constant 20544 : i64 | ||
| %c20608 = arith.constant 20608 : i64 | ||
| %c20672 = arith.constant 20672 : i64 | ||
| %c16 = arith.constant 16 : index | ||
| %c5120 = arith.constant 5120 : index | ||
| %c1 = arith.constant 1 : index | ||
| %cst = arith.constant 0.000000e+00 : f32 | ||
| %c0 = arith.constant 0 : index | ||
| %c40 = arith.constant 40 : index | ||
| %c128 = arith.constant 128 : index | ||
| %cst_1 = arith.constant 1.953125e-04 : f32 | ||
| %cst_2 = arith.constant 1.000000e-06 : f32 | ||
| %hidden_states__ssa_v0_view = pto.make_tensor_view %arg0, shape = [%c16, %c5120], strides = [%c5120, %c1] {layout = #pto.layout<nd>}: !pto.tensor_view<?x?xbf16> | ||
| %input_rms_weight__ssa_v0_view = pto.make_tensor_view %arg1, shape = [%c1, %c5120], strides = [%c5120, %c1] {layout = #pto.layout<nd>}: !pto.tensor_view<?x?xf32> | ||
| %normed_out__iter_v1_view = pto.make_tensor_view %arg2, shape = [%c16, %c5120], strides = [%c5120, %c1] {layout = #pto.layout<nd>}: !pto.tensor_view<?x?xbf16> | ||
| %partial_sq_flat__tile = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.texpands ins(%cst : f32) outs(%partial_sq_flat__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %partial_sq__tile = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0> | ||
| scf.for %kb__idx_v0 = %c0 to %c40 step %c1 { | ||
| %8 = arith.muli %kb__idx_v0, %c128 : index | ||
| %t__tile = pto.alloc_tile addr = %c64 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %hidden_states__ssa_v0_pview = pto.partition_view %hidden_states__ssa_v0_view, offsets = [%arg3, %8], sizes = [%c16, %c128] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x128xbf16> | ||
| pto.tload ins(%hidden_states__ssa_v0_pview : !pto.partition_tensor_view<16x128xbf16>) outs(%t__tile : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %x_chunk__tile = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.tcvt ins(%t__tile{rmode = #pto<round_mode ROUND>} : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%x_chunk__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %0 = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.tmul ins(%x_chunk__tile, %x_chunk__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%0 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %tmp_tile = pto.alloc_tile addr = %c12352 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %1 = pto.alloc_tile addr = %c20544 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.trowsum ins(%0, %tmp_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%1 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>) | ||
| %partial_sq__rm_a0_tmp_v0 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %partial_sq__rm_a1_tmp_v1 = pto.alloc_tile addr = %c20544 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %partial_sq__row_major_tmp_v2 = pto.alloc_tile addr = %c20608 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.tadd ins(%partial_sq__rm_a0_tmp_v0, %partial_sq__rm_a1_tmp_v1 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%partial_sq__row_major_tmp_v2 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %2 = pto.alloc_tile addr = %c20608 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0> | ||
| %partial_sq__tile_mv = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.tmov ins(%2 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>) outs(%partial_sq__tile_mv : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>) | ||
| } | ||
| %t__rm_a0_tmp_v3 = pto.alloc_tile addr = %c20608 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %t__row_major_tmp_v4 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.tmuls ins(%t__rm_a0_tmp_v3, %cst_1 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%t__row_major_tmp_v4 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %3 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0> | ||
| %variance__rm_a0_tmp_v5 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %variance__row_major_tmp_v6 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.tadds ins(%variance__rm_a0_tmp_v5, %cst_2 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%variance__row_major_tmp_v6 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The RMSNorm implementation appears to be missing the reciprocal square root operation. The current logic calculates the mean of squares plus epsilon ( |
||
| %variance__tile = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0> | ||
| scf.for %9 = %c0 to %c40 step %c1 { | ||
| %10 = arith.muli %9, %c128 : index | ||
| %4 = pto.alloc_tile addr = %c64 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %11 = pto.partition_view %hidden_states__ssa_v0_view, offsets = [%arg3, %10], sizes = [%c16, %c128] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x128xbf16> | ||
| pto.tload ins(%11 : !pto.partition_tensor_view<16x128xbf16>) outs(%4 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %5 = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.tcvt ins(%4{rmode = #pto<round_mode ROUND>} : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%5 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %gamma__tile = pto.alloc_tile addr = %c20672 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %input_rms_weight__ssa_v0_pview = pto.partition_view %input_rms_weight__ssa_v0_view, offsets = [%c0, %10], sizes = [%c1, %c128] : !pto.tensor_view<?x?xf32> -> !pto.partition_tensor_view<1x128xf32> | ||
| pto.tload ins(%input_rms_weight__ssa_v0_pview : !pto.partition_tensor_view<1x128xf32>) outs(%gamma__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %6 = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.trowexpandmul ins(%5, %variance__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>) outs(%6 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %normed__tile = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.tcolexpandmul ins(%6, %gamma__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%normed__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %7 = pto.alloc_tile addr = %c64 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.tcvt ins(%normed__tile{rmode = #pto<round_mode ROUND>} : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%7 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) | ||
| %normed_out__iter_v3_pview = pto.partition_view %normed_out__iter_v1_view, offsets = [%arg3, %10], sizes = [%c16, %c128] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x128xbf16> | ||
| pto.tstore ins(%7 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%normed_out__iter_v3_pview : !pto.partition_tensor_view<16x128xbf16>) | ||
| } | ||
| return | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This new file in
test/basichas no// RUN:directive, so it is not wired into the usual lit-style execution path used by other.ptotests in this directory. As written, the test content can be present in-tree without ever being exercised, which means regressions in this RMSNorm lowering path can slip through unnoticed; please add at least a compile RUN line (and preferablyFileCheckassertions) to make this test actionable.Useful? React with 👍 / 👎.