Skip to content

Commit 8bc7bba

Browse files
YunseonShinYWHyuk
authored andcommitted
[Bert] xlarge model tile size
1 parent caa5c62 commit 8bc7bba

9 files changed

Lines changed: 19 additions & 303 deletions

File tree

PyTorchSimFrontend/mlir/mlir_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,9 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p
214214
input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K)
215215
output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N)
216216
used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision
217-
n_tile = math.ceil(N / tile_N)
217+
n_tile = math.ceil(M / tile_M) * math.ceil(N / tile_N)
218218
check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane)
219-
if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and tile_N // tile_M < 10 and tile_N < 2048:
219+
if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and tile_N // tile_M < 10:
220220
max_used_spad_size = used_spad_size
221221
maximize_i_j = tile_M * tile_N
222222
mapping = (tile_M, tile_N, tile_K)

validation/gemm_candidates/gemm_512_64_512.txt

Lines changed: 0 additions & 9 deletions
This file was deleted.

validation/gemm_candidates/gemm_512_768_3072.txt

Lines changed: 0 additions & 90 deletions
This file was deleted.

validation/gemm_candidates/gemm_512_768_768.txt

Lines changed: 0 additions & 48 deletions
This file was deleted.

validation/gemm_candidates/gemm_64_512_512.txt

Lines changed: 0 additions & 9 deletions
This file was deleted.

validation/gemm_candidates/gemm_768_3072_512.txt

Lines changed: 0 additions & 90 deletions
This file was deleted.

validation/gemm_candidates/gemm_768_768_512.txt

Lines changed: 0 additions & 48 deletions
This file was deleted.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
"512_2048_8192" : {
3+
"TILE_M" : 512,
4+
"TILE_K" : 512,
5+
"TILE_N" : 1024
6+
},
7+
"512_2048_2048" : {
8+
"TILE_M" : 512,
9+
"TILE_K" : 512,
10+
"TILE_N" : 1024
11+
},
12+
"2048_2048_512" : {
13+
"TILE_M" : 1024,
14+
"TILE_K" : 512,
15+
"TILE_N" : 512
16+
}
17+
}

validation/gemm_tpuv3_cheetsheat.json

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)