@@ -327,7 +327,7 @@ def get_padding_type(self):
327327 # return 1
328328 return 0
329329
330- def convert_index (self , expr , buffer ):
330+ def convert_index (self , expr ):
331331 if len (expr .free_symbols ) != 1 :
332332 raise NotImplementedError ("Not supporting this view operation...!" )
333333
@@ -346,17 +346,37 @@ def convert_index(self, expr, buffer):
346346 first_arg = expr .args [0 ]
347347 if len (first_arg .free_symbols ) != 1 :
348348 raise NotImplementedError ("What is this case?" )
349+
350+ # Create affine.apply operation
349351 indices = [list (first_arg .free_symbols )[0 ]]
350- args = ", " .join (map (str , indices ))
351- map_var = self .map_cse .generate (self .global_vars , f"affine_map<({ args } ) -> ({ expr_str } )>" )
352- args = ", " .join ([f"%{ i } " for i in indices ])
353- index = self .apply_cse .generate (buffer , f"affine.apply #{ map_var } ({ args } )" )
352+ with self .override_buffer_cse (buffer = self .global_vars , cse = self .map_cse ):
353+ map_var = ops .affine_map (indices , expr_str )
354+ index = ops .affine_apply (map_var , indices )
354355 return index
355356
356- def parse_indices (self , expr , buffer = None , comments = "" , indirect_dims = []) -> common .CSEVariable :
357- if buffer is None :
358- buffer = self .applys
357+ def _convert_sympy_to_mlir_expr (self , expr , sorted_args ):
358+ """
359+ Convert sympy expression to MLIR affine map expression by replacing index variables.
360+ """
361+ indices = []
362+
363+ for arg in sorted_args :
364+ if arg .is_Mul and arg .args [0 ].is_number :
365+ target_arg = arg .args [1 ]
366+ elif not arg .is_number :
367+ target_arg = arg
368+ else :
369+ continue
370+ new_arg = sympy .Symbol (str (self .convert_index (target_arg )))
371+ expr = expr .replace (target_arg , new_arg )
372+ indices .append (str (new_arg ))
373+
374+ expr_str = str (expr )
375+ if "//" in expr_str :
376+ expr_str = expr_str .replace ("//" , " floordiv " )
377+ return expr_str , indices
359378
379+ def parse_indices (self , expr , comments = "" , indices = None , indirect_dims = []) -> common .CSEVariable :
360380 # Constant case
361381 if expr .is_number and len (indirect_dims ) == 0 :
362382 return self .get_const_cse (int (expr ))
@@ -372,33 +392,25 @@ def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> com
372392 # Sort index variable.. ex) (%index1, %index0)
373393 args_dict = {term : list (term .free_symbols )[0 ] for term in args if term .free_symbols }
374394 sorted_args = sorted (args_dict .keys (), key = lambda term : str (args_dict [term ]))
375- indices = []
376- for arg in sorted_args :
377- if arg .is_Mul and arg .args [0 ].is_number :
378- new_arg = sympy .Symbol (str (self .convert_index (arg .args [1 ], buffer )))
379- expr = expr .replace (arg .args [1 ], new_arg )
380- indices .append (str (new_arg ))
381- elif not arg .is_number :
382- new_arg = sympy .Symbol (str (self .convert_index (arg , buffer )))
383- expr = expr .replace (arg , new_arg )
384- indices .append (str (new_arg ))
395+
396+ # Convert sympy expression to affine map expression
397+ expr_str , indices = self ._convert_sympy_to_mlir_expr (expr , sorted_args )
385398
386399 # Extract index var
387- indirect_args = [f"%{ i } " for i in indirect_dims ]
388- if len (indirect_args ):
400+ if len (indirect_dims ):
389401 comments = "{indirect_access} " + comments # Add indirect access attribute
390- expr_str = str (expr )
391- if "//" in expr_str :
392- expr_str = expr_str .replace ("//" , " floordiv " )
393- args = ", " .join (map (str , indices ))
394- map_var = self .map_cse .generate (self .global_vars , f"affine_map<({ args } )[{ ',' .join (indirect_dims )} ] -> ({ expr_str } )>" )
395- args = ", " .join ([f"%{ i } " for i in indices ])
396- index = self .apply_cse .generate (buffer , f"affine.apply #{ map_var } ({ args } )[{ ',' .join (indirect_args )} ] { comments } " )
402+ indirect_args = [f"%{ i } " for i in indirect_dims ]
403+ # Create affine.apply operation
404+ with self .override_buffer_cse (buffer = self .global_vars , cse = self .map_cse ):
405+ map_var = ops .affine_map (indices , expr_str , symbol_names = indirect_dims )
406+
407+ if hasattr (self , "dim_aliasing" ):
408+ indices = [self .dim_aliasing .get (index , index ) for index in indices ]
409+ index = ops .affine_apply (map_var , indices , indirect_dims = indirect_args , comment = comments )
397410 return index
398411
399- def parse_index_list (self , expr_list :list , buffer = None , offset = sympy .Number (0 )) -> common .CSEVariable :
400- if buffer is None :
401- buffer = self .applys
412+ def parse_index_list (self , expr_list :list , offset = sympy .Number (0 )) -> common .CSEVariable :
413+ """ Need to override buffer and cse to use this function. """
402414 expr_list = [arg for arg in expr_list ]
403415 dim_list = [f"d{ i } " for i in range (len (expr_list ))]
404416
@@ -413,11 +425,11 @@ def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0))
413425 new_expr_list = [0 ] * len (expr_list )
414426 for idx , arg in enumerate (expr_list ):
415427 if arg .is_Mul and arg .args [0 ].is_number :
416- new_arg = sympy .Symbol (str (self .convert_index (arg .args [1 ], buffer )))
428+ new_arg = sympy .Symbol (str (self .convert_index (arg .args [1 ])))
417429 new_expr_list [idx ] = arg .subs (arg .args [1 ], dim_list [idx ])
418430 indices .append (str (new_arg ))
419431 elif not arg .is_number :
420- new_arg = sympy .Symbol (str (self .convert_index (arg , buffer )))
432+ new_arg = sympy .Symbol (str (self .convert_index (arg )))
421433 new_expr_list [idx ] = new_arg .subs (new_arg , dim_list [idx ])
422434 indices .append (str (new_arg ))
423435 else :
@@ -427,11 +439,11 @@ def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0))
427439 indices .append (str (new_arg ))
428440
429441 # Extract index var
442+ # Create affine.apply operation
430443 expr_str = str (sum (new_expr_list ) + offset )
431- args = ", " .join (map (str , dim_list ))
432- map_var = self .map_cse .generate (self .global_vars , f"affine_map<({ args } )[] -> ({ expr_str } )>" )
433- args = ", " .join ([f"%{ i } " for i in indices ])
434- index = self .apply_cse .generate (buffer , f"affine.apply #{ map_var } ({ args } )[]" )
444+ with self .override_buffer_cse (buffer = self .global_vars , cse = self .map_cse ):
445+ map_var = ops .affine_map (dim_list , expr_str )
446+ index = ops .affine_apply (map_var , indices )
435447 return index
436448
437449 def load (self , name : str , index : sympy .Expr ):
@@ -1080,7 +1092,8 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
10801092 if broadcast and (total_dims != local_dims or (self .reduction_depth != len (total_dims ) and total_dims [:self .reduction_depth ] == local_dims )):
10811093 local_dims = total_dims # Brodatcast tile shape
10821094
1083- index_var = self .parse_indices (index , buffer = buffer , indirect_dims = indirect_dims , comments = f"// store_reduction={ store_reduction } " )
1095+ with self .override_buffer_cse (buffer = buffer , cse = self .apply_cse ):
1096+ index_var = self .parse_indices (index , indirect_dims = indirect_dims , comments = f"// store_reduction={ store_reduction } " )
10841097
10851098 if kg_tile_desc .vmap .vlane_split_axis in local_dims :
10861099 local_vlane_split_axis = local_dims .index (kg_tile_desc .vmap .vlane_split_axis )
0 commit comments