Skip to content

Commit 0c6175f

Browse files
committed
[Template] Fix template fusion codegen
1 parent ea79ad0 commit 0c6175f

6 files changed

Lines changed: 153 additions & 91 deletions

File tree

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def get_padding_type(self):
327327
# return 1
328328
return 0
329329

330-
def convert_index(self, expr, buffer):
330+
def convert_index(self, expr):
331331
if len(expr.free_symbols) != 1:
332332
raise NotImplementedError("Not supporting this view operation...!")
333333

@@ -346,17 +346,37 @@ def convert_index(self, expr, buffer):
346346
first_arg = expr.args[0]
347347
if len(first_arg.free_symbols) != 1:
348348
raise NotImplementedError("What is this case?")
349+
350+
# Create affine.apply operation
349351
indices = [list(first_arg.free_symbols)[0]]
350-
args = ", ".join(map(str, indices))
351-
map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args}) -> ({expr_str})>")
352-
args = ", ".join([f"%{i}" for i in indices])
353-
index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})")
352+
with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse):
353+
map_var = ops.affine_map(indices, expr_str)
354+
index = ops.affine_apply(map_var, indices)
354355
return index
355356

356-
def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> common.CSEVariable:
357-
if buffer is None:
358-
buffer = self.applys
357+
def _convert_sympy_to_mlir_expr(self, expr, sorted_args):
358+
"""
359+
Convert sympy expression to MLIR affine map expression by replacing index variables.
360+
"""
361+
indices = []
362+
363+
for arg in sorted_args:
364+
if arg.is_Mul and arg.args[0].is_number:
365+
target_arg = arg.args[1]
366+
elif not arg.is_number:
367+
target_arg = arg
368+
else:
369+
continue
370+
new_arg = sympy.Symbol(str(self.convert_index(target_arg)))
371+
expr = expr.replace(target_arg, new_arg)
372+
indices.append(str(new_arg))
373+
374+
expr_str = str(expr)
375+
if "//" in expr_str:
376+
expr_str = expr_str.replace("//", " floordiv ")
377+
return expr_str, indices
359378

379+
def parse_indices(self, expr, comments="", indices=None, indirect_dims=[]) -> common.CSEVariable:
360380
# Constant case
361381
if expr.is_number and len(indirect_dims) == 0:
362382
return self.get_const_cse(int(expr))
@@ -372,33 +392,25 @@ def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> com
372392
# Sort index variable.. ex) (%index1, %index0)
373393
args_dict = {term: list(term.free_symbols)[0] for term in args if term.free_symbols}
374394
sorted_args = sorted(args_dict.keys(), key=lambda term: str(args_dict[term]))
375-
indices = []
376-
for arg in sorted_args:
377-
if arg.is_Mul and arg.args[0].is_number:
378-
new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer)))
379-
expr = expr.replace(arg.args[1], new_arg)
380-
indices.append(str(new_arg))
381-
elif not arg.is_number:
382-
new_arg = sympy.Symbol(str(self.convert_index(arg, buffer)))
383-
expr = expr.replace(arg, new_arg)
384-
indices.append(str(new_arg))
395+
396+
# Convert sympy expression to affine map expression
397+
expr_str, indices = self._convert_sympy_to_mlir_expr(expr, sorted_args)
385398

386399
# Extract index var
387-
indirect_args = [f"%{i}" for i in indirect_dims]
388-
if len(indirect_args):
400+
if len(indirect_dims):
389401
comments = "{indirect_access} " + comments # Add indirect access attribute
390-
expr_str = str(expr)
391-
if "//" in expr_str:
392-
expr_str = expr_str.replace("//", " floordiv ")
393-
args = ", ".join(map(str, indices))
394-
map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[{','.join(indirect_dims)}] -> ({expr_str})>")
395-
args = ", ".join([f"%{i}" for i in indices])
396-
index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[{','.join(indirect_args)}] {comments}")
402+
indirect_args = [f"%{i}" for i in indirect_dims]
403+
# Create affine.apply operation
404+
with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse):
405+
map_var = ops.affine_map(indices, expr_str, symbol_names=indirect_dims)
406+
407+
if hasattr(self, "dim_aliasing"):
408+
indices = [self.dim_aliasing.get(index, index) for index in indices]
409+
index = ops.affine_apply(map_var, indices, indirect_dims=indirect_args, comment=comments)
397410
return index
398411

399-
def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0)) -> common.CSEVariable:
400-
if buffer is None:
401-
buffer = self.applys
412+
def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSEVariable:
413+
""" Need to override buffer and cse to use this function. """
402414
expr_list = [arg for arg in expr_list]
403415
dim_list = [f"d{i}" for i in range(len(expr_list))]
404416

