Skip to content

Commit 62a1cd6

Browse files
committed
[Frontend] Make dma tag unique
1 parent 903ff13 commit 62a1cd6

3 files changed

Lines changed: 17 additions & 14 deletions

File tree

PyTorchSimFrontend/extension_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
# Backendsim config
3838
CONFIG_TORCHSIM_BACKEND_CONFIG = os.environ.get('TORCHSIM_CONFIG',
3939
default=f'{CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json')
40-
CONFIG_BACKENDSIM_SPIKE_ONLY = int(os.environ.get("BACKENDSIM_SPIKE_ONLY", True))
40+
CONFIG_BACKENDSIM_SPIKE_ONLY = int(os.environ.get("BACKENDSIM_SPIKE_ONLY", False))
4141
CONFIG_BACKENDSIM_EAGER_MODE = int(os.environ.get("BACKENDSIM_EAGER_MODE", default=False))
4242
CONFIG_BACKENDSIM_DRYRUN = int(os.environ.get('BACKENDSIM_DRYRUN', default=False))
4343
CONFIG_BACKENDSIM_DEBUG_LEVEL = os.environ.get("BACKENDSIM_DEBUG_LEVEL", "")

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ def __init__(self, kernel_group, reason=None):
877877
self.spadbuf_counter = 0
878878
self.dma_read_counter = 1
879879
self.dma_write_counter = 1
880+
self.dma_tag_id = 0
880881
self.affine_yield = {}
881882
self.welford_reduce_out = None
882883
self.reduce_iterator = {}
@@ -1028,7 +1029,7 @@ def load(self, name: str, index: sympy.Expr):
10281029
# MVIN Encoding
10291030
attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding={padding}}}"
10301031
code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
1031-
f"{name}_tag", dram_shape, tile_shape, attribute)
1032+
dram_shape, tile_shape, attribute)
10321033
self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching
10331034
compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"])
10341035
# Generate vector load instruction
@@ -1090,7 +1091,7 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs):
10901091
# Generate DMA instruction
10911092
attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}"
10921093
code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
1093-
f"{name}_tag", dram_shape, tile_shape, attribute)
1094+
dram_shape, tile_shape, attribute)
10941095
self.dma_stores.writeline(common.DeferredLine(name, code))
10951096

10961097
def reduction(self, dtype, src_dtype, reduction_type, value):
@@ -1243,7 +1244,7 @@ def store_reduction(self, name, index, value):
12431244
# Generate DMA instruction
12441245
attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}"
12451246
code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
1246-
f"{name}_tag", dram_shape, tile_shape, attribute)
1247+
dram_shape, tile_shape, attribute)
12471248
self.reductions_suffix.writeline(common.DeferredLine(name, code))
12481249

12491250
# Restore origin cse
@@ -1655,7 +1656,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
16551656
return local_tile_desc, index_var, dram_stride
16561657

16571658
def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var,
1658-
tag_name, dram_shape, tile_shape, attribute):
1659+
dram_shape, tile_shape, attribute):
16591660
dma_key = (vlane_split_axis, vlane_stride, mlir_dtype)
16601661
if dma_type_name == "MVIN" and dma_key in self.dma_read_cache:
16611662
dma_type, vlane_split_axis, vlane_stride = self.dma_read_cache[dma_key]
@@ -1670,9 +1671,9 @@ def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype
16701671
self.dma_read_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride]
16711672
else:
16721673
dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_write_counter}"])
1673-
# self.dma_write_counter += 1 Is it okay?
1674+
self.dma_write_counter += 1
16741675
self.dma_write_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride]
1675-
tag = self.get_tag_cse(tag_name)
1676+
tag = self.get_tag_cse()
16761677
zero_cse = self.get_const_cse(0)
16771678

16781679
# Prepare opearnds and attributes
@@ -1742,9 +1743,12 @@ def get_const_cse(self, value, dtype="index") -> common.CSEVariable:
17421743
self.consts[str(value)+dtype] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}")
17431744
return self.consts[str(value)+dtype]
17441745

1745-
def get_tag_cse(self, value, shape="memref<1xi32>"):
1746+
def get_tag_cse(self, value=None, shape="memref<1xi32>"):
1747+
if value is None:
1748+
value = self.dma_tag_id
1749+
self.dma_tag_id += 1
17461750
if value not in self.tags:
1747-
self.tags[value] = self.alloc_cse.generate(self.alloc_buffer, f"memref.alloc() : {shape}")
1751+
self.tags[value] = self.alloc_cse.generate(self.alloc_buffer, f"memref.alloc() : {shape} // {value}")
17481752
return self.tags[value]
17491753

17501754
def get_mask(self):

PyTorchSimFrontend/mlir/mlir_template.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,6 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com
663663
# Prepare code block
664664
local_code = IndentedBuffer()
665665
with V.set_kernel_handler(self):
666-
tag = f"mvint_{self.dma_read_counter}" if dma_type == "MVIN" else f"mvoutt_{self.dma_write_counter}"
667666
index_var = self.parse_index_list(index_list, local_code)
668667
node_layout = self.named_nodes[dram_var].get_layout()
669668
numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel()
@@ -696,7 +695,7 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com
696695
attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}")
697696
attribute = " {" + ", ".join(attribute_parts) + "}"
698697
code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
699-
tag, dram_shape, tile_shape, "")
698+
dram_shape, tile_shape, "")
700699
local_code.writeline(code)
701700
local_code.writeline(attribute)
702701
return textwrap.indent(local_code.getvalue(), " "*indent_size).strip()
@@ -749,7 +748,7 @@ def load_epilogue(self, name: str, index: sympy.Expr):
749748
sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index)
750749
attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}"
751750
code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
752-
f"{name}_tag", dram_shape, tile_shape, attribute)
751+
dram_shape, tile_shape, attribute)
753752
self.cse.generate(self.dma_loads, code, assignment = False)
754753
self.buffer_names[name] = sram_var
755754
else:
@@ -831,7 +830,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
831830
# Generate DMA instruction
832831
attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}"
833832
code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
834-
f"{name}_tag", dram_shape, tile_shape, attribute)
833+
dram_shape, tile_shape, attribute)
835834
self.dma_stores.writeline(DeferredLine(name, code))
836835

837836
def reduction_epilogue(self, dtype, src_dtype, reduction_type, value):
@@ -991,7 +990,7 @@ def store_reduction_epilogue(self, name, index, value):
991990
# Generate DMA instruction
992991
attribute = f"{{dram_stride={dram_stride}, sram_stride={final_tile_stride}, padding=0}}"
993992
code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
994-
f"{name}_tag", dram_shape, final_tile_shape, attribute)
993+
dram_shape, final_tile_shape, attribute)
995994
self.reductions_suffix.writeline(DeferredLine(name, code))
996995

997996
def set_tile_size(self, template_fusion_info, prologue=False):

0 commit comments

Comments (0)