Skip to content

Commit a90f114

Browse files
committed
[Fix] Fusion axis mechanism change
1 parent 0c6175f commit a90f114

5 files changed

Lines changed: 28 additions & 39 deletions

File tree

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -313,7 +313,9 @@ def __init__(self, kernel_group, reason=None):
313313
self.base_vector_initialized = False
314314

315315
def reset(self, reason):
316+
save = self.exit_stack, self._nested_context_depth
316317
self.__init__(self.kernel_group, reason=reason)
318+
self.exit_stack, self._nested_context_depth = save
317319

318320
# padding type 0: zero-padding 1: negative-padding(-inf) ...
319321
def get_padding_type(self):
@@ -395,17 +397,11 @@ def parse_indices(self, expr, comments="", indices=None, indirect_dims=[]) -> co
395397

396398
# Convert sympy expression to affine map expression
397399
expr_str, indices = self._convert_sympy_to_mlir_expr(expr, sorted_args)
398-
399-
# Extract index var
400-
if len(indirect_dims):
401-
comments = "{indirect_access} " + comments # Add indirect access attribute
402400
indirect_args = [f"%{i}" for i in indirect_dims]
403401
# Create affine.apply operation
404402
with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse):
405403
map_var = ops.affine_map(indices, expr_str, symbol_names=indirect_dims)
406404

407-
if hasattr(self, "dim_aliasing"):
408-
indices = [self.dim_aliasing.get(index, index) for index in indices]
409405
index = ops.affine_apply(map_var, indices, indirect_dims=indirect_args, comment=comments)
410406
return index
411407

PyTorchSimFrontend/mlir/mlir_common.py

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -614,7 +614,7 @@ def __init__(self, kernel_group, reason=None):
614614
self.target_cse_override = contextvars.ContextVar(f"Handler_cse_override_{instance_id}", default=self.cse)
615615
self._nested_context_depth = 0
616616

617-
def set_ranges(self, lengths, reduction_lengths):
617+
def set_ranges(self, lengths, reduction_lengths, index_names=None):
618618
if self.call_ranges:
619619
assert self.call_ranges == tuple(lengths) + tuple(
620620
reduction_lengths
@@ -623,7 +623,12 @@ def set_ranges(self, lengths, reduction_lengths):
623623
else:
624624
self.call_ranges = tuple(lengths) + tuple(reduction_lengths)
625625
self.ranges = [self.rename_indexing(x) for x in self.call_ranges]
626-
self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))]
626+
if index_names is None:
627+
self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))]
628+
else:
629+
assert len(index_names) == len(self.ranges), f"Index names length mismatch: {len(index_names)} != {len(self.ranges)}"
630+
self.itervars = [sympy.Symbol(str(n)) for n in index_names]
631+
627632
self.itervar_cses = {str(index) : self.register_var_cse(str(index), 1, "index") for index in self.itervars}
628633
self.reduction_depth = len(lengths)
629634
return (
@@ -867,18 +872,22 @@ def rename_indexing(self, index) -> sympy.Expr:
867872
def override_buffer_cse(self, *, buffer=None, cse=None):
868873
buffer_override = self.target_buffer_override
869874
cse_override = self.target_cse_override
870-
target_buffer = target_cse = None
875+
buffer_token = cse_token = None
871876
try:
877+
# Store tokens for proper restoration in nested contexts
878+
# contextvars.set() returns the previous value (token) which can be used for reset()
872879
if buffer is not None:
873-
target_buffer = buffer_override.set(buffer)
880+
buffer_token = buffer_override.set(buffer)
874881
if cse is not None:
875-
target_cse = cse_override.set(cse)
882+
cse_token = cse_override.set(cse)
876883
yield self
877884
finally:
878-
if target_cse is not None:
879-
cse_override.reset(target_cse)
880-
if target_buffer is not None:
881-
buffer_override.reset(target_buffer)
885+
# Restore using tokens - contextvars automatically handles nested contexts
886+
# Each level restores to its own previous value
887+
if cse_token is not None:
888+
cse_override.reset(cse_token)
889+
if buffer_token is not None:
890+
buffer_override.reset(buffer_token)
882891

883892
def __enter__(self):
884893
class CSEProxy:

PyTorchSimFrontend/mlir/mlir_ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1186,7 +1186,7 @@ def affine_apply(map_var, indices, indirect_dims=None, comment=None, *args, **kw
11861186
# Add indirect dimensions if provided
11871187
if indirect_dims:
11881188
indirect_str = ", ".join(indirect_dims)
1189-
op_str += f"[{indirect_str}]"
1189+
op_str += f"[{indirect_str}] {{indirect_access}}"
11901190
if comment:
11911191
op_str += f" // {comment}"
11921192
return op_str, [1, "index"]

PyTorchSimFrontend/mlir/mlir_scheduling.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,7 @@ def can_fuse_horizontal(self, node1, node2):
178178
return False
179179

180180
size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1)
181-
target_symbol = symbols("r0")
181+
target_symbol = symbols("r0_0")
182182
try:
183183
stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1]
184184
stride = int(sympify(stride).coeff(target_symbol))

