Skip to content

Commit 1c99b05

Browse files
committed
[Fix] WIP
1 parent f60cbe5 commit 1c99b05

2 files changed

Lines changed: 73 additions & 6 deletions

File tree

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import sympy
3+
import sys
34
import re
45
import os
56
from functools import reduce
@@ -374,9 +375,51 @@ def _convert_sympy_to_mlir_expr(self, expr, sorted_args):
374375
expr = expr.replace(target_arg, new_arg)
375376
indices.append(str(new_arg))
376377

377-
expr_str = str(expr)
378-
if "//" in expr_str:
379-
expr_str = expr_str.replace("//", " floordiv ")
378+
# Convert ModularIndexing and FloorDiv to sympy expressions
379+
# ModularIndexing(x, y, z) means (x // y) % z -> Mod(FloorDiv(x, y), z)
380+
# FloorDiv(x, y) means x // y -> will be converted to floordiv in string representation
381+
# Use preorder_traversal to find all instances
382+
replacements = {}
383+
for sub in sympy.preorder_traversal(expr):
384+
if isinstance(sub, ModularIndexing):
385+
# Convert ModularIndexing to Mod(FloorDiv(...), ...)
386+
if sub.args[1] != 1:
387+
floor_div = FloorDiv(sub.args[0], sub.args[1])
388+
else:
389+
floor_div = sub.args[0]
390+
mod_expr = sympy.Mod(floor_div, sub.args[2])
391+
replacements[sub] = mod_expr
392+
elif isinstance(sub, FloorDiv):
393+
# Keep FloorDiv as is, will be handled in custom string conversion
394+
# We need to mark it for special handling
395+
pass
396+
397+
# Apply replacements
398+
for old_expr, new_expr in replacements.items():
399+
expr = expr.subs(old_expr, new_expr)
400+
401+
# Custom string conversion for MLIR affine expressions
402+
def mlir_str(expr):
    """Render a sympy expression as an MLIR affine-expression string.

    ``FloorDiv`` nodes become ``(a floordiv b)``, ``sympy.Mod`` nodes become
    ``(a mod b)``, sums are joined with `` + `` and products with `` * ``.
    Symbols and numbers are emitted via ``str``.

    Args:
        expr: The sympy expression (or sub-expression) to render.

    Returns:
        The textual form of ``expr``.
    """

    def _operand(sub, wrap_types):
        # Parenthesize compound operands so the emitted text keeps the
        # grouping of the sympy tree. Without this, FloorDiv(x + y, 4)
        # would print as "(x + y floordiv 4)", which reads as
        # x + (y floordiv 4) because floordiv/mod/* bind tighter than +.
        text = mlir_str(sub)
        if isinstance(sub, wrap_types):
            return f"({text})"
        return text

    if isinstance(expr, FloorDiv):
        lhs = _operand(expr.args[0], (sympy.Add, sympy.Mul))
        rhs = _operand(expr.args[1], (sympy.Add, sympy.Mul))
        return f"({lhs} floordiv {rhs})"
    if isinstance(expr, sympy.Mod):
        lhs = _operand(expr.args[0], (sympy.Add, sympy.Mul))
        rhs = _operand(expr.args[1], (sympy.Add, sympy.Mul))
        return f"({lhs} mod {rhs})"
    if isinstance(expr, sympy.Add):
        # Every term is either atomic, a product, or already parenthesized,
        # so terms can be joined directly.
        return " + ".join(mlir_str(term) for term in expr.args)
    if isinstance(expr, sympy.Mul):
        # Only a sum nested in a product needs explicit grouping.
        return " * ".join(_operand(factor, sympy.Add) for factor in expr.args)
    if isinstance(expr, sympy.Symbol) or expr.is_number:
        return str(expr)
    # Fallback for node types not handled structurally: keep the plain
    # string form, but still rewrite Python-style "//" as "floordiv",
    # as the string-based conversion did before mlir_str was introduced.
    return str(expr).replace("//", " floordiv ")
421+
422+
expr_str = mlir_str(expr)
380423
return expr_str, indices
381424

