@@ -429,7 +429,7 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_
429429 ).group
430430 prologue_tile_desc = kernel .set_tile_size (kernel .prologue_info , prologue = True )
431431 kernel .kernel_group .set_tile_info (prologue_tile_desc )
432- vars , reduction_vars = kernel .set_ranges (group , reduction_group )
432+ vars , reduction_vars = kernel .set_ranges (group , reduction_group , list ( self . dim_aliasing . values ()) )
433433 for node in prologue_nodes :
434434 # Reuse created spad
435435 read_list = sorted ([i .name for i in node .read_writes .reads ])
@@ -469,10 +469,11 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_
469469 _ , (group , reduction_group ) = max (
470470 epilogue_nodes , key = lambda x : int (x .is_reduction ())
471471 ).group
472- vars , reduction_vars = kernel .set_ranges (group , reduction_group )
472+ vars , reduction_vars = kernel .set_ranges (group , reduction_group , list ( self . dim_aliasing . values ()) )
473473 for node in epilogue_nodes :
474474 node .codegen ((vars , reduction_vars ))
475475
476+ with self as kernel :
476477 src_code = (
477478 partial_code
478479 if isinstance (partial_code , str )
@@ -855,7 +856,7 @@ def load_epilogue(self, name: str, index: sympy.Expr):
855856 # Want to use tile_desc from epilogue_info
856857 with self .override_buffer_cse (buffer = self .applys , cse = self .apply_cse ):
857858 index_var = self .parse_indices (index )
858- dram_stride = [index .coeff (sympy .Symbol (val )) for val in self .dim_aliasing .keys ()]
859+ dram_stride = [index .coeff (sympy .Symbol (val )) for val in self .dim_aliasing .values ()]
859860 vlane_split_axis = self .kernel_group .tile_desc .vmap .vlane_split_axis
860861 vlane_stride = self .kernel_group .tile_desc .vmap .vlane_stride
861862 tile_shape = self .kernel_group .tile_desc .get_mlir_shape (mlir_dtype )
@@ -892,7 +893,6 @@ def load_epilogue(self, name: str, index: sympy.Expr):
892893 map_var = ops .affine_map (["d0" , "d1" ], f"d0 + d1*{ (self .r_tile_size )} " )
893894 with self .override_buffer_cse (buffer = self .loads ):
894895 offset = ops .affine_apply (map_var , [self .compute_idx , self .reduction_loop_idx ])
895- #offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.r_tile_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})")
896896 compute_index_var = "," .join ([f"%{ zero_var } " ] * (self .kernel_group .tile_desc .get_nr_dim ()- 1 ) + [f"%{ offset } " ])
897897
898898 with self .override_buffer_cse (buffer = self .loads ):
@@ -908,7 +908,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
908908
909909 with self .override_buffer_cse (buffer = self .applys , cse = self .apply_cse ):
910910 index_var = self .parse_indices (index )
911- dram_stride = [index .coeff (sympy .Symbol (val )) for val in self .dim_aliasing .keys ()]
911+ dram_stride = [index .coeff (sympy .Symbol (val )) for val in self .dim_aliasing .values ()]
912912 vlane_split_axis = self .kernel_group .tile_desc .vmap .vlane_split_axis
913913 vlane_stride = self .kernel_group .tile_desc .vmap .vlane_stride
914914 tile_shape = self .kernel_group .tile_desc .get_mlir_shape (mlir_dtype )
@@ -1012,7 +1012,7 @@ def store_reduction_epilogue(self, name, index, value):
10121012
10131013 with self .override_buffer_cse (buffer = self .reductions_suffix , cse = self .apply_cse ):
10141014 index_var = self .parse_indices (index , comments = "// Store reduction" )
1015- dram_stride = [index .coeff (sympy .Symbol (val )) for val in self .dim_aliasing .keys ()][:- 1 ] # Assume that there is only one reduction axis
1015+ dram_stride = [index .coeff (sympy .Symbol (val )) for val in self .dim_aliasing .values ()][:- 1 ] # Assume that there is only one reduction axis
10161016 vlane_split_axis = self .kernel_group .tile_desc .vmap .vlane_split_axis
10171017 vlane_stride = self .kernel_group .tile_desc .vmap .vlane_stride
10181018
@@ -1123,22 +1123,6 @@ def set_tile_size(self, template_fusion_info, prologue=False):
11231123 self .compute_body_loop .step = tile_desc .get_compute_vec_size ()
11241124 return tile_desc
11251125
1126- def rename_indexing (self , index ) -> sympy .Expr :
1127- # First step: replace dim_name with tmp_+dim_aliased_name to avoid circular dependencies
1128- # (e.g., {"index0":"index1", "index1":"index0"})
1129- tmp_subs = {
1130- sympy .Symbol (dim_name ): sympy .Symbol ("tmp_" + dim_aliased_name )
1131- for dim_name , dim_aliased_name in self .dim_aliasing .items ()
1132- }
1133- index = index .subs (tmp_subs )
1134- # Second step: replace tmp_+dim_aliased_name with dim_aliased_name
1135- final_subs = {
1136- sympy .Symbol ("tmp_" + dim_aliased_name ): sympy .Symbol (dim_aliased_name )
1137- for dim_aliased_name in self .dim_aliasing .values ()
1138- }
1139- index = index .subs (final_subs )
1140- return index
1141-
11421126class MLIRTemplateCaller (CUDATemplateCaller ):
11431127 def __str__ (self ):
11441128 return f"MLIRTemplateCaller(source_file={ self .bmreq .source_file } )"
0 commit comments