Draft
33 commits
37c5338
Typo.
isazi May 23, 2025
6c5b360
Add missing parameter to the interface.
isazi May 23, 2025
a21caf8
Formatting.
isazi May 23, 2025
a2328c4
First early draft of the parallel runner.
isazi Jun 5, 2025
c5896ad
Merge branch 'master' into parallel_runner
isazi Jun 5, 2025
68a569b
Need a dummy DeviceInterface even on the master.
isazi Jun 5, 2025
9d0dee4
Missing device_options in state.
isazi Jun 5, 2025
aff21f0
Flatten the results.
isazi Jun 5, 2025
d7e8cae
Various bug fixes.
isazi Jun 5, 2025
b4ff7fa
Add another example for the parallel runner.
isazi Jun 6, 2025
dd4f5ff
Merge branch 'master' into parallel_runner
isazi Jun 12, 2025
5cb0243
Merge branch 'master' into parallel_runner
isazi Jun 20, 2025
c4f7f32
Merge branch 'master' into parallel_runner
isazi Jul 1, 2025
dd4a4ed
Merge branch 'master' into parallel_runner
isazi Jul 8, 2025
e322824
Merge branch 'master' into parallel_runner
isazi Aug 12, 2025
426dd2a
Rewrite parallel runner to use stateful actors
stijnh Jan 19, 2026
baf4fd1
Merge branch 'master' into parallel_runner
stijnh Jan 19, 2026
f585d42
Move `tuning_options` to constructor of `ParallelRunner`
stijnh Jan 20, 2026
ad55ba4
Fix several errors related to parallel runner
stijnh Jan 20, 2026
4d8f4f5
Extend several strategies with support for parallel tuning: DiffEvo, …
stijnh Jan 20, 2026
fd41333
Add `pci_bus_id` to environment for Nvidia backends
stijnh Jan 27, 2026
96e168d
Add support for `eval_all` in `CostFunc`
stijnh Jan 27, 2026
d7129cd
Remove `return_raw` from `CostFunc` as it is unused
stijnh Jan 27, 2026
57fd617
Fix timings and handling of duplicate jobs in parallel runner
stijnh Jan 27, 2026
e1259b1
fix bug for continuous optimization
benvanwerkhoven Jan 29, 2026
72bfe94
fix test_time_keeping test ensuring at least two GA generations
benvanwerkhoven Jan 29, 2026
cc6bb97
fix tests needing more context for tuning_options
benvanwerkhoven Jan 29, 2026
ce8123a
do not count invalid for unique_results and avoid overshooting budget
benvanwerkhoven Jan 29, 2026
f6c63bf
fix for not overshooting/undershooting budget
benvanwerkhoven Jan 29, 2026
ce7e330
fix time accounting when using batched costfunc
benvanwerkhoven Jan 29, 2026
faf6cdc
fix timing issues
benvanwerkhoven Jan 30, 2026
35e61fb
merge master
benvanwerkhoven Jan 30, 2026
ec30052
fix budget overshoot issue for sequential tuning
benvanwerkhoven Jan 30, 2026
88 changes: 88 additions & 0 deletions examples/cuda/sepconv_parallel.py
@@ -0,0 +1,88 @@
#!/usr/bin/env python
import numpy
from kernel_tuner import tune_kernel
from collections import OrderedDict


def tune():
with open("convolution.cu", "r") as f:
kernel_string = f.read()

# setup tunable parameters
tune_params = OrderedDict()
tune_params["filter_height"] = [i for i in range(3, 19, 2)]
tune_params["filter_width"] = [i for i in range(3, 19, 2)]
tune_params["block_size_x"] = [16 * i for i in range(1, 65)]
tune_params["block_size_y"] = [2**i for i in range(6)]
tune_params["tile_size_x"] = [i for i in range(1, 11)]
tune_params["tile_size_y"] = [i for i in range(1, 11)]

tune_params["use_padding"] = [0, 1] # toggle the insertion of padding in shared memory
tune_params["read_only"] = [0, 1] # toggle using the read-only cache

    # limit the search to use padding only when it's effective, and require at least 32 threads per block
restrict = ["use_padding==0 or (block_size_x % 32 != 0)", "block_size_x*block_size_y >= 32"]

# setup input and output dimensions
problem_size = (4096, 4096)
size = numpy.prod(problem_size)
largest_fh = max(tune_params["filter_height"])
largest_fw = max(tune_params["filter_width"])
input_size = (problem_size[0] + largest_fw - 1) * (problem_size[1] + largest_fh - 1)

# create input data
output_image = numpy.zeros(size).astype(numpy.float32)
input_image = numpy.random.randn(input_size).astype(numpy.float32)
filter_weights = numpy.random.randn(largest_fh * largest_fw).astype(numpy.float32)

# setup kernel arguments
cmem_args = {"d_filter": filter_weights}
args = [output_image, input_image, filter_weights]

# tell the Kernel Tuner how to compute grid dimensions
grid_div_x = ["block_size_x", "tile_size_x"]
grid_div_y = ["block_size_y", "tile_size_y"]

# start tuning separable convolution (row)
tune_params["filter_height"] = [1]
tune_params["tile_size_y"] = [1]
results_row = tune_kernel(
"convolution_kernel",
kernel_string,
problem_size,
args,
tune_params,
grid_div_y=grid_div_y,
grid_div_x=grid_div_x,
cmem_args=cmem_args,
verbose=False,
restrictions=restrict,
        parallel_workers=True,
cache="convolution_kernel_row",
)

# start tuning separable convolution (col)
tune_params["filter_height"] = tune_params["filter_width"][:]
tune_params["file_size_y"] = tune_params["tile_size_x"][:]
tune_params["filter_width"] = [1]
tune_params["tile_size_x"] = [1]
results_col = tune_kernel(
"convolution_kernel",
kernel_string,
problem_size,
args,
tune_params,
grid_div_y=grid_div_y,
grid_div_x=grid_div_x,
cmem_args=cmem_args,
verbose=False,
restrictions=restrict,
        parallel_workers=True,
cache="convolution_kernel_col",
)

return results_row, results_col


if __name__ == "__main__":
results_row, results_col = tune()
35 changes: 35 additions & 0 deletions examples/cuda/vector_add_parallel.py
@@ -0,0 +1,35 @@
#!/usr/bin/env python

import numpy
from kernel_tuner import tune_kernel


def tune():
kernel_string = """
__global__ void vector_add(float *c, float *a, float *b, int n) {
int i = (blockIdx.x * block_size_x) + threadIdx.x;
if ( i < n ) {
c[i] = a[i] + b[i];
}
}
"""

size = 10000000

a = numpy.random.randn(size).astype(numpy.float32)
b = numpy.random.randn(size).astype(numpy.float32)
c = numpy.zeros_like(b)
n = numpy.int32(size)

args = [c, a, b, n]

tune_params = dict()
tune_params["block_size_x"] = [32 * i for i in range(1, 33)]

results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, parallel_workers=True)
print(env)
return results


if __name__ == "__main__":
tune()
1 change: 1 addition & 0 deletions kernel_tuner/backends/cupy.py
@@ -73,6 +73,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
s.split(":")[0].strip(): s.split(":")[1].strip() for s in cupy_info
}
env["device_name"] = info_dict[f"Device {device} Name"]
env["pci_bus_id"] = info_dict[f"Device {device} PCI Bus ID"]

env["cuda_version"] = cp.cuda.runtime.driverGetVersion()
env["compute_capability"] = self.cc
1 change: 1 addition & 0 deletions kernel_tuner/backends/nvcuda.py
@@ -100,6 +100,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
cuda_error_check(err)
env = dict()
env["device_name"] = device_properties.name.decode()
env["pci_bus_id"] = device_properties.pciBusID
env["cuda_version"] = driver.CUDA_VERSION
env["compute_capability"] = self.cc
env["iterations"] = self.iterations
1 change: 1 addition & 0 deletions kernel_tuner/backends/pycuda.py
@@ -139,6 +139,7 @@ def _finish_up():
# collect environment information
env = dict()
env["device_name"] = self.context.get_device().name()
env["pci_bus_id"] = self.context.get_device().pci_bus_id()
env["cuda_version"] = ".".join([str(i) for i in drv.get_version()])
env["compute_capability"] = self.cc
env["iterations"] = self.iterations
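All three CUDA backends above now record the device's PCI bus ID in the environment dict they return. As a hedged illustration of where this value comes from, the sketch below (not part of this PR; it assumes pycuda is installed and at least one CUDA device is visible) enumerates the local GPUs and prints the same identifier:

import pycuda.driver as drv

# Minimal sketch: list CUDA devices together with the PCI bus ID that the
# backends above expose as env["pci_bus_id"]. Assumes a working CUDA driver.
drv.init()
for ordinal in range(drv.Device.count()):
    device = drv.Device(ordinal)
    # pci_bus_id() returns a string such as "0000:01:00.0", which identifies
    # the physical GPU independently of the device ordinal.
    print(ordinal, device.name(), device.pci_bus_id())

A stable identifier like this makes it possible to tell apart results reported by parallel workers running on different physical GPUs.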
34 changes: 27 additions & 7 deletions kernel_tuner/interface.py
@@ -39,8 +39,6 @@
import kernel_tuner.util as util
from kernel_tuner.file_utils import get_input_file, get_t4_metadata, get_t4_results, import_class_from_file
from kernel_tuner.integration import get_objective_defaults
from kernel_tuner.runners.sequential import SequentialRunner
from kernel_tuner.runners.simulation import SimulationRunner
from kernel_tuner.searchspace import Searchspace

try:
@@ -476,6 +474,7 @@ def __deepcopy__(self, _):
),
("metrics", ("specifies user-defined metrics, please see :ref:`metrics`.", "dict")),
("simulation_mode", ("Simulate an auto-tuning search from an existing cachefile", "bool")),
("parallel_workers", ("Set to `True` or an integer to enable parallel tuning. If set to an integer, this will be the number of parallel workers.", "int|bool")),
("observers", ("""A list of Observers to use during tuning, please see :ref:`observers`.""", "list")),
]
)
@@ -587,6 +586,7 @@ def tune_kernel(
cache=None,
metrics=None,
simulation_mode=False,
parallel_workers=None,
observers=None,
objective=None,
objective_higher_is_better=None,
@@ -654,9 +654,22 @@
strategy = brute_force

# select the runner for this job based on input
selected_runner = SimulationRunner if simulation_mode else SequentialRunner
# TODO: we could use the "match case" syntax when removing support for 3.9
tuning_options.simulated_time = 0
runner = selected_runner(kernelsource, kernel_options, device_options, iterations, observers)

if parallel_workers and simulation_mode:
raise ValueError("Enabling `parallel_workers` and `simulation_mode` together is not supported")
elif simulation_mode:
from kernel_tuner.runners.simulation import SimulationRunner
runner = SimulationRunner(kernelsource, kernel_options, device_options, iterations, observers)
elif parallel_workers:
from kernel_tuner.runners.parallel import ParallelRunner
num_workers = None if parallel_workers is True else parallel_workers
runner = ParallelRunner(kernelsource, kernel_options, device_options, tuning_options, iterations, observers, num_workers=num_workers)
else:
from kernel_tuner.runners.sequential import SequentialRunner
runner = SequentialRunner(kernelsource, kernel_options, device_options, iterations, observers)


# the user-specified function may or may not have an optional atol argument;
# we normalize it so that it always accepts atol.
@@ -672,16 +685,20 @@ def preprocess_cache(filepath):
# process cache
if cache:
cache = preprocess_cache(cache)
util.process_cache(cache, kernel_options, tuning_options, runner)
tuning_options.cachefile = cache
tuning_options.cache = util.process_cache(cache, kernel_options, tuning_options, runner)
else:
tuning_options.cache = {}
tuning_options.cachefile = None
tuning_options.cache = {}

# create search space
tuning_options.restrictions_unmodified = deepcopy(restrictions)
searchspace = Searchspace(tune_params, restrictions, runner.dev.max_threads, **searchspace_construction_options)
device_info = runner.get_device_info()
searchspace = Searchspace(tune_params, restrictions, device_info.max_threads, **searchspace_construction_options)

restrictions = searchspace._modified_restrictions
tuning_options.restrictions = restrictions

if verbose:
print(f"Searchspace has {searchspace.size} configurations after restrictions.")

@@ -699,6 +716,9 @@ def preprocess_cache(filepath):
results = strategy.tune(searchspace, runner, tuning_options)
env = runner.get_environment(tuning_options)

# Shut down the runner
runner.shutdown()

# finished iterating over search space
if results: # checks if results is not empty
best_config = util.get_best_config(results, objective, objective_higher_is_better)
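To round off the interface changes, here is a minimal, hedged usage sketch of the new `parallel_workers` option. It mirrors examples/cuda/vector_add_parallel.py above but uses the integer form to request a fixed number of workers; the kernel and parameter values are illustrative and not taken from this PR. Passing `True` instead leaves the worker count to the runner (it is forwarded as `num_workers=None`), and combining `parallel_workers` with `simulation_mode` raises a `ValueError`, as implemented above.

import numpy
from kernel_tuner import tune_kernel

# Illustrative kernel in the same style as the new vector_add example.
kernel_string = """
__global__ void vector_add(float *c, float *a, float *b, int n) {
    int i = (blockIdx.x * block_size_x) + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i];
    }
}
"""

size = 1000000
a = numpy.random.randn(size).astype(numpy.float32)
b = numpy.random.randn(size).astype(numpy.float32)
c = numpy.zeros_like(a)
args = [c, a, b, numpy.int32(size)]

tune_params = dict()
tune_params["block_size_x"] = [32 * i for i in range(1, 33)]

# Request four parallel workers explicitly; parallel_workers=True would let
# the runner pick a default.
results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, parallel_workers=4)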