@@ -877,6 +877,7 @@ def __init__(self, kernel_group, reason=None):
877877 self .spadbuf_counter = 0
878878 self .dma_read_counter = 1
879879 self .dma_write_counter = 1
880+ self .dma_tag_id = 0
880881 self .affine_yield = {}
881882 self .welford_reduce_out = None
882883 self .reduce_iterator = {}
@@ -1028,7 +1029,7 @@ def load(self, name: str, index: sympy.Expr):
10281029 # MVIN Encoding
10291030 attribute = f"{{dram_stride={ dram_stride } , sram_stride={ tile_stride } , padding={ padding } }}"
10301031 code = self .get_dma_code ("MVIN" , vlane_split_axis , vlane_stride , mlir_dtype , dram_var , index_var , sram_var , sram_index_var ,
1031- f" { name } _tag" , dram_shape , tile_shape , attribute )
1032+ dram_shape , tile_shape , attribute )
10321033 self .cse .generate (self .dma_loads , code , assignment = False ) # FIXME: assignment = False does not support caching
10331034 compute_index_var = "," .join (sram_index_var .split ("," )[:- 1 ] + [f"%{ self .compute_idx } " ])
10341035 # Generate vector load instruction
@@ -1090,7 +1091,7 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs):
10901091 # Generate DMA instruction
10911092 attribute = f"{{dram_stride={ dram_stride } , sram_stride={ tile_stride } , padding=0}}"
10921093 code = self .get_dma_code ("MVOUT" , vlane_split_axis , vlane_stride , mlir_dtype , dram_var , index_var , sram_var , sram_index_var ,
1093- f" { name } _tag" , dram_shape , tile_shape , attribute )
1094+ dram_shape , tile_shape , attribute )
10941095 self .dma_stores .writeline (common .DeferredLine (name , code ))
10951096
10961097 def reduction (self , dtype , src_dtype , reduction_type , value ):
@@ -1243,7 +1244,7 @@ def store_reduction(self, name, index, value):
12431244 # Generate DMA instruction
12441245 attribute = f"{{dram_stride={ dram_stride } , sram_stride={ tile_stride } , padding=0}}"
12451246 code = self .get_dma_code ("MVOUT" , vlane_split_axis , vlane_stride , mlir_dtype , dram_var , index_var , sram_var , sram_index_var ,
1246- f" { name } _tag" , dram_shape , tile_shape , attribute )
1247+ dram_shape , tile_shape , attribute )
12471248 self .reductions_suffix .writeline (common .DeferredLine (name , code ))
12481249
12491250 # Restore origin cse
@@ -1655,7 +1656,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
16551656 return local_tile_desc , index_var , dram_stride
16561657
16571658 def get_dma_code (self , dma_type_name , vlane_split_axis , vlane_stride , mlir_dtype , dram_var , dram_index_var , sram_var , sram_index_var ,
1658- tag_name , dram_shape , tile_shape , attribute ):
1659+ dram_shape , tile_shape , attribute ):
16591660 dma_key = (vlane_split_axis , vlane_stride , mlir_dtype )
16601661 if dma_type_name == "MVIN" and dma_key in self .dma_read_cache :
16611662 dma_type , vlane_split_axis , vlane_stride = self .dma_read_cache [dma_key ]
@@ -1670,9 +1671,9 @@ def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype
16701671 self .dma_read_cache [dma_key ] = [dma_type , vlane_split_axis , vlane_stride ]
16711672 else :
16721673 dma_type = self .get_const_cse (DMA_TYPE [f"{ dma_type_name } { self .dma_write_counter } " ])
1673- # self.dma_write_counter += 1 Is it okay?
1674+ self .dma_write_counter += 1
16741675 self .dma_write_cache [dma_key ] = [dma_type , vlane_split_axis , vlane_stride ]
1675- tag = self .get_tag_cse (tag_name )
1676+ tag = self .get_tag_cse ()
16761677 zero_cse = self .get_const_cse (0 )
16771678
16781679 # Prepare opearnds and attributes
@@ -1742,9 +1743,12 @@ def get_const_cse(self, value, dtype="index") -> common.CSEVariable:
17421743 self .consts [str (value )+ dtype ] = self .const_cse .generate (self .const_buffer , f"arith.constant { value } : { dtype } " )
17431744 return self .consts [str (value )+ dtype ]
17441745
1745- def get_tag_cse (self , value , shape = "memref<1xi32>" ):
1746+ def get_tag_cse (self , value = None , shape = "memref<1xi32>" ):
1747+ if value is None :
1748+ value = self .dma_tag_id
1749+ self .dma_tag_id += 1
17461750 if value not in self .tags :
1747- self .tags [value ] = self .alloc_cse .generate (self .alloc_buffer , f"memref.alloc() : { shape } " )
1751+ self .tags [value ] = self .alloc_cse .generate (self .alloc_buffer , f"memref.alloc() : { shape } // { value } " )
17481752 return self .tags [value ]
17491753
17501754 def get_mask (self ):
0 commit comments