Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions test/samples/Sync/decode_projection_incore_0.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
module attributes {pto.target_arch = "a5"} {
func.func @decode_projection_incore_0(%arg0: !pto.ptr<bf16>, %arg1: !pto.ptr<f32>, %arg2: !pto.ptr<bf16>, %arg3: index) attributes {pto.kernel_kind = #pto.kernel_kind<vector>} {
%c0i = arith.constant 0 : i64
%c64 = arith.constant 64 : i64
%c4160 = arith.constant 4160 : i64
%c12352 = arith.constant 12352 : i64
%c20544 = arith.constant 20544 : i64
%c20608 = arith.constant 20608 : i64
%c20672 = arith.constant 20672 : i64
%c16 = arith.constant 16 : index
%c8192 = arith.constant 8192 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%9 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%cst_1 = arith.constant 1.220703e-04 : f32
%cst_2 = arith.constant 1.000000e-06 : f32
%hidden_states__ssa_v0_view = pto.make_tensor_view %arg0, shape = [%c16, %c8192], strides = [%c8192, %c1] {layout = #pto.layout<nd>}: !pto.tensor_view<?x?xbf16>
%input_rms_weight__ssa_v0_view = pto.make_tensor_view %arg1, shape = [%c1, %c8192], strides = [%c8192, %c1] {layout = #pto.layout<nd>}: !pto.tensor_view<?x?xf32>
%normed_tile__ssa_v0_view = pto.make_tensor_view %arg2, shape = [%c16, %c8192], strides = [%c8192, %c1] {layout = #pto.layout<nd>}: !pto.tensor_view<?x?xbf16>
%partial_sq_flat__tile = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.texpands ins(%cst : f32) outs(%partial_sq_flat__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%partial_sq__tile = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
scf.for %kb__idx_v0 = %c0 to %9 step %c1 {
%10 = arith.muli %kb__idx_v0, %c128 : index
%t__tile = pto.alloc_tile addr = %c64 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%hidden_states__ssa_v0_pview = pto.partition_view %hidden_states__ssa_v0_view, offsets = [%c0, %10], sizes = [%c16, %c128] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x128xbf16>
pto.tload ins(%hidden_states__ssa_v0_pview : !pto.partition_tensor_view<16x128xbf16>) outs(%t__tile : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%x_chunk__tile = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tcvt ins(%t__tile{rmode = #pto<round_mode ROUND>} : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%x_chunk__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%0 = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tmul ins(%x_chunk__tile, %x_chunk__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%0 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%tmp_tile = pto.alloc_tile addr = %c12352 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%1 = pto.alloc_tile addr = %c20544 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
pto.trowsum ins(%0, %tmp_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%1 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>)
%partial_sq__rm_a0_tmp_v0 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%partial_sq__rm_a1_tmp_v1 = pto.alloc_tile addr = %c20544 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%partial_sq__row_major_tmp_v2 = pto.alloc_tile addr = %c20608 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tadd ins(%partial_sq__rm_a0_tmp_v0, %partial_sq__rm_a1_tmp_v1 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%partial_sq__row_major_tmp_v2 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%2 = pto.alloc_tile addr = %c20608 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
%partial_sq__tile_mv = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
pto.tmov ins(%2 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>) outs(%partial_sq__tile_mv : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>)
}
%t__rm_a0_tmp_v3 = pto.alloc_tile addr = %c20608 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%t__row_major_tmp_v4 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tmuls ins(%t__rm_a0_tmp_v3, %cst_1 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%t__row_major_tmp_v4 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%3 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
%t__rm_a0_tmp_v5 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%t__row_major_tmp_v6 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tadds ins(%t__rm_a0_tmp_v5, %cst_2 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%t__row_major_tmp_v6 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%4 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
%inv_rms_tile__rm_a0_tmp_v7 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%inv_rms_tile__row_major_tmp_v8 = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.trsqrt ins(%inv_rms_tile__rm_a0_tmp_v7 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%inv_rms_tile__row_major_tmp_v8 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=16, v_row=1, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%inv_rms_tile__tile = pto.alloc_tile addr = %c0i : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>
scf.for %11 = %c0 to %9 step %c1 {
%12 = arith.muli %11, %c128 : index
%5 = pto.alloc_tile addr = %c64 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%13 = pto.partition_view %hidden_states__ssa_v0_view, offsets = [%c0, %12], sizes = [%c16, %c128] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x128xbf16>
pto.tload ins(%13 : !pto.partition_tensor_view<16x128xbf16>) outs(%5 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%6 = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tcvt ins(%5{rmode = #pto<round_mode ROUND>} : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%6 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%gamma__tile = pto.alloc_tile addr = %c20672 : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%input_rms_weight__ssa_v0_pview = pto.partition_view %input_rms_weight__ssa_v0_view, offsets = [%c0, %12], sizes = [%c1, %c128] : !pto.tensor_view<?x?xf32> -> !pto.partition_tensor_view<1x128xf32>
pto.tload ins(%input_rms_weight__ssa_v0_pview : !pto.partition_tensor_view<1x128xf32>) outs(%gamma__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%7 = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.trowexpandmul ins(%6, %inv_rms_tile__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>) outs(%7 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%normed__tile = pto.alloc_tile addr = %c4160 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tcolexpandmul ins(%7, %gamma__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%normed__tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%8 = pto.alloc_tile addr = %c64 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tcvt ins(%normed__tile{rmode = #pto<round_mode ROUND>} : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%8 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
%normed_tile__iter_v1_pview = pto.partition_view %normed_tile__ssa_v0_view, offsets = [%c0, %12], sizes = [%c16, %c128] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x128xbf16>
pto.tstore ins(%8 : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=128, v_row=16, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%normed_tile__iter_v1_pview : !pto.partition_tensor_view<16x128xbf16>)
}
return
}
}
Loading
Loading