382425
def parse_indices(self, expr, comments="", indices=None, indirect_dims=[]) -> common.CSEVariable:
@@ -470,6 +513,10 @@ def load(self, name: str, index: sympy.Expr):
470513
tile_numel_per_lane = local_tile_desc.get_numel_per_lane()
471514
tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype)
472515
tile_stride = local_tile_desc.get_tile_stride()
516+
517+
if len(dram_stride) != len(tile_stride):
518+
print("here")
519+
473520
# Compute vector unit size
474521
vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype)
475522
compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size()
@@ -1157,7 +1204,8 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
11571204
sorted_constraints = sorted(axis_constraints, key=lambda c: int(c.args[1]))
11581205
for constraint in sorted_constraints[1:]:
11591206
index = index.replace(constraint.original_expr, 0)
1160-
1207+
if self.is_modular_indexing(index):
1208+
print("here")
11611209
# Calculate dram stride
11621210
dram_stride = [0] * local_tile_desc.get_nr_dim()
11631211
if index.is_Symbol:
@@ -1167,13 +1215,20 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
11671215
pass
11681216
else:
11691217
dram_dict = defaultdict(list)
1218+
implicit_dim_divisors = defaultdict(lambda: sys.maxsize)
11701219
# Assume that div has higher priority than mod
11711220
for arg in index.as_ordered_terms():
11721221
coeff, dim = arg.as_coeff_mul()
11731222
if len(dim) == 0:
11741223
continue
11751224
real_dim = list(dim[0].free_symbols)[0]
1176-
dram_dict[str(real_dim)].append(coeff)
1225+
if dim[0].has(ModularIndexing):
1226+
if dim[0].args[1] < implicit_dim_divisors[str(real_dim)]:
1227+
implicit_dim_divisors[str(real_dim)] = dim[0].args[1]
1228+
dram_dict[str(real_dim)] = [coeff]
1229+
else:
1230+
dram_dict[str(real_dim)].append(coeff)
1231+
11771232
# Add missing dims if not added
11781233
max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1
11791234
for i in range(max_dim):

PyTorchSimFrontend/mlir/mlir_common.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N
504504
vlane_stride=vlane_stride
505505
)
506506

507-
self.implicit_dim_size = None
507+
self.implicit_dim_size = dict()
508508
self.nr_rdim = 0
509509
self.offset = sympy.Integer(0) # Dram offset
510510

@@ -654,6 +654,10 @@ def reduction(self, dtype, src_dtype, reduction_type, value):
654654
def indirect_indexing(self, index_var, size, check, wrap_neg):
655655
raise NotImplementedError()
656656

657+
def check_bounds(self, index, size, lower, upper):
    """Accept a bounds-check request without doing anything.

    All arguments are ignored and ``None`` is returned.
    """
    # Deliberate no-op: raising NotImplementedError here is currently disabled.
    return None
660+
657661
def codegen_global_init(self):
658662
raise NotImplementedError()
659663

@@ -918,6 +922,10 @@ def indirect_indexing(index_var, size, check=True, wrap_neg=True):
918922
# Skip CSE since this doesn't return an expression
919923
return self.indirect_indexing(index_var, size, check, wrap_neg)
920924

925+
@staticmethod
def check_bounds(index, size, lower, upper):
    """Forward a bounds-check request to ``self.check_bounds``.

    ``self`` is not a parameter of this static method; it is resolved from
    the surrounding scope. The arguments are passed through unchanged and
    the delegate's result is returned.
    """
    forwarded = self.check_bounds(index, size, lower, upper)
    return forwarded
928+
921929
@staticmethod
922930
def load(name: str, index: sympy.Expr):
923931
index = self.rename_indexing(index)
@@ -964,6 +972,10 @@ def store_reduction(name, index, value):
964972
def reduction(dtype, src_dtype, reduction_type, value):
965973
return self.reduction(dtype, src_dtype, reduction_type, value)
966974

975+
@staticmethod
def check_bounds(index, size, lower, upper):
    """Delegate the bounds check to the ``self`` object of the outer scope.

    This static method takes no ``self`` argument; the name is looked up in
    the enclosing scope. Arguments are handed over positionally, unchanged,
    and whatever the delegate returns is returned as-is.
    """
    outcome = self.check_bounds(index, size, lower, upper)
    return outcome
978+
967979
@staticmethod
968980
def _index_expr(tile_size, buffer, renamed_expression, index):
969981
return self._index_expr(tile_size, buffer, renamed_expression, index)

0 commit comments

Comments
 (0)