From 018514835311ba1ee5fcc5077bd5f1b7f28fabc3 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 22 Mar 2026 10:08:16 -0700 Subject: [PATCH 1/3] Add first_warmup parameter to perf_dispatch The first evaluation cycle (when kernels may need JIT compilation) uses first_warmup (default 1) warmup iterations. Subsequent re-evaluations use warmup (default 0) since kernels are already compiled, reducing the number of dispatch overhead steps from 8 to 2 per re-evaluation cycle. --- python/quadrants/lang/_perf_dispatch.py | 24 +++- tests/python/test_perf_dispatch.py | 149 ++++++++++++++++++++---- 2 files changed, 149 insertions(+), 24 deletions(-) diff --git a/python/quadrants/lang/_perf_dispatch.py b/python/quadrants/lang/_perf_dispatch.py index 42b8c734e..a9a80030a 100644 --- a/python/quadrants/lang/_perf_dispatch.py +++ b/python/quadrants/lang/_perf_dispatch.py @@ -10,7 +10,8 @@ from ._quadrants_callable import QuadrantsCallable from .exception import QuadrantsRuntimeError, QuadrantsSyntaxError -NUM_WARMUP: int = 3 +NUM_FIRST_WARMUP: int = 1 +NUM_WARMUP: int = 0 NUM_ACTIVE: int = 1 REPEAT_AFTER_COUNT: int = 0 REPEAT_AFTER_SECONDS: float = 1.0 @@ -71,12 +72,14 @@ def __init__( self, get_geometry_hash: Callable[P, int], fn: Callable, + num_first_warmup: int | None = None, num_warmup: int | None = None, num_active: int | None = None, repeat_after_count: int | None = None, repeat_after_seconds: float | None = None, ) -> None: self._name: str = fn.__name__ # type: ignore + self.num_first_warmup = num_first_warmup if num_first_warmup is not None else NUM_FIRST_WARMUP self.num_warmup = num_warmup if num_warmup is not None else NUM_WARMUP self.num_active = num_active if num_active is not None else NUM_ACTIVE self.repeat_after_count = repeat_after_count if repeat_after_count is not None else REPEAT_AFTER_COUNT @@ -98,6 +101,7 @@ def __init__( ) self._calls_since_last_update_by_geometry_hash: dict[int, int] = defaultdict(int) self._last_check_time_by_geometry_hash: dict[int, float] = defaultdict(float) + self._first_eval_completed: set[int] = set() def register( self, implementation: Callable | None = None, *, is_compatible: Callable[[dict], bool] | None = None @@ -199,6 +203,11 @@ def _get_next_dispatch_impl( assert least_trials_dispatch_impl is not None and least_trials is not None return least_trials, least_trials_dispatch_impl + def _get_effective_warmup(self, geometry_hash: int) -> int: + if geometry_hash not in self._first_eval_completed: + return self.num_first_warmup + return self.num_warmup + def _get_min_trials_finished(self, geometry_hash: int) -> int: return min(self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash].values()) @@ -207,7 +216,7 @@ def _compute_are_trials_finished(self, geometry_hash: int) -> bool: return False min_trials = min(self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash].values()) - res = min_trials >= self.num_warmup + self.num_active + res = min_trials >= self._get_effective_warmup(geometry_hash) + self.num_active return res def _compute_and_update_fastest(self, geometry_hash: int) -> None: @@ -293,12 +302,13 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs): print(log_str) return dispatch_impl_(*args, **kwargs) + effective_warmup = self._get_effective_warmup(geometry_hash) min_trial_count, dispatch_impl = self._get_next_dispatch_impl( compatible_set=compatible_set, geometry_hash=geometry_hash ) trial_count_by_dispatch_impl = self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash] trial_count_by_dispatch_impl[dispatch_impl] += 1 - in_warmup = min_trial_count < self.num_warmup + in_warmup = min_trial_count < effective_warmup start = 0 if not in_warmup: runtime.sync() @@ -311,6 +321,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs): self._times_by_dispatch_impl_by_geometry_hash[geometry_hash][dispatch_impl].append(elapsed) if self._compute_are_trials_finished(geometry_hash=geometry_hash): self._compute_and_update_fastest(geometry_hash) + self._first_eval_completed.add(geometry_hash) self._last_check_time_by_geometry_hash[geometry_hash] = time.time() return res @@ -318,6 +329,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs): def perf_dispatch( *, get_geometry_hash: Callable, + first_warmup: int = NUM_FIRST_WARMUP, warmup: int = NUM_WARMUP, active: int = NUM_ACTIVE, repeat_after_count: int = REPEAT_AFTER_COUNT, @@ -333,7 +345,10 @@ def perf_dispatch( Args: get_geometry_hash: A function that returns a geometry hash given the arguments. - warmup: Number of warmup iterations to run for each implementation before measuring. Default 3. + first_warmup: Number of warmup iterations for the very first evaluation cycle (default: 1). + Use a higher value when kernels need initial JIT compilation warmup. + warmup: Number of warmup iterations for subsequent re-evaluation cycles (default: 0). + After the first evaluation, kernels are already compiled, so less warmup is needed. active: Number of active (timed) iterations to run for each implementation. Default 1. repeat_after_count: repeats the cycle of warmup and active from scratch after repeat_after_count additional calls. @@ -402,6 +417,7 @@ def decorator(fn: Callable | QuadrantsCallable): return PerformanceDispatcher( get_geometry_hash=get_geometry_hash, fn=fn, + num_first_warmup=first_warmup, num_warmup=warmup, num_active=active, repeat_after_count=repeat_after_count, diff --git a/tests/python/test_perf_dispatch.py b/tests/python/test_perf_dispatch.py index fcd794d71..da1d93011 100644 --- a/tests/python/test_perf_dispatch.py +++ b/tests/python/test_perf_dispatch.py @@ -6,6 +6,7 @@ import quadrants as qd from quadrants.lang import _perf_dispatch from quadrants.lang._perf_dispatch import ( + NUM_FIRST_WARMUP, NUM_WARMUP, PerformanceDispatcher, _parse_force_map, @@ -32,12 +33,15 @@ def do_work_py(i_b, amount_work: qd.i32, state: qd.types.NDArray[qd.i32, 1]): @test_utils.test() def test_perf_dispatch_kernels() -> None: + WARMUP = 3 + class ImplEnum(IntEnum): slow = 0 fastest_a_shape0_lt2 = 1 a_shape0_ge2 = 2 - @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), repeat_after_seconds=0) + @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), + first_warmup=WARMUP, warmup=WARMUP, repeat_after_seconds=0) def my_func1( a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1], rand_state: qd.types.NDArray[qd.i32, 1] ): ... @@ -81,13 +85,13 @@ def my_func1_impl_a_shape0_ge_2( c = qd.ndarray(qd.i32, (len(ImplEnum),)) rand_state = qd.ndarray(qd.i32, (num_threads,)) - for it in range((NUM_WARMUP + 5)): + for it in range((WARMUP + 5)): c.fill(0) for _inner_it in range(2): # 2 compatible kernels a.fill(5) my_func1(a, c, rand_state=rand_state) assert (a.to_numpy()[:5] == [0, 5, 10, 15, 20]).all() - if it <= NUM_WARMUP: + if it <= WARMUP: assert c[ImplEnum.slow] == 1 assert c[ImplEnum.fastest_a_shape0_lt2] == 0 assert c[ImplEnum.a_shape0_ge2] == 1 @@ -98,18 +102,21 @@ def my_func1_impl_a_shape0_ge_2( speed_checker = cast(PerformanceDispatcher, my_func1) geometry = list(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash.keys())[0] for _dispatch_impl, trials in speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry].items(): - assert trials == NUM_WARMUP + 1 + assert trials == WARMUP + 1 assert len(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry]) == 2 @test_utils.test() def test_perf_dispatch_python() -> None: + WARMUP = 3 + class ImplEnum(IntEnum): slow = 0 fastest_a_shape0_lt2 = 1 a_shape0_ge2 = 2 - @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), repeat_after_seconds=0) + @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), + first_warmup=WARMUP, warmup=WARMUP, repeat_after_seconds=0) def my_func1( a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1], rand_state: qd.types.NDArray[qd.i32, 1] ): ... @@ -150,13 +157,13 @@ def my_func1_impl_a_shape0_ge_2( c = qd.ndarray(qd.i32, (len(ImplEnum),)) rand_state = qd.ndarray(qd.i32, (num_threads,)) - for it in range((NUM_WARMUP + 5)): + for it in range((WARMUP + 5)): c.fill(0) for _inner_it in range(2): # 2 compatible kernels a.fill(5) my_func1(a, c, rand_state=rand_state) assert (a.to_numpy()[:5] == [0, 5, 10, 15, 20]).all() - if it <= NUM_WARMUP: + if it <= WARMUP: assert c[ImplEnum.slow] == 1 assert c[ImplEnum.fastest_a_shape0_lt2] == 0 assert c[ImplEnum.a_shape0_ge2] == 1 @@ -167,18 +174,21 @@ def my_func1_impl_a_shape0_ge_2( speed_checker = cast(PerformanceDispatcher, my_func1) geometry = list(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash.keys())[0] for _dispatch_impl, trials in speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry].items(): - assert trials == NUM_WARMUP + 1 + assert trials == WARMUP + 1 assert len(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry]) == 2 @test_utils.test() def test_perf_dispatch_kernel_py_mix() -> None: + WARMUP = 3 + class ImplEnum(IntEnum): slow = 0 fastest_a_shape0_lt2 = 1 a_shape0_ge2 = 2 - @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), repeat_after_seconds=0) + @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), + first_warmup=WARMUP, warmup=WARMUP, repeat_after_seconds=0) def my_func1( a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1], rand_state: qd.types.NDArray[qd.i32, 1] ): ... @@ -221,13 +231,13 @@ def my_func1_impl_a_shape0_ge_2( c = qd.ndarray(qd.i32, (len(ImplEnum),)) rand_state = qd.ndarray(qd.i32, (num_threads,)) - for it in range((NUM_WARMUP + 5)): + for it in range((WARMUP + 5)): c.fill(0) for _inner_it in range(2): # 2 compatible kernels a.fill(5) my_func1(a, c, rand_state=rand_state) assert (a.to_numpy()[:5] == [0, 5, 10, 15, 20]).all() - if it <= NUM_WARMUP: + if it <= WARMUP: assert c[ImplEnum.slow] == 1 assert c[ImplEnum.fastest_a_shape0_lt2] == 0 assert c[ImplEnum.a_shape0_ge2] == 1 @@ -238,7 +248,7 @@ def my_func1_impl_a_shape0_ge_2( speed_checker = cast(PerformanceDispatcher, my_func1) geometry = list(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash.keys())[0] for _dispatch_impl, trials in speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry].items(): - assert trials == NUM_WARMUP + 1 + assert trials == WARMUP + 1 assert len(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry]) == 2 @@ -290,9 +300,108 @@ def my_func1_impl_impl2( ) -> None: ... +@test_utils.test() +def test_perf_dispatch_first_warmup_vs_warmup() -> None: + """first_warmup is used for the initial evaluation, warmup for re-evaluations.""" + called = [] + + @qd.perf_dispatch( + get_geometry_hash=lambda a: hash(a.shape), + first_warmup=2, + warmup=0, + repeat_after_count=3, + repeat_after_seconds=0, + ) + def my_func(a: qd.types.NDArray[qd.i32, 1]): ... + + @my_func.register + def impl_a(a: qd.types.NDArray[qd.i32, 1]) -> None: + called.append("a") + + @my_func.register + def impl_b(a: qd.types.NDArray[qd.i32, 1]) -> None: + called.append("b") + + a = qd.ndarray(qd.i32, (4,)) + speed_checker = cast(PerformanceDispatcher, my_func) + + # --- First evaluation cycle: first_warmup=2, active=1 --- + # 2 impls × (2 warmup + 1 active) = 6 calls to complete first eval + for _ in range(6): + my_func(a) + + assert speed_checker._fastest_dispatch_impl_by_geometry_hash, "first eval should have completed" + geometry = list(speed_checker._fastest_dispatch_impl_by_geometry_hash.keys())[0] + assert geometry in speed_checker._first_eval_completed + + first_eval_calls = list(called) + assert len(first_eval_calls) == 6 + + # --- Now fastest is chosen; 3 calls before re-eval triggers (repeat_after_count=3) --- + called.clear() + for _ in range(3): + my_func(a) + assert len(called) == 3 + # All 3 should be the fastest impl (no re-eval yet) + assert len(set(called)) == 1 + + # --- Re-evaluation cycle: warmup=0, active=1 --- + # Next call triggers re-eval. With warmup=0, each impl gets 1 active call. + # 2 impls × (0 warmup + 1 active) = 2 calls to complete re-eval + called.clear() + for _ in range(2): + my_func(a) + assert len(called) == 2 + # Both impls should have been called (one each for the active measurement) + assert set(called) == {"a", "b"} + + +@test_utils.test() +def test_perf_dispatch_default_warmup_values() -> None: + """With defaults (first_warmup=1, warmup=0), first eval uses 1 warmup, re-evals use 0.""" + called = [] + + @qd.perf_dispatch( + get_geometry_hash=lambda a: hash(a.shape), + repeat_after_count=2, + repeat_after_seconds=0, + ) + def my_func(a: qd.types.NDArray[qd.i32, 1]): ... + + @my_func.register + def impl_a(a: qd.types.NDArray[qd.i32, 1]) -> None: + called.append("a") + + @my_func.register + def impl_b(a: qd.types.NDArray[qd.i32, 1]) -> None: + called.append("b") + + a = qd.ndarray(qd.i32, (4,)) + speed_checker = cast(PerformanceDispatcher, my_func) + + # First eval: 2 impls × (1 first_warmup + 1 active) = 4 calls + for _ in range(4): + my_func(a) + assert speed_checker._fastest_dispatch_impl_by_geometry_hash + + # 2 calls using fastest, then re-eval triggers + called.clear() + for _ in range(2): + my_func(a) + assert len(called) == 2 + assert len(set(called)) == 1 # all same (fastest) + + # Re-eval: 2 impls × (0 warmup + 1 active) = 2 calls + called.clear() + for _ in range(2): + my_func(a) + assert len(called) == 2 + assert set(called) == {"a", "b"} + + @test_utils.test() def test_perf_dispatch_sanity_check_register_args() -> None: - @qd.perf_dispatch(get_geometry_hash=lambda a, c: hash(a.shape + c.shape), warmup=25, active=25) + @qd.perf_dispatch(get_geometry_hash=lambda a, c: hash(a.shape + c.shape), first_warmup=25, warmup=25, active=25) def my_func1(a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1]): ... @@ -337,10 +446,10 @@ def impl_b(a: qd.types.NDArray[qd.i32, 1]) -> None: called.append("b") a = qd.ndarray(qd.i32, (4,)) - for _ in range(NUM_WARMUP * 2 + 3): + for _ in range(NUM_FIRST_WARMUP * 2 + 3): my_func(a) - assert len(called) == NUM_WARMUP * 2 + 3 + assert len(called) == NUM_FIRST_WARMUP * 2 + 3 assert all(c == "b" for c in called) @@ -364,9 +473,9 @@ def impl_b(a: qd.types.NDArray[qd.i32, 1]) -> None: called.append("b") a = qd.ndarray(qd.i32, (4,)) - for _ in range(NUM_WARMUP * 2 + 3): + for _ in range(NUM_FIRST_WARMUP * 2 + 3): my_func(a) - assert len(called) == NUM_WARMUP * 2 + 3 + assert len(called) == NUM_FIRST_WARMUP * 2 + 3 @test_utils.test() @@ -401,11 +510,11 @@ def op_b_v2(a: qd.types.NDArray[qd.i32, 1]) -> None: called_b.append("v2") a = qd.ndarray(qd.i32, (4,)) - for _ in range(NUM_WARMUP * 2 + 3): + for _ in range(NUM_FIRST_WARMUP * 2 + 3): op_a(a) op_b(a) - assert len(called_a) == NUM_WARMUP * 2 + 3 - assert len(called_b) == NUM_WARMUP * 2 + 3 + assert len(called_a) == NUM_FIRST_WARMUP * 2 + 3 + assert len(called_b) == NUM_FIRST_WARMUP * 2 + 3 assert all(c == "v2" for c in called_a) assert all(c == "v1" for c in called_b) From 2b02d2e2727b8b5d2606053e7cbbe165c5d2e925 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 22 Mar 2026 10:49:22 -0700 Subject: [PATCH 2/3] Fix _compute_are_trials_finished bug with warmup=0 When warmup=0, the method would declare trials finished after trying only one impl because untried impls weren't in the trial count dict. Now checks that all compatible impls have been tried before evaluating the min trial threshold. --- python/quadrants/lang/_perf_dispatch.py | 12 +++++----- tests/python/test_perf_dispatch.py | 29 ++++++++++++------------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/python/quadrants/lang/_perf_dispatch.py b/python/quadrants/lang/_perf_dispatch.py index a9a80030a..8f63ffb1a 100644 --- a/python/quadrants/lang/_perf_dispatch.py +++ b/python/quadrants/lang/_perf_dispatch.py @@ -211,13 +211,13 @@ def _get_effective_warmup(self, geometry_hash: int) -> int: def _get_min_trials_finished(self, geometry_hash: int) -> int: return min(self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash].values()) - def _compute_are_trials_finished(self, geometry_hash: int) -> bool: - if len(self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash]) == 0: + def _compute_are_trials_finished(self, geometry_hash: int, num_compatible: int) -> bool: + trial_counts = self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash] + if len(trial_counts) < num_compatible: return False - min_trials = min(self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash].values()) - res = min_trials >= self._get_effective_warmup(geometry_hash) + self.num_active - return res + min_trials = min(trial_counts.values()) + return min_trials >= self._get_effective_warmup(geometry_hash) + self.num_active def _compute_and_update_fastest(self, geometry_hash: int) -> None: times_by_dispatch_impl = self._times_by_dispatch_impl_by_geometry_hash[geometry_hash] @@ -319,7 +319,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs): end = time.time() elapsed = end - start self._times_by_dispatch_impl_by_geometry_hash[geometry_hash][dispatch_impl].append(elapsed) - if self._compute_are_trials_finished(geometry_hash=geometry_hash): + if self._compute_are_trials_finished(geometry_hash=geometry_hash, num_compatible=len(compatible_set)): self._compute_and_update_fastest(geometry_hash) self._first_eval_completed.add(geometry_hash) self._last_check_time_by_geometry_hash[geometry_hash] = time.time() diff --git a/tests/python/test_perf_dispatch.py b/tests/python/test_perf_dispatch.py index da1d93011..259e15b74 100644 --- a/tests/python/test_perf_dispatch.py +++ b/tests/python/test_perf_dispatch.py @@ -309,7 +309,7 @@ def test_perf_dispatch_first_warmup_vs_warmup() -> None: get_geometry_hash=lambda a: hash(a.shape), first_warmup=2, warmup=0, - repeat_after_count=3, + repeat_after_count=5, repeat_after_seconds=0, ) def my_func(a: qd.types.NDArray[qd.i32, 1]): ... @@ -337,22 +337,21 @@ def impl_b(a: qd.types.NDArray[qd.i32, 1]) -> None: first_eval_calls = list(called) assert len(first_eval_calls) == 6 - # --- Now fastest is chosen; 3 calls before re-eval triggers (repeat_after_count=3) --- + # --- Now fastest is chosen; repeat_after_count=5, so 4 calls use fastest, + # then 5th call triggers re-eval --- called.clear() - for _ in range(3): + for _ in range(4): my_func(a) - assert len(called) == 3 - # All 3 should be the fastest impl (no re-eval yet) - assert len(set(called)) == 1 + assert len(called) == 4 + assert len(set(called)) == 1 # all fastest - # --- Re-evaluation cycle: warmup=0, active=1 --- - # Next call triggers re-eval. With warmup=0, each impl gets 1 active call. - # 2 impls × (0 warmup + 1 active) = 2 calls to complete re-eval + # --- 5th call triggers re-eval; warmup=0, active=1 --- + # Re-eval needs 2 impls × (0 warmup + 1 active) = 2 calls. + # The 5th call from above starts re-eval, so we need 2 more calls total. called.clear() for _ in range(2): my_func(a) assert len(called) == 2 - # Both impls should have been called (one each for the active measurement) assert set(called) == {"a", "b"} @@ -363,7 +362,7 @@ def test_perf_dispatch_default_warmup_values() -> None: @qd.perf_dispatch( get_geometry_hash=lambda a: hash(a.shape), - repeat_after_count=2, + repeat_after_count=4, repeat_after_seconds=0, ) def my_func(a: qd.types.NDArray[qd.i32, 1]): ... @@ -384,14 +383,14 @@ def impl_b(a: qd.types.NDArray[qd.i32, 1]) -> None: my_func(a) assert speed_checker._fastest_dispatch_impl_by_geometry_hash - # 2 calls using fastest, then re-eval triggers + # 3 calls using fastest (calls before repeat_after_count=4 triggers on the 4th) called.clear() - for _ in range(2): + for _ in range(3): my_func(a) - assert len(called) == 2 + assert len(called) == 3 assert len(set(called)) == 1 # all same (fastest) - # Re-eval: 2 impls × (0 warmup + 1 active) = 2 calls + # 4th call triggers re-eval; with warmup=0, active=1, needs 2 calls total called.clear() for _ in range(2): my_func(a) From 433c72762335925ada04273ee9824c5e1f8b2c82 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 22 Mar 2026 11:08:43 -0700 Subject: [PATCH 3/3] Fix formatting and remove unused NUM_WARMUP import --- tests/python/test_perf_dispatch.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/python/test_perf_dispatch.py b/tests/python/test_perf_dispatch.py index 259e15b74..eaef03d99 100644 --- a/tests/python/test_perf_dispatch.py +++ b/tests/python/test_perf_dispatch.py @@ -7,7 +7,6 @@ from quadrants.lang import _perf_dispatch from quadrants.lang._perf_dispatch import ( NUM_FIRST_WARMUP, - NUM_WARMUP, PerformanceDispatcher, _parse_force_map, ) @@ -40,8 +39,12 @@ class ImplEnum(IntEnum): fastest_a_shape0_lt2 = 1 a_shape0_ge2 = 2 - @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), - first_warmup=WARMUP, warmup=WARMUP, repeat_after_seconds=0) + @qd.perf_dispatch( + get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), + first_warmup=WARMUP, + warmup=WARMUP, + repeat_after_seconds=0, + ) def my_func1( a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1], rand_state: qd.types.NDArray[qd.i32, 1] ): ... @@ -115,8 +118,12 @@ class ImplEnum(IntEnum): fastest_a_shape0_lt2 = 1 a_shape0_ge2 = 2 - @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), - first_warmup=WARMUP, warmup=WARMUP, repeat_after_seconds=0) + @qd.perf_dispatch( + get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), + first_warmup=WARMUP, + warmup=WARMUP, + repeat_after_seconds=0, + ) def my_func1( a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1], rand_state: qd.types.NDArray[qd.i32, 1] ): ... @@ -187,8 +194,12 @@ class ImplEnum(IntEnum): fastest_a_shape0_lt2 = 1 a_shape0_ge2 = 2 - @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), - first_warmup=WARMUP, warmup=WARMUP, repeat_after_seconds=0) + @qd.perf_dispatch( + get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), + first_warmup=WARMUP, + warmup=WARMUP, + repeat_after_seconds=0, + ) def my_func1( a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1], rand_state: qd.types.NDArray[qd.i32, 1] ): ...