Skip to content

Commit 7c4a055

Browse files
committed
[Fix] Fusion axis mechanism change
1 parent b8a23f8 commit 7c4a055

3 files changed

Lines changed: 19 additions & 33 deletions

File tree

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -400,12 +400,6 @@ def parse_indices(self, expr, comments="", indices=None, indirect_dims=[]) -> co
400400
with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse):
401401
map_var = ops.affine_map(indices, expr_str, symbol_names=indirect_dims)
402402

403-
if hasattr(self, "dim_aliasing"):
404-
# Create reverse mapping: value -> key
405-
reverse_mapping = {v: k for k, v in self.dim_aliasing.items()}
406-
indices_as_keys = [reverse_mapping[idx] for idx in indices]
407-
sorted_pairs = sorted(indices_as_keys, key=lambda x: list(self.dim_aliasing.keys()).index(x))
408-
indices = [self.dim_aliasing[idx] for idx in sorted_pairs]
409403
index = ops.affine_apply(map_var, indices, indirect_dims=indirect_args, comment=comments)
410404
return index
411405

PyTorchSimFrontend/mlir/mlir_common.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def __init__(self, kernel_group, reason=None):
614614
self.target_cse_override = contextvars.ContextVar(f"Handler_cse_override_{instance_id}", default=self.cse)
615615
self._nested_context_depth = 0
616616

617-
def set_ranges(self, lengths, reduction_lengths):
617+
def set_ranges(self, lengths, reduction_lengths, index_names=None):
618618
if self.call_ranges:
619619
assert self.call_ranges == tuple(lengths) + tuple(
620620
reduction_lengths
@@ -623,7 +623,12 @@ def set_ranges(self, lengths, reduction_lengths):
623623
else:
624624
self.call_ranges = tuple(lengths) + tuple(reduction_lengths)
625625
self.ranges = [self.rename_indexing(x) for x in self.call_ranges]
626-
self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))]
626+
if index_names is None:
627+
self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))]
628+
else:
629+
assert len(index_names) == len(self.ranges), f"Index names length mismatch: {len(index_names)} != {len(self.ranges)}"
630+
self.itervars = [sympy.Symbol(str(n)) for n in index_names]
631+
627632
self.itervar_cses = {str(index) : self.register_var_cse(str(index), 1, "index") for index in self.itervars}
628633
self.reduction_depth = len(lengths)
629634
return (
@@ -867,18 +872,22 @@ def rename_indexing(self, index) -> sympy.Expr:
867872
def override_buffer_cse(self, *, buffer=None, cse=None):
868873
buffer_override = self.target_buffer_override
869874
cse_override = self.target_cse_override
870-
target_buffer = target_cse = None
875+
buffer_token = cse_token = None
871876
try:
877+
# Store tokens for proper restoration in nested contexts
878+
# contextvars.set() returns the previous value (token) which can be used for reset()
872879
if buffer is not None:
873-
target_buffer = buffer_override.set(buffer)
880+
buffer_token = buffer_override.set(buffer)
874881
if cse is not None:
875-
target_cse = cse_override.set(cse)
882+
cse_token = cse_override.set(cse)
876883
yield self
877884
finally:
878-
if target_cse is not None:
879-
cse_override.reset(target_cse)
880-
if target_buffer is not None:
881-
buffer_override.reset(target_buffer)
885+
# Restore using tokens - contextvars automatically handles nested contexts
886+
# Each level restores to its own previous value
887+
if cse_token is not None:
888+
cse_override.reset(cse_token)
889+
if buffer_token is not None:
890+
buffer_override.reset(buffer_token)
882891

883892
def __enter__(self):
884893
class CSEProxy:

PyTorchSimFrontend/mlir/mlir_template.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_
469469
_, (group, reduction_group) = max(
470470
epilogue_nodes, key=lambda x: int(x.is_reduction())
471471
).group
472-
vars, reduction_vars = kernel.set_ranges(group, reduction_group)
472+
vars, reduction_vars = kernel.set_ranges(group, reduction_group, list(self.dim_aliasing.values()))
473473
for node in epilogue_nodes:
474474
node.codegen((vars, reduction_vars))
475475

@@ -893,7 +893,6 @@ def load_epilogue(self, name: str, index: sympy.Expr):
893893
map_var = ops.affine_map(["d0", "d1"], f"d0 + d1*{(self.r_tile_size)}")
894894
with self.override_buffer_cse(buffer=self.loads):
895895
offset = ops.affine_apply(map_var, [self.compute_idx, self.reduction_loop_idx])
896-
#offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.r_tile_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})")
897896
compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"])
898897

899898
with self.override_buffer_cse(buffer=self.loads):
@@ -1124,22 +1123,6 @@ def set_tile_size(self, template_fusion_info, prologue=False):
11241123
self.compute_body_loop.step = tile_desc.get_compute_vec_size()
11251124
return tile_desc
11261125

1127-
def rename_indexing(self, index) -> sympy.Expr:
1128-
# First step: replace dim_name with tmp_+dim_aliased_name to avoid circular dependencies
1129-
# (e.g., {"index0":"index1", "index1":"index0"})
1130-
tmp_subs = {
1131-
sympy.Symbol(dim_name): sympy.Symbol("tmp_"+dim_aliased_name)
1132-
for dim_name, dim_aliased_name in self.dim_aliasing.items()
1133-
}
1134-
index = index.subs(tmp_subs)
1135-
# Second step: replace tmp_+dim_aliased_name with dim_aliased_name
1136-
final_subs = {
1137-
sympy.Symbol("tmp_"+dim_aliased_name): sympy.Symbol(dim_aliased_name)
1138-
for dim_aliased_name in self.dim_aliasing.values()
1139-
}
1140-
index = index.subs(final_subs)
1141-
return index
1142-
11431126
class MLIRTemplateCaller(CUDATemplateCaller):
11441127
def __str__(self):
11451128
return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})"

0 commit comments

Comments
 (0)