Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions test/basic/rmsnorm_incore_0.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
module attributes {pto.target_arch = "a5"} {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Add RUN directive so the new basic test actually executes

This new file in test/basic has no // RUN: directive, so it is not wired into the usual lit-style execution path used by other .pto tests in this directory. As written, the test content can be present in-tree without ever being exercised, which means regressions in this RMSNorm lowering path can slip through unnoticed; please add at least a compile RUN line (and preferably FileCheck assertions) to make this test actionable.

Useful? React with 👍 / 👎.

func.func @rmsnorm_incore_0(%arg0: !pto.ptr<bf16>, %arg1: !pto.ptr<f32>, %arg2: !pto.ptr<bf16>, %arg3: index) attributes {pto.kernel_kind = #pto.kernel_kind<vector>} {
%c0i = arith.constant 0 : i64
%c64 = arith.constant 64 : i64
%c4160 = arith.constant 4160 : i64
%c12352 = arith.constant 12352 : i64
%c20544 = arith.constant 20544 : i64
%c20608 = arith.constant 20608 : i64
%c20672 = arith.constant 20672 : i64
%c16 = arith.constant 16 : index
%c5120 = arith.constant 5120 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c40 = arith.constant 40 : index
%c128 = arith.constant 128 : index
%cst_1 = arith.constant 1.953125e-04 : f32
%cst_2 = arith.constant 1.000000e-06 : f32
%hidden_states__ssa_v0_view = pto.make_tensor_view %arg0, shape = [%c16, %c5120], strides = [%c5120, %c1] {layout = #pto.layout<nd>}: !pto.tensor_view<?x?xbf16>
%input_rms_weight__ssa_v0_view = pto.make_tensor_view %arg1, shape = [%c1, %c5120], strides = [%c5120, %c1] {layout = #pto.layout<nd>}: !pto.tensor_view<?x?xf32>
%normed_out__iter_v1_view = pto.make_tensor_view %arg2, shape = [%c16, %c5120], strides = [%c5120, %c1] {layout = #pto.layout<nd>}: !pto.tensor_view<?x?xbf16>
%partial_sq_flat__tile = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.texpands ins(%cst : f32) outs(%partial_sq_flat__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%partial_sq__tile = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
scf.for %kb__idx_v0 = %c0 to %c40 step %c1 {
%8 = arith.muli %kb__idx_v0, %c128 : index
%t__tile = pto.alloc_tile addr = %c64 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%hidden_states__ssa_v0_pview = pto.partition_view %hidden_states__ssa_v0_view, offsets = [%arg3, %8], sizes = [%c16, %c128] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x128xbf16>
pto.tload ins(%hidden_states__ssa_v0_pview : !pto.partition_tensor_view<16x128xbf16>) outs(%t__tile : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%x_chunk__tile = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tcvt ins(%t__tile{rmode = #pto<round_mode ROUND>} : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%x_chunk__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%0 = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tmul ins(%x_chunk__tile, %x_chunk__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%0 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%tmp_tile = pto.alloc_tile addr = %c12352 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%1 = pto.alloc_tile addr = %c20544 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
pto.trowsum ins(%0, %tmp_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%1 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>)
%partial_sq__rm_a0_tmp_v0 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%partial_sq__rm_a1_tmp_v1 = pto.alloc_tile addr = %c20544 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%partial_sq__row_major_tmp_v2 = pto.alloc_tile addr = %c20608 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tadd ins(%partial_sq__rm_a0_tmp_v0, %partial_sq__rm_a1_tmp_v1 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%partial_sq__row_major_tmp_v2 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%2 = pto.alloc_tile addr = %c20608 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
%partial_sq__tile_mv = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
pto.tmov ins(%2 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>) outs(%partial_sq__tile_mv : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>)
}
%t__rm_a0_tmp_v3 = pto.alloc_tile addr = %c20608 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%t__row_major_tmp_v4 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tmuls ins(%t__rm_a0_tmp_v3, %cst_1 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%t__row_major_tmp_v4 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%3 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
%variance__rm_a0_tmp_v5 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%variance__row_major_tmp_v6 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tadds ins(%variance__rm_a0_tmp_v5, %cst_2 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%variance__row_major_tmp_v6 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The RMSNorm implementation appears to be missing the reciprocal square root operation. The current logic calculates the mean of squares plus epsilon (mean(x^2) + eps) and then multiplies the input by this value in the second loop (line 64). For a correct RMSNorm, the input should be multiplied by rsqrt(mean(x^2) + eps). Consider adding a pto.trsqrt operation after the epsilon addition to maintain semantic correctness for an RMSNorm test.

  pto.tadds ins(%variance__rm_a0_tmp_v5, %cst_2 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%variance__row_major_tmp_v6 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
  pto.trsqrt ins(%variance__row_major_tmp_v6 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%variance__row_major_tmp_v6 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)

%variance__tile = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
scf.for %9 = %c0 to %c40 step %c1 {
%10 = arith.muli %9, %c128 : index
%4 = pto.alloc_tile addr = %c64 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%11 = pto.partition_view %hidden_states__ssa_v0_view, offsets = [%arg3, %10], sizes = [%c16, %c128] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x128xbf16>
pto.tload ins(%11 : !pto.partition_tensor_view<16x128xbf16>) outs(%4 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%5 = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tcvt ins(%4{rmode = #pto<round_mode ROUND>} : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%5 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%gamma__tile = pto.alloc_tile addr = %c20672 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%input_rms_weight__ssa_v0_pview = pto.partition_view %input_rms_weight__ssa_v0_view, offsets = [%c0, %10], sizes = [%c1, %c128] : !pto.tensor_view<?x?xf32> -> !pto.partition_tensor_view<1x128xf32>
pto.tload ins(%input_rms_weight__ssa_v0_pview : !pto.partition_tensor_view<1x128xf32>) outs(%gamma__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%6 = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.trowexpandmul ins(%5, %variance__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>) outs(%6 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%normed__tile = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tcolexpandmul ins(%6, %gamma__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%normed__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%7 = pto.alloc_tile addr = %c64 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tcvt ins(%normed__tile{rmode = #pto<round_mode ROUND>} : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%7 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%normed_out__iter_v3_pview = pto.partition_view %normed_out__iter_v1_view, offsets = [%arg3, %10], sizes = [%c16, %c128] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x128xbf16>
pto.tstore ins(%7 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%normed_out__iter_v3_pview : !pto.partition_tensor_view<16x128xbf16>)
}
return
}
}
Loading