From 150f5b1af5c211797902b6662cae57a4f58962be Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Mon, 9 Mar 2026 17:18:49 +0200 Subject: [PATCH 1/2] Add gridsearch and genetic algorithm tuners --- examples/xegpu/README.md | 88 +++- examples/xegpu/csv_logger.py | 36 ++ examples/xegpu/genetic_algorithm.py | 218 ++++++++++ examples/xegpu/lit.local.cfg | 8 +- examples/xegpu/matmul.py | 137 ++++--- examples/xegpu/tune_matmul_ga.py | 144 +++++++ examples/xegpu/tune_matmul_gridsearch.py | 494 +++++++++++++++++++++++ 7 files changed, 1056 insertions(+), 69 deletions(-) create mode 100644 examples/xegpu/csv_logger.py create mode 100644 examples/xegpu/genetic_algorithm.py create mode 100644 examples/xegpu/tune_matmul_ga.py create mode 100644 examples/xegpu/tune_matmul_gridsearch.py diff --git a/examples/xegpu/README.md b/examples/xegpu/README.md index 812ef46e..138b87f8 100644 --- a/examples/xegpu/README.md +++ b/examples/xegpu/README.md @@ -65,7 +65,7 @@ export PYTHONPATH=${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core/ ## Matrix multiplication benchmark -Run the default 4k (float16, float16) -> float32 matrix multiplication benchmark with correctness test: +Run the default 4k (float16, float16) -> float32 matrix-multiply-accumulate benchmark with correctness test: ```bash python matmul.py --check-result @@ -77,13 +77,25 @@ Set different M, N, K problem size python matmul.py --sizes 1024 2048 4096 ... ``` -Run with ReLU post-op: +To run matrix multiply (C = A * B) kernel instead of matrix-multiply-accumulate (C += A * B): ```bash -python matmul.py --relu ... +python matmul.py --no-accumulate-c ... ``` -See all command line arguments: +Run with bias and ReLU post-op: + +```bash +python matmul.py --bias --relu ... +``` + +Set tiling parameters from the command line: + +```bash +python matmul.py --wg-tile 128 256 ... 
+``` + +See all command-line arguments: ```bash python matmul.py --help @@ -124,8 +136,72 @@ Add bias to all layers and ReLU to hidden layers: python mlp.py --bias --relu ... ``` -See all command line arguments: +## Kernel tuning + +### Exhaustive grid search + +`tune_matmul_gridsearch.py` performs an exhaustive grid search on a matrix multiplication kernel. It takes similar arguments as the `matmul.py` benchmark: + +```bash +python tune_matmul_gridsearch.py --sizes 1024 2048 4096 --bias --relu --no-accumulate-c +``` + +The executed parameter combinations are stored in `out_gridsearch.csv` file along with the obtained performance metrics: + +```txt +M,N,K,wg_m,wg_n,sg_m,sg_n,k,load_a_m,load_a_k,load_b_k,load_b_n,pf_a_m,pf_a_k,pf_b_k,pf_b_n,pf_nb,time (us),GFLOPS/s +4096,4096,4096,64,256,32,32,64,8,16,16,16,8,16,8,16,1,???,??? +... +``` + +To get information about the search space (e.g., tile parameter choices) without actually executing the kernels run with `--dry-run` flag: + +```bash +python tune_matmul_gridsearch.py --dry-run +``` + +Example output: + +```txt +ab_type='f16' +c_type='f32' +has_bias=False +has_relu=False +accumulate_c=True +Variable set: +wg_m=[64, 128, 256] +wg_n=[64, 128, 256] +sg_m=[32, 64, 128] +sg_n=[32, 64, 128] +k=[16, 32, 64, 128, 256] +load_a_m=[8, 16, 32] +load_a_k=[8, 16, 32] +load_b_k=[8, 16, 32] +load_b_n=[8, 16, 32] +pf_a_m=[8, 16, 32] +pf_a_k=[8, 16, 32] +pf_b_k=[8, 16, 32] +pf_b_n=[8, 16, 32] +pf_nb=[1] +Total complexity: 2657205 configurations +Number of executed configurations: 3588 +``` + +Total complexity is the number of parameter combinations without any filtering. The number of executed configurations shows the number of valid combinations, i.e. ones that satisfy appropriate (e.g., hardware) constraints. + +To dump the best found configurations as JSON files at the end of the search, use `--dump-json n` flag where `n` stands for the number of best configurations. 
The files are named as `matmul_params_*_00.json` with increasing integer suffix (best configuration being 00). + +> [!NOTE] +> Running the grid search typically takes several hours to complete. + +### Adaptive sampling with Genetic Algorithm + +`tune_matmul_ga.py` employs a genetic algorithm for adaptive sampling to explore the kernel tuning search space. This approach is typically an order of magnitude faster while discovering high throughput parameter combinations. + +The command-line interface is similar to `tune_matmul_gridsearch.py`: ```bash -python mlp.py --help +python tune_matmul_ga.py --sizes 1024 2048 4096 --bias --relu --dump-json 10 ``` + +The executed parameter combinations are stored in `out_genetic_algorithm.csv` file. diff --git a/examples/xegpu/csv_logger.py b/examples/xegpu/csv_logger.py new file mode 100644 index 00000000..86c9c546 --- /dev/null +++ b/examples/xegpu/csv_logger.py @@ -0,0 +1,36 @@ +import logging +import csv +import os + + +class CSVLogger: + def __init__(self, filename: str = None): + self.filename = filename + self.header_written = False + self.fieldnames = None + self.logger = logging.getLogger("csv_logger") + self.logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(message)s")) + if not self.logger.hasHandlers(): + self.logger.addHandler(handler) + if self.filename is not None: + assert not os.path.exists(self.filename), ( + f"CSV file '{self.filename}' already exists" + ) + + def log(self, data: dict): + if self.fieldnames is None: + self.fieldnames = list(data.keys()) + write_header = not os.path.exists(self.filename) or not self.header_written + if write_header: + self.logger.info(",".join(self.fieldnames)) + self.logger.info(",".join(str(data[k]) for k in self.fieldnames)) + if self.filename is None: + return + with open(self.filename, mode="a", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=self.fieldnames) + if write_header: + 
writer.writeheader() + self.header_written = True + writer.writerow(data) diff --git a/examples/xegpu/genetic_algorithm.py b/examples/xegpu/genetic_algorithm.py new file mode 100644 index 00000000..493aab44 --- /dev/null +++ b/examples/xegpu/genetic_algorithm.py @@ -0,0 +1,218 @@ +""" +Genetic algorithm-based optimization of kernel parameters. +""" + +import numpy as np +import random +import time +from types import FunctionType + + +class Variable: + """Represents a single tunable parameter with list of valid choices.""" + + def __init__(self, name: str, choices: list): + self.name = name + self.choices = choices + + def random_sample(self) -> int: + return random.choice(self.choices) + + +class VariableSet: + """A tunable variable set forming the search space.""" + + def __init__(self, variables: list[Variable], is_valid_fn: FunctionType = None): + self.variables = variables + self.is_valid_fn = is_valid_fn + + def random_sample(self) -> list: + return [var.random_sample() for var in self.variables] + + def names(self) -> list[str]: + return [var.name for var in self.variables] + + def complexity(self) -> int: + """Return total number of unconstrained combinations.""" + total = 1 + for var in self.variables: + total *= len(var.choices) + return total + + def is_valid(self, sample: list) -> bool: + return self.is_valid_fn(self.sample_to_dict(sample)) + + def sample_to_dict(self, sample: list) -> dict: + assert len(sample) == len(self.variables) + return dict(zip(self.names(), sample)) + + def iterables(self) -> list: + return [v.choices for v in self.variables] + + def print(self): + print("Variable set:") + for v in self.variables: + print(f"{v.name}={v.choices}") + print(f"Total complexity: {self.complexity()} configurations") + + +class Population: + """A population of individuals drawn from the variable set.""" + + def __init__(self, variable_set: VariableSet, individuals: list = None): + self.variable_set = variable_set + self.individuals = individuals if 
individuals is not None else [] + self.fitness_scores = [] + self.generation = 0 + + def increment_generation(self): + self.generation += 1 + + def size(self) -> int: + return len(self.individuals) + + def sort(self): + scores = np.array(self.fitness_scores) + i_sorted = np.argsort(scores)[::-1] + self.individuals = [self.individuals[i] for i in i_sorted] + self.fitness_scores = [self.fitness_scores[i] for i in i_sorted] + + def extend(self, new_individuals: list, new_fitness: list): + assert len(new_individuals) == len(new_fitness) + for ind, fit in zip(new_individuals, new_fitness): + if ind not in self.individuals: + self.individuals.append(ind) + self.fitness_scores.append(fit) + + def shrink(self, nbest: int): + if nbest >= len(self.individuals): + return + self.sort() + self.individuals = self.individuals[:nbest] + self.fitness_scores = self.fitness_scores[:nbest] + + def print(self): + print( + f"\nPopulation of size {len(self.individuals)}, generation {self.generation}:" + ) + if not self.fitness_scores: + for individual in self.individuals: + print(f" {individual}") + else: + for individual, fitness in zip(self.individuals, self.fitness_scores): + print(f" {fitness:.2f}: {individual}") + print("\n") + + +def init_random_population(pop_size: int, variable_set: VariableSet) -> Population: + population = Population(variable_set=variable_set) + population.individuals = [] + i = 0 + while len(population.individuals) < pop_size: + sample = variable_set.random_sample() + if sample not in population.individuals and variable_set.is_valid(sample): + population.individuals.append(sample) + i += 1 + if i > pop_size * 10000 or i > 0.2 * variable_set.complexity(): + raise RuntimeError( + "Unable to initialize population with given constraints." 
+ ) + return population + + +class GeneticAlgorithm: + def __init__( + self, + population: Population, + recombination_rate: float = 0.5, + mutation_rate: float = 0.001, + fertility_rate: float = 1.0, + evaluate_fitness: FunctionType = None, + ): + self.fixed_population_size = population.size() + self.population = population + self.recombination_rate = recombination_rate + self.mutation_rate = mutation_rate + self.fertility_rate = fertility_rate + self.evaluate_fitness = evaluate_fitness + self.ntrials = 50 + self.population_history = [] + self.fitness_history = [] + + def recombine_and_mutate(self, individuals: list) -> list: + variable_set = self.population.variable_set + # every individual gets an update from another donor + new_individuals = [] + npopulation = len(individuals) + for i in range(npopulation): + parent = individuals[i] + donor_idx = random.choice([j for j in range(npopulation) if j != i]) + donor = individuals[donor_idx] + for _ in range(self.ntrials): + child = parent.copy() + # perform recombination + # one gene is always copied from donor + force_idx = random.randint(0, len(child) - 1) + # a gene is copied from donor with probability recombination_rate + for j in range(len(child)): + if random.random() < self.recombination_rate or j == force_idx: + child[j] = donor[j] + # mutate + if random.random() < self.mutation_rate: + child[j] = variable_set.variables[j].random_sample() + if ( + child not in individuals + and child not in new_individuals + and variable_set.is_valid(child) + ): + new_individuals.append(child) + break + return new_individuals + + def initialize(self): + if not self.population.fitness_scores: + # evaluate fitness for the initial population + self.population.fitness_scores = [ + self.evaluate_fitness(*ind) for ind in self.population.individuals + ] + self.population.sort() + + def next_generation(self): + # select parents probabilistically based on fitness + nb_parents = int(self.population.size() * self.fertility_rate) + 
scores = np.array(self.population.fitness_scores) + default = scores.min() / 20 + scores[scores == 0] = default + parents = random.choices( + population=self.population.individuals, + k=nb_parents, + weights=scores, + ) + # get new set of individuals and extend population + new_individuals = self.recombine_and_mutate(parents) + new_fitness = [self.evaluate_fitness(*ind) for ind in new_individuals] + self.population.extend(new_individuals, new_fitness) + # keep only the best individuals + self.population.shrink(self.fixed_population_size) + self.population.increment_generation() + + def optimize(self, ngen: int, verbose: int = 0): + self.initialize() + tic = time.perf_counter() + for gen in range(ngen): + self.population_history.append(self.population.individuals.copy()) + self.fitness_history.append(self.population.fitness_scores.copy()) + self.next_generation() + if verbose: + best_individual = self.population.individuals[0] + best_fitness = self.population.fitness_scores[0] + scores = np.array(self.population.fitness_scores) + avg_fitness = scores[scores > 0].mean() + print( + f"Generation {self.population.generation:4d}: " + f" best: {best_fitness:.2f}, avg: {avg_fitness:.2f}," + f" best config: {best_individual}" + ) + toc = time.perf_counter() + if verbose: + print(f"\nTime spent in optimization: {toc - tic:.2f} s\n") diff --git a/examples/xegpu/lit.local.cfg b/examples/xegpu/lit.local.cfg index 5def7216..4a26f6a3 100644 --- a/examples/xegpu/lit.local.cfg +++ b/examples/xegpu/lit.local.cfg @@ -1 +1,7 @@ -config.excludes = ["parameter_selector.py", "xegpu_workload.py"] +config.excludes = [ + "csv_logger.py", + "genetic_algorithm.py", + "parameter_selector.py", + "tune_matmul_ga.py", + "xegpu_workload.py", +] diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index 4d956219..a8c83bbe 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -12,6 +12,7 @@ import argparse import ctypes +import json from typing import Optional from 
functools import cached_property @@ -213,9 +214,10 @@ def shared_libs(self) -> list[str]: return ["libmlir_levelzero_runtime.so"] -def parse_cli(): +def cli_parser(description="Matrix Multiplication using MLIR"): + """CLI argument parser for args shared with autotuner.""" parser = argparse.ArgumentParser( - description="Matrix Multiplication using MLIR", + description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( @@ -228,6 +230,26 @@ def parse_cli(): default=[4096, 4096, 4096], help="M,N,K matrix sizes (A=MxK, B=KxN, C=MxN).", ) + parser.add_argument( + "--bias", + action="store_true", + help="Add bias after the matrix multiplication.", + ) + parser.add_argument( + "--relu", + action="store_true", + help="Add relu op after the matrix multiplication (and bias if any).", + ) + parser.add_argument( + "--no-accumulate-c", + action="store_true", + help="Compute plain matrix-multiply C=A*B instead of matrix-multiply-accumulate C+=A*B.", + ) + return parser + + +def parse_cli_args(): + parser = cli_parser() parser.add_argument( "--wg-tile", type=int, @@ -282,6 +304,11 @@ def parse_cli(): default=1, help="Number of initial prefetches.", ) + parser.add_argument( + "--check-result", + action="store_true", + help="Check the result of the matrix multiplication.", + ) parser.add_argument( "--nruns", type=int, @@ -294,26 +321,6 @@ def parse_cli(): default=20, help="Number of warm-up iterations before benchmarking.", ) - parser.add_argument( - "--bias", - action="store_true", - help="Add bias after the matrix multiplication.", - ) - parser.add_argument( - "--relu", - action="store_true", - help="Add relu op after the matrix multiplication (and bias if any).", - ) - parser.add_argument( - "--no-accumulate-c", - action="store_true", - help="Compute plain matrix-multiply C=A*B instead of matrix-multiply-accumulate C+=A*B.", - ) - parser.add_argument( - "--check-result", - action="store_true", - help="Check the result of the matrix 
multiplication.", - ) parser.add_argument( "--dump-kernel", type=str, @@ -334,46 +341,53 @@ def parse_cli(): action="store_true", help="Dump transform schedule.", ) + parser.add_argument( + "--json", + help="Read problem sizes and tile parameters from a JSON file.", + ) args = parser.parse_args() return args if __name__ == "__main__": - args = parse_cli() - - M, N, K = args.sizes - - params = { - "m": M, - "n": N, - "k": K, - "wg_m": None if args.all_knobs else args.wg_tile[0], - "wg_n": None if args.all_knobs else args.wg_tile[1], - "sg_m": None if args.all_knobs else args.sg_tile[0], - "sg_n": None if args.all_knobs else args.sg_tile[1], - "k_tile": None if args.all_knobs else args.k_tile, - "load_a_m": None if args.all_knobs else args.load_tile_a[0], - "load_a_k": None if args.all_knobs else args.load_tile_a[1], - "load_b_k": None if args.all_knobs else args.load_tile_b[0], - "load_b_n": None if args.all_knobs else args.load_tile_b[1], - "prefetch_a_m": None if args.all_knobs else args.prefetch_tile_a[0], - "prefetch_a_k": None if args.all_knobs else args.prefetch_tile_a[1], - "prefetch_b_k": None if args.all_knobs else args.prefetch_tile_b[0], - "prefetch_b_n": None if args.all_knobs else args.prefetch_tile_b[1], - "prefetch_nb": args.nb_prefetch, - } + args = parse_cli_args() ab_type = "f16" c_type = "f32" + if args.json: + with open(args.json, "r") as f: + params = json.load(f) + else: + M, N, K = args.sizes + params = { + "m": M, + "n": N, + "k": K, + "wg_m": None if args.all_knobs else args.wg_tile[0], + "wg_n": None if args.all_knobs else args.wg_tile[1], + "sg_m": None if args.all_knobs else args.sg_tile[0], + "sg_n": None if args.all_knobs else args.sg_tile[1], + "k_tile": None if args.all_knobs else args.k_tile, + "load_a_m": None if args.all_knobs else args.load_tile_a[0], + "load_a_k": None if args.all_knobs else args.load_tile_a[1], + "load_b_k": None if args.all_knobs else args.load_tile_b[0], + "load_b_n": None if args.all_knobs else 
args.load_tile_b[1], + "prefetch_a_m": None if args.all_knobs else args.prefetch_tile_a[0], + "prefetch_a_k": None if args.all_knobs else args.prefetch_tile_a[1], + "prefetch_b_k": None if args.all_knobs else args.prefetch_tile_b[0], + "prefetch_b_n": None if args.all_knobs else args.prefetch_tile_b[1], + "prefetch_nb": args.nb_prefetch, + } + with ir.Context(), ir.Location.unknown(): lh_dialects.register_and_load() wload = XeGPUMatMul( - M=M, - N=N, - K=K, + M=params["m"], + N=params["n"], + K=params["k"], ab_type=ab_type, c_type=c_type, has_bias=args.bias, @@ -404,17 +418,16 @@ def parse_cli(): def list2str(a): return ",".join(map(str, a)) - parts = [ - f"sizes={list2str(args.sizes)}", - f"dt={ab_type},{c_type}", - f"wg-tile={list2str(args.wg_tile)}", - f"sg-tile={list2str(args.sg_tile)}", - f"k-tile={args.k_tile}", - f"load-a-tile={list2str(args.load_tile_a)}", - f"load-b-tile={list2str(args.load_tile_b)}", - f"pf-a-tile={list2str(args.prefetch_tile_a)}", - f"pf-b-tile={list2str(args.prefetch_tile_b)}", - f"time(us): {elapsed:.2f}", - f"GFLOPS: {gflops:.2f}", - ] - print(" ".join(parts)) + print( + f"sizes={list2str(args.sizes)} " + f"dt={ab_type},{c_type} " + f"wg-tile={list2str(args.wg_tile)} " + f"sg-tile={list2str(args.sg_tile)} " + f"k-tile={args.k_tile} " + f"load-a-tile={list2str(args.load_tile_a)} " + f"load-b-tile={list2str(args.load_tile_b)} " + f"pf-a-tile={list2str(args.prefetch_tile_a)} " + f"pf-b-tile={list2str(args.prefetch_tile_b)} " + f"time(us): {elapsed:.2f} " + f"GFLOPS: {gflops:.2f}" + ) diff --git a/examples/xegpu/tune_matmul_ga.py b/examples/xegpu/tune_matmul_ga.py new file mode 100644 index 00000000..3f36972a --- /dev/null +++ b/examples/xegpu/tune_matmul_ga.py @@ -0,0 +1,144 @@ +""" +Genetic algorithm-based optimization of kernel parameters. 
+""" + +from functools import cache +import sys +import os +from typing import Optional +import random +from matmul import cli_parser +from tune_matmul_gridsearch import ( + construct_search_space, + execute_and_log, + dump_configs_json, +) +from genetic_algorithm import ( + init_random_population, + GeneticAlgorithm, +) +from csv_logger import CSVLogger + + +def optimize_kernel( + sizes: list[int], + has_bias: bool, + has_relu: bool, + accumulate_c: bool, + ab_type: str = "f16", + c_type: str = "f32", + check_result: bool = True, + npopulation: int = 14, + ngenerations: int = 30, + mutation_rate: float = 0.001, + dump_json: int = 0, + random_seed: Optional[int] = None, +): + if random_seed: + # set random seed for reproducibility + random.seed(random_seed) + + # timeout for kernel execution in seconds + timeout = 50 + + # number of iterations in kernel timing is chosen adaptively + nwarmup = None + nruns = None + + # disable IGC compiler cache + os.environ["NEO_CACHE_PERSISTENT"] = "0" + + var_set, sample_to_dict = construct_search_space(*sizes) + print(f"Matmul problem size: {sizes}") + print(f"{ab_type=}") + print(f"{c_type=}") + print(f"{has_bias=}") + print(f"{has_relu=}") + print(f"{accumulate_c=}") + var_set.print() + sys.stdout.flush() + + csv_file = "out_genetic_algorithm.csv" + csv_logger = CSVLogger(csv_file) + + @cache + def evaluate_fitness(*parameters) -> float: + elapsed, gflops = execute_and_log( + csv_logger, + nruns, + nwarmup, + sample_to_dict(parameters), + check_result, + timeout=timeout, + ab_type=ab_type, + c_type=c_type, + has_bias=has_bias, + has_relu=has_relu, + accumulate_c=accumulate_c, + ) + return gflops + + pop = init_random_population(npopulation, var_set) + ga_optimizer = GeneticAlgorithm( + population=pop, + mutation_rate=mutation_rate, + evaluate_fitness=evaluate_fitness, + ) + + ga_optimizer.initialize() + pop.print() + ga_optimizer.optimize(ngen=ngenerations, verbose=1) + + nb_kernel_evals = 
evaluate_fitness.cache_info().currsize + print("Best configurations found:") + for params, gflops in zip(pop.individuals, pop.fitness_scores): + print(f" GFLOPS: {gflops:.2f}: {params}") + print(f"Number of kernel evaluations: {nb_kernel_evals}") + + if dump_json > 0: + configs = [sample_to_dict(p) for p in pop.individuals[:dump_json]] + sizes_str = "-".join(str(s) for s in sizes) + relu_str = "_relu" if has_relu else "" + bias_str = "_bias" if has_bias else "" + acc_str = "_acc" if accumulate_c else "" + prefix = ( + f"matmul_params_{sizes_str}_{ab_type}-{c_type}{bias_str}{relu_str}{acc_str}" + ) + dump_configs_json(configs, filename_prefix=prefix) + + +if __name__ == "__main__": + parser = cli_parser( + description="Optimize matmul kernel parameters using a genetic algorithm." + ) + parser.add_argument( + "--generations", + type=int, + default=30, + help="Number of generations for the genetic algorithm.", + ) + parser.add_argument( + "--dump-json", + dest="n_dump_json", + type=int, + default=0, + help="Dump the best n configurations as JSON files.", + ) + parser.add_argument( + "--no-check-result", + action="store_true", + help="Skip correctness check.", + ) + + args = parser.parse_args() + + optimize_kernel( + args.sizes, + args.bias, + args.relu, + not args.no_accumulate_c, + check_result=not args.no_check_result, + ngenerations=args.generations, + dump_json=args.n_dump_json, + random_seed=2, + ) diff --git a/examples/xegpu/tune_matmul_gridsearch.py b/examples/xegpu/tune_matmul_gridsearch.py new file mode 100644 index 00000000..032ceb70 --- /dev/null +++ b/examples/xegpu/tune_matmul_gridsearch.py @@ -0,0 +1,494 @@ +# RUN: %PYTHON %s --dry-run | FileCheck %s +# CHECK: Total complexity: 2657205 configurations +# CHECK: Number of executed configurations: 5292 + +from time import perf_counter +import multiprocessing +from multiprocessing.sharedctypes import Value +from ctypes import c_double +from datetime import timedelta +from itertools import product +import 
numpy as np +import os +import sys +import json +from csv_logger import CSVLogger + +from mlir import ir + +from lighthouse import dialects as lh_dialects +from lighthouse.workload import benchmark +from lighthouse.schedule.xegpu.mlp_schedule import DPAS + +from matmul import XeGPUMatMul, cli_parser +from genetic_algorithm import ( + Variable, + VariableSet, +) + + +def run_experiment( + ab_type: str = "f16", + c_type: str = "f32", + nruns: int = None, + nwarmup: int = None, + check_result: bool = False, + has_bias: bool = False, + has_relu: bool = False, + accumulate_c: bool = True, + **params, +) -> tuple[float, float]: + with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + + wload = XeGPUMatMul( + M=params["m"], + N=params["n"], + K=params["k"], + ab_type=ab_type, + c_type=c_type, + has_bias=has_bias, + has_relu=has_relu, + accumulate_c=accumulate_c, + ) + if nruns is None and nwarmup is None: + # first run to estimate cost + times = benchmark( + wload, + nruns=10, + nwarmup=10, + schedule_parameters=params, + check_correctness=False, + verbose=0, + ) + # estimate number of runs + cost = times.mean() + warmup_target = 0.25 + nwarmup = max(int(warmup_target / cost), 10) + nruns = 3 * nwarmup + print(f"{nwarmup=} {nruns=}") + # benchmark + times = benchmark( + wload, + nruns=nruns, + nwarmup=nwarmup, + schedule_parameters=params, + check_correctness=check_result, + verbose=0, + ) + + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + return elapsed, gflops + + +def run_with_timeout(*args, timeout: int = 20, **kwargs) -> tuple[float, float]: + """ + Wrapper to execute the experiment with a new thread and a timeout. + + Experiments must be run in a new process to ensure reliable timings. + + Sends kill signal if timeout is reached. 
+ """ + # wrap return values + timing = Value(c_double, 0.0) + gflops = Value(c_double, 0.0) + + def wrapped(timing, gflops, *args, **kwargs): + res = run_experiment(*args, **kwargs) + timing.value = res[0] + gflops.value = res[1] + + all_args = tuple([timing, gflops] + list(args)) + proc = multiprocessing.Process(target=wrapped, args=all_args, kwargs=kwargs) + proc.start() + proc.join(timeout) + if proc.is_alive(): + print("TIMEOUT") + proc.kill() + proc.join() + return 0, 0 + proc.close() + return timing.value, gflops.value + + +def execute_and_log( + csv_logger: CSVLogger, + nruns: int, + nwarmup: int, + params: dict, + check_result: bool = True, + ab_type: str = "f16", + c_type: str = "f32", + has_bias: bool = False, + has_relu: bool = False, + accumulate_c: bool = True, + timeout: int = 20, +) -> tuple[float, float]: + try: + tic = perf_counter() + elapsed, gflops = run_with_timeout( + ab_type=ab_type, + c_type=c_type, + nruns=nruns, + nwarmup=nwarmup, + check_result=check_result, + timeout=timeout, + has_bias=has_bias, + has_relu=has_relu, + accumulate_c=accumulate_c, + **params, + ) + duration = perf_counter() - tic + entry = params.copy() + entry["time (us)"] = elapsed + entry["GFLOPS/s"] = gflops + csv_logger.log(entry) + duration_str = f"Duration: {duration:.3f} s" + print(duration_str) + except Exception as e: + print("FAILED") + print(entry) + print(f" Error: {e}") + sys.stdout.flush() + return elapsed, gflops + + +def check_constraints(params: dict, verbose: bool = False) -> bool: + def print_reason(msg): + if verbose: + print(f" Invalid: {msg}") + + # hardware constraints + max_nb_sg_threads = 64 + load_max_rows = 32 + load_max_cols = 16 + pfetch_min_rows = 8 + pfetch_max_rows = 32 + pfetch_min_cols = 16 + pfetch_max_cols = 32 + + # heuristics: skip likely suboptimal configurations + min_nb_threads = 16 + + M = params["m"] + N = params["n"] + wg_tile_m = params["wg_m"] + wg_tile_n = params["wg_n"] + sg_tile_m = params["sg_m"] + sg_tile_n = 
params["sg_n"] + load_tile_a_m = params["load_a_m"] + load_tile_a_k = params["load_a_k"] + load_tile_b_k = params["load_b_k"] + load_tile_b_n = params["load_b_n"] + prefetch_tile_a_m = params["prefetch_a_m"] + prefetch_tile_a_k = params["prefetch_a_k"] + prefetch_tile_b_k = params["prefetch_b_k"] + prefetch_tile_b_n = params["prefetch_b_n"] + k_tile = params["k_tile"] + + if M % wg_tile_m != 0: + print_reason("wg_tile_m does not divide M") + return False + if N % wg_tile_n != 0: + print_reason("wg_tile_n does not divide N") + return False + if wg_tile_m % sg_tile_m != 0: + print_reason("sg_tile_m does not divide wg_tile_m") + return False + if wg_tile_n % sg_tile_n != 0: + print_reason("sg_tile_n does not divide wg_tile_n") + return False + if sg_tile_m % DPAS.M != 0: + print_reason("sg_tile_m not multiple of dpas_m") + return False + if sg_tile_n % DPAS.N != 0: + print_reason("sg_tile_n not multiple of dpas_n") + return False + if k_tile % DPAS.K != 0: + print_reason("k_tile not multiple of dpas_k") + return False + + # SG level thread layout: [nb_sg_threads_m, nb_sg_threads_n] + nb_sg_threads_m = wg_tile_m // sg_tile_m + nb_sg_threads_n = wg_tile_n // sg_tile_n + nb_sg_threads = nb_sg_threads_m * nb_sg_threads_n + if nb_sg_threads > max_nb_sg_threads: + print_reason("too many sg threads") + return False + if nb_sg_threads < min_nb_threads: + print_reason("too few sg threads") + return False + + if sg_tile_m % load_tile_a_m != 0: + print_reason("load_tile_a_m does not divide sg_tile_m") + return False + if k_tile % load_tile_a_k != 0: + print_reason("load_tile_a_k does not divide k_tile") + return False + if k_tile % load_tile_b_k != 0: + print_reason("load_tile_b_k does not divide k_tile") + return False + if sg_tile_n % load_tile_b_n != 0: + print_reason("load_tile_b_n does not divide sg_tile_n") + return False + if load_tile_a_m > load_max_rows: + print_reason("too large load_tile_a_m") + return False + if load_tile_a_k > load_max_cols: + print_reason("too 
large load_tile_a_k") + return False + if load_tile_b_k > load_max_rows: + print_reason("too large load_tile_b_k") + return False + if load_tile_b_n > load_max_cols: + print_reason("too large load_tile_b_n") + return False + if sg_tile_m % prefetch_tile_a_m != 0: + print_reason("prefetch_tile_a_m does not divide sg_tile_m") + return False + if k_tile % prefetch_tile_a_k != 0: + print_reason("prefetch_tile_a_k does not divide k_tile") + return False + if k_tile % prefetch_tile_b_k != 0: + print_reason("prefetch_tile_b_k does not divide k_tile") + return False + if sg_tile_n % prefetch_tile_b_n != 0: + print_reason("prefetch_tile_b_n does not divide sg_tile_n") + return False + if prefetch_tile_a_m > pfetch_max_rows: + print_reason("too large prefetch_tile_a_m") + return False + if prefetch_tile_a_k > pfetch_max_cols: + print_reason("too large prefetch_tile_a_k") + return False + if prefetch_tile_b_k > pfetch_max_rows: + print_reason("too large prefetch_tile_b_k") + return False + if prefetch_tile_b_n > pfetch_max_cols: + print_reason("too large prefetch_tile_b_n") + return False + if prefetch_tile_a_m < pfetch_min_rows: + print_reason("too small prefetch_tile_a_m") + return False + if prefetch_tile_a_k < pfetch_min_cols: + print_reason("too small prefetch_tile_a_k") + return False + if prefetch_tile_b_k < pfetch_min_rows: + print_reason("too small prefetch_tile_b_k") + return False + if prefetch_tile_b_n < pfetch_min_cols: + print_reason("too small prefetch_tile_b_n") + return False + if load_tile_a_m % DPAS.M != 0: + print_reason("load_tile_a_m not multiple of dpas_m") + return False + if load_tile_a_k % DPAS.K != 0: + print_reason("load_tile_a_k not multiple of dpas_k") + return False + if load_tile_b_k % DPAS.K != 0: + print_reason("load_tile_b_k not multiple of dpas_k") + return False + if load_tile_b_n % DPAS.N != 0: + print_reason("load_tile_b_n not multiple of dpas_n") + return False + + nb_load_b_n = load_tile_b_n // DPAS.N + if nb_load_b_n > 1: + # 
unsupported VNNI layout, loaded tile can only be row-sliced for vnni + # NOTE this can plausibly be relaxed + print_reason("invalid load_tile_b_n for VNNI") + return False + + # prefetch A layout + nb_prefetch_a_m = wg_tile_m // prefetch_tile_a_m + nb_prefetch_a_k = k_tile // prefetch_tile_a_k + if nb_prefetch_a_m * nb_prefetch_a_k > max_nb_sg_threads: + print_reason("too many prefetch A tiles") + return False + if nb_prefetch_a_m * nb_prefetch_a_k < min_nb_threads: + print_reason("too few prefetch A threads") + return False + + # prefetch B layout + nb_prefetch_b_k = k_tile // prefetch_tile_b_k + nb_prefetch_b_n = wg_tile_n // prefetch_tile_b_n + if nb_prefetch_b_k * nb_prefetch_b_n > max_nb_sg_threads: + print_reason("too many prefetch B tiles") + return False + if nb_prefetch_b_k * nb_prefetch_b_n < min_nb_threads: + print_reason("too few prefetch B threads") + return False + + return True + + +def get_divisors(n: int, min_tile: int = 32, max_tile: int = 256) -> list[int]: + p = np.ceil(n / max_tile) + q = n // min_tile + candidates = n / np.arange(max(p, 1), q + 1) + candidates = [int(v) for v in candidates if int(v) == v] + return candidates[::-1] + + +def divisible_by(a_list: list, b: int) -> list: + return [a for a in a_list if a % b == 0] + + +def construct_search_space(M: int, N: int, K: int): + wg_tile_lim_m = min(max(M // 4, 16), 64), min(M, 256) + wg_tile_lim_n = min(max(N // 4, 16), 64), min(N, 256) + sg_tile_lim_m = min(max(M // 8, 16), 32), min(M, 128) + sg_tile_lim_n = min(max(N // 8, 16), 32), min(N, 128) + + wg_tiles_m = divisible_by(get_divisors(M, *wg_tile_lim_m), DPAS.M) + wg_tiles_n = divisible_by(get_divisors(N, *wg_tile_lim_n), DPAS.N) + sg_tiles_m = divisible_by(get_divisors(M, *sg_tile_lim_m), DPAS.M) + sg_tiles_n = divisible_by(get_divisors(N, *sg_tile_lim_n), DPAS.N) + k_tiles = divisible_by(get_divisors(K, 16, min(K, 256)), DPAS.K) + load_tiles = [8, 16, 32] + prefetches = [1] + + def sample_is_valid(sample_params, verbose=False): + 
params = {"m": M, "n": N, "k": K} + params.update(sample_params) + return check_constraints(params, verbose=verbose) + + var_set = VariableSet( + [ + Variable("wg_m", wg_tiles_m), + Variable("wg_n", wg_tiles_n), + Variable("sg_m", sg_tiles_m), + Variable("sg_n", sg_tiles_n), + Variable("k_tile", k_tiles), + Variable("load_a_m", load_tiles), + Variable("load_a_k", load_tiles), + Variable("load_b_k", load_tiles), + Variable("load_b_n", load_tiles), + Variable("prefetch_a_m", load_tiles), + Variable("prefetch_a_k", load_tiles), + Variable("prefetch_b_k", load_tiles), + Variable("prefetch_b_n", load_tiles), + Variable("prefetch_nb", prefetches), + ], + is_valid_fn=sample_is_valid, + ) + + def sample_to_dict(sample: list) -> dict: + res = {"m": M, "n": N, "k": K} + res.update(var_set.sample_to_dict(sample)) + return res + + return var_set, sample_to_dict + + +def dump_configs_json(param_list: list[dict], filename_prefix: str = "matmul_params"): + print("\nSaving parameters:") + for i, params in enumerate(param_list): + filename = f"{filename_prefix}_{i:02d}.json" + with open(filename, "w") as f: + json.dump(params, f, indent=4) + print(f" {filename}") + + +if __name__ == "__main__": + parser = cli_parser( + description="Optimize matmul kernel parameters using an exhaustive search." 
+ ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Check validity of combinations but do not execute kernels.", + ) + parser.add_argument( + "--no-check-result", + action="store_true", + help="Skip correctness check.", + ) + parser.add_argument( + "--dump-json", + dest="n_dump_json", + type=int, + default=0, + help="Dump the best n configurations as JSON files.", + ) + args = parser.parse_args() + + sizes = args.sizes + has_bias = args.bias + has_relu = args.relu + accumulate_c = not args.no_accumulate_c + ab_type = "f16" + c_type = "f32" + + # timeout for kernel execution in seconds + timeout = 50 + + # number of iterations in kernel timing is chosen adaptively + nwarmup = None + nruns = None + + # disable IGC compiler cache + os.environ["NEO_CACHE_PERSISTENT"] = "0" + + if not args.dry_run: + csv_file = "out_gridsearch.csv" + csv_logger = CSVLogger(csv_file) + + var_set, sample_to_dict = construct_search_space(*sizes) + print(f"Matmul problem size: {sizes}") + print(f"{ab_type=}") + print(f"{c_type=}") + print(f"{has_bias=}") + print(f"{has_relu=}") + print(f"{accumulate_c=}") + var_set.print() + sys.stdout.flush() + + i = 0 + executed_configs = [] + tic = perf_counter() + for sample in product(*var_set.iterables()): + params = sample_to_dict(sample) + if not check_constraints(params, verbose=False): + continue + + i += 1 + if args.dry_run: + continue + time, gflops = execute_and_log( + csv_logger, + nruns, + nwarmup, + params, + check_result=not args.no_check_result, + timeout=timeout, + ab_type=ab_type, + c_type=c_type, + has_bias=has_bias, + has_relu=has_relu, + accumulate_c=accumulate_c, + ) + executed_configs.append((gflops, params)) + + duration = perf_counter() - tic + print(f"Number of executed configurations: {i}") + print(f"Total duration: {timedelta(seconds=duration)}") + + if args.n_dump_json > 0: + executed_configs.sort(key=lambda x: x[0], reverse=True) + best_configs = [c for c in executed_configs[: args.n_dump_json]] + 
print("Best configurations found:") + for gflops, params in best_configs: + print(f" GFLOPS: {gflops:.2f}: {params}") + sizes_str = "-".join(str(s) for s in sizes) + relu_str = "_relu" if has_relu else "" + bias_str = "_bias" if has_bias else "" + acc_str = "_acc" if accumulate_c else "" + prefix = ( + f"matmul_params_{sizes_str}_{ab_type}-{c_type}{bias_str}{relu_str}{acc_str}" + ) + dump_configs_json([params for _, params in best_configs]) From 9275ec08bdd9bde720cc3d8590f85d7019470aca Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 20 Mar 2026 17:03:22 +0200 Subject: [PATCH 2/2] address copilot comments --- examples/xegpu/README.md | 19 +++++++++-------- examples/xegpu/csv_logger.py | 26 ++++++++++++------------ examples/xegpu/genetic_algorithm.py | 4 ++++ examples/xegpu/matmul.py | 16 +++++++-------- examples/xegpu/tune_matmul_ga.py | 2 +- examples/xegpu/tune_matmul_gridsearch.py | 8 ++++---- 6 files changed, 40 insertions(+), 35 deletions(-) diff --git a/examples/xegpu/README.md b/examples/xegpu/README.md index 138b87f8..8a59ffa3 100644 --- a/examples/xegpu/README.md +++ b/examples/xegpu/README.md @@ -65,7 +65,7 @@ export PYTHONPATH=${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core/ ## Matrix multiplication benchmark -Run the default 4k (float16, float16) -> float32 matrix-multipy-accumulate benchmark with correctness test: +Run the default 4k (float16, float16) -> float32 matrix-multiply-accumulate benchmark with correctness test: ```bash python matmul.py --check-result @@ -149,7 +149,7 @@ python tune_matmul_gridsearch.py --sizes 1024 2048 4096 --bias --relu --no-accum The executed parameter combinations are stored in `out_gridsearch.csv` file along with the obtained performance metrics: ```txt -M,N,K,wg_m,wg_n,sg_m,sg_n,k,load_a_m,load_a_k,load_b_k,load_b_n,pf_a_m,pf_a_k,pf_b_k,pf_b_n,pf_nb,time (us),GFLOPS/s +m,n,k,wg_m,wg_n,sg_m,sg_n,k_tile,load_a_m,load_a_k,load_b_k,load_b_n,prefetch_a_m,prefetch_a_k,prefetch_b_k,prefetch_b_n,prefetch_nb,time 
(us),GFLOPS/s 4096,4096,4096,64,256,32,32,64,8,16,16,16,8,16,8,16,1,???,??? ... ``` @@ -163,6 +163,7 @@ python tune_matmul_gridsearch.py --dry-run Example output: ```txt +Matmul problem size: [4096, 4096, 4096] ab_type='f16' c_type='f32' has_bias=False @@ -173,18 +174,18 @@ wg_m=[64, 128, 256] wg_n=[64, 128, 256] sg_m=[32, 64, 128] sg_n=[32, 64, 128] -k=[16, 32, 64, 128, 256] +k_tile=[16, 32, 64, 128, 256] load_a_m=[8, 16, 32] load_a_k=[8, 16, 32] load_b_k=[8, 16, 32] load_b_n=[8, 16, 32] -pf_a_m=[8, 16, 32] -pf_a_k=[8, 16, 32] -pf_b_k=[8, 16, 32] -pf_b_n=[8, 16, 32] -pf_nb=[1] +prefetch_a_m=[8, 16, 32] +prefetch_a_k=[8, 16, 32] +prefetch_b_k=[8, 16, 32] +prefetch_b_n=[8, 16, 32] +prefetch_nb=[1] Total complexity: 2657205 configurations -Number of executed configurations: 3588 +Number of executed configurations: 5292 ``` Total complexity is the number of parameter combinations without any filtering. The number of executed configurations shows the number of valid combinations, i.e. ones that satisfy appropriate (e.g., hardware) constraints. 
diff --git a/examples/xegpu/csv_logger.py b/examples/xegpu/csv_logger.py index 86c9c546..915cdd17 100644 --- a/examples/xegpu/csv_logger.py +++ b/examples/xegpu/csv_logger.py @@ -8,11 +8,13 @@ def __init__(self, filename: str = None): self.filename = filename self.header_written = False self.fieldnames = None - self.logger = logging.getLogger("csv_logger") + self.logger = logging.getLogger( + "csv_logger_" + (filename if filename else "stdout") + ) self.logger.setLevel(logging.INFO) - handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter("%(message)s")) if not self.logger.hasHandlers(): + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(message)s")) self.logger.addHandler(handler) if self.filename is not None: assert not os.path.exists(self.filename), ( @@ -22,15 +24,13 @@ def __init__(self, filename: str = None): def log(self, data: dict): if self.fieldnames is None: self.fieldnames = list(data.keys()) - write_header = not os.path.exists(self.filename) or not self.header_written - if write_header: + if not self.header_written: self.logger.info(",".join(self.fieldnames)) self.logger.info(",".join(str(data[k]) for k in self.fieldnames)) - if self.filename is None: - return - with open(self.filename, mode="a", newline="") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=self.fieldnames) - if write_header: - writer.writeheader() - self.header_written = True - writer.writerow(data) + if self.filename is not None: + with open(self.filename, mode="a", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=self.fieldnames) + if not self.header_written: + writer.writeheader() + writer.writerow(data) + self.header_written = True diff --git a/examples/xegpu/genetic_algorithm.py b/examples/xegpu/genetic_algorithm.py index 493aab44..740bf771 100644 --- a/examples/xegpu/genetic_algorithm.py +++ b/examples/xegpu/genetic_algorithm.py @@ -40,6 +40,8 @@ def complexity(self) -> int: return total def 
is_valid(self, sample: list) -> bool: + if self.is_valid_fn is None: + return True return self.is_valid_fn(self.sample_to_dict(sample)) def sample_to_dict(self, sample: list) -> dict: @@ -183,6 +185,8 @@ def next_generation(self): scores = np.array(self.population.fitness_scores) default = scores.min() / 20 scores[scores == 0] = default + if all(s == 0 for s in scores): + scores = None # uniform if all scores are zero parents = random.choices( population=self.population.individuals, k=nb_parents, diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index a8c83bbe..be73fd52 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -419,15 +419,15 @@ def list2str(a): return ",".join(map(str, a)) print( - f"sizes={list2str(args.sizes)} " + f"sizes={list2str([params['m'], params['n'], params['k']])} " f"dt={ab_type},{c_type} " - f"wg-tile={list2str(args.wg_tile)} " - f"sg-tile={list2str(args.sg_tile)} " - f"k-tile={args.k_tile} " - f"load-a-tile={list2str(args.load_tile_a)} " - f"load-b-tile={list2str(args.load_tile_b)} " - f"pf-a-tile={list2str(args.prefetch_tile_a)} " - f"pf-b-tile={list2str(args.prefetch_tile_b)} " + f"wg-tile={list2str([params['wg_m'], params['wg_n']])} " + f"sg-tile={list2str([params['sg_m'], params['sg_n']])} " + f"k-tile={params['k_tile']} " + f"load-a-tile={list2str([params['load_a_m'], params['load_a_k']])} " + f"load-b-tile={list2str([params['load_b_k'], params['load_b_n']])} " + f"pf-a-tile={list2str([params['prefetch_a_m'], params['prefetch_a_k']])} " + f"pf-b-tile={list2str([params['prefetch_b_k'], params['prefetch_b_n']])} " f"time(us): {elapsed:.2f} " f"GFLOPS: {gflops:.2f}" ) diff --git a/examples/xegpu/tune_matmul_ga.py b/examples/xegpu/tune_matmul_ga.py index 3f36972a..54e166ee 100644 --- a/examples/xegpu/tune_matmul_ga.py +++ b/examples/xegpu/tune_matmul_ga.py @@ -34,7 +34,7 @@ def optimize_kernel( dump_json: int = 0, random_seed: Optional[int] = None, ): - if random_seed: + if random_seed is not None: # set 
random seed for reproducibility random.seed(random_seed) diff --git a/examples/xegpu/tune_matmul_gridsearch.py b/examples/xegpu/tune_matmul_gridsearch.py index 032ceb70..7cd6e8f1 100644 --- a/examples/xegpu/tune_matmul_gridsearch.py +++ b/examples/xegpu/tune_matmul_gridsearch.py @@ -128,6 +128,8 @@ def execute_and_log( accumulate_c: bool = True, timeout: int = 20, ) -> tuple[float, float]: + entry = params.copy() + elapsed, gflops = 0, 0 try: tic = perf_counter() elapsed, gflops = run_with_timeout( @@ -143,12 +145,10 @@ def execute_and_log( **params, ) duration = perf_counter() - tic - entry = params.copy() entry["time (us)"] = elapsed entry["GFLOPS/s"] = gflops csv_logger.log(entry) - duration_str = f"Duration: {duration:.3f} s" - print(duration_str) + print(f"Duration: {duration:.3f} s") except Exception as e: print("FAILED") print(entry) @@ -491,4 +491,4 @@ def dump_configs_json(param_list: list[dict], filename_prefix: str = "matmul_par prefix = ( f"matmul_params_{sizes_str}_{ab_type}-{c_type}{bias_str}{relu_str}{acc_str}" ) - dump_configs_json([params for _, params in best_configs]) + dump_configs_json([p for _, p in best_configs], filename_prefix=prefix)