Skip to content

Commit b8a23f8

Browse files
committed
[Fix] indirect access attribute
1 parent 9f9583b commit b8a23f8

3 files changed

Lines changed: 6 additions & 10 deletions

File tree

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -395,10 +395,6 @@ def parse_indices(self, expr, comments="", indices=None, indirect_dims=[]) -> co
395395

396396
# Convert sympy expression to affine map expression
397397
expr_str, indices = self._convert_sympy_to_mlir_expr(expr, sorted_args)
398-
399-
# Extract index var
400-
if len(indirect_dims):
401-
comments = "{indirect_access} " + comments # Add indirect access attribute
402398
indirect_args = [f"%{i}" for i in indirect_dims]
403399
# Create affine.apply operation
404400
with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse):
@@ -407,8 +403,8 @@ def parse_indices(self, expr, comments="", indices=None, indirect_dims=[]) -> co
407403
if hasattr(self, "dim_aliasing"):
408404
# Create reverse mapping: value -> key
409405
reverse_mapping = {v: k for k, v in self.dim_aliasing.items()}
410-
indices_as_keys = [reverse_mapping.get(idx, idx) for idx in indices]
411-
sorted_pairs = sorted(indices_as_keys, key=lambda x: str(x))
406+
indices_as_keys = [reverse_mapping[idx] for idx in indices]
407+
sorted_pairs = sorted(indices_as_keys, key=lambda x: list(self.dim_aliasing.keys()).index(x))
412408
indices = [self.dim_aliasing[idx] for idx in sorted_pairs]
413409
index = ops.affine_apply(map_var, indices, indirect_dims=indirect_args, comment=comments)
414410
return index

PyTorchSimFrontend/mlir/mlir_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1186,7 +1186,7 @@ def affine_apply(map_var, indices, indirect_dims=None, comment=None, *args, **kw
11861186
# Add indirect dimensions if provided
11871187
if indirect_dims:
11881188
indirect_str = ", ".join(indirect_dims)
1189-
op_str += f"[{indirect_str}]"
1189+
op_str += f"[{indirect_str}] {{indirect_access}}"
11901190
if comment:
11911191
op_str += f" // {comment}"
11921192
return op_str, [1, "index"]

PyTorchSimFrontend/mlir/mlir_template.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ def load_epilogue(self, name: str, index: sympy.Expr):
856856
# Want to use tile_desc from epilogue_info
857857
with self.override_buffer_cse(buffer=self.applys, cse=self.apply_cse):
858858
index_var = self.parse_indices(index)
859-
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.keys()]
859+
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()]
860860
vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis
861861
vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride
862862
tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype)
@@ -909,7 +909,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
909909

910910
with self.override_buffer_cse(buffer=self.applys, cse=self.apply_cse):
911911
index_var = self.parse_indices(index)
912-
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.keys()]
912+
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()]
913913
vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis
914914
vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride
915915
tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype)
@@ -1013,7 +1013,7 @@ def store_reduction_epilogue(self, name, index, value):
10131013

10141014
with self.override_buffer_cse(buffer=self.reductions_suffix, cse=self.apply_cse):
10151015
index_var = self.parse_indices(index, comments="// Store reduction")
1016-
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.keys()][:-1] # Assume that there is only one reduction axis
1016+
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()][:-1] # Assume that there is only one reduction axis
10171017
vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis
10181018
vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride
10191019

0 commit comments

Comments
 (0)