From 7ad4a7a3f4739381a3032766378ada9531d608fc Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Wed, 18 Mar 2026 15:53:10 +0200 Subject: [PATCH 1/3] fix prefetch layout calculation in schedule --- lighthouse/schedule/xegpu/mlp_schedule.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index 1eed63a..e9fa790 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -427,6 +427,8 @@ def calc_sg_layout(WG_M, WG_N, SG_M, SG_N): prefetch_tile_b = [prefetch_b_k, prefetch_b_n] @td_smt_ext.constrain_params( + wg_m, + wg_n, sg_m, sg_n, k_tile, @@ -440,7 +442,19 @@ def calc_sg_layout(WG_M, WG_N, SG_M, SG_N): prefetch_b_n, ) def constrain_and_calculate_load_and_prefetch_params( - SG_M, SG_N, K_TILE, LDA_M, LDA_K, LDB_K, LDB_N, PFA_M, PFA_K, PFB_K, PFB_N + WG_M, + WG_N, + SG_M, + SG_N, + K_TILE, + LDA_M, + LDA_K, + LDB_K, + LDB_N, + PFA_M, + PFA_K, + PFB_K, + PFB_N, ): # NB: normal asserts in case of concrete values, SMT assert ops for symbolic values smt_ext.assert_(SG_M % LDA_M == 0) @@ -478,7 +492,7 @@ def constrain_and_calculate_load_and_prefetch_params( smt_ext.assert_(nb_load_b_n <= 1, "invalid load_tile_b_n for VNNI") # prefetch A layout - prefetch_nb_a_m = SG_M // PFA_M + prefetch_nb_a_m = WG_M // PFA_M prefetch_nb_a_k = K_TILE // PFA_K prefetch_nb_a = prefetch_nb_a_m * prefetch_nb_a_k smt_ext.assert_(prefetch_nb_a <= MAX_NB_SG_THREADS) @@ -486,7 +500,7 @@ def constrain_and_calculate_load_and_prefetch_params( # prefetch B layout prefetch_nb_b_k = K_TILE // PFB_K - prefetch_nb_b_n = SG_N // PFB_N + prefetch_nb_b_n = WG_N // PFB_N prefetch_nb_b = prefetch_nb_b_k * prefetch_nb_b_n smt_ext.assert_(prefetch_nb_b <= MAX_NB_SG_THREADS) if isinstance(prefetch_nb_b, smt_ext.SMTIntValue): From 041bc9339ef03fec579ed58db09ccd6c994152f9 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 19 Mar 2026 15:27:59 +0200 Subject: [PATCH 2/3] parameter selector: update matmul params Re-ran ga optimization after fixing prefetch layout calculation. --- examples/xegpu/matmul.py | 6 +- examples/xegpu/matmul_params.json | 100 +++++++++++++++--------------- 2 files changed, 53 insertions(+), 53 deletions(-) diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index f3425f1..4d95621 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -245,7 +245,7 @@ def parse_cli(): parser.add_argument( "--k-tile", type=int, - default=128, + default=64, help="Inner reduction dimension tile size K.", ) parser.add_argument( @@ -266,14 +266,14 @@ def parse_cli(): "--prefetch-tile-a", type=int, nargs=2, - default=[8, 16], + default=[8, 32], help="Tile size for cooperative prefetching of subgroup A matrix", ) parser.add_argument( "--prefetch-tile-b", type=int, nargs=2, - default=[16, 16], + default=[16, 32], help="Tile size for cooperative prefetching of subgroup B matrix", ) parser.add_argument( diff --git a/examples/xegpu/matmul_params.json b/examples/xegpu/matmul_params.json index 26b1fa3..16ddd52 100644 --- a/examples/xegpu/matmul_params.json +++ b/examples/xegpu/matmul_params.json @@ -4,17 +4,17 @@ "n": 1024, "k": 1024, "wg_m": 256, - "wg_n": 64, + "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k_tile": 128, - "load_a_m": 32, + "k_tile": 32, + "load_a_m": 16, "load_a_k": 16, - "load_b_k": 32, + "load_b_k": 16, "load_b_n": 16, - "prefetch_a_m": 16, + "prefetch_a_m": 8, "prefetch_a_k": 16, - "prefetch_b_k": 8, + "prefetch_b_k": 16, "prefetch_b_n": 16, "prefetch_nb": 1 }, @@ -23,10 +23,10 @@ "n": 1024, "k": 8192, "wg_m": 256, - "wg_n": 64, + "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k_tile": 256, + "k_tile": 32, "load_a_m": 32, "load_a_k": 16, "load_b_k": 32, @@ -41,18 +41,18 @@ "m": 1024, "n": 8192, "k": 1024, - "wg_m": 256, - "wg_n": 128, + "wg_m": 128, + "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k_tile": 128, - "load_a_m": 32, + "k_tile": 32, + "load_a_m": 16, "load_a_k": 16, - "load_b_k": 32, + "load_b_k": 16, "load_b_n": 16, "prefetch_a_m": 8, "prefetch_a_k": 16, - "prefetch_b_k": 8, + "prefetch_b_k": 32, "prefetch_b_n": 16, "prefetch_nb": 1 }, @@ -62,16 +62,16 @@ "k": 16384, "wg_m": 128, "wg_n": 256, - "sg_m": 32, + "sg_m": 16, "sg_n": 32, - "k_tile": 256, - "load_a_m": 32, + "k_tile": 32, + "load_a_m": 8, "load_a_k": 16, - "load_b_k": 32, + "load_b_k": 16, "load_b_n": 16, - "prefetch_a_m": 32, + "prefetch_a_m": 8, "prefetch_a_k": 16, - "prefetch_b_k": 16, + "prefetch_b_k": 8, "prefetch_b_n": 32, "prefetch_nb": 1 }, @@ -83,15 +83,15 @@ "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k_tile": 256, - "load_a_m": 32, + "k_tile": 32, + "load_a_m": 8, "load_a_k": 16, - "load_b_k": 32, + "load_b_k": 16, "load_b_n": 16, - "prefetch_a_m": 32, + "prefetch_a_m": 8, "prefetch_a_k": 16, - "prefetch_b_k": 16, - "prefetch_b_n": 32, + "prefetch_b_k": 8, + "prefetch_b_n": 16, "prefetch_nb": 1 }, { @@ -100,17 +100,17 @@ "k": 16384, "wg_m": 128, "wg_n": 256, - "sg_m": 32, + "sg_m": 16, "sg_n": 32, - "k_tile": 128, - "load_a_m": 32, + "k_tile": 32, + "load_a_m": 16, "load_a_k": 16, - "load_b_k": 32, + "load_b_k": 16, "load_b_n": 16, - "prefetch_a_m": 8, + "prefetch_a_m": 16, "prefetch_a_k": 16, - "prefetch_b_k": 8, - "prefetch_b_n": 16, + "prefetch_b_k": 16, + "prefetch_b_n": 32, "prefetch_nb": 1 }, { @@ -118,18 +118,18 @@ "n": 32768, "k": 32768, "wg_m": 128, - "wg_n": 128, + "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k_tile": 128, - "load_a_m": 32, + "k_tile": 32, + "load_a_m": 16, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, - "prefetch_a_m": 8, + "prefetch_a_m": 16, "prefetch_a_k": 16, "prefetch_b_k": 8, - "prefetch_b_n": 32, + "prefetch_b_n": 16, "prefetch_nb": 1 }, { @@ -137,17 +137,17 @@ "n": 8192, "k": 16384, "wg_m": 128, - "wg_n": 128, - "sg_m": 32, - "sg_n": 32, + "wg_n": 256, + "sg_m": 16, + "sg_n": 64, "k_tile": 128, - "load_a_m": 32, + "load_a_m": 16, "load_a_k": 16, - "load_b_k": 32, + "load_b_k": 16, "load_b_n": 16, - "prefetch_a_m": 8, - "prefetch_a_k": 16, - "prefetch_b_k": 8, + "prefetch_a_m": 16, + "prefetch_a_k": 32, + "prefetch_b_k": 32, "prefetch_b_n": 16, "prefetch_nb": 1 }, @@ -159,15 +159,15 @@ "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k_tile": 128, + "k_tile": 64, "load_a_m": 32, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, "prefetch_a_m": 8, - "prefetch_a_k": 16, - "prefetch_b_k": 16, - "prefetch_b_n": 16, + "prefetch_a_k": 32, + "prefetch_b_k": 8, + "prefetch_b_n": 32, "prefetch_nb": 1 } ] From 4e081fa1ec5db74d50e450420ebe1cc1fec97036 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 20 Mar 2026 14:45:21 +0200 Subject: [PATCH 3/3] fix enum test --- examples/xegpu/enumerate_matmul_schedules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/xegpu/enumerate_matmul_schedules.py b/examples/xegpu/enumerate_matmul_schedules.py index 7b60f5b..a03ba9b 100644 --- a/examples/xegpu/enumerate_matmul_schedules.py +++ b/examples/xegpu/enumerate_matmul_schedules.py @@ -1,6 +1,6 @@ # RUN: %PYTHON %s | lh-tune - -n 1 | FileCheck %s # RUN: %PYTHON %s | lh-tune - -n 2147483647 --count-only | FileCheck %s --check-prefix=ENUM-CHECK -# ENUM-CHECK: count: 174 +# ENUM-CHECK: count: 192 """Enumerate concrete schedules given a schedule with tunable parameters."""