Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/xegpu/enumerate_matmul_schedules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# RUN: %PYTHON %s | lh-tune - -n 1 | FileCheck %s
# RUN: %PYTHON %s | lh-tune - -n 2147483647 --count-only | FileCheck %s --check-prefix=ENUM-CHECK
# ENUM-CHECK: count: 174
# ENUM-CHECK: count: 192

"""Enumerate concrete schedules given a schedule with tunable parameters."""

Expand Down
6 changes: 3 additions & 3 deletions examples/xegpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def parse_cli():
parser.add_argument(
"--k-tile",
type=int,
default=128,
default=64,
help="Inner reduction dimension tile size K.",
)
parser.add_argument(
Expand All @@ -266,14 +266,14 @@ def parse_cli():
"--prefetch-tile-a",
type=int,
nargs=2,
default=[8, 16],
default=[8, 32],
help="Tile size for cooperative prefetching of subgroup A matrix",
)
parser.add_argument(
"--prefetch-tile-b",
type=int,
nargs=2,
default=[16, 16],
default=[16, 32],
help="Tile size for cooperative prefetching of subgroup B matrix",
)
parser.add_argument(
Expand Down
100 changes: 50 additions & 50 deletions examples/xegpu/matmul_params.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
"n": 1024,
"k": 1024,
"wg_m": 256,
"wg_n": 64,
"wg_n": 128,
"sg_m": 32,
"sg_n": 32,
"k_tile": 128,
"load_a_m": 32,
"k_tile": 32,
"load_a_m": 16,
"load_a_k": 16,
"load_b_k": 32,
"load_b_k": 16,
"load_b_n": 16,
"prefetch_a_m": 16,
"prefetch_a_m": 8,
"prefetch_a_k": 16,
"prefetch_b_k": 8,
"prefetch_b_k": 16,
"prefetch_b_n": 16,
"prefetch_nb": 1
},
Expand All @@ -23,10 +23,10 @@
"n": 1024,
"k": 8192,
"wg_m": 256,
"wg_n": 64,
"wg_n": 256,
"sg_m": 32,
"sg_n": 32,
"k_tile": 256,
"k_tile": 32,
"load_a_m": 32,
"load_a_k": 16,
"load_b_k": 32,
Expand All @@ -41,18 +41,18 @@
"m": 1024,
"n": 8192,
"k": 1024,
"wg_m": 256,
"wg_n": 128,
"wg_m": 128,
"wg_n": 256,
"sg_m": 32,
"sg_n": 32,
"k_tile": 128,
"load_a_m": 32,
"k_tile": 32,
"load_a_m": 16,
"load_a_k": 16,
"load_b_k": 32,
"load_b_k": 16,
"load_b_n": 16,
"prefetch_a_m": 8,
"prefetch_a_k": 16,
"prefetch_b_k": 8,
"prefetch_b_k": 32,
"prefetch_b_n": 16,
"prefetch_nb": 1
},
Expand All @@ -62,16 +62,16 @@
"k": 16384,
"wg_m": 128,
"wg_n": 256,
"sg_m": 32,
"sg_m": 16,
"sg_n": 32,
"k_tile": 256,
"load_a_m": 32,
"k_tile": 32,
"load_a_m": 8,
"load_a_k": 16,
"load_b_k": 32,
"load_b_k": 16,
"load_b_n": 16,
"prefetch_a_m": 32,
"prefetch_a_m": 8,
"prefetch_a_k": 16,
"prefetch_b_k": 16,
"prefetch_b_k": 8,
"prefetch_b_n": 32,
"prefetch_nb": 1
},
Expand All @@ -83,15 +83,15 @@
"wg_n": 256,
"sg_m": 32,
"sg_n": 32,
"k_tile": 256,
"load_a_m": 32,
"k_tile": 32,
"load_a_m": 8,
"load_a_k": 16,
"load_b_k": 32,
"load_b_k": 16,
"load_b_n": 16,
"prefetch_a_m": 32,
"prefetch_a_m": 8,
"prefetch_a_k": 16,
"prefetch_b_k": 16,
"prefetch_b_n": 32,
"prefetch_b_k": 8,
"prefetch_b_n": 16,
"prefetch_nb": 1
},
{
Expand All @@ -100,54 +100,54 @@
"k": 16384,
"wg_m": 128,
"wg_n": 256,
"sg_m": 32,
"sg_m": 16,
"sg_n": 32,
"k_tile": 128,
"load_a_m": 32,
"k_tile": 32,
"load_a_m": 16,
"load_a_k": 16,
"load_b_k": 32,
"load_b_k": 16,
"load_b_n": 16,
"prefetch_a_m": 8,
"prefetch_a_m": 16,
"prefetch_a_k": 16,
"prefetch_b_k": 8,
"prefetch_b_n": 16,
"prefetch_b_k": 16,
"prefetch_b_n": 32,
"prefetch_nb": 1
},
{
"m": 128,
"n": 32768,
"k": 32768,
"wg_m": 128,
"wg_n": 128,
"wg_n": 256,
"sg_m": 32,
"sg_n": 32,
"k_tile": 128,
"load_a_m": 32,
"k_tile": 32,
"load_a_m": 16,
"load_a_k": 16,
"load_b_k": 32,
"load_b_n": 16,
"prefetch_a_m": 8,
"prefetch_a_m": 16,
"prefetch_a_k": 16,
"prefetch_b_k": 8,
"prefetch_b_n": 32,
"prefetch_b_n": 16,
"prefetch_nb": 1
},
{
"m": 128,
"n": 8192,
"k": 16384,
"wg_m": 128,
"wg_n": 128,
"sg_m": 32,
"sg_n": 32,
"wg_n": 256,
"sg_m": 16,
"sg_n": 64,
"k_tile": 128,
"load_a_m": 32,
"load_a_m": 16,
"load_a_k": 16,
"load_b_k": 32,
"load_b_k": 16,
"load_b_n": 16,
"prefetch_a_m": 8,
"prefetch_a_k": 16,
"prefetch_b_k": 8,
"prefetch_a_m": 16,
"prefetch_a_k": 32,
"prefetch_b_k": 32,
"prefetch_b_n": 16,
"prefetch_nb": 1
},
Expand All @@ -159,15 +159,15 @@
"wg_n": 256,
"sg_m": 32,
"sg_n": 32,
"k_tile": 128,
"k_tile": 64,
"load_a_m": 32,
"load_a_k": 16,
"load_b_k": 32,
"load_b_n": 16,
"prefetch_a_m": 8,
"prefetch_a_k": 16,
"prefetch_b_k": 16,
"prefetch_b_n": 16,
"prefetch_a_k": 32,
"prefetch_b_k": 8,
"prefetch_b_n": 32,
"prefetch_nb": 1
}
]
20 changes: 17 additions & 3 deletions lighthouse/schedule/xegpu/mlp_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,8 @@ def calc_sg_layout(WG_M, WG_N, SG_M, SG_N):
prefetch_tile_b = [prefetch_b_k, prefetch_b_n]