@@ -413,11 +425,11 @@ def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0))
413425
new_expr_list = [0] * len(expr_list)
414426
for idx, arg in enumerate(expr_list):
415427
if arg.is_Mul and arg.args[0].is_number:
416-
new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer)))
428+
new_arg = sympy.Symbol(str(self.convert_index(arg.args[1])))
417429
new_expr_list[idx] = arg.subs(arg.args[1], dim_list[idx])
418430
indices.append(str(new_arg))
419431
elif not arg.is_number:
420-
new_arg = sympy.Symbol(str(self.convert_index(arg, buffer)))
432+
new_arg = sympy.Symbol(str(self.convert_index(arg)))
421433
new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx])
422434
indices.append(str(new_arg))
423435
else:
@@ -427,11 +439,11 @@ def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0))
427439
indices.append(str(new_arg))
428440

429441
# Extract index var
442+
# Create affine.apply operation
430443
expr_str = str(sum(new_expr_list) + offset)
431-
args = ", ".join(map(str, dim_list))
432-
map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[] -> ({expr_str})>")
433-
args = ", ".join([f"%{i}" for i in indices])
434-
index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[]")
444+
with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse):
445+
map_var = ops.affine_map(dim_list, expr_str)
446+
index = ops.affine_apply(map_var, indices)
435447
return index
436448

437449
def load(self, name: str, index: sympy.Expr):
@@ -1080,7 +1092,8 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
10801092
if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)):
10811093
local_dims = total_dims # Broadcast tile shape
10821094

1083-
index_var = self.parse_indices(index, buffer=buffer, indirect_dims=indirect_dims, comments=f"// store_reduction={store_reduction}")
1095+
with self.override_buffer_cse(buffer=buffer, cse=self.apply_cse):
1096+
index_var = self.parse_indices(index, indirect_dims=indirect_dims, comments=f"// store_reduction={store_reduction}")
10841097

10851098
if kg_tile_desc.vmap.vlane_split_axis in local_dims:
10861099
local_vlane_split_axis = local_dims.index(kg_tile_desc.vmap.vlane_split_axis)

PyTorchSimFrontend/mlir/mlir_common.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ def __init__(self, kernel_group, reason=None):
612612
instance_id = id(self)
613613
self.target_buffer_override = contextvars.ContextVar(f"Handler_compute_override_{instance_id}", default=self.compute)
614614
self.target_cse_override = contextvars.ContextVar(f"Handler_cse_override_{instance_id}", default=self.cse)
615+
self._nested_context_depth = 0
615616

616617
def set_ranges(self, lengths, reduction_lengths):
617618
if self.call_ranges:
@@ -992,13 +993,20 @@ def bucketize(
992993
values, offsets_name, offsets_size, indexing_dtype, right
993994
)
994995

995-
super().__enter__()
996-
assert self.overrides
997-
parent_handler = self.overrides()
998-
self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
999-
self.exit_stack.enter_context(V.set_kernel_handler(self))
996+
if self._nested_context_depth == 0:
997+
self.exit_stack.__enter__()
998+
assert self.overrides
999+
parent_handler = self.overrides()
1000+
1001+
self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
1002+
self.exit_stack.enter_context(V.set_kernel_handler(self))
1003+
self._nested_context_depth += 1
10001004
return self
10011005

1006+
def __exit__(self, exc_type, exc_val, exc_tb):
1007+
self._nested_context_depth -= 1
1008+
if self._nested_context_depth == 0:
1009+
super().__exit__(exc_type, exc_val, exc_tb)
10021010

10031011
@dataclasses.dataclass
10041012
class LoopLevel:

PyTorchSimFrontend/mlir/mlir_gemm_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def render(self,
154154
W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride)
155155
W_tile_desc.set_name("W_buffer")
156156
W_tile_desc.offset = W.get_layout().offset
157-
W_stride = W.get_layout().stride
157+
W_stride = W.get_layout().stride if N>1 else [Y.get_layout().stride[0], 0]
158158
W_idx = [sympy.Symbol("index2") * W_stride[0], sympy.Symbol("index1") * W_stride[1]]
159159

160160
vlane_split_axis = vlane_split_axis if nr_rdim==0 else 0
@@ -163,7 +163,7 @@ def render(self,
163163
Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride)
164164
Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride)
165165
Y_tile_desc.set_name("Y_buffer")
166-
Y_stride = Y.get_layout().stride
166+
Y_stride = Y.get_layout().stride if N>1 else [Y.get_layout().stride[0], 0]
167167
if nr_rdim == 0:
168168
Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]]
169169
else:

