diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index 37db4956..a3ae6192 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -238,7 +238,7 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: %chunk_val = affine.vector_load %mul_buffer[0, %index5] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ chunk_size }}x{{ data_stype }}> %local_max = arith.maximumf %chunk_val, %iter_max : vector<{{ chunk_size }}x{{ data_stype }}> affine.yield %local_max : vector<{{ chunk_size }}x{{ data_stype }}> - } + } { accumulation_loop=true } %max_cast = vector.shape_cast %chunk_max_res : vector<{{ chunk_size }}x{{ data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> %max_reduced_1 = vector.multi_reduction , %max_cast, %v_neg_inf_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> @@ -284,7 +284,7 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: %chunk_exp = affine.vector_load %mul_buffer[0, %index5] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ chunk_size }}x{{ data_stype }}> %local_sum = arith.addf %chunk_exp, %iter_sum : vector<{{ chunk_size }}x{{ data_stype }}> affine.yield %local_sum : vector<{{ chunk_size }}x{{ data_stype }}> - } + } { accumulation_loop=true } %zero_2x = vector.broadcast %c0 : {{ data_stype }} to vector<2x{{ data_stype }}> %sum_cast = vector.shape_cast %chunk_sum_res : vector<{{ chunk_size }}x{{ data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> @@ -301,7 +301,7 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: { idx_map = array } ins(%vt_buffer2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(data_stype) }}) outs(%ot_buffer2D : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) - } {inner_loop=true} + } { accumulation_loop=true } // out @ row_sum^(-1) %final_row_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> @@ -317,7 +317,7 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: %out_dram_offset = affine.apply {{ out_offset_map }}(%index0, %index1, %index3) {{ kernel.def_dma_op("MVOUT", "out", [], out_tile_desc, indent_size=8, dram_stride=out_dram_stride, dram_offset="out_dram_offset") }} - } { accumulation_loop=true } + } { outer_loop=true } } { outer_loop=true } } { outer_loop=true } return