@td_smt_ext.constrain_params(
wg_m,
wg_n,
sg_m,
sg_n,
k_tile,
Expand All @@ -440,7 +442,19 @@ def calc_sg_layout(WG_M, WG_N, SG_M, SG_N):
prefetch_b_n,
)
def constrain_and_calculate_load_and_prefetch_params(
SG_M, SG_N, K_TILE, LDA_M, LDA_K, LDB_K, LDB_N, PFA_M, PFA_K, PFB_K, PFB_N
WG_M,
WG_N,
SG_M,
SG_N,
K_TILE,
LDA_M,
LDA_K,
LDB_K,
LDB_N,
PFA_M,
PFA_K,
PFB_K,
PFB_N,
):
# NB: normal asserts in case of concrete values, SMT assert ops for symbolic values
smt_ext.assert_(SG_M % LDA_M == 0)
Expand Down Expand Up @@ -478,15 +492,15 @@ def constrain_and_calculate_load_and_prefetch_params(
smt_ext.assert_(nb_load_b_n <= 1, "invalid load_tile_b_n for VNNI")

# prefetch A layout
prefetch_nb_a_m = SG_M // PFA_M
prefetch_nb_a_m = WG_M // PFA_M
prefetch_nb_a_k = K_TILE // PFA_K
prefetch_nb_a = prefetch_nb_a_m * prefetch_nb_a_k
smt_ext.assert_(prefetch_nb_a <= MAX_NB_SG_THREADS)
smt_ext.assert_(prefetch_nb_a_m * prefetch_nb_a_k >= MIN_NB_THREADS)

# prefetch B layout
prefetch_nb_b_k = K_TILE // PFB_K
prefetch_nb_b_n = SG_N // PFB_N
prefetch_nb_b_n = WG_N // PFB_N
prefetch_nb_b = prefetch_nb_b_k * prefetch_nb_b_n
smt_ext.assert_(prefetch_nb_b <= MAX_NB_SG_THREADS)
if isinstance(prefetch_nb_b, smt_ext.SMTIntValue):
Expand Down
Loading