diff --git a/python/quadrants/lang/_perf_dispatch.py b/python/quadrants/lang/_perf_dispatch.py index 42b8c734e..8f63ffb1a 100644 --- a/python/quadrants/lang/_perf_dispatch.py +++ b/python/quadrants/lang/_perf_dispatch.py @@ -10,7 +10,8 @@ from ._quadrants_callable import QuadrantsCallable from .exception import QuadrantsRuntimeError, QuadrantsSyntaxError -NUM_WARMUP: int = 3 +NUM_FIRST_WARMUP: int = 1 +NUM_WARMUP: int = 0 NUM_ACTIVE: int = 1 REPEAT_AFTER_COUNT: int = 0 REPEAT_AFTER_SECONDS: float = 1.0 @@ -71,12 +72,14 @@ def __init__( self, get_geometry_hash: Callable[P, int], fn: Callable, + num_first_warmup: int | None = None, num_warmup: int | None = None, num_active: int | None = None, repeat_after_count: int | None = None, repeat_after_seconds: float | None = None, ) -> None: self._name: str = fn.__name__ # type: ignore + self.num_first_warmup = num_first_warmup if num_first_warmup is not None else NUM_FIRST_WARMUP self.num_warmup = num_warmup if num_warmup is not None else NUM_WARMUP self.num_active = num_active if num_active is not None else NUM_ACTIVE self.repeat_after_count = repeat_after_count if repeat_after_count is not None else REPEAT_AFTER_COUNT @@ -98,6 +101,7 @@ def __init__( ) self._calls_since_last_update_by_geometry_hash: dict[int, int] = defaultdict(int) self._last_check_time_by_geometry_hash: dict[int, float] = defaultdict(float) + self._first_eval_completed: set[int] = set() def register( self, implementation: Callable | None = None, *, is_compatible: Callable[[dict], bool] | None = None @@ -199,16 +203,21 @@ def _get_next_dispatch_impl( assert least_trials_dispatch_impl is not None and least_trials is not None return least_trials, least_trials_dispatch_impl + def _get_effective_warmup(self, geometry_hash: int) -> int: + if geometry_hash not in self._first_eval_completed: + return self.num_first_warmup + return self.num_warmup + def _get_min_trials_finished(self, geometry_hash: int) -> int: return 
min(self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash].values()) - def _compute_are_trials_finished(self, geometry_hash: int) -> bool: - if len(self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash]) == 0: + def _compute_are_trials_finished(self, geometry_hash: int, num_compatible: int) -> bool: + trial_counts = self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash] + if len(trial_counts) < num_compatible: return False - min_trials = min(self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash].values()) - res = min_trials >= self.num_warmup + self.num_active - return res + min_trials = min(trial_counts.values()) + return min_trials >= self._get_effective_warmup(geometry_hash) + self.num_active def _compute_and_update_fastest(self, geometry_hash: int) -> None: times_by_dispatch_impl = self._times_by_dispatch_impl_by_geometry_hash[geometry_hash] @@ -293,12 +302,13 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs): print(log_str) return dispatch_impl_(*args, **kwargs) + effective_warmup = self._get_effective_warmup(geometry_hash) min_trial_count, dispatch_impl = self._get_next_dispatch_impl( compatible_set=compatible_set, geometry_hash=geometry_hash ) trial_count_by_dispatch_impl = self._trial_count_by_dispatch_impl_by_geometry_hash[geometry_hash] trial_count_by_dispatch_impl[dispatch_impl] += 1 - in_warmup = min_trial_count < self.num_warmup + in_warmup = min_trial_count < effective_warmup start = 0 if not in_warmup: runtime.sync() @@ -309,8 +319,9 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs): end = time.time() elapsed = end - start self._times_by_dispatch_impl_by_geometry_hash[geometry_hash][dispatch_impl].append(elapsed) - if self._compute_are_trials_finished(geometry_hash=geometry_hash): + if self._compute_are_trials_finished(geometry_hash=geometry_hash, num_compatible=len(compatible_set)): self._compute_and_update_fastest(geometry_hash) + self._first_eval_completed.add(geometry_hash) 
self._last_check_time_by_geometry_hash[geometry_hash] = time.time() return res @@ -318,6 +329,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs): def perf_dispatch( *, get_geometry_hash: Callable, + first_warmup: int = NUM_FIRST_WARMUP, warmup: int = NUM_WARMUP, active: int = NUM_ACTIVE, repeat_after_count: int = REPEAT_AFTER_COUNT, @@ -333,7 +345,10 @@ def perf_dispatch( Args: get_geometry_hash: A function that returns a geometry hash given the arguments. - warmup: Number of warmup iterations to run for each implementation before measuring. Default 3. + first_warmup: Number of warmup iterations for the very first evaluation cycle (default: 1). + Use a higher value when kernels need initial JIT compilation warmup. + warmup: Number of warmup iterations for subsequent re-evaluation cycles (default: 0). + After the first evaluation, kernels are already compiled, so less warmup is needed. active: Number of active (timed) iterations to run for each implementation. Default 1. repeat_after_count: repeats the cycle of warmup and active from scratch after repeat_after_count additional calls. 
@@ -402,6 +417,7 @@ def decorator(fn: Callable | QuadrantsCallable): return PerformanceDispatcher( get_geometry_hash=get_geometry_hash, fn=fn, + num_first_warmup=first_warmup, num_warmup=warmup, num_active=active, repeat_after_count=repeat_after_count, diff --git a/tests/python/test_perf_dispatch.py b/tests/python/test_perf_dispatch.py index fcd794d71..eaef03d99 100644 --- a/tests/python/test_perf_dispatch.py +++ b/tests/python/test_perf_dispatch.py @@ -6,7 +6,7 @@ import quadrants as qd from quadrants.lang import _perf_dispatch from quadrants.lang._perf_dispatch import ( - NUM_WARMUP, + NUM_FIRST_WARMUP, PerformanceDispatcher, _parse_force_map, ) @@ -32,12 +32,19 @@ def do_work_py(i_b, amount_work: qd.i32, state: qd.types.NDArray[qd.i32, 1]): @test_utils.test() def test_perf_dispatch_kernels() -> None: + WARMUP = 3 + class ImplEnum(IntEnum): slow = 0 fastest_a_shape0_lt2 = 1 a_shape0_ge2 = 2 - @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), repeat_after_seconds=0) + @qd.perf_dispatch( + get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), + first_warmup=WARMUP, + warmup=WARMUP, + repeat_after_seconds=0, + ) def my_func1( a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1], rand_state: qd.types.NDArray[qd.i32, 1] ): ... 
@@ -81,13 +88,13 @@ def my_func1_impl_a_shape0_ge_2( c = qd.ndarray(qd.i32, (len(ImplEnum),)) rand_state = qd.ndarray(qd.i32, (num_threads,)) - for it in range((NUM_WARMUP + 5)): + for it in range((WARMUP + 5)): c.fill(0) for _inner_it in range(2): # 2 compatible kernels a.fill(5) my_func1(a, c, rand_state=rand_state) assert (a.to_numpy()[:5] == [0, 5, 10, 15, 20]).all() - if it <= NUM_WARMUP: + if it <= WARMUP: assert c[ImplEnum.slow] == 1 assert c[ImplEnum.fastest_a_shape0_lt2] == 0 assert c[ImplEnum.a_shape0_ge2] == 1 @@ -98,18 +105,25 @@ def my_func1_impl_a_shape0_ge_2( speed_checker = cast(PerformanceDispatcher, my_func1) geometry = list(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash.keys())[0] for _dispatch_impl, trials in speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry].items(): - assert trials == NUM_WARMUP + 1 + assert trials == WARMUP + 1 assert len(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry]) == 2 @test_utils.test() def test_perf_dispatch_python() -> None: + WARMUP = 3 + class ImplEnum(IntEnum): slow = 0 fastest_a_shape0_lt2 = 1 a_shape0_ge2 = 2 - @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), repeat_after_seconds=0) + @qd.perf_dispatch( + get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), + first_warmup=WARMUP, + warmup=WARMUP, + repeat_after_seconds=0, + ) def my_func1( a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1], rand_state: qd.types.NDArray[qd.i32, 1] ): ... 
@@ -150,13 +164,13 @@ def my_func1_impl_a_shape0_ge_2( c = qd.ndarray(qd.i32, (len(ImplEnum),)) rand_state = qd.ndarray(qd.i32, (num_threads,)) - for it in range((NUM_WARMUP + 5)): + for it in range((WARMUP + 5)): c.fill(0) for _inner_it in range(2): # 2 compatible kernels a.fill(5) my_func1(a, c, rand_state=rand_state) assert (a.to_numpy()[:5] == [0, 5, 10, 15, 20]).all() - if it <= NUM_WARMUP: + if it <= WARMUP: assert c[ImplEnum.slow] == 1 assert c[ImplEnum.fastest_a_shape0_lt2] == 0 assert c[ImplEnum.a_shape0_ge2] == 1 @@ -167,18 +181,25 @@ def my_func1_impl_a_shape0_ge_2( speed_checker = cast(PerformanceDispatcher, my_func1) geometry = list(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash.keys())[0] for _dispatch_impl, trials in speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry].items(): - assert trials == NUM_WARMUP + 1 + assert trials == WARMUP + 1 assert len(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry]) == 2 @test_utils.test() def test_perf_dispatch_kernel_py_mix() -> None: + WARMUP = 3 + class ImplEnum(IntEnum): slow = 0 fastest_a_shape0_lt2 = 1 a_shape0_ge2 = 2 - @qd.perf_dispatch(get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), repeat_after_seconds=0) + @qd.perf_dispatch( + get_geometry_hash=lambda a, c, rand_state: hash(a.shape + c.shape), + first_warmup=WARMUP, + warmup=WARMUP, + repeat_after_seconds=0, + ) def my_func1( a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1], rand_state: qd.types.NDArray[qd.i32, 1] ): ... 
@@ -221,13 +242,13 @@ def my_func1_impl_a_shape0_ge_2( c = qd.ndarray(qd.i32, (len(ImplEnum),)) rand_state = qd.ndarray(qd.i32, (num_threads,)) - for it in range((NUM_WARMUP + 5)): + for it in range((WARMUP + 5)): c.fill(0) for _inner_it in range(2): # 2 compatible kernels a.fill(5) my_func1(a, c, rand_state=rand_state) assert (a.to_numpy()[:5] == [0, 5, 10, 15, 20]).all() - if it <= NUM_WARMUP: + if it <= WARMUP: assert c[ImplEnum.slow] == 1 assert c[ImplEnum.fastest_a_shape0_lt2] == 0 assert c[ImplEnum.a_shape0_ge2] == 1 @@ -238,7 +259,7 @@ def my_func1_impl_a_shape0_ge_2( speed_checker = cast(PerformanceDispatcher, my_func1) geometry = list(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash.keys())[0] for _dispatch_impl, trials in speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry].items(): - assert trials == NUM_WARMUP + 1 + assert trials == WARMUP + 1 assert len(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry]) == 2 @@ -290,9 +311,107 @@ def my_func1_impl_impl2( ) -> None: ... +@test_utils.test() +def test_perf_dispatch_first_warmup_vs_warmup() -> None: + """first_warmup is used for the initial evaluation, warmup for re-evaluations.""" + called = [] + + @qd.perf_dispatch( + get_geometry_hash=lambda a: hash(a.shape), + first_warmup=2, + warmup=0, + repeat_after_count=5, + repeat_after_seconds=0, + ) + def my_func(a: qd.types.NDArray[qd.i32, 1]): ... 
+
+    @my_func.register
+    def impl_a(a: qd.types.NDArray[qd.i32, 1]) -> None:
+        called.append("a")
+
+    @my_func.register
+    def impl_b(a: qd.types.NDArray[qd.i32, 1]) -> None:
+        called.append("b")
+
+    a = qd.ndarray(qd.i32, (4,))
+    speed_checker = cast(PerformanceDispatcher, my_func)
+
+    # --- First evaluation cycle: first_warmup=2, active=1 ---
+    # 2 impls × (2 warmup + 1 active) = 6 calls to complete first eval
+    for _ in range(6):
+        my_func(a)
+
+    assert speed_checker._fastest_dispatch_impl_by_geometry_hash, "first eval should have completed"
+    geometry = list(speed_checker._fastest_dispatch_impl_by_geometry_hash.keys())[0]
+    assert geometry in speed_checker._first_eval_completed
+
+    first_eval_calls = list(called)
+    assert len(first_eval_calls) == 6
+
+    # --- Now fastest is chosen; repeat_after_count=5, so 4 calls use fastest,
+    # then 5th call triggers re-eval ---
+    called.clear()
+    for _ in range(4):
+        my_func(a)
+    assert len(called) == 4
+    assert len(set(called)) == 1  # all fastest
+
+    # --- 5th call triggers re-eval; warmup=0, active=1 ---
+    # Re-eval needs 2 impls × (0 warmup + 1 active) = 2 calls.
+    # Only 4 calls have happened since the last eval, so the next call (the 5th)
+    # starts re-eval: the 2 calls below are exactly the re-eval trials.
+    called.clear()
+    for _ in range(2):
+        my_func(a)
+    assert len(called) == 2
+    assert set(called) == {"a", "b"}
+
+
+@test_utils.test()
+def test_perf_dispatch_default_warmup_values() -> None:
+    """With defaults (first_warmup=1, warmup=0), first eval uses 1 warmup, re-evals use 0."""
+    called = []
+
+    @qd.perf_dispatch(
+        get_geometry_hash=lambda a: hash(a.shape),
+        repeat_after_count=4,
+        repeat_after_seconds=0,
+    )
+    def my_func(a: qd.types.NDArray[qd.i32, 1]): ...
+
+    @my_func.register
+    def impl_a(a: qd.types.NDArray[qd.i32, 1]) -> None:
+        called.append("a")
+
+    @my_func.register
+    def impl_b(a: qd.types.NDArray[qd.i32, 1]) -> None:
+        called.append("b")
+
+    a = qd.ndarray(qd.i32, (4,))
+    speed_checker = cast(PerformanceDispatcher, my_func)
+
+    # First eval: 2 impls × (1 first_warmup + 1 active) = 4 calls
+    for _ in range(4):
+        my_func(a)
+    assert speed_checker._fastest_dispatch_impl_by_geometry_hash
+
+    # 3 calls using fastest (re-eval is only triggered on the 4th call, since repeat_after_count=4)
+    called.clear()
+    for _ in range(3):
+        my_func(a)
+    assert len(called) == 3
+    assert len(set(called)) == 1  # all same (fastest)
+
+    # 4th call triggers re-eval; with warmup=0, active=1, needs 2 calls total
+    called.clear()
+    for _ in range(2):
+        my_func(a)
+    assert len(called) == 2
+    assert set(called) == {"a", "b"}
+
+
 @test_utils.test()
 def test_perf_dispatch_sanity_check_register_args() -> None:
-    @qd.perf_dispatch(get_geometry_hash=lambda a, c: hash(a.shape + c.shape), warmup=25, active=25)
+    @qd.perf_dispatch(get_geometry_hash=lambda a, c: hash(a.shape + c.shape), first_warmup=25, warmup=25, active=25)
     def my_func1(a: qd.types.NDArray[qd.i32, 1], c: qd.types.NDArray[qd.i32, 1]): ...
@@ -337,10 +456,10 @@ def impl_b(a: qd.types.NDArray[qd.i32, 1]) -> None: called.append("b") a = qd.ndarray(qd.i32, (4,)) - for _ in range(NUM_WARMUP * 2 + 3): + for _ in range(NUM_FIRST_WARMUP * 2 + 3): my_func(a) - assert len(called) == NUM_WARMUP * 2 + 3 + assert len(called) == NUM_FIRST_WARMUP * 2 + 3 assert all(c == "b" for c in called) @@ -364,9 +483,9 @@ def impl_b(a: qd.types.NDArray[qd.i32, 1]) -> None: called.append("b") a = qd.ndarray(qd.i32, (4,)) - for _ in range(NUM_WARMUP * 2 + 3): + for _ in range(NUM_FIRST_WARMUP * 2 + 3): my_func(a) - assert len(called) == NUM_WARMUP * 2 + 3 + assert len(called) == NUM_FIRST_WARMUP * 2 + 3 @test_utils.test() @@ -401,11 +520,11 @@ def op_b_v2(a: qd.types.NDArray[qd.i32, 1]) -> None: called_b.append("v2") a = qd.ndarray(qd.i32, (4,)) - for _ in range(NUM_WARMUP * 2 + 3): + for _ in range(NUM_FIRST_WARMUP * 2 + 3): op_a(a) op_b(a) - assert len(called_a) == NUM_WARMUP * 2 + 3 - assert len(called_b) == NUM_WARMUP * 2 + 3 + assert len(called_a) == NUM_FIRST_WARMUP * 2 + 3 + assert len(called_b) == NUM_FIRST_WARMUP * 2 + 3 assert all(c == "v2" for c in called_a) assert all(c == "v1" for c in called_b)