PyTorchSimFrontend/mlir/mlir_ops.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1175,4 +1175,41 @@ def _store(operand, buffer, indices, buffer_shape, *args, buffer_name=None, **kw
11751175
if buffer_name is not None:
11761176
return common.DeferredLine(buffer_name, line), [None, None]
11771177
else:
1178-
return line, [None, None]
1178+
return line, [None, None]
1179+
1180+
@staticmethod
1181+
def affine_apply(map_var, indices, indirect_dims=None, comment=None, *args, **kwargs):
1182+
# Format indices arguments
1183+
indices_str = ", ".join([f"%{i}" for i in indices])
1184+
op_str = f"affine.apply #{map_var}({indices_str})"
1185+
1186+
# Add indirect dimensions if provided
1187+
if indirect_dims:
1188+
indirect_str = ", ".join(indirect_dims)
1189+
op_str += f"[{indirect_str}]"
1190+
if comment:
1191+
op_str += f" // {comment}"
1192+
return op_str, [1, "index"]
1193+
1194+
@staticmethod
1195+
def affine_map(dim_names, expr_str, symbol_names=None, comment=None, *args, **kwargs):
1196+
# Handle dim_names as list or string
1197+
if isinstance(dim_names, list):
1198+
dims_str = ", ".join([str(dim) for dim in dim_names])
1199+
else:
1200+
dims_str = dim_names
1201+
1202+
# Build the map string
1203+
if symbol_names:
1204+
if isinstance(symbol_names, list):
1205+
symbols_str = ", ".join(symbol_names)
1206+
else:
1207+
symbols_str = symbol_names
1208+
map_str = f"affine_map<({dims_str})[{symbols_str}] -> ({expr_str})>"
1209+
else:
1210+
map_str = f"affine_map<({dims_str}) -> ({expr_str})>"
1211+
1212+
if comment:
1213+
map_str += f" // {comment}"
1214+
1215+
return map_str, [1, "map"]

PyTorchSimFrontend/mlir/mlir_scheduling.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ def __init__(self, scheduler):
3535
self.max_fusion_size = 5
3636

3737
def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
38-
if not extension_config.CONFIG_FUSION:
39-
return False
38+
if not extension_config.CONFIG_FUSION_PROLOGUE:
39+
return self.scheduler.can_fuse_origin(node1, node2)
4040

4141
# Extract base template node
4242
base_template_node1 = [node for node in node1.get_nodes() if node.is_template()]
4343
base_template_node2 = [node for node in node2.get_nodes() if node.is_template()]
4444

4545
# Case 3: Prologue(Pointwise) + Template
46-
if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE:
46+
if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE:
4747
from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
4848
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
4949

@@ -126,7 +126,7 @@ def can_fuse_horizontal(self, node1, node2):
126126
return same_iter and no_dependency
127127

128128
# Case 1: Template + Pointwise fusion
129-
if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(base_template_node2) == 0 and not node2.is_reduction():
129+
if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and not node2.is_reduction():
130130
# Don't fuse maxpool template code
131131
from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate
132132
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
@@ -170,7 +170,7 @@ def can_fuse_horizontal(self, node1, node2):
170170
return True
171171

172172
# Case 2: Template + Reduction fusion
173-
if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE:
173+
if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE:
174174
from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
175175
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
176176
target_node = base_template_node1[0].node
@@ -185,39 +185,35 @@ def can_fuse_horizontal(self, node1, node2):
185185
except:
186186
return False
187187

188-
# We can't fuse dim=-1
189-
layout_possible = stride != 1
188+
# We can't fuse dim=-1 & N == 1
189+
layout_possible = stride != 1 and (1 not in node1.node.get_size())
190190
# Directed linked?
191191
dependency_check = writes1 & reads2
192192
dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads])
193193
return size_match and layout_possible and dependency_check and dependency_size
194194

195195
# Case 3: Prologue(Pointwise) + Template
196-
if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE:
197-
from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
198-
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
199-
200-
target_node = base_template_node2[0].node
201-
# Currently only BMM, MM support prologue fusion
202-
if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)):
203-
return False
204-
205-
if len(node1.read_writes.writes) != 1:
206-
return False
207-
if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME
208-
return False
209-
210-
# We don't fuse this edge case...
211-
if base_template_node2[0].group[1][0][0] == 1:
212-
return False
213-
214-
if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]:
215-
node1 = self.revert_group(node1)
216-
return True
217-
218-
# Check elementwise fusion
219-
if vars1 == vars2 and reduce1 == reduce2 and not node1.is_reduction() and not node2.is_reduction():
220-
return writes1 & reads2
196+
# if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE:
197+
# from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
198+
# from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
199+
200+
# target_node = base_template_node2[0].node
201+
# # Currently only BMM, MM support prologue fusion
202+
# if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)):
203+
# return False
204+
205+
# if len(node1.read_writes.writes) != 1:
206+
# return False
207+
# if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME
208+
# return False
209+
210+
# # We don't fuse this edge case...
211+
# if base_template_node2[0].group[1][0][0] == 1:
212+
# return False
213+
214+
# if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]:
215+
# node1 = self.revert_group(node1)
216+
# return True
221217
return False
222218

223219
def revert_group(self, act_nodes, args=None, var_ranges=None):
@@ -340,7 +336,7 @@ def codegen_template(self, template_node, epilogue_nodes, prologue_nodes):
340336
_, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs()
341337
src_code, meta_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes)
342338

343-
with V.set_kernel_handler(kernel):
339+
with kernel:
344340
kernel_name = self.define_kernel(src_code, meta_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info,
345341
kernel.loop_size, origins={str(i) for i in template_node.node.origins})
346342
self.define_function(kernel)

0 commit comments

Comments
 (0)