PyTorchSimFrontend/mlir/mlir_template.py

Lines changed: 6 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -429,7 +429,7 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_
429429
).group
430430
prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True)
431431
kernel.kernel_group.set_tile_info(prologue_tile_desc)
432-
vars, reduction_vars = kernel.set_ranges(group, reduction_group)
432+
vars, reduction_vars = kernel.set_ranges(group, reduction_group, list(self.dim_aliasing.values()))
433433
for node in prologue_nodes:
434434
# Reuse created spad
435435
read_list = sorted([i.name for i in node.read_writes.reads])
@@ -469,10 +469,11 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_
469469
_, (group, reduction_group) = max(
470470
epilogue_nodes, key=lambda x: int(x.is_reduction())
471471
).group
472-
vars, reduction_vars = kernel.set_ranges(group, reduction_group)
472+
vars, reduction_vars = kernel.set_ranges(group, reduction_group, list(self.dim_aliasing.values()))
473473
for node in epilogue_nodes:
474474
node.codegen((vars, reduction_vars))
475475

476+
with self as kernel:
476477
src_code = (
477478
partial_code
478479
if isinstance(partial_code, str)
@@ -855,7 +856,7 @@ def load_epilogue(self, name: str, index: sympy.Expr):
855856
# Want to use tile_desc from epilogue_info
856857
with self.override_buffer_cse(buffer=self.applys, cse=self.apply_cse):
857858
index_var = self.parse_indices(index)
858-
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.keys()]
859+
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()]
859860
vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis
860861
vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride
861862
tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype)
@@ -892,7 +893,6 @@ def load_epilogue(self, name: str, index: sympy.Expr):
892893
map_var = ops.affine_map(["d0", "d1"], f"d0 + d1*{(self.r_tile_size)}")
893894
with self.override_buffer_cse(buffer=self.loads):
894895
offset = ops.affine_apply(map_var, [self.compute_idx, self.reduction_loop_idx])
895-
#offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.r_tile_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})")
896896
compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"])
897897

898898
with self.override_buffer_cse(buffer=self.loads):
@@ -908,7 +908,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
908908

909909
with self.override_buffer_cse(buffer=self.applys, cse=self.apply_cse):
910910
index_var = self.parse_indices(index)
911-
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.keys()]
911+
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()]
912912
vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis
913913
vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride
914914
tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype)
@@ -1012,7 +1012,7 @@ def store_reduction_epilogue(self, name, index, value):
10121012

10131013
with self.override_buffer_cse(buffer=self.reductions_suffix, cse=self.apply_cse):
10141014
index_var = self.parse_indices(index, comments="// Store reduction")
1015-
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.keys()][:-1] # Assume that there is only one reduction axis
1015+
dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()][:-1] # Assume that there is only one reduction axis
10161016
vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis
10171017
vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride
10181018

@@ -1123,22 +1123,6 @@ def set_tile_size(self, template_fusion_info, prologue=False):
11231123
self.compute_body_loop.step = tile_desc.get_compute_vec_size()
11241124
return tile_desc
11251125

1126-
def rename_indexing(self, index) -> sympy.Expr:
1127-
# First step: replace dim_name with tmp_+dim_aliased_name to avoid circular dependencies
1128-
# (e.g., {"index0":"index1", "index1":"index0"})
1129-
tmp_subs = {
1130-
sympy.Symbol(dim_name): sympy.Symbol("tmp_"+dim_aliased_name)
1131-
for dim_name, dim_aliased_name in self.dim_aliasing.items()
1132-
}
1133-
index = index.subs(tmp_subs)
1134-
# Second step: replace tmp_+dim_aliased_name with dim_aliased_name
1135-
final_subs = {
1136-
sympy.Symbol("tmp_"+dim_aliased_name): sympy.Symbol(dim_aliased_name)
1137-
for dim_aliased_name in self.dim_aliasing.values()
1138-
}
1139-
index = index.subs(final_subs)
1140-
return index
1141-
11421126
class MLIRTemplateCaller(CUDATemplateCaller):
11431127
def __str__(self):
11441128
return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})"

0 commit comments

Comments (0)