2929from PyTorchSimFrontend .mlir .mlir_scheduling import SchedulerNode
3030from torch ._inductor .codegen import common
3131
32- from PyTorchSimFrontend .extension_config import CONFIG_TORCHSIM_DIR , CONFIG_AUTOTUNE_TEMPLATE_TOPK
32+ from PyTorchSimFrontend .extension_config import CONFIG_TORCHSIM_DIR , CONFIG_AUTOTUNE_TEMPLATE_TOPK , CONFIG_AUTOTUNE_TEMPLATE
3333from . import mlir_common
3434
3535class IndentedBufferGroup :
@@ -494,7 +494,7 @@ def make_choices(self, tile_candidates, render, template_node, prologue_nodes, e
494494 print (f"[Auto-tune] Trying tile size: { list (tile_info )} " )
495495 src_code = self .codegen_template_code (render , template_node , prologue_nodes , epilogue_nodes , tile_info )
496496 bench_runner = self .run_bench ([template_node ], self .kernel_name , src_code )
497- choices .append ((bench_runner , src_code , tile_info ))
497+ choices .append ((bench_runner , src_code , tile_info , self . loop_size ))
498498 self .reset (reason = None )
499499 return choices
500500
@@ -506,7 +506,12 @@ def _log_autotune_result(self, best_choice, best_cycle):
506506 )
507507
508508 def codegen_nodes (self , tile_candidates , render , template_node , prologue_nodes , epilogue_nodes ):
509- src_code = self .autotune (tile_candidates , render , template_node , prologue_nodes , epilogue_nodes )
509+ if CONFIG_AUTOTUNE_TEMPLATE and len (tile_candidates ):
510+ src_code , loop_size = self .autotune (tile_candidates , render , template_node , prologue_nodes , epilogue_nodes )
511+ self .loop_size = loop_size
512+ else :
513+ tile_info = tile_candidates [0 ] if tile_candidates else None
514+ src_code = self .codegen_template_code (render , template_node , prologue_nodes , epilogue_nodes , tile_info )
510515
511516 with V .set_kernel_handler (self ):
512517 self .meta_kernel ()
0 commit comments