diff --git a/benchmarks/speed.py b/benchmarks/speed.py index 6f0cc91..3d89aaa 100644 --- a/benchmarks/speed.py +++ b/benchmarks/speed.py @@ -25,7 +25,7 @@ """Measure the speed of the MCMC and its interfaces.""" from collections.abc import Mapping -from contextlib import nullcontext, redirect_stdout +from contextlib import redirect_stdout from dataclasses import replace from functools import partial from inspect import signature @@ -50,7 +50,6 @@ from jax.tree_util import tree_map from jaxtyping import Array, Float32, Integer, Key, UInt8 -import bartz from bartz import mcmcloop, mcmcstep from bartz.mcmcloop import run_mcmc from benchmarks.latest_bartz.jaxext import get_device_count, split @@ -313,7 +312,6 @@ def setup( niters: int = NITERS, nchains: int = 1, cache: Cache = 'warm', - profile: bool = False, predict: bool = False, kwargs: Mapping[str, Any] = MappingProxyType({}), ) -> None: @@ -357,22 +355,12 @@ def setup( self.kw.update(kwargs) block_until_ready(self.kw) - # set profile mode - if not profile: - self.context = nullcontext - elif hasattr(bartz, 'profile_mode'): - self.context = lambda: bartz.profile_mode(True) - else: - msg = 'Profile mode not supported.' - raise NotImplementedError(msg) - # save information used to run predictions self.predict = predict if predict: self.test = test - with self.context(): - self.bart = gbart(**self.kw) - block_bart(self.bart) + self.bart = gbart(**self.kw) + block_bart(self.bart) # decide how much to cold-start match cache: @@ -385,7 +373,7 @@ def setup( def time_gbart(self, *_: Any) -> None: """Time instantiating the class.""" - with redirect_stdout(StringIO()), self.context(): + with redirect_stdout(StringIO()): if self.predict: ypred = self.bart.predict(self.test.x) block_until_ready(ypred) @@ -439,13 +427,13 @@ def setup(self, nchains: int, shard: bool) -> None: # ty:ignore[invalid-method- # on gpu shard explicitly kwargs = dict(num_chain_devices=min(nchains, get_device_count())) - super().setup(NITERS, nchains, 'warm', False, False, dict(bart_kwargs=kwargs)) + super().setup(NITERS, nchains, 'warm', False, dict(bart_kwargs=kwargs)) class GbartGeneric(BaseGbart): """General timing of `mc_gbart` with many settings.""" - params = ((0, NITERS), (1, 6), ('warm', 'cold'), (False, True), (False, True)) + params = ((0, NITERS), (1, 6), ('warm', 'cold'), (False, True)) class BaseRunMcmc(AutoParamNames): diff --git a/docs/development.rst b/docs/development.rst index 805e138..0de28c7 100644 --- a/docs/development.rst +++ b/docs/development.rst @@ -148,13 +148,13 @@ This runs only benchmarks whose name matches , only once, within the wo Profiling --------- -Use the `JAX profiling utilities `_ to profile `bartz`. By default the MCMC loop is compiled all at once, which makes it quite opaque to profiling. There are two ways to understand what's going on inside in more detail: 1) inspect the individual operations and use intuition to understand to what piece of code they correspond to, 2) turn on bartz's profile mode. Basic workflow: +Use the `JAX profiling utilities `_ to profile `bartz`. It works well on GPU, not on CPU. .. code-block:: python from jax.profiler import trace, ProfileOptions + from jax import block_until_ready from bartz.BART import gbart - from bartz import profile_mode traceopt = ProfileOptions() @@ -162,11 +162,12 @@ Use the `JAX profiling utilities traceopt.python_tracer_level = 1 # on cpu, this makes the trace detailed enough to understand what's going on - # even within compiled functions + # even within compiled functions by manual inspection of each operation traceopt.host_tracer_level = 2 - with trace('./trace_results', profiler_options=traceopt), profile_mode(True): + with trace('./trace_results', profiler_options=traceopt): bart = gbart(...) + block_until_ready(bart) On the first run, the trace will show compilation operations, while subsequent runs (within the same Python shell) will be warmed-up. Start a xprof server to visualize the results: @@ -177,5 +178,3 @@ On the first run, the trace will show compilation operations, while subsequent r XProf at http://localhost:8791/ (Press CTRL+C to quit) Open the provided URL in a browser. In the sidebar, select the tool "Trace Viewer". - -In "profile mode", the MCMC loop is split into a few chunks that are compiled separately, allowing to see at a glance how much time each phase of the MCMC cycle takes. This causes some overhead, so the timings are not equivalent to the normal mode ones. On some specific example on CPU, Bartz was 20% slower in profile mode with one chain, and 2x slower with multiple chains. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 9ea214c..7af6dff 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -37,4 +37,3 @@ Reference jaxext.rst debug.rst test.rst - profile.rst diff --git a/docs/reference/profile.rst b/docs/reference/profile.rst deleted file mode 100644 index 183c08b..0000000 --- a/docs/reference/profile.rst +++ /dev/null @@ -1,28 +0,0 @@ -.. bartz/docs/reference/profile.rst -.. -.. Copyright (c) 2025, The Bartz Contributors -.. -.. This file is part of bartz. -.. -.. Permission is hereby granted, free of charge, to any person obtaining a copy -.. of this software and associated documentation files (the "Software"), to deal -.. in the Software without restriction, including without limitation the rights -.. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -.. copies of the Software, and to permit persons to whom the Software is -.. furnished to do so, subject to the following conditions: -.. -.. The above copyright notice and this permission notice shall be included in all -.. copies or substantial portions of the Software. -.. -.. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -.. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -.. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -.. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -.. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -.. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -.. SOFTWARE. - -Profiling ---------- - -.. autofunction:: bartz.profile_mode diff --git a/src/bartz/__init__.py b/src/bartz/__init__.py index 19a30d9..069d757 100644 --- a/src/bartz/__init__.py +++ b/src/bartz/__init__.py @@ -30,5 +30,4 @@ from bartz import BART, grove, jaxext, mcmcloop, mcmcstep, prepcovars # noqa: F401 from bartz._interface import Bart # noqa: F401 -from bartz._profiler import profile_mode # noqa: F401 from bartz._version import __version__, __version_info__ # noqa: F401 diff --git a/src/bartz/_profiler.py b/src/bartz/_profiler.py deleted file mode 100644 index d38eb38..0000000 --- a/src/bartz/_profiler.py +++ /dev/null @@ -1,318 +0,0 @@ -# bartz/src/bartz/_profiler.py -# -# Copyright (c) 2025-2026, The Bartz Contributors -# -# This file is part of bartz. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Module with utilities related to profiling bartz.""" - -from collections.abc import Callable, Iterator -from contextlib import contextmanager -from functools import wraps -from typing import Any, TypeVar - -from jax import block_until_ready, debug, jit, lax -from jax.profiler import TraceAnnotation -from jaxtyping import Array, Bool - -from bartz.mcmcstep._state import vmap_chains - -PROFILE_MODE: bool = False - -T = TypeVar('T') - - -def get_profile_mode() -> bool: - """Return the current profile mode status. - - Returns - ------- - True if profile mode is enabled, False otherwise. - """ - return PROFILE_MODE - - -def set_profile_mode(value: bool, /) -> None: - """Set the profile mode status. - - Parameters - ---------- - value - If True, enable profile mode. If False, disable it. - """ - global PROFILE_MODE # noqa: PLW0603 - PROFILE_MODE = value - - -@contextmanager -def profile_mode(value: bool, /) -> Iterator[None]: - """Context manager to temporarily set profile mode. - - Parameters - ---------- - value - Profile mode value to set within the context. - - Examples - -------- - >>> with profile_mode(True): - ... # Code runs with profile mode enabled - ... pass - - Notes - ----- - In profiling mode, the MCMC loop is not compiled into a single function, but - instead compiled in smaller pieces that are instrumented to show up in the - jax tracer and Python profiling statistics. Search for function names - starting with 'jab' (see `jit_and_block_if_profiling`). - - Jax tracing is not enabled by this context manager and if used must be - handled separately by the user; this context manager only makes sure that - the execution flow will be more interpretable in the traces if the tracer is - used. - """ - old_value = get_profile_mode() - set_profile_mode(value) - try: - yield - finally: - set_profile_mode(old_value) - - -def jit_and_block_if_profiling( - func: Callable[..., T], block_before: bool = False, **kwargs: Any -) -> Callable[..., T]: - """Apply JIT compilation and block if profiling is enabled. - - When profile mode is off, the function runs without JIT. When profile mode - is on, the function is JIT compiled and blocks outputs to ensure proper - timing. - - Parameters - ---------- - func - Function to wrap. - block_before - If True block inputs before passing them to the JIT-compiled function. - This ensures that any pending computations are completed before entering - the JIT-compiled function. This phase is not included in the trace - event. - **kwargs - Additional arguments to pass to `jax.jit`. - - Returns - ------- - Wrapped function. - - Notes - ----- - Under profiling mode, the function invocation is handled such that a custom - jax trace event with name `jab[]` is created. The statistics on - the actual Python function will be off, while the function - `jab_inner_wrapper` represents the actual execution time. - """ - jitted_func = jit(func, **kwargs) - - event_name = f'jab[{func.__name__}]' - - # this wrapper is meant to measure the time spent executing the function - def jab_inner_wrapper(*args: Any, **kwargs: Any) -> T: - with TraceAnnotation(event_name): - result = jitted_func(*args, **kwargs) - return block_until_ready(result) - - @wraps(func) - def jab_outer_wrapper(*args: Any, **kwargs: Any) -> T: - if get_profile_mode(): - if block_before: - args, kwargs = block_until_ready((args, kwargs)) - return jab_inner_wrapper(*args, **kwargs) - else: - return func(*args, **kwargs) - - return jab_outer_wrapper - - -def jit_if_profiling( - func: Callable[..., T], *args: Any, **kwargs: Any -) -> Callable[..., T]: - """Apply JIT compilation only when profiling. - - Parameters - ---------- - func - Function to wrap. - *args - **kwargs - Additional arguments to pass to `jax.jit`. - - Returns - ------- - Wrapped function. - """ - jitted_func = jit(func, *args, **kwargs) - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> T: - if get_profile_mode(): - return jitted_func(*args, **kwargs) - else: - return func(*args, **kwargs) - - return wrapper - - -def jit_if_not_profiling( - func: Callable[..., T], *args: Any, **kwargs: Any -) -> Callable[..., T]: - """Apply JIT compilation only when not profiling. - - When profile mode is off, the function is JIT compiled. When profile mode is - on, the function runs as-is. - - Parameters - ---------- - func - Function to wrap. - *args - **kwargs - Additional arguments to pass to `jax.jit`. - - Returns - ------- - Wrapped function. - """ - jitted_func = jit(func, *args, **kwargs) - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> T: - if get_profile_mode(): - return func(*args, **kwargs) - else: - return jitted_func(*args, **kwargs) - - wrapper._fun = func # used by run_mcmc # noqa: SLF001 - - return wrapper - - -def while_loop_if_not_profiling( - cond_fun: Callable[[T], Bool[Array, ''] | bool], - body_fun: Callable[[T], T], - init_val: T, - /, -) -> T: - """Restricted replacement for `jax.lax.while_loop` that uses a Python loop when profiling. - - Parameters - ---------- - cond_fun - Function to evaluate to determine whether to continue the loop. - body_fun - Function that updates the state in each iteration. - init_val - Initial state. - - Returns - ------- - Final state. - """ - if get_profile_mode(): - val = init_val - while cond_fun(val): - val = body_fun(val) - return val - - else: - return lax.while_loop(cond_fun, body_fun, init_val) - - -def cond_if_not_profiling( - pred: bool | Bool[Array, ''], - true_fun: Callable[..., T], - false_fun: Callable[..., T], - /, - *operands: Any, -) -> T: - """Restricted replacement for `jax.lax.cond` that uses a Python if when profiling. - - Parameters - ---------- - pred - Boolean predicate to choose which function to execute. - true_fun - Function to execute if `pred` is True. - false_fun - Function to execute if `pred` is False. - *operands - Arguments passed to `true_fun` and `false_fun`. - - Returns - ------- - Result of either `true_fun()` or `false_fun()`. - """ - if get_profile_mode(): - if pred: - return true_fun(*operands) - else: - return false_fun(*operands) - else: - return lax.cond(pred, true_fun, false_fun, *operands) - - -def callback_if_not_profiling( - callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any -) -> None: - """Restricted replacement for `jax.debug.callback` that calls the callback directly in profiling mode.""" - if get_profile_mode(): - callback(*args, **kwargs) - else: - debug.callback(callback, *args, ordered=ordered, **kwargs) - - -def vmap_chains_if_profiling(fun: Callable[..., T], **kwargs: Any) -> Callable[..., T]: - """Apply `vmap_chains` only when profile mode is enabled.""" - new_fun = vmap_chains(fun, **kwargs) - - @wraps(fun) - def wrapper(*args: Any, **kwargs: Any) -> T: - if get_profile_mode(): - return new_fun(*args, **kwargs) - else: - return fun(*args, **kwargs) - - return wrapper - - -def vmap_chains_if_not_profiling( - fun: Callable[..., T], **kwargs: Any -) -> Callable[..., T]: - """Apply `vmap_chains` only when profile mode is disabled.""" - new_fun = vmap_chains(fun, **kwargs) - - @wraps(fun) - def wrapper(*args: Any, **kwargs: Any) -> T: - if get_profile_mode(): - return fun(*args, **kwargs) - else: - return new_fun(*args, **kwargs) - - return wrapper diff --git a/src/bartz/jaxext/__init__.py b/src/bartz/jaxext/__init__.py index dd9971d..c268fa2 100644 --- a/src/bartz/jaxext/__init__.py +++ b/src/bartz/jaxext/__init__.py @@ -26,7 +26,6 @@ import math from collections.abc import Callable, Sequence -from contextlib import nullcontext from functools import partial from typing import Any @@ -38,7 +37,6 @@ import jax from jax import ( Device, - debug_key_reuse, device_count, ensure_compile_time_eval, jit, @@ -46,7 +44,6 @@ random, tree, typeof, - vmap, ) from jax import numpy as jnp from jax.dtypes import prng_key @@ -135,33 +132,19 @@ class split: The key to split. num The number of keys to split into. - - Notes - ----- - Unlike `jax.random.split`, this class supports a vector of keys as input. In - this case, it behaves as if everything had been vmapped over, so `keys.pop` - has an additional initial output dimension equal to the number of input - keys, and the deterministic dependency respects this axis. """ - _keys: tuple[Key[Array, '*batch'], ...] + _keys: tuple[Key[Array, ''], ...] _num_used: int - def __init__(self, key: Key[Array, '*batch'], num: int = 2) -> None: - if key.ndim: - context = debug_key_reuse(False) - else: - context = nullcontext() - with context: - # jitted-vmapped key split seems to be triggering a false positive - # with key reuse checks - self._keys = _split_unpack(key, num) + def __init__(self, key: Key[Array, ''], num: int = 2) -> None: + self._keys = _split_unpack(key, num) self._num_used = 0 def __len__(self) -> int: return len(self._keys) - self._num_used - def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, '*batch {shape}']: + def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, ' {shape}']: """ Pop one or more keys from the list. @@ -194,26 +177,18 @@ def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, '*batch {shape}'] @partial(jit, static_argnums=(1,)) -def _split_unpack( - key: Key[Array, '*batch'], num: int -) -> tuple[Key[Array, '*batch'], ...]: - if key.ndim == 0: - keys = random.split(key, num) - elif key.ndim == 1: - keys = vmap(random.split, in_axes=(0, None), out_axes=1)(key, num) +def _split_unpack(key: Key[Array, ''], num: int) -> tuple[Key[Array, ''], ...]: + keys = random.split(key, num) return tuple(keys) @partial(jit, static_argnums=(1,)) def _split_shaped( - key: Key[Array, '*batch'], shape: tuple[int, ...] -) -> Key[Array, '*batch {shape}']: + key: Key[Array, ''], shape: tuple[int, ...] +) -> Key[Array, ' {shape}']: num = math.prod(shape) - if key.ndim == 0: - keys = random.split(key, num) - elif key.ndim == 1: - keys = vmap(random.split, in_axes=(0, None))(key, num) - return keys.reshape(*key.shape, *shape) + keys = random.split(key, num) + return keys.reshape(shape) def truncated_normal_onesided( diff --git a/src/bartz/mcmcloop.py b/src/bartz/mcmcloop.py index 70ef735..065f54c 100644 --- a/src/bartz/mcmcloop.py +++ b/src/bartz/mcmcloop.py @@ -29,7 +29,7 @@ from collections.abc import Callable from dataclasses import fields -from functools import partial, wraps +from functools import partial, update_wrapper, wraps from math import floor from typing import Any, NamedTuple, Protocol, TypeVar @@ -43,6 +43,8 @@ device_put, eval_shape, jit, + lax, + named_call, tree, ) from jax import numpy as jnp @@ -62,12 +64,6 @@ ) from bartz import jaxext, mcmcstep -from bartz._profiler import ( - cond_if_not_profiling, - get_profile_mode, - jit_if_not_profiling, - while_loop_if_not_profiling, -) from bartz.grove import TreeHeaps, evaluate_forest, forest_fill, var_histogram from bartz.jaxext import autobatch, jit_active from bartz.mcmcstep import State @@ -314,13 +310,12 @@ def run_mcmc( # setting to 0 would make for a clean noop, but it's useful to keep the # same code path for benchmarking and testing - # error if under jit and there are unrolled loops or profile mode is on - if jit_active() and (n_outer > 1 or get_profile_mode()): + # error if under jit and there are unrolled loops + if jit_active() and n_outer > 1: msg = ( '`run_mcmc` was called within a jit-compiled function and ' - 'there are either more than 1 outer loops or profile mode is active, ' - 'please either do not jit, set `inner_loop_length=None`, or disable ' - 'profile mode.' + 'there are more than 1 outer loops, ' + 'please either do not jit or set `inner_loop_length=None`' ) raise RuntimeError(msg) @@ -383,13 +378,14 @@ class _CallCounter: def __init__(self, func: Callable[..., T]) -> None: self.func = func self.n_calls = 0 + update_wrapper(self, func) def reset_call_counter(self) -> None: """Reset the call counter.""" self.n_calls = 0 def __call__(self, *args: Any, **kwargs: Any) -> T: - if self.n_calls and not get_profile_mode(): + if self.n_calls: msg = ( 'The inner loop of `run_mcmc` was traced more than once, ' 'which indicates a double compilation of the MCMC code. This ' @@ -403,7 +399,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> T: return self.func(*args, **kwargs) -@partial(jit_if_not_profiling, donate_argnums=(0,), static_argnums=(2, 3, 4)) +@partial(jit, donate_argnums=(0,), static_argnums=(2, 3, 4)) @_CallCounter def _run_mcmc_inner_loop( carry: _Carry, @@ -472,13 +468,10 @@ def body(carry: _Carry) -> _Carry: callback_state=callback_state, ) - return while_loop_if_not_profiling(cond, body, carry) + return lax.while_loop(cond, body, carry) -@partial(jit, donate_argnums=(0, 1), static_argnums=(2, 3)) -# this is jitted because under profiling _run_mcmc_inner_loop and the loop -# within it are not, so I need the donate_argnums feature of jit to avoid -# creating copies of the traces +@named_call def _save_state_to_trace( burnin_trace: PyTree, main_trace: PyTree, @@ -649,10 +642,10 @@ def just_dot_branch() -> None: ) # logging can't do in-line printing so we use print - cond_if_not_profiling( + lax.cond( report_cond, line_report_branch, - lambda: cond_if_not_profiling(dot_cond, just_dot_branch, lambda: None), + lambda: lax.cond(dot_cond, just_dot_branch, lambda: None), ) diff --git a/src/bartz/mcmcstep/_moves.py b/src/bartz/mcmcstep/_moves.py index d071c0d..c89193e 100644 --- a/src/bartz/mcmcstep/_moves.py +++ b/src/bartz/mcmcstep/_moves.py @@ -28,14 +28,13 @@ import jax from equinox import Module +from jax import named_call, random from jax import numpy as jnp -from jax import random from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt from bartz import grove -from bartz._profiler import jit_and_block_if_profiling from bartz.jaxext import minimal_unsigned_dtype, split, vmap_nodoc -from bartz.mcmcstep._state import Forest, field, vmap_chains +from bartz.mcmcstep._state import Forest, field class Moves(Module): @@ -106,8 +105,7 @@ class Moves(Module): computed.""" -@jit_and_block_if_profiling -@vmap_chains +@named_call def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves: """ Propose moves for all the trees. @@ -218,6 +216,7 @@ class GrowMoves(Module): would be produced as `True` if it would have available decision rules.""" +@named_call @partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None)) def propose_grow_moves( key: Key[Array, ' num_trees'], @@ -779,6 +778,7 @@ class PruneMoves(Module): growable.""" +@named_call @partial(vmap_nodoc, in_axes=(0, 0, 0, None, None)) def propose_prune_moves( key: Key[Array, ''], diff --git a/src/bartz/mcmcstep/_state.py b/src/bartz/mcmcstep/_state.py index 8275aa6..05945dd 100644 --- a/src/bartz/mcmcstep/_state.py +++ b/src/bartz/mcmcstep/_state.py @@ -1203,9 +1203,7 @@ def split_key(x: object) -> object: return tree.map(split_key, x) -def vmap_chains( - fun: Callable[..., T], *, auto_split_keys: bool = False -) -> Callable[..., T]: +def vmap_chains(fun: Callable[..., T]) -> Callable[..., T]: """Apply vmap on chain axes automatically if the inputs are multichain.""" @wraps(fun) @@ -1213,8 +1211,7 @@ def auto_vmapped_fun(*args: Any, **kwargs: Any) -> T: all_args = args, kwargs num_chains = get_num_chains(all_args) if num_chains is not None: - if auto_split_keys: - all_args = _split_all_keys(all_args, num_chains) + all_args = _split_all_keys(all_args, num_chains) def wrapped_fun(args: tuple[Any, ...], kwargs: dict[str, Any]) -> T: return fun(*args, **kwargs) diff --git a/src/bartz/mcmcstep/_step.py b/src/bartz/mcmcstep/_step.py index 52638c8..89db1a9 100644 --- a/src/bartz/mcmcstep/_step.py +++ b/src/bartz/mcmcstep/_step.py @@ -36,29 +36,21 @@ import jax from equinox import Module, tree_at -from jax import lax, random, vmap +from jax import jit, lax, named_call, random, vmap from jax import numpy as jnp from jax.scipy.linalg import solve_triangular from jax.scipy.special import gammaln, logsumexp from jax.sharding import Mesh, PartitionSpec from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt, UInt32 -from bartz._profiler import ( - get_profile_mode, - jit_and_block_if_profiling, - jit_if_not_profiling, - jit_if_profiling, - vmap_chains_if_not_profiling, - vmap_chains_if_profiling, -) from bartz.grove import var_histogram from bartz.jaxext import split, truncated_normal_onesided, vmap_nodoc from bartz.mcmcstep._moves import Moves, propose_moves -from bartz.mcmcstep._state import State, StepConfig, chol_with_gersh, field +from bartz.mcmcstep._state import State, StepConfig, chol_with_gersh, field, vmap_chains -@partial(jit_if_not_profiling, donate_argnums=(1,)) -@partial(vmap_chains_if_not_profiling, auto_split_keys=True) +@partial(jit, donate_argnums=(1,)) +@vmap_chains def step(key: Key[Array, ''], bart: State) -> State: """ Do one MCMC step. @@ -80,17 +72,10 @@ def step(key: Key[Array, ''], bart: State) -> State: state can not be used any more after calling `step`. All this applies outside of `jax.jit`. """ - # handle the interactions between chains and profile mode - num_chains = bart.forest.num_chains() - chain_shape = () if num_chains is None else (num_chains,) - if get_profile_mode() and num_chains is not None and key.ndim == 0: - key = random.split(key, num_chains) - assert key.shape == chain_shape - keys = split(key, 3) if bart.y.dtype == bool: - bart = replace(bart, error_cov_inv=jnp.ones(chain_shape)) + bart = replace(bart, error_cov_inv=jnp.array(1.0)) bart = step_trees(keys.pop(), bart) bart = replace(bart, error_cov_inv=None) bart = step_z(keys.pop(), bart) @@ -103,6 +88,7 @@ def step(key: Key[Array, ''], bart: State) -> State: return step_config(bart) +@named_call def step_trees(key: Key[Array, ''], bart: State) -> State: """ Forest sampling step of BART MCMC. @@ -127,6 +113,7 @@ def step_trees(key: Key[Array, ''], bart: State) -> State: return accept_moves_and_sample_leaves(keys.pop(), bart, moves) +@named_call def accept_moves_and_sample_leaves( key: Key[Array, ''], bart: State, moves: Moves ) -> State: @@ -295,8 +282,7 @@ class ParallelStageOut(Module): """Object with pre-computed terms of the leaf samples.""" -@partial(jit_and_block_if_profiling, donate_argnums=(1, 2)) -@vmap_chains_if_profiling +@named_call def accept_moves_parallel_stage( key: Key[Array, ''], bart: State, moves: Moves ) -> ParallelStageOut: @@ -400,6 +386,7 @@ def accept_moves_parallel_stage( ) +@named_call @partial(vmap_nodoc, in_axes=(0, 0, None)) def apply_grow_to_indices( moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n'] @@ -491,6 +478,7 @@ def _compute_count_or_prec_tree( return trees, counts +@named_call def compute_count_trees( leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, config: StepConfig ) -> tuple[UInt32[Array, 'num_trees 2**d'], Counts]: @@ -518,6 +506,7 @@ def compute_count_trees( return _compute_count_or_prec_trees(None, leaf_indices, moves, config) +@named_call def compute_prec_trees( prec_scale: Float32[Array, ' n'], leaf_indices: UInt[Array, 'num_trees n'], @@ -604,6 +593,7 @@ def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Move ) +@named_call @vmap_nodoc def adapt_leaf_trees_to_grow_indices( leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves @@ -702,6 +692,7 @@ def _term_from_chol( return prelkv, None +@named_call def precompute_likelihood_terms( error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'], leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'], @@ -818,6 +809,7 @@ def _precompute_leaf_terms_mv( return PreLf(mean_factor=mean_factor_out, centered_leaves=centered_leaves_out) +@named_call def precompute_leaf_terms( key: Key[Array, ''], prec_trees: Float32[Array, 'num_trees 2**d'], @@ -864,8 +856,7 @@ def precompute_leaf_terms( ) -@partial(jit_and_block_if_profiling, donate_argnums=(0,)) -@vmap_chains_if_profiling +@named_call def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]: """ Accept/reject the moves one tree at a time. @@ -982,6 +973,7 @@ class SeqStageInPerTree(Module): """The pre-computed terms of the leaf sampling which are specific to the tree.""" +@named_call def accept_move_and_sample_leaves( resid: Float32[Array, ' n'] | Float32[Array, ' k n'], at: SeqStageInAllTrees, @@ -1076,6 +1068,7 @@ def accept_move_and_sample_leaves( return resid, leaf_tree, acc, to_prune, log_lk_ratio +@named_call @partial(jnp.vectorize, excluded=(1, 2, 3, 4), signature='(n)->(ts)') def sum_resid( scaled_resid: Float32[Array, ' n'] | Float32[Array, 'k n'], @@ -1224,6 +1217,7 @@ def _quadratic_form( return prelkv.log_sqrt_term + exp_term +@named_call def compute_likelihood_ratio( total_resid: Float32[Array, ''] | Float32[Array, ' k'], left_resid: Float32[Array, ''] | Float32[Array, ' k'], @@ -1264,8 +1258,7 @@ def compute_likelihood_ratio( ) -@partial(jit_and_block_if_profiling, donate_argnums=(0, 1)) -@vmap_chains_if_profiling +@named_call def accept_moves_final_stage(bart: State, moves: Moves) -> State: """ Post-process the mcmc state after accepting/rejecting the moves. @@ -1297,6 +1290,7 @@ def accept_moves_final_stage(bart: State, moves: Moves) -> State: ) +@named_call @vmap_nodoc def apply_moves_to_leaf_indices( leaf_indices: UInt[Array, 'num_trees n'], moves: Moves @@ -1325,6 +1319,7 @@ def apply_moves_to_leaf_indices( ) +@named_call @vmap_nodoc def apply_moves_to_split_trees( split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves @@ -1416,8 +1411,7 @@ def _step_error_cov_inv_mv(key: Key[Array, ''], bart: State) -> State: return replace(bart, error_cov_inv=prec) -@partial(jit_and_block_if_profiling, donate_argnums=(1,)) -@vmap_chains_if_profiling +@named_call def step_error_cov_inv(key: Key[Array, ''], bart: State) -> State: """ MCMC-update the inverse error covariance. @@ -1443,8 +1437,7 @@ def step_error_cov_inv(key: Key[Array, ''], bart: State) -> State: return _step_error_cov_inv_uv(key, bart) -@partial(jit_and_block_if_profiling, donate_argnums=(1,)) -@vmap_chains_if_profiling +@named_call def step_z(key: Key[Array, ''], bart: State) -> State: """ MCMC-update the latent variable for binary regression. @@ -1467,6 +1460,7 @@ def step_z(key: Key[Array, ''], bart: State) -> State: return replace(bart, z=z, resid=resid) +@named_call def step_s(key: Key[Array, ''], bart: State) -> State: """ Update `log_s` using Dirichlet sampling. @@ -1508,6 +1502,7 @@ def step_s(key: Key[Array, ''], bart: State) -> State: return replace(bart, forest=replace(bart.forest, log_s=log_s)) +@named_call def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State: """ Update `theta`. @@ -1569,8 +1564,7 @@ def _log_p_lamda( ), theta -@partial(jit_and_block_if_profiling, donate_argnums=(1,)) -@vmap_chains_if_profiling +@named_call def step_sparse(key: Key[Array, ''], bart: State) -> State: """ Update the sparsity parameters. @@ -1608,8 +1602,7 @@ def _step_sparse(key: Key[Array, ''], bart: State) -> State: return bart -@jit_if_profiling -# jit to avoid the overhead of replace(_: Module) +@named_call def step_config(bart: State) -> State: config = bart.config config = replace(config, steps_done=config.steps_done + 1) diff --git a/tests/test_BART.py b/tests/test_BART.py index 00da99e..16db441 100644 --- a/tests/test_BART.py +++ b/tests/test_BART.py @@ -34,7 +34,6 @@ from gc import collect from os import getpid, kill from signal import SIG_IGN, SIGINT, getsignal, signal -from sys import version_info from threading import Event, Thread from time import monotonic from typing import Any, Literal @@ -65,7 +64,6 @@ from numpy.testing import assert_allclose, assert_array_equal from pytest_subtests import SubTests -from bartz import profile_mode from bartz.debug import TraceWithOffset, check_trace, sample_prior, trees_BART_to_bartz from bartz.debug import debug_gbart as gbart from bartz.debug import debug_mc_gbart as mc_gbart @@ -82,11 +80,7 @@ from bartz.mcmcstep._state import chain_vmap_axes from tests.rbartpackages import BART3 from tests.test_mcmcstep import check_sharding, get_normal_spec, normalize_spec -from tests.util import ( - assert_close_matrices, - assert_different_matrices, - get_old_python_tuple, -) +from tests.util import assert_close_matrices, assert_different_matrices def gen_X( @@ -1495,39 +1489,6 @@ def test_gbart_multichain_error(keys: split) -> None: gbart(X, y, mc_cores='gatto') -def test_same_result_profiling(variant: int, kw: dict) -> None: - """Check that the result is the same in profiling mode.""" - bart = mc_gbart(**kw) - with profile_mode(True): - kw.update(seed=random.clone(kw['seed'])) - bartp = mc_gbart(**kw) - - platform = get_default_device().platform - python_version = version_info[:2] - old_python = get_old_python_tuple() - exact_check = platform != 'gpu' and python_version != old_python - - def check_same(_path: KeyPath, x: Array, xp: Array) -> None: - if exact_check: - assert_array_equal(xp, x) - else: - assert_allclose(xp, x, atol=1e-5, rtol=1e-5) - - try: - tree.map_with_path(check_same, bart._mcmc_state, bartp._mcmc_state) - tree.map_with_path(check_same, bart._main_trace, bartp._main_trace) - except AssertionError as a: - if ( - '\nNot equal to tolerance ' in str(a) - and not exact_check - and python_version == old_python - and variant in (1, 3) - ): - pytest.xfail('unsolved bug with old toolchain') - else: - raise - - def get_expect_sharded(kw: dict) -> bool: """Check whether we expect sharding to be set up based on the arguments.""" bart_kwargs = kw.get('bart_kwargs', {}) diff --git a/tests/test_jaxext.py b/tests/test_jaxext.py index eaf5a99..3a4c502 100644 --- a/tests/test_jaxext.py +++ b/tests/test_jaxext.py @@ -385,11 +385,6 @@ def test_split(keys: split) -> None: ks = jaxext.split(keys.pop()) assert len(ks) == 2 - ks = keys.pop(4) - k1 = jnp.stack([jaxext.split(k).pop() for k in ks]) - k2 = jaxext.split(random.clone(ks)).pop() - assert not different_keys(k1, k2) - class TestJaxPatches: """Check that some jax stuff I patch is correct and still to be patched.""" diff --git a/tests/test_mcmcloop.py b/tests/test_mcmcloop.py index 131456b..d32c790 100644 --- a/tests/test_mcmcloop.py +++ b/tests/test_mcmcloop.py @@ -37,9 +37,7 @@ from jaxtyping import Array, Float32, UInt8 from numpy.testing import assert_array_equal from pytest import FixtureRequest # noqa: PT013 -from pytest_subtests import SubTests -from bartz import profile_mode from bartz.jaxext import get_default_device, split from bartz.mcmcloop import BurninTrace, MainTrace, run_mcmc from bartz.mcmcstep import State, init, make_p_nonterminal @@ -159,9 +157,7 @@ def check_trace(trace: MainTrace | BurninTrace) -> None: check_trace(burnin_trace) check_trace(main_trace) - def test_predicted_double_compilation( - self, keys: split, subtests: SubTests - ) -> None: + def test_predicted_double_compilation(self, keys: split) -> None: """Check that an error is raised under jit if the configuration would lead to double compilation.""" initial_state = simple_init(10, 100, 20) @@ -169,18 +165,10 @@ def test_predicted_double_compilation( run_mcmc, static_argnames=('n_save', 'inner_loop_length') ) - msg = r'there are either more than 1 outer loops' - - with subtests.test('outer loops'), pytest.raises(RuntimeError, match=msg): + msg = r'there are more than 1 outer loops' + with pytest.raises(RuntimeError, match=msg): compiled_run_mcmc(keys.pop(), initial_state, 2, inner_loop_length=1) - with ( - subtests.test('profile mode'), - profile_mode(True), - pytest.raises(RuntimeError, match=msg), - ): - compiled_run_mcmc(keys.pop(), initial_state, 1) - def test_detected_double_compilation(self, keys: split) -> None: """Check that double compilation is detected.""" state = simple_init(10, 100, 20) diff --git a/tests/test_mcmcstep.py b/tests/test_mcmcstep.py index 20975bf..82a4018 100644 --- a/tests/test_mcmcstep.py +++ b/tests/test_mcmcstep.py @@ -25,6 +25,7 @@ """Test `bartz.mcmcstep`.""" from collections.abc import Sequence +from functools import wraps from math import prod from typing import Literal, NamedTuple @@ -40,7 +41,6 @@ from pytest_subtests import SubTests from scipy import stats -from bartz import profile_mode from bartz.jaxext import get_device_count, minimal_unsigned_dtype, split from bartz.mcmcstep import State, init, step from bartz.mcmcstep._moves import ( @@ -536,6 +536,13 @@ def test_minimal_tree(self) -> None: assert r == 4 +@jaxtyped(typechecker=beartype) +@wraps(step) +def typechecking_step(key: Key[Array, ''], state: State) -> State: + """Wrap `bartz.mcmcstep.step` because `jaxtyping.jaxtyped` can not be applied to a jitted function.""" + return step(key, state) + + class TestMultichain: """Basic tests of the multichain functionality.""" @@ -601,7 +608,6 @@ def test_basic( check_sharding(state, state.config.mesh) with subtests.test('step'): - typechecking_step = jaxtyped(step, typechecker=beartype) with debug_key_reuse(num_chains != 0): # key reuse checks trigger with empty key array apparently new_state = typechecking_step(keys.pop(), state) @@ -609,10 +615,7 @@ def test_basic( check_strong_types(new_state) check_sharding(new_state, state.config.mesh) - @pytest.mark.parametrize('profile', [False, True]) - def test_multichain_equiv_stack( - self, init_kwargs: dict, keys: split, profile: bool - ) -> None: + def test_multichain_equiv_stack(self, init_kwargs: dict, keys: split) -> None: """Check that stacking multiple chains is equivalent to a multichain trace.""" num_chains = 4 num_iters = 10 @@ -631,16 +634,14 @@ def test_multichain_equiv_stack( ] # run a few mcmc steps with the same random keys - with profile_mode(profile): - for _ in range(num_iters): - mc_key = keys.pop() - sc_keys = random.split(random.clone(mc_key), num_chains) - - mc_state = step(mc_key, mc_state) - sc_states = [ - step(key, state) - for key, state in zip(sc_keys, sc_states, strict=True) - ] + for _ in range(num_iters): + mc_key = keys.pop() + sc_keys = random.split(random.clone(mc_key), num_chains) + + mc_state = step(mc_key, mc_state) + sc_states = [ + step(key, state) for key, state in zip(sc_keys, sc_states, strict=True) + ] # stack single-chain states def stack_leaf( diff --git a/tests/test_profiler.py b/tests/test_profiler.py deleted file mode 100644 index c9ddfa1..0000000 --- a/tests/test_profiler.py +++ /dev/null @@ -1,337 +0,0 @@ -# bartz/tests/test_profiler.py -# -# Copyright (c) 2025-2026, The Bartz Contributors -# -# This file is part of bartz. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Test `bartz._profiler`.""" - -from cProfile import Profile -from functools import partial -from pstats import Stats -from time import perf_counter, sleep -from typing import NamedTuple - -import pytest -from jax import debug_infs, debug_nans, jit, pure_callback, random -from jax import numpy as jnp -from jaxtyping import Array, Bool, Float32, Int32, Integer -from numpy.testing import assert_array_equal - -from bartz._profiler import ( - cond_if_not_profiling, - get_profile_mode, - jit_and_block_if_profiling, - jit_if_not_profiling, - profile_mode, - set_profile_mode, - while_loop_if_not_profiling, -) -from bartz.jaxext import get_default_device - - -class TestFlag: - """Test the functionality of the global profile mode flag.""" - - def test_initial_state(self) -> None: - """Check profiling mode is off by default.""" - assert not get_profile_mode() - - def test_getter_setter(self) -> None: - """Test setting and getting the profile mode.""" - set_profile_mode(True) - assert get_profile_mode() - set_profile_mode(False) - assert not get_profile_mode() - - def test_context_manager(self) -> None: - """Test the profile mode context manager.""" - with profile_mode(True): - assert get_profile_mode() - assert not get_profile_mode() - - set_profile_mode(True) - with profile_mode(False): - assert not get_profile_mode() - assert get_profile_mode() - set_profile_mode(False) - - with profile_mode(True): - assert get_profile_mode() - with profile_mode(False): - assert not get_profile_mode() - assert get_profile_mode() - assert not get_profile_mode() - - -class TestScanIfNotProfiling: - """Test `scan_if_not_profiling`.""" - - @pytest.mark.parametrize('mode', [True, False]) - def test_result(self, mode: bool) -> None: - """Test that `scan_if_not_profiling` has the right output on a simple example.""" - - def cond(carry: Integer[Array, '']) -> Bool[Array, '']: - return carry < 5 - - def body(carry: Integer[Array, '']) -> Integer[Array, '']: - return carry + 1 - - with profile_mode(mode): - carry = while_loop_if_not_profiling(cond, body, 0) - assert carry == 5 - - def test_does_not_jit(self) -> None: - """Check that `scan_if_not_profiling` does not jit the function in profiling mode.""" - - class Carry(NamedTuple): - i: Int32[Array, ''] - state: Int32[Array, ''] - - def cond(carry: Carry) -> Bool[Array, '']: - return carry.i < 5 - - def body(carry: Carry) -> Carry: - return Carry(carry.i + 1, carry.state.block_until_ready()) - - with profile_mode(True): - while_loop_if_not_profiling(cond, body, Carry(jnp.int32(0), jnp.int32(0))) - - with pytest.raises( - AttributeError, - match='DynamicJaxprTracer has no attribute block_until_ready', - ): - while_loop_if_not_profiling(cond, body, Carry(jnp.int32(0), jnp.int32(0))) - - -class TestCondIfNotProfiling: - """Test `cond_if_not_profiling`.""" - - @pytest.mark.parametrize('mode', [True, False]) - @pytest.mark.parametrize('pred', [True, False]) - def test_result(self, mode: bool, pred: bool) -> None: - """Test that `cond_if_not_profiling` has the right output on a simple example.""" - with profile_mode(mode): - out = cond_if_not_profiling( - pred, lambda x: x - 1, lambda x: x + 1, jnp.int32(5) - ) - assert out == (4 if pred else 6) - - def test_does_not_jit(self) -> None: - """Check that `cond_if_not_profiling` does not jit the function in profiling mode.""" - with profile_mode(True): - cond_if_not_profiling( - True, - lambda x: x.block_until_ready(), - lambda x: x.block_until_ready(), - jnp.int32(5), - ) - - with pytest.raises( - AttributeError, - match='DynamicJaxprTracer has no attribute block_until_ready', - ): - cond_if_not_profiling( - False, - lambda x: x.block_until_ready(), - lambda x: x.block_until_ready(), - jnp.int32(5), - ) - - -class TestJitIfNotProfiling: - """Test `jit_if_not_profiling`.""" - - @pytest.mark.parametrize('mode', [True, False]) - def test_result(self, mode: bool) -> None: - """Test that `jit_if_not_profiling` has the right output in both modes.""" - - def func(x: Integer[Array, '']) -> Integer[Array, '']: - return x * 2 + 1 - - jitted_func = jit_if_not_profiling(func) - - with profile_mode(mode): - result = jitted_func(5) - assert result == 11 - - def test_does_not_jit(self) -> None: - """Check that `jit_if_not_profiling` does not jit the function in profiling mode.""" - - def func(x: Int32[Array, '']) -> Int32[Array, '']: - return x.block_until_ready() - # block_until_ready errors under jit - - jitted_func = jit_if_not_profiling(func) - - with profile_mode(True): - result = jitted_func(jnp.int32(42)) - assert result == 42 - - with pytest.raises( - AttributeError, - match='DynamicJaxprTracer has no attribute block_until_ready', - ): - jitted_func(jnp.int32(42)) - - -class TestJitAndBlockIfProfiling: - """Test `jit_and_block_if_profiling`.""" - - @pytest.mark.parametrize('mode', [True, False]) - def test_result(self, mode: bool) -> None: - """Test that `jit_and_block_if_profiling` has the right output in both modes.""" - - def func(x: Integer[Array, '']) -> Integer[Array, '']: - return x * 2 + 1 - - jitted_func = jit_and_block_if_profiling(func) - - with profile_mode(mode): - result = jitted_func(5) - assert result == 11 - - def test_jits_when_profiling(self) -> None: - """Check that `jit_and_block_if_profiling` jits when profiling is enabled.""" - - def func(x: Int32[Array, '']) -> Int32[Array, '']: - return x.block_until_ready() - # block_until_ready errors under jit - - jitted_func = jit_and_block_if_profiling(func) - - # When profiling is ON, function IS jitted, so should error - with ( - pytest.raises( - AttributeError, - match='DynamicJaxprTracer has no attribute block_until_ready', - ), - profile_mode(True), - ): - jitted_func(0) - - # When profiling is OFF, function is NOT jitted, so should work - with profile_mode(False): - jitted_func(jnp.int32(0)) - - def test_static_args(self) -> None: - """Check that it works with static arguments.""" - - def func(n: int) -> Integer[Array, ' {n}']: - return jnp.arange(n) - - jitted_func = jit_and_block_if_profiling(func, static_argnums=(0,)) - - with profile_mode(True): - result = jitted_func(5) - assert_array_equal(result, jnp.arange(5)) - - @pytest.mark.flaky(max_runs=3) - # flaky because it involves comparing time measurements done on the fly - def test_blocks_execution(self) -> None: - """Check that `jit_and_block_if_profiling` blocks execution when profiling.""" - with debug_nans(False), debug_infs(False): - platform = get_default_device().platform - match platform: - case 'cpu': - n = 2000 - case 'gpu': - n = 10_000 - case _: # pragma: no cover - msg = f'Unsupported platform for timing test: {platform}' - raise RuntimeError(msg) - - func = lambda: idle(n) # about 50-100 ms - jit_func = jit(func) - jab_func = jit_and_block_if_profiling(func) - - # Time the jitted function - for _ in range(3): - jit_func().block_until_ready() # Warm-up - start = perf_counter() - jit_func().block_until_ready() - expected = perf_counter() - start - - # Check execution is async - start = perf_counter() - result = jit_func() - elapsed = perf_counter() - start - result.block_until_ready() # Ensure completion - assert elapsed < expected / 2 - - # Test profiling mode first (should block and wait >= expected) - with profile_mode(True): - jab_func() # Warm-up - start = perf_counter() - jab_func() - elapsed = perf_counter() - start - assert elapsed >= expected * 0.9, ( - f'Expected blocking to wait >= {expected:#.2g}s, got {elapsed:#.2g}s' - ) - - # Test non-profiling mode (should be async, < expected) - jab_func().block_until_ready() # Warm-up - start = perf_counter() - result = jab_func() - elapsed = perf_counter() - start - result.block_until_ready() # Ensure completion - assert elapsed < expected / 2, ( - f'Expected async execution << {expected:#.2g}s, got {elapsed:#.2g}s' - ) - - def test_profile(self) -> None: - """Test `jit_and_block_if_profiling` under the Python profiler.""" - runtime = 0.1 - - @jit_and_block_if_profiling - # weird name to make sure identifiers are legit - def awlkugh() -> Int32[Array, '']: - x = jnp.int32(0) - - def sleeper(x: Int32[Array, '']) -> Int32[Array, '']: - sleep(runtime) - return x - - return pure_callback(sleeper, x, x) - - with profile_mode(True): - for _ in range(2): - awlkugh() # warm-up - - with Profile() as prof, profile_mode(True): - awlkugh() - - stats = Stats(prof).get_stats_profile() - - assert 'awlkugh' not in stats.func_profiles - # it's not there because it was traced during warm-up - - p_run = stats.func_profiles['jab_inner_wrapper'] - - assert runtime < p_run.cumtime < 10 * runtime - - -@partial(jit, static_argnums=(0,)) -def idle(n: int) -> Float32[Array, ' {n} {n}']: - """Waste time in jax computation.""" - key = random.key(0) - x = random.normal(key, (n, n)) - return x @ x