11import contextlib
22import sympy
3+ import sys
34import re
45import os
56from functools import reduce
@@ -374,9 +375,51 @@ def _convert_sympy_to_mlir_expr(self, expr, sorted_args):
374375 expr = expr .replace (target_arg , new_arg )
375376 indices .append (str (new_arg ))
376377
377- expr_str = str (expr )
378- if "//" in expr_str :
379- expr_str = expr_str .replace ("//" , " floordiv " )
378+ # Convert ModularIndexing and FloorDiv to sympy expressions
379+ # ModularIndexing(x, y, z) means (x // y) % z -> Mod(FloorDiv(x, y), z)
380+ # FloorDiv(x, y) means x // y -> will be converted to floordiv in string representation
381+ # Use preorder_traversal to find all instances
382+ replacements = {}
383+ for sub in sympy .preorder_traversal (expr ):
384+ if isinstance (sub , ModularIndexing ):
385+ # Convert ModularIndexing to Mod(FloorDiv(...), ...)
386+ if sub .args [1 ] != 1 :
387+ floor_div = FloorDiv (sub .args [0 ], sub .args [1 ])
388+ else :
389+ floor_div = sub .args [0 ]
390+ mod_expr = sympy .Mod (floor_div , sub .args [2 ])
391+ replacements [sub ] = mod_expr
392+ elif isinstance (sub , FloorDiv ):
393+ # Keep FloorDiv as is, will be handled in custom string conversion
394+ # We need to mark it for special handling
395+ pass
396+
397+ # Apply replacements
398+ for old_expr , new_expr in replacements .items ():
399+ expr = expr .subs (old_expr , new_expr )
400+
401+ # Custom string conversion for MLIR affine expressions
402+ def mlir_str (expr ):
403+ """Convert sympy expression to MLIR affine expression string"""
404+ if isinstance (expr , FloorDiv ):
405+ return f"({ mlir_str (expr .args [0 ])} floordiv { mlir_str (expr .args [1 ])} )"
406+ elif isinstance (expr , sympy .Mod ):
407+ return f"({ mlir_str (expr .args [0 ])} mod { mlir_str (expr .args [1 ])} )"
408+ elif isinstance (expr , sympy .Add ):
409+ terms = [mlir_str (term ) for term in expr .args ]
410+ return " + " .join (terms )
411+ elif isinstance (expr , sympy .Mul ):
412+ factors = [mlir_str (factor ) for factor in expr .args ]
413+ return " * " .join (factors )
414+ elif isinstance (expr , sympy .Symbol ):
415+ return str (expr )
416+ elif expr .is_number :
417+ return str (expr )
418+ else :
419+ # Fallback to string representation
420+ return str (expr )
421+
422+ expr_str = mlir_str (expr )
380423 return expr_str , indices
381424
382425 def parse_indices (self , expr , comments = "" , indices = None , indirect_dims = []) -> common .CSEVariable :
@@ -470,6 +513,10 @@ def load(self, name: str, index: sympy.Expr):
470513 tile_numel_per_lane = local_tile_desc .get_numel_per_lane ()
471514 tile_shape = local_tile_desc .get_mlir_shape (mlir_dtype )
472515 tile_stride = local_tile_desc .get_tile_stride ()
516+
517+ if len (dram_stride ) != len (tile_stride ):
518+ print ("here" )
519+
473520 # Compute vector unit size
474521 vshape = self .kernel_group .tile_desc .get_mlir_vshape (mlir_dtype )
475522 compute_vec_size = self .kernel_group .tile_desc .get_compute_vec_size ()
@@ -1157,7 +1204,8 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
11571204 sorted_constraints = sorted (axis_constraints , key = lambda c : int (c .args [1 ]))
11581205 for constraint in sorted_constraints [1 :]:
11591206 index = index .replace (constraint .original_expr , 0 )
1160-
1207+ if self .is_modular_indexing (index ):
1208+ print ("here" )
11611209 # Calculate dram stride
11621210 dram_stride = [0 ] * local_tile_desc .get_nr_dim ()
11631211 if index .is_Symbol :
@@ -1167,13 +1215,20 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
11671215 pass
11681216 else :
11691217 dram_dict = defaultdict (list )
1218+ implicit_dim_divisors = defaultdict (lambda : sys .maxsize )
11701219 # Assume that div will have high priority than mod
11711220 for arg in index .as_ordered_terms ():
11721221 coeff , dim = arg .as_coeff_mul ()
11731222 if len (dim ) == 0 :
11741223 continue
11751224 real_dim = list (dim [0 ].free_symbols )[0 ]
1176- dram_dict [str (real_dim )].append (coeff )
1225+ if dim [0 ].has (ModularIndexing ):
1226+ if dim [0 ].args [1 ] < implicit_dim_divisors [str (real_dim )]:
1227+ implicit_dim_divisors [str (real_dim )] = dim [0 ].args [1 ]
1228+ dram_dict [str (real_dim )] = [coeff ]
1229+ else :
1230+ dram_dict [str (real_dim )].append (coeff )
1231+
11771232 # Add missing dims if not added
11781233 max_dim = len (self .ranges ) if not store_reduction else len (self .ranges ) - 1
11791234 for i in range (max_dim ):
0 commit comments