diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 93a0e410..dbe9c7d2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,6 +43,8 @@ env: jobs: pre-commit: runs-on: ubuntu-latest + env: + SKIP: no-commit-to-branch steps: - uses: actions/checkout@v6 with: @@ -84,11 +86,36 @@ jobs: - id: set-matrix run: | BASE_MATRIX='[ - {"os":"ubuntu-latest","runs-on":"ubuntu-latest","target-suffix":""}, - {"os":"ubuntu-22.04","runs-on":"ubuntu-22.04","target-suffix":"-old"} + { + "id": "latest", + "runs-on": "ubuntu-latest", + "target-suffix": "", + "tests-args": "" + }, + { + "id": "old", + "runs-on": "ubuntu-22.04", + "target-suffix": "-old", + "tests-args": "" + }, + { + "id": "single-cpu", + "runs-on": "ubuntu-latest", + "target-suffix": "", + "tests-args": "ARGS=--num-cpu-devices=1" + } ]' BENCHMARKS_BASE_MATRIX='[ - {"os":"ubuntu-latest","runs-on":"ubuntu-latest","target-suffix":""} + { + "id": "latest", + "runs-on": "ubuntu-latest", + "vars": "" + }, + { + "id": "single-cpu", + "runs-on": "ubuntu-latest", + "vars": "BARTZ_BENCHMARKS_SINGLE_CPU_DEVICE=1 " + } ]' MATRIX=$(echo "$BASE_MATRIX" | jq -c ".") @@ -98,7 +125,7 @@ jobs: echo "benchmarks-matrix={\"include\":$BENCHMARKS_MATRIX}" >> $GITHUB_OUTPUT tests: - name: tests-${{ matrix.os }}${{ matrix.target-suffix }} + name: tests-${{ matrix.id }} needs: setup-matrix runs-on: ${{ matrix.runs-on }} continue-on-error: ${{ matrix.target-suffix == '-gpu' }} @@ -115,7 +142,7 @@ jobs: uses: astral-sh/setup-uv@v7 with: version: ${{ env.UV_VERSION }} - cache-suffix: -${{ matrix.os }}${{ matrix.target-suffix }} + cache-suffix: -${{ matrix.id }} - name: Install R uses: r-lib/actions/setup-r@v2 @@ -134,21 +161,15 @@ jobs: ldd "$(R RHOME)/library/methods/libs/methods.so" || true uv run --dev python -c "import rpy2.situation as s; print('\n'.join(map(str, s.iter_info())))" - - name: Install jax GPU support - if: matrix.target-suffix == '-gpu' - run: | - CUDA_VERSION=`nvidia-smi 
2>/dev/null | grep -o 'CUDA Version: [0-9]*' | cut -d' ' -f3` - uv pip install "jax[cuda$CUDA_VERSION]" - - name: Run unit tests timeout-minutes: 30 - run: make COVERAGE_SUFFIX=-${{ matrix.os }}${{ matrix.target-suffix }} tests${{ matrix.target-suffix }} + run: make COVERAGE_SUFFIX=-${{ matrix.id }} tests${{ matrix.target-suffix }} ${{ matrix.tests-args }} - name: Save coverage information uses: actions/upload-artifact@v6 with: - name: coverage.tests-${{ matrix.os }}${{ matrix.target-suffix }} - path: .coverage.tests-${{ matrix.os }}${{ matrix.target-suffix }} + name: coverage.tests-${{ matrix.id }} + path: .coverage.tests-${{ matrix.id }} include-hidden-files: true if-no-files-found: error @@ -279,7 +300,7 @@ jobs: uses: actions/deploy-pages@v4 benchmarks: - name: benchmarks-${{ matrix.os }}${{ matrix.target-suffix }} + name: benchmarks-${{ matrix.id }} needs: setup-matrix runs-on: ${{ matrix.runs-on }} continue-on-error: ${{ matrix.target-suffix == '-gpu' }} @@ -296,19 +317,14 @@ jobs: uses: astral-sh/setup-uv@v7 with: version: ${{ env.UV_VERSION }} - cache-suffix: -${{ matrix.os }}${{ matrix.target-suffix }} + cache-suffix: -${{ matrix.id }} - name: Get machine information for asv run: RPY2_CFFI_MODE=ABI uv run --dev python -m asv machine --yes - - name: Install jax GPU support - if: matrix.target-suffix == '-gpu' - run: | - CUDA_VERSION=`nvidia-smi 2>/dev/null | grep -o 'CUDA Version: [0-9]*' | cut -d' ' -f3` - uv pip install "jax[cuda$CUDA_VERSION]" - - name: Run asv in quick mode - run: make asv-quick ARGS='--attribute timeout=120' # long timeout because CI runners are slow + # long timeout because CI runners are slow + run: ${{ matrix.vars }}make asv-quick ARGS='--attribute timeout=120' dev-commands: runs-on: ubuntu-latest @@ -328,6 +344,9 @@ jobs: - name: run IPython with old toolchain run: RPY2_CFFI_MODE=ABI make ipython-old ARGS=config/ipython/profile_default/startup/startup.ipy + - name: Update package version + run: make copy-version + required-dummy: 
runs-on: ubuntu-latest needs: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a6733d15..7c2f2a8b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,6 +31,7 @@ repos: exclude: tests/(rbartpackages/.+\.py|util\.py)$ - id: trailing-whitespace exclude: '^renv/activate\.R$' # because renv edits it automatically, with trailing whitespace + - id: no-commit-to-branch - repo: https://github.com/sbrunner/hooks rev: 1.6.1 hooks: @@ -55,4 +56,8 @@ repos: - repo: https://github.com/jorisroovers/gitlint rev: v0.19.1 hooks: - - id: gitlint # this is already defined on stage commit-msg + - id: gitlint # this is already defined to run in stage commit-msg + - repo: https://github.com/astral-sh/uv-pre-commit + rev: 0.10.4 + hooks: + - id: uv-lock diff --git a/Makefile b/Makefile index 1534a6d8..5fe3ba62 100644 --- a/Makefile +++ b/Makefile @@ -35,8 +35,8 @@ UV_RUN = uv run --dev $(EXTRAS) OLD_PYTHON = $(shell grep 'requires-python' pyproject.toml | sed 's/.*>=\([0-9.]*\).*/\1/') UV_RUN_OLD = $(UV_RUN) --python=$(OLD_PYTHON) --resolution=lowest-direct --exclude-newer=2025-05-15 --isolated -.PHONY: all -all: +.PHONY: help +help: @echo "Available targets:" @echo "- setup: create R and Python environments for development" @echo "- tests: run unit tests on cpu, saving coverage information" @@ -61,8 +61,10 @@ all: @echo "- asv-quick: run quick benchmarks on current code, no saving" @echo "- ipython: start an ipython shell with stuff pre-imported" @echo "- ipython-old: start an ipython shell with oldest supported python and dependencies" + @echo "- lint: run the linter used in pre-commit" @echo @echo "Release workflow:" + @echo "- do a PR that re-runs benchmarks" @echo "- create a new branch" @echo "- $$ uv version --bump major|minor|patch" @echo "- describe release in docs/changelog.md" @@ -71,7 +73,7 @@ all: @echo "- $$ make release (iterate to fix problems)" @echo "- if CI does not pass, debug and go back to make release" @echo "- merge PR" - 
@echo "- if CI does not pass, debug and go back to make release" + @echo "- if CI does not pass, debug and go back to open PR" @echo "- $$ make upload" @echo "- publish github release (updates zenodo automatically)" @echo "- if the online docs are not up-to-date, merge another PR to trigger a new merge CI" @@ -82,6 +84,9 @@ setup: Rscript -e "renv::restore()" $(UV_RUN) pre-commit install --install-hooks +.PHONY: lint +lint: + $(UV_RUN) pre-commit run --all-files ruff-check ################# TESTS ################# @@ -89,7 +94,7 @@ TESTS_VARS = COVERAGE_FILE=.coverage.tests$(COVERAGE_SUFFIX) TESTS_COMMAND = python -m pytest --cov --cov-context=test --dist=worksteal --durations=1000 TESTS_CPU_VARS = $(TESTS_VARS) JAX_PLATFORMS=cpu TESTS_CPU_COMMAND = $(TESTS_COMMAND) --platform=cpu --numprocesses=2 -TESTS_GPU_VARS = $(TESTS_VARS) XLA_PYTHON_CLIENT_MEM_FRACTION=.20 +TESTS_GPU_VARS = $(TESTS_VARS) XLA_PYTHON_CLIENT_PREALLOCATE=false TESTS_GPU_COMMAND = $(TESTS_COMMAND) --platform=gpu --numprocesses=3 .PHONY: tests @@ -129,7 +134,6 @@ docs-latest: WORKTREE_DIR=$$(mktemp -d) && \ trap "git worktree remove --force '$$WORKTREE_DIR' 2>/dev/null || rm -rf '$$WORKTREE_DIR'" EXIT && \ git worktree add --detach "$$WORKTREE_DIR" "$$LATEST_TAG" && \ - uv sync --all-groups --directory "$$WORKTREE_DIR" && \ $(MAKE) -C "$$WORKTREE_DIR" docs && \ test ! 
-d _site/docs || rm -r _site/docs && \ mv "$$WORKTREE_DIR/_site/docs-dev" _site/docs @@ -172,6 +176,7 @@ clean: rm -fr .venv rm -fr dist rm -fr config/jax_cache + rm -fr docs/_build .PHONY: release release: clean update-deps copy-version check-committed tests tests-old docs @@ -180,11 +185,17 @@ release: clean update-deps copy-version check-committed tests tests-old docs .PHONY: version-tag version-tag: copy-version check-committed git fetch --tags - git tag v$(shell uv run python -c 'import bartz; print(bartz.__version__)') - git push --tags + $(eval VERSION_TAG := v$(shell uv run python -c 'import bartz; print(bartz.__version__)')) + git tag $(VERSION_TAG) + git push origin $(VERSION_TAG) + +.PHONY: smoke-test +smoke-test: + uv run --isolated --no-project --with dist/*.whl python -c 'import bartz' + uv run --isolated --no-project --with dist/*.tar.gz python -c 'import bartz' .PHONY: upload -upload: version-tag +upload: smoke-test version-tag @echo "Enter PyPI token:" @read -s UV_PUBLISH_TOKEN && \ export UV_PUBLISH_TOKEN="$$UV_PUBLISH_TOKEN" && \ @@ -194,7 +205,7 @@ upload: version-tag uv tool run --with="bartz==$$VERSION" python -c 'import bartz; print(bartz.__version__)' .PHONY: upload-test -upload-test: check-committed +upload-test: smoke-test check-committed @echo "Enter TestPyPI token:" @read -s UV_PUBLISH_TOKEN && \ export UV_PUBLISH_TOKEN="$$UV_PUBLISH_TOKEN" && \ diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py index 9268620e..7843f633 100644 --- a/benchmarks/__init__.py +++ b/benchmarks/__init__.py @@ -24,10 +24,13 @@ """Benchmarking code run by asv.""" +from os import getenv + from jax import config from benchmarks._vendor_latest_bartz import ensure_vendored ensure_vendored() -config.update('jax_num_cpu_devices', 16) +if getenv('BARTZ_BENCHMARKS_SINGLE_CPU_DEVICE') is None: + config.update('jax_num_cpu_devices', 16) diff --git a/benchmarks/rmse.py b/benchmarks/rmse.py index 088125f3..51a2c0d9 100644 --- a/benchmarks/rmse.py +++ 
b/benchmarks/rmse.py @@ -99,7 +99,7 @@ def run_sim_impl( # adapt for older versions sig = signature(mc_gbart) - def drop_if_missing(arg: str): + def drop_if_missing(arg: str) -> None: if arg not in sig.parameters: kw.pop(arg) diff --git a/benchmarks/speed.py b/benchmarks/speed.py index e27b04eb..6f0cc91e 100644 --- a/benchmarks/speed.py +++ b/benchmarks/speed.py @@ -48,13 +48,19 @@ from jax.errors import JaxRuntimeError from jax.sharding import Mesh from jax.tree_util import tree_map -from jaxtyping import Array, Float32, UInt8 +from jaxtyping import Array, Float32, Integer, Key, UInt8 import bartz from bartz import mcmcloop, mcmcstep from bartz.mcmcloop import run_mcmc from benchmarks.latest_bartz.jaxext import get_device_count, split +try: + from bartz.mcmcstep import State +except ImportError: + # old versions use a dictionary to store the mcmc state + State: type = dict + try: from bartz.BART import mc_gbart as gbart except ImportError: @@ -91,13 +97,13 @@ def gen_nonsense_data( return X, y, max_split -Kind = Literal['plain', 'weights', 'binary', 'sparse'] +Kind = Literal['plain', 'weights', 'binary', 'sparse', 'multivariate'] def get_default_platform() -> str: """Get the default JAX platform (cpu, gpu).""" with ensure_compile_time_eval(): - return jnp.zeros(()).platform() + return jnp.zeros(0).platform() def simple_init( # noqa: C901, PLR0915 @@ -106,24 +112,28 @@ def simple_init( # noqa: C901, PLR0915 num_trees: int, kind: Kind = 'plain', *, - k: int | None = None, num_chains: int | None = None, mesh: dict[str, int] | Mesh | None = None, - **kwargs, -): + **kwargs: Any, +) -> State: """Glue code to support `mcmcstep.init` across API changes.""" + # generate data + if kind == 'multivariate': + k = 2 + else: + k = None X, y, max_split = gen_nonsense_data(p, n, k) kw: dict = dict( X=X, y=y, - offset=0.0, + offset=0.0 if k is None else jnp.zeros(k), max_split=max_split, num_trees=num_trees, p_nonterminal=make_p_nonterminal(6, 0.95, 2), - 
leaf_prior_cov_inv=jnp.float32(num_trees), + leaf_prior_cov_inv=jnp.float32(num_trees) * (1.0 if k is None else jnp.eye(k)), error_cov_df=2.0, - error_cov_scale=2.0, + error_cov_scale=2.0 * (1.0 if k is None else jnp.eye(k)), min_points_per_decision_node=10, num_chains=num_chains, mesh=mesh, @@ -191,6 +201,11 @@ def simple_init( # noqa: C901, PLR0915 if 'sparse_on_at' not in sig.parameters: kw.pop('sparse_on_at') + case 'multivariate': + if 'leaf_prior_cov_inv' not in sig.parameters: + msg = 'multivariate not supported' + raise NotImplementedError(msg) + kw.update(kwargs) return init(**kw) @@ -203,7 +218,7 @@ def simple_init( # noqa: C901, PLR0915 class AutoParamNames: """Superclass that automatically sets `param_names` on subclasses.""" - def __init_subclass__(cls, **_): + def __init_subclass__(cls, **_: Any) -> None: method = cls.setup sig = signature(method) params = list(sig.parameters) @@ -216,11 +231,11 @@ class StepGeneric(AutoParamNames): params: tuple[tuple[Mode, ...], tuple[Kind, ...], tuple[int | None, ...]] = ( ('compile', 'run'), - ('plain', 'binary', 'weights', 'sparse'), + ('plain', 'binary', 'weights', 'sparse', 'multivariate'), (None, 1, 2), ) - def setup(self, mode: Mode, kind: Kind, chains: int | None, **kwargs): + def setup(self, mode: Mode, kind: Kind, chains: int | None, **kwargs: Any) -> None: """Create an initial MCMC state and random seed, compile & warm-up.""" keys = list(random.split(random.key(2025_06_24_12_07))) @@ -229,7 +244,7 @@ def setup(self, mode: Mode, kind: Kind, chains: int | None, **kwargs): self.args = (keys, simple_init(**kw)) - def func(keys, bart): + def func(keys: list[Key[Array, '']], bart: State) -> State: sparse_inside_step = not hasattr(mcmcloop, 'sparse_callback') if kind == 'sparse' and sparse_inside_step: bart = replace(bart, config=replace(bart.config, sparse_on_at=0)) @@ -244,7 +259,7 @@ def func(keys, bart): block_until_ready(self.compiled_func(*self.args)) self.mode = mode - def time_step(self, *_): + def 
time_step(self, *_: Any) -> None: """Time compiling `step` or running it.""" match self.mode: case 'compile': @@ -261,7 +276,7 @@ class StepSharded(StepGeneric): params = ((False, True),) - def setup(self, sharded: bool): # ty:ignore[invalid-method-override] + def setup(self, sharded: bool) -> None: # ty:ignore[invalid-method-override] """Set up with settings that make the effect of sharding salient.""" sig = signature(init) if 'mesh' not in sig.parameters: @@ -299,8 +314,9 @@ def setup( nchains: int = 1, cache: Cache = 'warm', profile: bool = False, + predict: bool = False, kwargs: Mapping[str, Any] = MappingProxyType({}), - ): + ) -> None: """Prepare the arguments and run once to warm-up.""" # check support for multiple chains sig = signature(gbart) @@ -315,7 +331,7 @@ def setup( # generate simulated data dgp = gen_data( keys.pop(), - n=N, + n=2 * N, p=P, k=1, q=2, @@ -324,11 +340,13 @@ def setup( sigma2_quad=0.4, sigma2_eps=0.2, ) + train, test = dgp.split() + block_until_ready((train, test)) # arguments - self.kw = dict( - x_train=dgp.x, - y_train=dgp.y.squeeze(0), + self.kw: dict = dict( + x_train=train.x, + y_train=train.y.squeeze(0), ntree=NTREE, nskip=niters // 2, ndpost=(niters - niters // 2) * nchains, @@ -337,6 +355,7 @@ def setup( if support_multichain: self.kw.update(mc_cores=nchains) self.kw.update(kwargs) + block_until_ready(self.kw) # set profile mode if not profile: @@ -347,6 +366,14 @@ def setup( msg = 'Profile mode not supported.' 
raise NotImplementedError(msg) + # save information used to run predictions + self.predict = predict + if predict: + self.test = test + with self.context(): + self.bart = gbart(**self.kw) + block_bart(self.bart) + # decide how much to cold-start match cache: case 'cold': @@ -356,14 +383,23 @@ def setup( case _: raise KeyError(cache) - def time_gbart(self, *_): + def time_gbart(self, *_: Any) -> None: """Time instantiating the class.""" with redirect_stdout(StringIO()), self.context(): - bart = gbart(**self.kw) - if isinstance(bart, Module): - block_until_ready(bart) + if self.predict: + ypred = self.bart.predict(self.test.x) + block_until_ready(ypred) else: - block_until_ready((bart._mcmc_state, bart._main_trace)) + bart = gbart(**self.kw) + block_bart(bart) + + +def block_bart(bart: gbart) -> None: + """Block a bart object until ready, adapting for old versions.""" + if isinstance(bart, Module): + block_until_ready(bart) + else: + block_until_ready((bart._mcmc_state, bart._main_trace)) class GbartIters(BaseGbart): @@ -380,7 +416,7 @@ class GbartChains(BaseGbart): params = ((1, 2, 4, 8, 16, 32), (False, True)) - def setup(self, nchains: int, shard: bool): # ty:ignore[invalid-method-override] + def setup(self, nchains: int, shard: bool) -> None: # ty:ignore[invalid-method-override] """Set up to use or not multiple cpus.""" # check there is support for multichain if 'mc_cores' not in signature(gbart).parameters: @@ -403,13 +439,13 @@ def setup(self, nchains: int, shard: bool): # ty:ignore[invalid-method-override # on gpu shard explicitly kwargs = dict(num_chain_devices=min(nchains, get_device_count())) - super().setup(NITERS, nchains, 'warm', False, dict(bart_kwargs=kwargs)) + super().setup(NITERS, nchains, 'warm', False, False, dict(bart_kwargs=kwargs)) class GbartGeneric(BaseGbart): """General timing of `mc_gbart` with many settings.""" - params = ((0, NITERS), (1, 6), ('warm', 'cold'), (False, True)) + params = ((0, NITERS), (1, 6), ('warm', 'cold'), (False, 
True), (False, True)) class BaseRunMcmc(AutoParamNames): @@ -430,7 +466,7 @@ def setup( cache: Cache = 'warm', kwargs: Mapping[str, Any] = MappingProxyType({}), n: int = N, - ): + ) -> None: """Prepare the arguments, compile the function, and run to warm-up.""" kw: dict = dict( key=random.key(2025_04_25_15_57), @@ -462,12 +498,12 @@ def setup( static_argnames += ('inner_callback',) f = jit(run_mcmc, static_argnames=static_argnames) - def task(): + def task() -> None: f.clear_cache() f.lower(**kw).compile() case 'run': - def task(): + def task() -> None: block_until_ready(run_mcmc(**kw)) case _: raise KeyError(mode) @@ -488,7 +524,7 @@ def task(): case _: raise KeyError(cache) - def time_run_mcmc(self, *_): + def time_run_mcmc(self, *_: Any) -> None: """Time running or compiling the function.""" try: self.task() @@ -502,7 +538,14 @@ def time_run_mcmc(self, *_): raise RuntimeError(msg) -def kill_callback(*, canary: str, kill_niters: int | None, bart, i_total, **_): +def kill_callback( + *, + canary: str, + kill_niters: int | None, + bart: State, + i_total: Integer[Array, ''], + **_: Any, +) -> None: """Throw error `canary` after `kill_niters` in `run_mcmc`. 
Partially evaluate `kill_callback` on the first two arguments before @@ -524,7 +567,7 @@ def kill_callback(*, canary: str, kill_niters: int | None, bart, i_total, **_): debug.callback(lambda _token: None, token) # to avoid DCE -def detect_zero_division_error_bug(kw: dict): +def detect_zero_division_error_bug(kw: dict) -> None: """Detect a division by zero error with 0 iterations in v0.6.0.""" try: array_kw = {k: v for k, v in kw.items() if isinstance(v, jnp.ndarray)} @@ -550,7 +593,7 @@ class RunMcmcVsTraceLength(BaseRunMcmc): # asv config params = ((2**6, 2**8, 2**10, 2**12, 2**14, 2**16),) - def setup(self, n_save: int): # ty:ignore[invalid-method-override] + def setup(self, n_save: int) -> None: # ty:ignore[invalid-method-override] """Set up to kill after a certain number of iterations.""" kill_niters = min(self.params[0]) super().setup(kill_niters, kwargs=dict(n_save=n_save), n=0) @@ -562,7 +605,7 @@ class RunMcmc(BaseRunMcmc): # asv config params = (('compile', 'run'), (0, NITERS), ('cold', 'warm')) - def setup(self, mode: Mode, niters: int, cache: Cache): # ty:ignore[invalid-method-override] + def setup(self, mode: Mode, niters: int, cache: Cache) -> None: # ty:ignore[invalid-method-override] """Prepare the arguments, compile the function, and run to warm-up.""" super().setup( None, mode, cache, dict(n_save=niters // 2, n_burn=(niters - niters // 2)) diff --git a/config/refs-for-asv.py b/config/refs-for-asv.py index 0ff6d3fe..85d8670e 100644 --- a/config/refs-for-asv.py +++ b/config/refs-for-asv.py @@ -40,7 +40,7 @@ CUTOFF_DATE = datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc) -def main(): +def main() -> None: repo = Repo('.') # Get the default branch name from git diff --git a/config/util.py b/config/util.py index e65b3088..332bdd90 100644 --- a/config/util.py +++ b/config/util.py @@ -34,13 +34,17 @@ def get_version() -> str: return tomli.load(file)['project']['version'] -def update_version(): +def update_version() -> None: """Update the version 
file.""" version = get_version() - Path('src/bartz/_version.py').write_text(f'__version__ = {version!r}\n') + version_info = tuple(map(int, version.split('.'))) + Path('src/bartz/_version.py').write_text(f"""\ +__version__ = {version!r} +__version_info__ = {version_info!r} +""") -def main(): +def main() -> None: command = sys.argv[1] if command == 'get_version': print(get_version()) diff --git a/docs/_static/custom.css b/docs/_static/custom.css index 95f370f7..9b1a695a 100644 --- a/docs/_static/custom.css +++ b/docs/_static/custom.css @@ -1,6 +1,6 @@ /* bartz/docs/_static/custom.css * - * Copyright (c) 2024-2025, The Bartz Contributors + * Copyright (c) 2024-2026, The Bartz Contributors * * This file is part of bartz. * @@ -23,21 +23,25 @@ * SOFTWARE. */ -dl.py.method, dl.py.function { +dl.py.method, +dl.py.function { margin-top: 2em; margin-bottom: 2em; } -dl.py.property { +dl.py.property, +dl.py.attribute { margin-top: 1em; } -dl.py.class, dl.py.function { +dl.py.class, +dl.py.function { margin-top: 2.5em; margin-bottom: 2.5em; } -h2 + dl.py.class, h2 + dl.py.function { +h2 + dl.py.class, +h2 + dl.py.function { margin-top: 1em; } diff --git a/docs/conf.py b/docs/conf.py index cd13da10..2796c48e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -143,7 +143,7 @@ default_role = 'py:obj' # autodoc -autoclass_content = 'both' # concatenate the class and __init__ docstrings +autoclass_content = 'class' # default arguments are printed as in source instead of being evaluated autodoc_preserve_defaults = True autodoc_default_options = {'member-order': 'bysource'} @@ -173,7 +173,7 @@ viewcode_line_numbers = True -def linkcode_resolve(domain, info): +def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None: """ Determine the URL corresponding to Python object, for extension linkcode. diff --git a/docs/development.rst b/docs/development.rst index 4c0087e2..805e1388 100644 --- a/docs/development.rst +++ b/docs/development.rst @@ -1,6 +1,6 @@ .. 
bartz/docs/development.rst .. -.. Copyright (c) 2024-2025, The Bartz Contributors +.. Copyright (c) 2024-2026, The Bartz Contributors .. .. This file is part of bartz. .. @@ -41,16 +41,33 @@ Install `R `_ and `uv `_ miss a :code:`sudo apt install r-base-dev` at the end.) The Python environment is managed by uv. To run commands that involve the Python installation, do :literal:`uv run `. For example, to start an IPython shell, do :literal:`uv run ipython`. Alternatively, do :literal:`source .venv/bin/activate` to activate the virtual environment in the current shell. The R environment is automatically active when you use :literal:`R` in the project directory. +We don't support using conda's R, though it might work. + +Contributing +------------ + +To contribute code changes to the main repository, create a `pull request `_ from your fork to the main repo. + Pre-defined commands -------------------- -Development commands are defined in a makefile. Run :literal:`make` without arguments to list the targets. +Development commands are defined in a makefile. Run :literal:`make` without arguments to list the targets. All commands that simply consist in invoking a tool with the right command line arguments use the :literal:`ARGS` variable to add extra arguments, for example: + +.. code-block:: shell + + make tests ARGS='-k test_pigs_fly' + +will invoke something like + +.. code-block:: shell + + uv run pytest --foo=1 --bar=128 --etc-etc -k test_pigs_fly Documentation ------------- @@ -73,6 +90,40 @@ To debug the documentation build, do make docs SPHINXOPTS='--fresh-env --pdb' +Unit tests +---------- + +The typical workflow to debug new changes is to first run all tests with + +.. code-block:: shell + + make tests + +Then, if some tests fail, use :literal:`pytest` directly to run and debug only the relevant tests, e.g., with + +.. 
code-block:: shell + + uv run pytest --lf --sw --pdb + +Where :code:`--lf` selects only the tests that failed, :code:`--sw` stops on the first failed test, starting again from it on the next run, and :code:`--pdb` opens the python debugger at the point where the test failed. Another useful option is :code:`-k `, which selects only tests whose name matches . + +Debugging dependencies +---------------------- + +To debug tests that fail with old versions of dependencies, it's convenient to piggyback on the predefined make target using :code:`ARGS`: + +.. code-block:: shell + + make tests-old ARGS='-n0 -k test_pigs_fly' + +Where :code:`-n0` disables test parallelization. + +For more fine-grained control, it's useful to invoke directly :code:`uv` with the :code:`--with` option, e.g., the following command will start an IPython shell equipped with specific versions of python and jax: + +.. code-block:: shell + + uv run --with='jax<0.7,jaxlib<0.7' --isolated --python=3.11 --dev python -m IPython + Benchmarks ---------- diff --git a/docs/reference/debug.rst b/docs/reference/debug.rst index 5c30e582..365abecb 100644 --- a/docs/reference/debug.rst +++ b/docs/reference/debug.rst @@ -1,6 +1,6 @@ .. bartz/docs/reference/debug.rst .. -.. Copyright (c) 2025, The Bartz Contributors +.. Copyright (c) 2025-2026, The Bartz Contributors .. .. This file is part of bartz. .. @@ -27,3 +27,5 @@ Debugging .. automodule:: bartz.debug :members: + :imported-members: + :special-members: __call__ diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 74e6f06a..9ea214c5 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -1,6 +1,6 @@ .. bartz/docs/reference/index.rst .. -.. Copyright (c) 2024-2025, The Bartz Contributors +.. Copyright (c) 2024-2026, The Bartz Contributors .. .. This file is part of bartz. .. 
@@ -36,4 +36,5 @@ Reference prepcovars.rst jaxext.rst debug.rst + test.rst profile.rst diff --git a/docs/reference/test.rst b/docs/reference/test.rst new file mode 100644 index 00000000..890c445e --- /dev/null +++ b/docs/reference/test.rst @@ -0,0 +1,30 @@ +.. bartz/docs/reference/test.rst +.. +.. Copyright (c) 2026, The Bartz Contributors +.. +.. This file is part of bartz. +.. +.. Permission is hereby granted, free of charge, to any person obtaining a copy +.. of this software and associated documentation files (the "Software"), to deal +.. in the Software without restriction, including without limitation the rights +.. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +.. copies of the Software, and to permit persons to whom the Software is +.. furnished to do so, subject to the following conditions: +.. +.. The above copyright notice and this permission notice shall be included in all +.. copies or substantial portions of the Software. +.. +.. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +.. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +.. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +.. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +.. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +.. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +.. SOFTWARE. + +Testing +------- + +.. 
automodule:: bartz.testing + :members: + :imported-members: diff --git a/pyproject.toml b/pyproject.toml index 7f80cf78..fb751806 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,8 +144,12 @@ skip-magic-trailing-comma = true [tool.ruff.lint.isort] split-on-trailing-comma = false +[tool.ruff.lint.flake8-annotations] +allow-star-arg-any = true + [tool.ruff.lint] select = [ + "ICN", # flake8-import-conventions "ERA", # eradicate "S", # flake8-bandit "BLE", # flake8-blind-except @@ -187,6 +191,7 @@ select = [ "FURB", # refurb "RUF", # Ruff-specific rules "TRY", # tryceratops + "ANN", # flake8-annotations ] ignore = [ "B028", # warn with stacklevel = 2 @@ -234,9 +239,18 @@ convention = "numpy" min-file-size = 1 [tool.ruff.lint.flake8-tidy-imports] -banned-module-level-imports = ["bartz.debug"] ban-relative-imports = "all" +[tool.ruff.lint.flake8-import-conventions.aliases] +"jax.numpy" = "jnp" + +[tool.ruff.lint.flake8-import-conventions] +banned-from = ["jax.random", "jax.lax", "jax.numpy", "jax.tree"] + +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"jax.lax.reciprocal".msg = "Use jnp.reciprocal" +"jax.random.PRNGKey".msg = "Use jax.random.key" + [tool.pydoclint] arg-type-hints-in-signature = true arg-type-hints-in-docstring = false @@ -245,6 +259,7 @@ check-yield-types = false treat-property-methods-as-class-attributes = true check-style-mismatch = true show-filenames-in-every-violation-message = true +allow-init-docstring = false check-class-attributes = false -# do not check class attributes because in dataclasses I document them as -# init parameters because they are duplicated in the html docs otherwise.
+# do not check class attributes because pydoclint only supports them being +# documented in the class docstring while we document them individually diff --git a/scripts/optnumbatches.py b/scripts/optnumbatches.py index d039f478..41afe54d 100644 --- a/scripts/optnumbatches.py +++ b/scripts/optnumbatches.py @@ -62,7 +62,7 @@ class ParamsBase(ABC): num_batches: Sequence[None | int] = (None, *(2**i for i in range(13 + 1))) @abstractmethod - def valid(self, *, n: int, num_batches: int | None, **_) -> bool: + def valid(self, *, n: int, num_batches: int | None, **_: Any) -> bool: """Check if a set of parameter values is valid.""" return num_batches is None or num_batches <= n @@ -87,7 +87,7 @@ class ParamsResid(ParamsBase): k: Sequence[int | None] = (None, 1, 2, 4) - def valid(self, *, weights: bool, k: int | None, **values) -> bool: # ty:ignore[invalid-method-override] + def valid(self, *, weights: bool, k: int | None, **values: Any) -> bool: # ty:ignore[invalid-method-override] """Skip heteroskedastic multivariate.""" return (not weights or k is None) and super().valid(**values) @@ -98,7 +98,7 @@ class ParamsCount(ParamsBase): num_trees: Sequence[int] = tuple(4**i for i in range(5 + 1)) - def valid(self, *, n: int, num_trees: int, **values) -> bool: # ty:ignore[invalid-method-override] + def valid(self, *, n: int, num_trees: int, **values: Any) -> bool: # ty:ignore[invalid-method-override] """Skip if it would use too much memory.""" return num_trees * n <= MAX_LEAF_INDICES_SIZE and super().valid(n=n, **values) @@ -121,7 +121,7 @@ def setup( num_batches: None | int, k: int | None = None, num_trees: int = 5, - ): + ) -> None: """Initialize BART state and warmup MCMC step.""" # generate data X, y, max_split = gen_nonsense_data(1, n, k) @@ -167,11 +167,11 @@ def setup( # warm up MCMC self.task() - def task(self): + def task(self) -> None: """Run `step_trees` once.""" self.state = block_until_ready(step_func(self.key, self.state)) - def teardown(self): + def teardown(self) 
-> None: """Delete the state and the compiled function.""" del self.state step_func.clear_cache() @@ -274,7 +274,7 @@ def benchmark_loop(args: Namespace) -> DataFrame: return DataFrame(results) -def enable_compilation_cache(): +def enable_compilation_cache() -> None: """Enable JAX compilation caching to speed repeated runs.""" config.update('jax_compilation_cache_dir', 'config/jax_cache') config.update('jax_persistent_cache_min_entry_size_bytes', -1) @@ -325,7 +325,7 @@ def parse_args() -> Namespace: return parser.parse_args() -def main(): +def main() -> None: """Entry point of the script.""" enable_compilation_cache() args = parse_args() diff --git a/src/bartz/BART/_gbart.py b/src/bartz/BART/_gbart.py index 6ce546ed..3014771b 100644 --- a/src/bartz/BART/_gbart.py +++ b/src/bartz/BART/_gbart.py @@ -24,7 +24,7 @@ """Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package.""" -from collections.abc import Mapping +from collections.abc import Hashable, Mapping from functools import cached_property from os import cpu_count from types import MappingProxyType @@ -33,12 +33,11 @@ from equinox import Module from jax import device_count -from jax import numpy as jnp from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real from bartz import mcmcloop, mcmcstep from bartz._interface import Bart, DataFrame, FloatLike, Series -from bartz.jaxext import get_default_device +from bartz.jaxext import get_default_device, jit_active class mc_gbart(Module): @@ -259,7 +258,14 @@ def __init__( mc_cores: int = 2, seed: int | Key[Array, ''] = 0, bart_kwargs: Mapping = MappingProxyType({}), - ): + ) -> None: + # set defaults that depend on type of regression + if keepevery is None: + keepevery = 10 if type == 'pbart' else 1 + if ntree is None: + ntree = 50 if type == 'pbart' else 200 + + # set most calling arguments for Bart kwargs: dict = dict( x_train=x_train, y_train=y_train, @@ -284,7 +290,7 @@ def __init__( tau_num=tau_num, offset=offset, w=w, - ntree=ntree, + 
num_trees=ntree, numcut=numcut, ndpost=ndpost, nskip=nskip, @@ -294,7 +300,18 @@ def __init__( maxdepth=6, **process_mc_cores(y_train, mc_cores), ) + + # set min_points_per_leaf unless the user set it already + if 'min_points_per_leaf' not in bart_kwargs.get('init_kw', {}): + bart_kwargs = dict(bart_kwargs) + init_kw = dict(bart_kwargs.get('init_kw', {})) + init_kw['min_points_per_leaf'] = 5 + bart_kwargs['init_kw'] = init_kw + + # add user arguments kwargs.update(bart_kwargs) + + # invoke Bart self._bart = Bart(**kwargs) # Public attributes from Bart @@ -338,7 +355,7 @@ def _splits(self) -> Real[Array, 'p max_num_splits']: return self._bart._splits # noqa: SLF001 @property - def _x_train_fmt(self) -> Any: + def _x_train_fmt(self) -> Hashable: return self._bart._x_train_fmt # noqa: SLF001 # Cached properties from Bart @@ -450,7 +467,7 @@ def predict( class gbart(mc_gbart): """Subclass of `mc_gbart` that forces `mc_cores=1`.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: if 'mc_cores' in kwargs: msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'" raise TypeError(msg) @@ -458,11 +475,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -def process_mc_cores(y_train: Array | Any, mc_cores: int) -> dict[str, Any]: +def process_mc_cores(y_train: Array | Series, mc_cores: int) -> dict[str, Any]: """Determine the arguments to pass to `Bart` to configure multiple chains.""" - # one chain, leave default configuration which is num_chains=None + # one chain, disable multichain altogether if abs(mc_cores) == 1: - return {} + return dict(num_chains=None) # determine if we are on cpu; this point may raise an exception platform = get_platform(y_train, mc_cores) @@ -507,12 +524,12 @@ def process_mc_cores(y_train: Array | Any, mc_cores: int) -> dict[str, Any]: return kwargs -def get_platform(y_train: Array | Any, mc_cores: int) -> str: +def get_platform(y_train: Array | Series, mc_cores: 
int) -> str: """Get the platform for `process_mc_cores` from `y_train` or the default device.""" if isinstance(y_train, Array) and hasattr(y_train, 'platform'): return y_train.platform() elif ( - not isinstance(y_train, Array) and hasattr(jnp.zeros(()), 'platform') + not isinstance(y_train, Array) and not jit_active() # this condition means: y_train is not an array, but we are not under # jit, so y_train is going to be converted to an array on the default # device diff --git a/src/bartz/__init__.py b/src/bartz/__init__.py index ad5565a5..19a30d92 100644 --- a/src/bartz/__init__.py +++ b/src/bartz/__init__.py @@ -1,6 +1,6 @@ # bartz/src/bartz/__init__.py # -# Copyright (c) 2024-2025, The Bartz Contributors +# Copyright (c) 2024-2026, The Bartz Contributors # # This file is part of bartz. # @@ -31,4 +31,4 @@ from bartz import BART, grove, jaxext, mcmcloop, mcmcstep, prepcovars # noqa: F401 from bartz._interface import Bart # noqa: F401 from bartz._profiler import profile_mode # noqa: F401 -from bartz._version import __version__ # noqa: F401 +from bartz._version import __version__, __version_info__ # noqa: F401 diff --git a/src/bartz/_interface.py b/src/bartz/_interface.py index 23c09bcf..de97cf10 100644 --- a/src/bartz/_interface.py +++ b/src/bartz/_interface.py @@ -25,15 +25,15 @@ """Main high-level interface of the package.""" import math -from collections.abc import Sequence -from functools import cached_property +from collections.abc import Mapping, Sequence +from functools import cached_property, partial +from types import MappingProxyType from typing import Any, Literal, Protocol, TypedDict import jax import jax.numpy as jnp from equinox import Module, error_if, field -from jax import Device, device_put, jit, make_mesh -from jax.lax import collapse +from jax import Device, device_put, jit, lax, make_mesh from jax.scipy.special import ndtr from jax.sharding import AxisType, Mesh from jaxtyping import ( @@ -54,7 +54,7 @@ from bartz.jaxext import is_key from 
bartz.jaxext.scipy.special import ndtri from bartz.jaxext.scipy.stats import invgamma -from bartz.mcmcloop import compute_varcount, evaluate_trace, run_mcmc +from bartz.mcmcloop import RunMCMCResult, compute_varcount, evaluate_trace, run_mcmc from bartz.mcmcstep import make_p_nonterminal from bartz.mcmcstep._state import get_num_chains @@ -194,9 +194,8 @@ class Bart(Module): datapoints. Note: `w` is ignored in the automatic determination of `sigest`, so either the weights should be O(1), or `sigest` should be specified by the user. - ntree - The number of trees used to represent the latent mean function. By - default 200 for continuous regression and 50 for binary regression. + num_trees + The number of trees used to represent the latent mean function. numcut If `usequants` is `False`: the exact number of cutpoints used to bin the predictors, ranging between the minimum and maximum observed values @@ -221,20 +220,18 @@ class Bart(Module): The number of initial MCMC samples to discard as burn-in. This number of samples is discarded from each chain. keepevery - The thinning factor for the MCMC samples, after burn-in. By default, 1 - for continuous regression and 10 for binary regression. + The thinning factor for the MCMC samples, after burn-in. printevery The number of iterations (including thinned-away ones) between each log line. Set to `None` to disable logging. ^C interrupts the MCMC only every `printevery` iterations, so with logging disabled it's impossible to kill the MCMC conveniently. num_chains - The number of independent Markov chains to run. By default only one - chain is run. + The number of independent Markov chains to run. - The difference between not specifying `num_chains` and setting it to 1 - is that in the latter case in the object attributes and some methods - there will be an explicit chain axis of size 1. 
+ The difference between ``num_chains=None`` and ``num_chains=1`` is that + in the latter case in the object attributes and some methods there will + be an explicit chain axis of size 1. num_chain_devices The number of devices to spread the chains across. Must be a divisor of `num_chains`. Each device will run a fraction of the chains. @@ -309,21 +306,21 @@ def __init__( tau_num: FloatLike | None = None, offset: FloatLike | None = None, w: Float[Array, ' n'] | Series | None = None, - ntree: int | None = None, - numcut: int = 100, + num_trees: int = 200, + numcut: int = 255, ndpost: int = 1000, - nskip: int = 100, - keepevery: int | None = None, + nskip: int = 1000, + keepevery: int = 1, printevery: int | None = 100, - num_chains: int | None = None, + num_chains: int | None = 4, num_chain_devices: int | None = None, num_data_devices: int | None = None, devices: Device | Sequence[Device] | None = None, seed: int | Key[Array, ''] = 0, maxdepth: int = 6, - init_kw: dict | None = None, - run_mcmc_kw: dict | None = None, - ): + init_kw: Mapping = MappingProxyType({}), + run_mcmc_kw: Mapping = MappingProxyType({}), + ) -> None: # check data and put it in the right format x_train, x_train_fmt = self._process_predictor_input(x_train) y_train = self._process_response_input(y_train) @@ -336,12 +333,6 @@ def __init__( self._check_type_settings(y_train, type, w) # from here onwards, the type is determined by y_train.dtype == bool - # set defaults that depend on type of regression - if ntree is None: - ntree = 50 if y_train.dtype == bool else 200 - if keepevery is None: - keepevery = 10 if y_train.dtype == bool else 1 - # process sparsity settings theta, a, b, rho = self._process_sparsity_settings( x_train, sparse, theta, a, b, rho @@ -349,7 +340,7 @@ def __init__( # process "standardization" settings offset = self._process_offset_settings(y_train, offset) - sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num) + sigma_mu = 
self._process_leaf_sdev_settings(y_train, k, num_trees, tau_num) lamda, sigest = self._process_error_variance_settings( x_train, y_train, sigest, sigdf, sigquant, lamda ) @@ -371,7 +362,7 @@ def __init__( power, base, maxdepth, - ntree, + num_trees, init_kw, rm_const, theta, @@ -386,18 +377,19 @@ def __init__( sparse, nskip, ) - final_state, burnin_trace, main_trace = self._run_mcmc( + result = self._run_mcmc( initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw ) # set public attributes - self.offset = final_state.offset # from the state because of buffer donation + # set offset from the state because of buffer donation + self.offset = result.final_state.offset self.sigest = sigest # set private attributes - self._main_trace = main_trace - self._burnin_trace = burnin_trace - self._mcmc_state = final_state + self._main_trace = result.main_trace + self._burnin_trace = result.burnin_trace + self._mcmc_state = result.final_state self._splits = splits self._x_train_fmt = x_train_fmt @@ -406,7 +398,7 @@ def __init__( self.yhat_test = self.predict(x_test) @property - def ndpost(self): + def ndpost(self) -> int: """The total number of posterior samples after burn-in across all chains. 
May be larger than the initialization argument `ndpost` if it was not @@ -414,6 +406,11 @@ def ndpost(self): """ return self._main_trace.grow_prop_count.size + @property + def num_trees(self) -> int: + """Return the number of trees used in the model.""" + return self._mcmc_state.forest.split_tree.shape[-2] + @cached_property def prob_test(self) -> Float32[Array, 'ndpost m'] | None: """The posterior probability of y being True at `x_test` for each MCMC iteration.""" @@ -491,9 +488,7 @@ def sigma_mean(self) -> Float32[Array, ''] | None: def varcount(self) -> Int32[Array, 'ndpost p']: """Histogram of predictor usage for decision rules in the trees.""" p = self._mcmc_state.forest.max_split.size - varcount: Int32[Array, '*chains samples p'] - varcount = compute_varcount(p, self._main_trace) - return collapse(varcount, 0, -1) + return varcount(p, self._main_trace) @cached_property def varcount_mean(self) -> Float32[Array, ' p']: @@ -577,7 +572,9 @@ def predict( return self._predict(x_test) @staticmethod - def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]: + def _process_predictor_input( + x: Real[Any, 'p n'] | DataFrame, + ) -> tuple[Shaped[Array, 'p n'], Any]: if hasattr(x, 'columns'): fmt = dict(kind='dataframe', columns=x.columns) x = x.to_numpy().T @@ -588,7 +585,7 @@ def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]: return x, fmt @staticmethod - def _process_response_input(y) -> Shaped[Array, ' n']: + def _process_response_input(y: Shaped[Array, ' n'] | Series) -> Shaped[Array, ' n']: if hasattr(y, 'to_numpy'): y = y.to_numpy() y = jnp.asarray(y) @@ -596,13 +593,19 @@ def _process_response_input(y) -> Shaped[Array, ' n']: return y @staticmethod - def _check_same_length(x1, x2): + def _check_same_length(x1: Array, x2: Array) -> None: get_length = lambda x: x.shape[-1] assert get_length(x1) == get_length(x2) @classmethod def _process_error_variance_settings( - cls, x_train, y_train, sigest, sigdf, sigquant, lamda + cls, + x_train: 
Shaped[Array, 'p n'], + y_train: Float32[Array, ' n'] | Bool[Array, ' n'], + sigest: FloatLike | None, + sigdf: FloatLike, + sigquant: FloatLike, + lamda: FloatLike | None, ) -> tuple[Float32[Array, ''] | None, ...]: """Return (lamda, sigest).""" if y_train.dtype == bool: @@ -636,7 +639,7 @@ def _process_error_variance_settings( @jit def _linear_regression( x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n'] - ): + ) -> Float32[Array, '']: """Return the error variance estimated with OLS with intercept.""" x_centered = x_train.T - x_train.mean(axis=1) y_centered = y_train - y_train.mean() @@ -647,7 +650,11 @@ def _linear_regression( return chisq / dof @staticmethod - def _check_type_settings(y_train, type, w): # noqa: A002 + def _check_type_settings( + y_train: Float32[Array, ' n'] | Bool[Array, ' n'], + type: str, # noqa: A002 + w: Float[Array, ' n'] | None, + ) -> None: match type: case 'wbart': if y_train.dtype != jnp.float32: @@ -717,10 +724,10 @@ def _process_offset_settings( @staticmethod def _process_leaf_sdev_settings( y_train: Float32[Array, ' n'] | Bool[Array, ' n'], - k: float, - ntree: int, + k: FloatLike, + num_trees: int, tau_num: FloatLike | None, - ): + ) -> FloatLike: """Return sigma_mu.""" if tau_num is None: if y_train.dtype == bool: @@ -730,7 +737,7 @@ def _process_leaf_sdev_settings( else: tau_num = (y_train.max() - y_train.min()) / 2 - return tau_num / (k * math.sqrt(ntree)) + return tau_num / (k * math.sqrt(num_trees)) @staticmethod def _determine_splits( @@ -768,8 +775,8 @@ def _setup_mcmc( power: FloatLike, base: FloatLike, maxdepth: int, - ntree: int, - init_kw: dict[str, Any] | None, + num_trees: int, + init_kw: Mapping[str, Any], rm_const: bool, theta: FloatLike | None, a: FloatLike | None, @@ -782,7 +789,7 @@ def _setup_mcmc( devices: Device | Sequence[Device] | None, sparse: bool, nskip: int, - ): + ) -> mcmcstep.State: p_nonterminal = make_p_nonterminal(maxdepth, base, power) if y_train.dtype == bool: @@ -806,13 +813,12 @@ def 
_setup_mcmc( offset=offset, error_scale=w, max_split=max_split, - num_trees=ntree, + num_trees=num_trees, p_nonterminal=p_nonterminal, leaf_prior_cov_inv=jnp.reciprocal(jnp.square(sigma_mu)), error_cov_df=error_cov_df, error_cov_scale=error_cov_scale, min_points_per_decision_node=10, - min_points_per_leaf=5, log_s=process_varprob(varprob, max_split), theta=theta, a=a, @@ -826,8 +832,7 @@ def _setup_mcmc( n_empty = jnp.sum(max_split == 0).item() kw.update(filter_splitless_vars=n_empty) - if init_kw is not None: - kw.update(init_kw) + kw.update(init_kw) state = mcmcstep.init(**kw) @@ -846,8 +851,8 @@ def _run_mcmc( keepevery: int, printevery: int | None, seed: int | Integer[Array, ''] | Key[Array, ''], - run_mcmc_kw: dict | None, - ) -> tuple[mcmcstep.State, mcmcloop.BurninTrace, mcmcloop.MainTrace]: + run_mcmc_kw: Mapping, + ) -> RunMCMCResult: # prepare random generator seed if is_key(seed): key = jnp.copy(seed) @@ -869,15 +874,32 @@ def _run_mcmc( report_every=printevery, ) ) - if run_mcmc_kw is not None: - kw.update(run_mcmc_kw) + kw.update(run_mcmc_kw) return run_mcmc(key, mcmc_state, n_save, **kw) def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']: """Evaluate trees on already quantized `x`.""" - out = evaluate_trace(x, self._main_trace) - return collapse(out, 0, -1) + return predict(x, self._main_trace) + + +@partial(jit, static_argnames='p') +# this is jitted such that lax.collapse below does not create a copy +def varcount(p: int, trace: mcmcloop.MainTrace) -> Int32[Array, 'ndpost p']: + """Histogram of predictor usage for decision rules in the trees, squashing chains.""" + varcount: Int32[Array, '*chains samples p'] + varcount = compute_varcount(p, trace) + return lax.collapse(varcount, 0, -1) + + +@jit +# this is jitted such that lax.collapse below does not create a copy +def predict( + x: UInt[Array, 'p m'], trace: mcmcloop.MainTrace +) -> Float32[Array, 'ndpost m']: + """Evaluate trees on already quantized `x`, and squash chains.""" 
+ out = evaluate_trace(x, trace) + return lax.collapse(out, 0, -1) class DeviceKwArgs(TypedDict): diff --git a/src/bartz/_profiler.py b/src/bartz/_profiler.py index ebe54b83..d38eb38f 100644 --- a/src/bartz/_profiler.py +++ b/src/bartz/_profiler.py @@ -29,8 +29,7 @@ from functools import wraps from typing import Any, TypeVar -from jax import block_until_ready, debug, jit -from jax.lax import cond, scan +from jax import block_until_ready, debug, jit, lax from jax.profiler import TraceAnnotation from jaxtyping import Array, Bool @@ -39,7 +38,6 @@ PROFILE_MODE: bool = False T = TypeVar('T') -Carry = TypeVar('Carry') def get_profile_mode() -> bool: @@ -100,7 +98,7 @@ def profile_mode(value: bool, /) -> Iterator[None]: def jit_and_block_if_profiling( - func: Callable[..., T], block_before: bool = False, **kwargs + func: Callable[..., T], block_before: bool = False, **kwargs: Any ) -> Callable[..., T]: """Apply JIT compilation and block if profiling is enabled. @@ -136,7 +134,7 @@ def jit_and_block_if_profiling( event_name = f'jab[{func.__name__}]' # this wrapper is meant to measure the time spent executing the function - def jab_inner_wrapper(*args, **kwargs) -> T: + def jab_inner_wrapper(*args: Any, **kwargs: Any) -> T: with TraceAnnotation(event_name): result = jitted_func(*args, **kwargs) return block_until_ready(result) @@ -153,7 +151,9 @@ def jab_outer_wrapper(*args: Any, **kwargs: Any) -> T: return jab_outer_wrapper -def jit_if_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]: +def jit_if_profiling( + func: Callable[..., T], *args: Any, **kwargs: Any +) -> Callable[..., T]: """Apply JIT compilation only when profiling. 
Parameters @@ -180,7 +180,9 @@ def wrapper(*args: Any, **kwargs: Any) -> T: return wrapper -def jit_if_not_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]: +def jit_if_not_profiling( + func: Callable[..., T], *args: Any, **kwargs: Any +) -> Callable[..., T]: """Apply JIT compilation only when not profiling. When profile mode is off, the function is JIT compiled. When profile mode is @@ -212,39 +214,35 @@ def wrapper(*args: Any, **kwargs: Any) -> T: return wrapper -def scan_if_not_profiling( - f: Callable[[Carry, None], tuple[Carry, None]], - init: Carry, - xs: None, - length: int, +def while_loop_if_not_profiling( + cond_fun: Callable[[T], Bool[Array, ''] | bool], + body_fun: Callable[[T], T], + init_val: T, /, -) -> tuple[Carry, None]: - """Restricted replacement for `jax.lax.scan` that uses a Python loop when profiling. +) -> T: + """Restricted replacement for `jax.lax.while_loop` that uses a Python loop when profiling. Parameters ---------- - f - Scan body function with signature (carry, None) -> (carry, None). - init - Initial carry value. - xs - Input values to scan over (not supported). - length - Integer specifying the number of loop iterations. + cond_fun + Function to evaluate to determine whether to continue the loop. + body_fun + Function that updates the state in each iteration. + init_val + Initial state. Returns ------- - Tuple of (final_carry, None) (stacked outputs not supported). + Final state. 
""" - assert xs is None if get_profile_mode(): - carry = init - for _i in range(length): - carry, _ = f(carry, None) - return carry, None + val = init_val + while cond_fun(val): + val = body_fun(val) + return val else: - return scan(f, init, None, length) + return lax.while_loop(cond_fun, body_fun, init_val) def cond_if_not_profiling( @@ -252,7 +250,7 @@ def cond_if_not_profiling( true_fun: Callable[..., T], false_fun: Callable[..., T], /, - *operands, + *operands: Any, ) -> T: """Restricted replacement for `jax.lax.cond` that uses a Python if when profiling. @@ -277,12 +275,12 @@ def cond_if_not_profiling( else: return false_fun(*operands) else: - return cond(pred, true_fun, false_fun, *operands) + return lax.cond(pred, true_fun, false_fun, *operands) def callback_if_not_profiling( callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any -): +) -> None: """Restricted replacement for `jax.debug.callback` that calls the callback directly in profiling mode.""" if get_profile_mode(): callback(*args, **kwargs) @@ -290,12 +288,12 @@ def callback_if_not_profiling( debug.callback(callback, *args, ordered=ordered, **kwargs) -def vmap_chains_if_profiling(fun: Callable[..., T], **kwargs) -> Callable[..., T]: +def vmap_chains_if_profiling(fun: Callable[..., T], **kwargs: Any) -> Callable[..., T]: """Apply `vmap_chains` only when profile mode is enabled.""" new_fun = vmap_chains(fun, **kwargs) @wraps(fun) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> T: if get_profile_mode(): return new_fun(*args, **kwargs) else: @@ -304,12 +302,14 @@ def wrapper(*args, **kwargs): return wrapper -def vmap_chains_if_not_profiling(fun: Callable[..., T], **kwargs) -> Callable[..., T]: +def vmap_chains_if_not_profiling( + fun: Callable[..., T], **kwargs: Any +) -> Callable[..., T]: """Apply `vmap_chains` only when profile mode is disabled.""" new_fun = vmap_chains(fun, **kwargs) @wraps(fun) - def wrapper(*args, **kwargs): + def wrapper(*args: 
Any, **kwargs: Any) -> T: if get_profile_mode(): return fun(*args, **kwargs) else: diff --git a/src/bartz/_version.py b/src/bartz/_version.py index 32a90a3b..cb20abcc 100644 --- a/src/bartz/_version.py +++ b/src/bartz/_version.py @@ -1 +1,2 @@ __version__ = '0.8.0' +__version_info__ = (0, 8, 0) diff --git a/src/bartz/debug.py b/src/bartz/debug.py deleted file mode 100644 index 26bd5cea..00000000 --- a/src/bartz/debug.py +++ /dev/null @@ -1,1319 +0,0 @@ -# bartz/src/bartz/debug.py -# -# Copyright (c) 2024-2026, The Bartz Contributors -# -# This file is part of bartz. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Debugging utilities. 
The main functionality is the class `debug_mc_gbart`.""" - -from collections.abc import Callable -from dataclasses import replace -from functools import partial -from math import ceil, log2 -from re import fullmatch -from typing import Literal - -import numpy -from equinox import Module, field -from jax import jit, lax, random, vmap -from jax import numpy as jnp -from jax.tree_util import tree_map -from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt - -from bartz.BART import gbart, mc_gbart -from bartz.BART._gbart import FloatLike -from bartz.grove import ( - TreeHeaps, - evaluate_forest, - is_actual_leaf, - is_leaves_parent, - normalize_axis_tuple, - traverse_forest, - tree_depth, - tree_depths, -) -from bartz.jaxext import autobatch, minimal_unsigned_dtype, vmap_nodoc -from bartz.jaxext import split as split_key -from bartz.mcmcloop import TreesTrace -from bartz.mcmcstep._moves import randint_masked - - -def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str: - """Convert a tree to a human-readable string. - - Parameters - ---------- - tree - A single tree to format. - print_all - If `True`, also print the contents of unused node slots in the arrays. - - Returns - ------- - A string representation of the tree. 
- """ - tee = '├──' - corner = '└──' - join = '│ ' - space = ' ' - down = '┐' - bottom = '╢' # '┨' # - - def traverse_tree( - lines: list[str], - index: int, - depth: int, - indent: str, - first_indent: str, - next_indent: str, - unused: bool, - ): - if index >= len(tree.leaf_tree): - return - - var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item() - split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item() - - is_leaf = split == 0 - left_child = 2 * index - right_child = 2 * index + 1 - - if print_all: - if unused: - category = 'unused' - elif is_leaf: - category = 'leaf' - else: - category = 'decision' - node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})' - else: - assert not unused - if is_leaf: - node_str = f'{tree.leaf_tree[index]:#.2g}' - else: - node_str = f'x{var} < {split}' - - if not is_leaf or (print_all and left_child < len(tree.leaf_tree)): - link = down - elif not print_all and left_child >= len(tree.leaf_tree): - link = bottom - else: - link = ' ' - - max_number = len(tree.leaf_tree) - 1 - ndigits = len(str(max_number)) - number = str(index).rjust(ndigits) - - lines.append(f' {number} {indent}{first_indent}{link}{node_str}') - - indent += next_indent - unused = unused or is_leaf - - if unused and not print_all: - return - - traverse_tree(lines, left_child, depth + 1, indent, tee, join, unused) - traverse_tree(lines, right_child, depth + 1, indent, corner, space, unused) - - lines = [] - traverse_tree(lines, 1, 0, '', '', '', False) - return '\n'.join(lines) - - -def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']: - """Measure the depth of the tree. - - Parameters - ---------- - split_tree - The cutpoints of the decision rules. - - Returns - ------- - The depth of the deepest leaf in the tree. The root is at depth 0. 
- """ - # this could be done just with split_tree != 0 - is_leaf = is_actual_leaf(split_tree, add_bottom_level=True) - depth = tree_depths(is_leaf.size) - depth = jnp.where(is_leaf, depth, 0) - return jnp.max(depth) - - -@jit -@partial(jnp.vectorize, signature='(nt,hts)->(d)') -def forest_depth_distr( - split_tree: UInt[Array, '*batch_shape num_trees 2**(d-1)'], -) -> Int32[Array, '*batch_shape d']: - """Histogram the depths of a set of trees. - - Parameters - ---------- - split_tree - The cutpoints of the decision rules of the trees. - - Returns - ------- - An integer vector where the i-th element counts how many trees have depth i. - """ - depth = tree_depth(split_tree) + 1 - depths = vmap(tree_actual_depth)(split_tree) - return jnp.bincount(depths, length=depth) - - -@partial(jit, static_argnames=('node_type', 'sum_batch_axis')) -def points_per_node_distr( - X: UInt[Array, 'p n'], - var_tree: UInt[Array, '*batch_shape 2**(d-1)'], - split_tree: UInt[Array, '*batch_shape 2**(d-1)'], - node_type: Literal['leaf', 'leaf-parent'], - *, - sum_batch_axis: int | tuple[int, ...] = (), -) -> Int32[Array, '*reduced_batch_shape n+1']: - """Histogram points-per-node counts in a set of trees. - - Count how many nodes in a tree select each possible amount of points, - over a certain subset of nodes. - - Parameters - ---------- - X - The set of points to count. - var_tree - The variables of the decision rules. - split_tree - The cutpoints of the decision rules. - node_type - The type of nodes to consider. Can be: - - 'leaf' - Count only leaf nodes. - 'leaf-parent' - Count only parent-of-leaf nodes. - sum_batch_axis - Aggregate the histogram over these batch axes, counting how many nodes - have each possible amount of points over subsets of trees instead of - in each tree separately. - - Returns - ------- - A vector where the i-th element counts how many nodes have i points. 
- """ - batch_ndim = var_tree.ndim - 1 - axes = normalize_axis_tuple(sum_batch_axis, batch_ndim) - - def func( - var_tree: UInt[Array, '*batch_shape 2**(d-1)'], - split_tree: UInt[Array, '*batch_shape 2**(d-1)'], - ) -> Int32[Array, '*reduced_batch_shape n+1']: - indices: UInt[Array, '*batch_shape n'] - indices = traverse_forest(X, var_tree, split_tree) - - @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)') - def count_points( - split_tree: UInt[Array, '*batch_shape 2**(d-1)'], - indices: UInt[Array, '*batch_shape n'], - ) -> ( - tuple[UInt[Array, '*batch_shape 2**d'], Bool[Array, '*batch_shape 2**d']] - | tuple[ - UInt[Array, '*batch_shape 2**(d-1)'], - Bool[Array, '*batch_shape 2**(d-1)'], - ] - ): - if node_type == 'leaf-parent': - indices >>= 1 - predicate = is_leaves_parent(split_tree) - elif node_type == 'leaf': - predicate = is_actual_leaf(split_tree, add_bottom_level=True) - else: - raise ValueError(node_type) - count_tree = jnp.zeros(predicate.size, int).at[indices].add(1).at[0].set(0) - return count_tree, predicate - - count_tree, predicate = count_points(split_tree, indices) - - def count_nodes( - count_tree: UInt[Array, '*summed_batch_axes half_tree_size'], - predicate: Bool[Array, '*summed_batch_axes half_tree_size'], - ) -> Int32[Array, ' n+1']: - return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate) - - # vmap count_nodes over non-batched dims - for i in reversed(range(batch_ndim)): - neg_i = i - var_tree.ndim - if i not in axes: - count_nodes = vmap(count_nodes, in_axes=neg_i) - - return count_nodes(count_tree, predicate) - - # automatically batch over all batch dimensions - max_io_nbytes = 2**27 # 128 MiB - out_dim_shift = len(axes) - for i in reversed(range(batch_ndim)): - if i in axes: - out_dim_shift -= 1 - else: - func = autobatch(func, max_io_nbytes, i, i - out_dim_shift) - assert out_dim_shift == 0 - - return func(var_tree, split_tree) - - -check_functions = [] - - -CheckFunc = Callable[[TreeHeaps, 
UInt[Array, ' p']], bool | Bool[Array, '']] - - -def check(func: CheckFunc) -> CheckFunc: - """Add a function to a list of functions used to check trees. - - Use to decorate functions that check whether a tree is valid in some way. - These functions are invoked automatically by `check_tree`, `check_trace` and - `debug_gbart`. - - Parameters - ---------- - func - The function to add to the list. It must accept a `TreeHeaps` and a - `max_split` argument, and return a boolean scalar that indicates if the - tree is ok. - - Returns - ------- - The function unchanged. - """ - check_functions.append(func) - return func - - -@check -def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool: - """Check that integer types are as small as possible and coherent.""" - expected_var_dtype = minimal_unsigned_dtype(max_split.size - 1) - expected_split_dtype = max_split.dtype - return ( - tree.var_tree.dtype == expected_var_dtype - and tree.split_tree.dtype == expected_split_dtype - and jnp.issubdtype(max_split.dtype, jnp.unsignedinteger) - ) - - -@check -def check_sizes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool: - """Check that array sizes are coherent.""" - return tree.leaf_tree.size == 2 * tree.var_tree.size == 2 * tree.split_tree.size - - -@check -def check_unused_node( - tree: TreeHeaps, _max_split: UInt[Array, ' p'] -) -> Bool[Array, '']: - """Check that the unused node slot at index 0 is not dirty.""" - return (tree.var_tree[0] == 0) & (tree.split_tree[0] == 0) - - -@check -def check_leaf_values( - tree: TreeHeaps, _max_split: UInt[Array, ' p'] -) -> Bool[Array, '']: - """Check that all leaf values are not inf of nan.""" - return jnp.all(jnp.isfinite(tree.leaf_tree)) - - -@check -def check_stray_nodes( - tree: TreeHeaps, _max_split: UInt[Array, ' p'] -) -> Bool[Array, '']: - """Check if there is any marked-non-leaf node with a marked-leaf parent.""" - index = jnp.arange( - 2 * tree.split_tree.size, - dtype=minimal_unsigned_dtype(2 * 
tree.split_tree.size - 1), - ) - parent_index = index >> 1 - is_not_leaf = tree.split_tree.at[index].get(mode='fill', fill_value=0) != 0 - parent_is_leaf = tree.split_tree[parent_index] == 0 - stray = is_not_leaf & parent_is_leaf - stray = stray.at[1].set(False) - return ~jnp.any(stray) - - -@check -def check_rule_consistency( - tree: TreeHeaps, max_split: UInt[Array, ' p'] -) -> bool | Bool[Array, '']: - """Check that decision rules define proper subsets of ancestor rules.""" - if tree.var_tree.size < 4: - return True - - # initial boundaries of decision rules. use extreme integers instead of 0, - # max_split to avoid checking if there is something out of bounds. - dtype = tree.split_tree.dtype - small = jnp.iinfo(dtype).min - large = jnp.iinfo(dtype).max - lower = jnp.full(max_split.size, small, dtype) - upper = jnp.full(max_split.size, large, dtype) - # the split must be in (lower[var], upper[var]] - - def _check_recursive(node, lower, upper): - # read decision rule - var = tree.var_tree[node] - split = tree.split_tree[node] - - # get rule boundaries from ancestors. 
use fill value in case var is - # out of bounds, we don't want to check out of bounds in this function - lower_var = lower.at[var].get(mode='fill', fill_value=small) - upper_var = upper.at[var].get(mode='fill', fill_value=large) - - # check rule is in bounds - bad = jnp.where(split, (split <= lower_var) | (split > upper_var), False) - - # recurse - if node < tree.var_tree.size // 2: - idx = jnp.where(split, var, max_split.size) - bad |= _check_recursive(2 * node, lower, upper.at[idx].set(split - 1)) - bad |= _check_recursive(2 * node + 1, lower.at[idx].set(split), upper) - - return bad - - return ~_check_recursive(1, lower, upper) - - -@check -def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001 - """Check that #leaves = 1 + #(internal nodes).""" - is_leaf = is_actual_leaf(tree.split_tree, add_bottom_level=True) - num_leaves = jnp.count_nonzero(is_leaf) - num_internal = jnp.count_nonzero(tree.split_tree) - return num_leaves == num_internal + 1 - - -@check -def check_var_in_bounds( - tree: TreeHeaps, max_split: UInt[Array, ' p'] -) -> Bool[Array, '']: - """Check that variables are in [0, max_split.size).""" - decision_node = tree.split_tree.astype(bool) - in_bounds = (tree.var_tree >= 0) & (tree.var_tree < max_split.size) - return jnp.all(in_bounds | ~decision_node) - - -@check -def check_split_in_bounds( - tree: TreeHeaps, max_split: UInt[Array, ' p'] -) -> Bool[Array, '']: - """Check that splits are in [0, max_split[var]].""" - max_split_var = ( - max_split.astype(jnp.int32) - .at[tree.var_tree] - .get(mode='fill', fill_value=jnp.iinfo(jnp.int32).max) - ) - return jnp.all((tree.split_tree >= 0) & (tree.split_tree <= max_split_var)) - - -def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']: - """Check the validity of a tree. - - Use `describe_error` to parse the error code returned by this function. - - Parameters - ---------- - tree - The tree to check. 
- max_split - The maximum split value for each variable. - - Returns - ------- - An integer where each bit indicates whether a check failed. - """ - error_type = minimal_unsigned_dtype(2 ** len(check_functions) - 1) - error = error_type(0) - for i, func in enumerate(check_functions): - ok = func(tree, max_split) - ok = jnp.bool_(ok) - bit = (~ok) << i - error |= bit - return error - - -def describe_error(error: int | Integer[Array, '']) -> list[str]: - """Describe the error code returned by `check_tree`. - - Parameters - ---------- - error - The error code returned by `check_tree`. - - Returns - ------- - A list of the function names that implement the failed checks. - """ - return [func.__name__ for i, func in enumerate(check_functions) if error & (1 << i)] - - -@jit -def check_trace( - trace: TreeHeaps, max_split: UInt[Array, ' p'] -) -> UInt[Array, '*batch_shape']: - """Check the validity of a set of trees. - - Use `describe_error` to parse the error codes returned by this function. - - Parameters - ---------- - trace - The set of trees to check. This object can have additional attributes - beyond the tree arrays, they are ignored. - max_split - The maximum split value for each variable. - - Returns - ------- - A tensor of error codes for each tree. 
- """ - # vectorize check_tree over all batch dimensions - unpack_check_tree = lambda l, v, s: check_tree(TreesTrace(l, v, s), max_split) - is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim - signature = '(k,ts),(hts),(hts)->()' if is_mv else '(ts),(hts),(hts)->()' - vec_check_tree = jnp.vectorize(unpack_check_tree, signature=signature) - - # automatically batch over all batch dimensions - max_io_nbytes = 2**24 # 16 MiB - batch_ndim = trace.split_tree.ndim - 1 - batched_check_tree = vec_check_tree - for i in reversed(range(batch_ndim)): - batched_check_tree = autobatch(batched_check_tree, max_io_nbytes, i, i) - - return batched_check_tree(trace.leaf_tree, trace.var_tree, trace.split_tree) - - -def _get_next_line(s: str, i: int) -> tuple[str, int]: - """Get the next line from a string and the new index.""" - i_new = s.find('\n', i) - if i_new == -1: - return s[i:], len(s) - return s[i:i_new], i_new + 1 - - -class BARTTraceMeta(Module): - """Metadata of R BART tree traces.""" - - ndpost: int = field(static=True) - """The number of posterior draws.""" - - ntree: int = field(static=True) - """The number of trees in the model.""" - - numcut: UInt[Array, ' p'] - """The maximum split value for each variable.""" - - heap_size: int = field(static=True) - """The size of the heap required to store the trees.""" - - -def scan_BART_trees(trees: str) -> BARTTraceMeta: - """Scan an R BART tree trace checking for errors and parsing metadata. - - Parameters - ---------- - trees - The string representation of a trace of trees of the R BART package. - Can be accessed from ``mc_gbart(...).treedraws['trees']``. - - Returns - ------- - An object containing the metadata. - - Raises - ------ - ValueError - If the string is malformed or contains leftover characters. 
- """ - # parse first line - line, i_char = _get_next_line(trees, 0) - i_line = 1 - match = fullmatch(r'(\d+) (\d+) (\d+)', line) - if match is None: - msg = f'Malformed header at {i_line=}' - raise ValueError(msg) - ndpost, ntree, p = map(int, match.groups()) - - # initial values for maxima - max_heap_index = 0 - numcut = numpy.zeros(p, int) - - # cycle over iterations and trees - for i_iter in range(ndpost): - for i_tree in range(ntree): - # parse first line of tree definition - line, i_char = _get_next_line(trees, i_char) - i_line += 1 - match = fullmatch(r'(\d+)', line) - if match is None: - msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}' - raise ValueError(msg) - num_nodes = int(line) - - # cycle over nodes - for i_node in range(num_nodes): - # parse node definition - line, i_char = _get_next_line(trees, i_char) - i_line += 1 - match = fullmatch( - r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line - ) - if match is None: - msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}' - raise ValueError(msg) - i_heap = int(match.group(1)) - var = int(match.group(2)) - split = int(match.group(3)) - - # update maxima - numcut[var] = max(numcut[var], split) - max_heap_index = max(max_heap_index, i_heap) - - assert i_char <= len(trees) - if i_char < len(trees): - msg = f'Leftover {len(trees) - i_char} characters in string' - raise ValueError(msg) - - # determine minimal integer type for numcut - numcut += 1 # because BART is 0-based - split_dtype = minimal_unsigned_dtype(numcut.max()) - numcut = jnp.array(numcut.astype(split_dtype)) - - # determine minimum heap size to store the trees - heap_size = 2 ** ceil(log2(max_heap_index + 1)) - - return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size) - - -class TraceWithOffset(Module): - """Implementation of `bartz.mcmcloop.Trace`.""" - - leaf_tree: Float32[Array, 'ndpost ntree 2**d'] - var_tree: UInt[Array, 'ndpost ntree 2**(d-1)'] - split_tree: 
UInt[Array, 'ndpost ntree 2**(d-1)'] - offset: Float32[Array, ' ndpost'] - - @classmethod - def from_trees_trace( - cls, trees: TreeHeaps, offset: Float32[Array, ''] - ) -> 'TraceWithOffset': - """Create a `TraceWithOffset` from a `TreeHeaps`.""" - ndpost, _, _ = trees.leaf_tree.shape - return cls( - leaf_tree=trees.leaf_tree, - var_tree=trees.var_tree, - split_tree=trees.split_tree, - offset=jnp.full(ndpost, offset), - ) - - -def trees_BART_to_bartz( - trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None -) -> tuple[TraceWithOffset, BARTTraceMeta]: - """Convert trees from the R BART format to the bartz format. - - Parameters - ---------- - trees - The string representation of a trace of trees of the R BART package. - Can be accessed from ``mc_gbart(...).treedraws['trees']``. - min_maxdepth - The maximum tree depth of the output will be set to the maximum - observed depth in the input trees. Use this parameter to require at - least this maximum depth in the output format. - offset - The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be - summed to the sum of trees. To match that behavior, this function - returns an offset as well, zero by default. Set with this parameter - otherwise. - - Returns - ------- - trace : TraceWithOffset - A representation of the trees compatible with the trace returned by - `bartz.mcmcloop.run_mcmc`. - meta : BARTTraceMeta - The metadata of the trace, containing the number of iterations, trees, - and the maximum split value. 
- """ - # scan all the string checking for errors and determining sizes - meta = scan_BART_trees(trees) - - # skip first line - _, i_char = _get_next_line(trees, 0) - - heap_size = max(meta.heap_size, 2**min_maxdepth) - leaf_trees = numpy.zeros((meta.ndpost, meta.ntree, heap_size), dtype=numpy.float32) - var_trees = numpy.zeros( - (meta.ndpost, meta.ntree, heap_size // 2), - dtype=minimal_unsigned_dtype(meta.numcut.size - 1), - ) - split_trees = numpy.zeros( - (meta.ndpost, meta.ntree, heap_size // 2), dtype=meta.numcut.dtype - ) - - # cycle over iterations and trees - for i_iter in range(meta.ndpost): - for i_tree in range(meta.ntree): - # parse first line of tree definition - line, i_char = _get_next_line(trees, i_char) - num_nodes = int(line) - - is_internal = numpy.zeros(heap_size // 2, dtype=bool) - - # cycle over nodes - for _ in range(num_nodes): - # parse node definition - line, i_char = _get_next_line(trees, i_char) - values = line.split() - i_heap = int(values[0]) - var = int(values[1]) - split = int(values[2]) - leaf = float(values[3]) - - # update values - leaf_trees[i_iter, i_tree, i_heap] = leaf - is_internal[i_heap // 2] = True - if i_heap < heap_size // 2: - var_trees[i_iter, i_tree, i_heap] = var - split_trees[i_iter, i_tree, i_heap] = split + 1 - - is_internal[0] = False - split_trees[i_iter, i_tree, ~is_internal] = 0 - - return TraceWithOffset( - leaf_tree=jnp.array(leaf_trees), - var_tree=jnp.array(var_trees), - split_tree=jnp.array(split_trees), - offset=jnp.zeros(meta.ndpost) - if offset is None - else jnp.full(meta.ndpost, offset), - ), meta - - -class SamplePriorStack(Module): - """Represent the manually managed stack used in `sample_prior`. - - Each level of the stack represents a recursion into a child node in a - binary tree of maximum depth `d`. 
- """ - - nonterminal: Bool[Array, ' d-1'] - """Whether the node is valid or the recursion is into unused node slots.""" - - lower: UInt[Array, 'd-1 p'] - """The available cutpoints along ``var`` are in the integer range - ``[1 + lower[var], 1 + upper[var])``.""" - - upper: UInt[Array, 'd-1 p'] - """The available cutpoints along ``var`` are in the integer range - ``[1 + lower[var], 1 + upper[var])``.""" - - var: UInt[Array, ' d-1'] - """The variable of a decision node.""" - - split: UInt[Array, ' d-1'] - """The cutpoint of a decision node.""" - - @classmethod - def initial( - cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p'] - ) -> 'SamplePriorStack': - """Initialize the stack. - - Parameters - ---------- - p_nonterminal - The prior probability of a node being non-terminal conditional on - its ancestors and on having available decision rules, at each depth. - max_split - The number of cutpoints along each variable. - - Returns - ------- - A `SamplePriorStack` initialized to start the recursion. 
- """ - var_dtype = minimal_unsigned_dtype(max_split.size - 1) - return cls( - nonterminal=jnp.ones(p_nonterminal.size, bool), - lower=jnp.zeros((p_nonterminal.size, max_split.size), max_split.dtype), - upper=jnp.broadcast_to(max_split, (p_nonterminal.size, max_split.size)), - var=jnp.zeros(p_nonterminal.size, var_dtype), - split=jnp.zeros(p_nonterminal.size, max_split.dtype), - ) - - -class SamplePriorTrees(Module): - """Object holding the trees generated by `sample_prior`.""" - - leaf_tree: Float32[Array, '* 2**d'] - """The array representing the trees, see `bartz.grove`.""" - - var_tree: UInt[Array, '* 2**(d-1)'] - """The array representing the trees, see `bartz.grove`.""" - - split_tree: UInt[Array, '* 2**(d-1)'] - """The array representing the trees, see `bartz.grove`.""" - - @classmethod - def initial( - cls, - key: Key[Array, ''], - sigma_mu: Float32[Array, ''], - p_nonterminal: Float32[Array, ' d-1'], - max_split: UInt[Array, ' p'], - ) -> 'SamplePriorTrees': - """Initialize the trees. - - The leaves are already correct and do not need to be changed. - - Parameters - ---------- - key - A jax random key. - sigma_mu - The prior standard deviation of each leaf. - p_nonterminal - The prior probability of a node being non-terminal conditional on - its ancestors and on having available decision rules, at each depth. - max_split - The number of cutpoints along each variable. - - Returns - ------- - Trees initialized with random leaves and stub tree structures. 
- """ - heap_size = 2 ** (p_nonterminal.size + 1) - return cls( - leaf_tree=sigma_mu * random.normal(key, (heap_size,)), - var_tree=jnp.zeros( - heap_size // 2, dtype=minimal_unsigned_dtype(max_split.size - 1) - ), - split_tree=jnp.zeros(heap_size // 2, dtype=max_split.dtype), - ) - - -class SamplePriorCarry(Module): - """Object holding values carried along the recursion in `sample_prior`.""" - - key: Key[Array, ''] - """A jax random key used to sample decision rules.""" - - stack: SamplePriorStack - """The stack used to manage the recursion.""" - - trees: SamplePriorTrees - """The output arrays.""" - - @classmethod - def initial( - cls, - key: Key[Array, ''], - sigma_mu: Float32[Array, ''], - p_nonterminal: Float32[Array, ' d-1'], - max_split: UInt[Array, ' p'], - ) -> 'SamplePriorCarry': - """Initialize the carry object. - - Parameters - ---------- - key - A jax random key. - sigma_mu - The prior standard deviation of each leaf. - p_nonterminal - The prior probability of a node being non-terminal conditional on - its ancestors and on having available decision rules, at each depth. - max_split - The number of cutpoints along each variable. - - Returns - ------- - A `SamplePriorCarry` initialized to start the recursion. - """ - keys = split_key(key) - return cls( - keys.pop(), - SamplePriorStack.initial(p_nonterminal, max_split), - SamplePriorTrees.initial(keys.pop(), sigma_mu, p_nonterminal, max_split), - ) - - -class SamplePriorX(Module): - """Object representing the recursion scan in `sample_prior`. - - The sequence of nodes to visit is pre-computed recursively once, unrolling - the recursion schedule. 
- """ - - node: Int32[Array, ' 2**(d-1)-1'] - """The heap index of the node to visit.""" - - depth: Int32[Array, ' 2**(d-1)-1'] - """The depth of the node.""" - - next_depth: Int32[Array, ' 2**(d-1)-1'] - """The depth of the next node to visit, either the left child or the right - sibling of the node or of an ancestor.""" - - @classmethod - def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX': - """Initialize the sequence of nodes to visit. - - Parameters - ---------- - p_nonterminal - The prior probability of a node being non-terminal conditional on - its ancestors and on having available decision rules, at each depth. - - Returns - ------- - A `SamplePriorX` initialized with the sequence of nodes to visit. - """ - seq = cls._sequence(p_nonterminal.size) - assert len(seq) == 2**p_nonterminal.size - 1 - node = [node for node, depth in seq] - depth = [depth for node, depth in seq] - next_depth = [*depth[1:], p_nonterminal.size] - return cls( - node=jnp.array(node), - depth=jnp.array(depth), - next_depth=jnp.array(next_depth), - ) - - @classmethod - def _sequence( - cls, max_depth: int, depth: int = 0, node: int = 1 - ) -> tuple[tuple[int, int], ...]: - """Recursively generate a sequence [(node, depth), ...].""" - if depth < max_depth: - out = ((node, depth),) - out += cls._sequence(max_depth, depth + 1, 2 * node) - out += cls._sequence(max_depth, depth + 1, 2 * node + 1) - return out - return () - - -def sample_prior_onetree( - key: Key[Array, ''], - max_split: UInt[Array, ' p'], - p_nonterminal: Float32[Array, ' d-1'], - sigma_mu: Float32[Array, ''], -) -> SamplePriorTrees: - """Sample a tree from the BART prior. - - Parameters - ---------- - key - A jax random key. - max_split - The maximum split value for each variable. - p_nonterminal - The prior probability of a node being non-terminal conditional on - its ancestors and on having available decision rules, at each depth. - sigma_mu - The prior standard deviation of each leaf. 
- - Returns - ------- - An object containing a generated tree. - """ - carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split) - xs = SamplePriorX.initial(p_nonterminal) - - def loop(carry: SamplePriorCarry, x: SamplePriorX): - keys = split_key(carry.key, 4) - - # get variables at current stack level - stack = carry.stack - nonterminal = stack.nonterminal[x.depth] - lower = stack.lower[x.depth, :] - upper = stack.upper[x.depth, :] - - # sample a random decision rule - available: Bool[Array, ' p'] = lower < upper - allowed = jnp.any(available) - var = randint_masked(keys.pop(), available) - split = 1 + random.randint(keys.pop(), (), lower[var], upper[var]) - - # cast to shorter integer types - var = var.astype(carry.trees.var_tree.dtype) - split = split.astype(carry.trees.split_tree.dtype) - - # decide whether to try to grow the node if it is growable - pnt = p_nonterminal[x.depth] - try_nonterminal: Bool[Array, ''] = random.bernoulli(keys.pop(), pnt) - nonterminal &= try_nonterminal & allowed - - # update trees - trees = carry.trees - trees = replace( - trees, - var_tree=trees.var_tree.at[x.node].set(var), - split_tree=trees.split_tree.at[x.node].set( - jnp.where(nonterminal, split, 0) - ), - ) - - def write_push_stack() -> SamplePriorStack: - """Update the stack to go to the left child.""" - return replace( - stack, - nonterminal=stack.nonterminal.at[x.next_depth].set(nonterminal), - lower=stack.lower.at[x.next_depth, :].set(lower), - upper=stack.upper.at[x.next_depth, :].set(upper.at[var].set(split - 1)), - var=stack.var.at[x.depth].set(var), - split=stack.split.at[x.depth].set(split), - ) - - def pop_push_stack() -> SamplePriorStack: - """Update the stack to go to the right sibling, possibly at lower depth.""" - var = stack.var[x.next_depth - 1] - split = stack.split[x.next_depth - 1] - lower = stack.lower[x.next_depth - 1, :] - upper = stack.upper[x.next_depth - 1, :] - return replace( - stack, - lower=stack.lower.at[x.next_depth, 
:].set(lower.at[var].set(split)), - upper=stack.upper.at[x.next_depth, :].set(upper), - ) - - # update stack - stack = lax.cond(x.next_depth > x.depth, write_push_stack, pop_push_stack) - - # update carry - carry = replace(carry, key=keys.pop(), stack=stack, trees=trees) - return carry, None - - carry, _ = lax.scan(loop, carry, xs) - return carry.trees - - -@partial(vmap_nodoc, in_axes=(0, None, None, None)) -def sample_prior_forest( - keys: Key[Array, ' num_trees'], - max_split: UInt[Array, ' p'], - p_nonterminal: Float32[Array, ' d-1'], - sigma_mu: Float32[Array, ''], -) -> SamplePriorTrees: - """Sample a set of independent trees from the BART prior. - - Parameters - ---------- - keys - A sequence of jax random keys, one for each tree. This determined the - number of trees sampled. - max_split - The maximum split value for each variable. - p_nonterminal - The prior probability of a node being non-terminal conditional on - its ancestors and on having available decision rules, at each depth. - sigma_mu - The prior standard deviation of each leaf. - - Returns - ------- - An object containing the generated trees. - """ - return sample_prior_onetree(keys, max_split, p_nonterminal, sigma_mu) - - -@partial(jit, static_argnums=(1, 2)) -def sample_prior( - key: Key[Array, ''], - trace_length: int, - num_trees: int, - max_split: UInt[Array, ' p'], - p_nonterminal: Float32[Array, ' d-1'], - sigma_mu: Float32[Array, ''], -) -> SamplePriorTrees: - """Sample independent trees from the BART prior. - - Parameters - ---------- - key - A jax random key. - trace_length - The number of iterations. - num_trees - The number of trees for each iteration. - max_split - The number of cutpoints along each variable. - p_nonterminal - The prior probability of a node being non-terminal conditional on - its ancestors and on having available decision rules, at each depth. - This determines the maximum depth of the trees. - sigma_mu - The prior standard deviation of each leaf. 
- - Returns - ------- - An object containing the generated trees, with batch shape (trace_length, num_trees). - """ - keys = random.split(key, trace_length * num_trees) - trees = sample_prior_forest(keys, max_split, p_nonterminal, sigma_mu) - return tree_map(lambda x: x.reshape(trace_length, num_trees, -1), trees) - - -class debug_mc_gbart(mc_gbart): - """A subclass of `mc_gbart` that adds debugging functionality. - - Parameters - ---------- - *args - Passed to `mc_gbart`. - check_trees - If `True`, check all trees with `check_trace` after running the MCMC, - and assert that they are all valid. Set to `False` to allow jax tracing. - **kw - Passed to `mc_gbart`. - """ - - def __init__(self, *args, check_trees: bool = True, **kw): - super().__init__(*args, **kw) - if check_trees: - bad = self.check_trees() - bad_count = jnp.count_nonzero(bad) - assert bad_count == 0 - - def print_tree( - self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False - ): - """Print a single tree in human-readable format. - - Parameters - ---------- - i_chain - The index of the MCMC chain. - i_sample - The index of the (post-burnin) sample in the chain. - i_tree - The index of the tree in the sample. - print_all - If `True`, also print the content of unused node slots. - """ - tree = TreesTrace.from_dataclass(self._main_trace) - tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], tree) - s = format_tree(tree, print_all=print_all) - print(s) # noqa: T201, this method is intended for debug - - def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']: - """Return the harmonic mean of the error variance. - - Parameters - ---------- - prior - If `True`, use the prior distribution, otherwise use the full - conditional at the last MCMC iteration. - - Returns - ------- - The harmonic mean 1/E[1/sigma^2] in the selected distribution. 
- """ - bart = self._mcmc_state - assert bart.error_cov_df is not None - assert bart.z is None - # inverse gamma prior: alpha = df / 2, beta = scale / 2 - if prior: - alpha = bart.error_cov_df / 2 - beta = bart.error_cov_scale / 2 - else: - alpha = bart.error_cov_df / 2 + bart.resid.size / 2 - norm2 = jnp.einsum('ij,ij->i', bart.resid, bart.resid) - beta = bart.error_cov_scale / 2 + norm2 / 2 - error_cov_inv = alpha / beta - return jnp.sqrt(lax.reciprocal(error_cov_inv)) - - def compare_resid( - self, - ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]: - """Re-compute residuals to compare them with the updated ones. - - Returns - ------- - resid1 : Float32[Array, 'mc_cores n'] - The final state of the residuals updated during the MCMC. - resid2 : Float32[Array, 'mc_cores n'] - The residuals computed from the final state of the trees. - """ - bart = self._mcmc_state - resid1 = bart.resid - - forests = TreesTrace.from_dataclass(bart.forest) - trees = evaluate_forest(bart.X, forests, sum_batch_axis=-1) - - if bart.z is not None: - ref = bart.z - else: - ref = bart.y - resid2 = ref - (trees + bart.offset) - - return resid1, resid2 - - def avg_acc( - self, - ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]: - """Compute the average acceptance rates of tree moves. - - Returns - ------- - acc_grow : Float32[Array, 'mc_cores'] - The average acceptance rate of grow moves. - acc_prune : Float32[Array, 'mc_cores'] - The average acceptance rate of prune moves. - """ - trace = self._main_trace - - def acc(prefix): - acc = getattr(trace, f'{prefix}_acc_count') - prop = getattr(trace, f'{prefix}_prop_count') - return acc.sum(axis=1) / prop.sum(axis=1) - - return acc('grow'), acc('prune') - - def avg_prop( - self, - ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]: - """Compute the average proposal rate of grow and prune moves. 
- - Returns - ------- - prop_grow : Float32[Array, 'mc_cores'] - The fraction of times grow was proposed instead of prune. - prop_prune : Float32[Array, 'mc_cores'] - The fraction of times prune was proposed instead of grow. - - Notes - ----- - This function does not take into account cases where no move was - proposed. - """ - trace = self._main_trace - - def prop(prefix): - return getattr(trace, f'{prefix}_prop_count').sum(axis=1) - - pgrow = prop('grow') - pprune = prop('prune') - total = pgrow + pprune - return pgrow / total, pprune / total - - def avg_move( - self, - ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]: - """Compute the move rate. - - Returns - ------- - rate_grow : Float32[Array, 'mc_cores'] - The fraction of times a grow move was proposed and accepted. - rate_prune : Float32[Array, 'mc_cores'] - The fraction of times a prune move was proposed and accepted. - """ - agrow, aprune = self.avg_acc() - pgrow, pprune = self.avg_prop() - return agrow * pgrow, aprune * pprune - - def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']: - """Histogram of tree depths for each state of the trees. - - Returns - ------- - A matrix where each row contains a histogram of tree depths. - """ - out: Int32[Array, '*chains samples d'] - out = forest_depth_distr(self._main_trace.split_tree) - if out.ndim < 3: - out = out[None, :, :] - return out - - def _points_per_node_distr( - self, node_type: str - ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']: - out: Int32[Array, '*chains samples n+1'] - out = points_per_node_distr( - self._mcmc_state.X, - self._main_trace.var_tree, - self._main_trace.split_tree, - node_type, - sum_batch_axis=-1, - ) - if out.ndim < 3: - out = out[None, :, :] - return out - - def points_per_decision_node_distr( - self, - ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']: - """Histogram of number of points belonging to parent-of-leaf nodes. 
- - Returns - ------- - For each chain, a matrix where each row contains a histogram of number of points. - """ - return self._points_per_node_distr('leaf-parent') - - def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']: - """Histogram of number of points belonging to leaves. - - Returns - ------- - A matrix where each row contains a histogram of number of points. - """ - return self._points_per_node_distr('leaf') - - def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']: - """Apply `check_trace` to all the tree draws.""" - out: UInt[Array, '*chains samples num_trees'] - out = check_trace(self._main_trace, self._mcmc_state.forest.max_split) - if out.ndim < 3: - out = out[None, :, :] - return out - - def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']: - """Find iterations where a tree becomes invalid. - - Returns - ------- - A where (i,j) is `True` if tree j is invalid at iteration i but not i-1. - """ - bad = self.check_trees().astype(bool) - bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)]) - return bad & ~bad_before - - -class debug_gbart(debug_mc_gbart, gbart): - """A subclass of `gbart` that adds debugging functionality. - - Parameters - ---------- - *args - Passed to `gbart`. - check_trees - If `True`, check all trees with `check_trace` after running the MCMC, - and assert that they are all valid. Set to `False` to allow jax tracing. - **kw - Passed to `gbart`. - """ diff --git a/src/bartz/debug/__init__.py b/src/bartz/debug/__init__.py new file mode 100644 index 00000000..105a58e9 --- /dev/null +++ b/src/bartz/debug/__init__.py @@ -0,0 +1,39 @@ +# bartz/src/bartz/debug/__init__.py +# +# Copyright (c) 2024-2026, The Bartz Contributors +# +# This file is part of bartz. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +""" +Debugging utilities. + + - `check_trace`: check the validity of a set of trees. + - `debug_mc_gbart`: version of `mc_gbart` with debug checks and methods. + - `trees_BART_to_bartz`: convert an R package BART3 trace to a bartz trace. + - `sample_prior`: sample the bart prior. +""" + +# ruff: noqa: F401 + +from bartz.debug._check import check_trace, describe_error +from bartz.debug._debuggbart import debug_gbart, debug_mc_gbart +from bartz.debug._prior import SamplePriorTrees, sample_prior +from bartz.debug._traceconv import BARTTraceMeta, TraceWithOffset, trees_BART_to_bartz diff --git a/src/bartz/debug/_check.py b/src/bartz/debug/_check.py new file mode 100644 index 00000000..db63f031 --- /dev/null +++ b/src/bartz/debug/_check.py @@ -0,0 +1,284 @@ +# bartz/src/bartz/debug/_check.py +# +# Copyright (c) 2026, The Bartz Contributors +# +# This file is part of bartz. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Implement functions to check validity of trees.""" + +from typing import Protocol + +from jax import jit +from jax import numpy as jnp +from jaxtyping import Array, Bool, Integer, UInt + +from bartz.grove import TreeHeaps, is_actual_leaf +from bartz.jaxext import autobatch, minimal_unsigned_dtype +from bartz.mcmcloop import TreesTrace + +CHECK_FUNCTIONS = [] + + +class CheckFunc(Protocol): + """Protocol for functions that check whether a tree is valid.""" + + def __call__( + self, tree: TreeHeaps, max_split: UInt[Array, ' p'], / + ) -> bool | Bool[Array, '']: + """Check whether a tree is valid. + + Parameters + ---------- + tree + The tree to check. + max_split + The maximum split value for each variable. + + Returns + ------- + A boolean scalar indicating whether the tree is valid. + """ + ... + + +def check(func: CheckFunc) -> CheckFunc: + """Add a function to a list of functions used to check trees. 
+ + Use to decorate functions that check whether a tree is valid in some way. + These functions are invoked automatically by `check_tree`, `check_trace` and + `debug_gbart`. + + Parameters + ---------- + func + The function to add to the list. It must accept a `TreeHeaps` and a + `max_split` argument, and return a boolean scalar that indicates if the + tree is ok. + + Returns + ------- + The function unchanged. + """ + CHECK_FUNCTIONS.append(func) + return func + + +@check +def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool: + """Check that integer types are as small as possible and coherent.""" + expected_var_dtype = minimal_unsigned_dtype(max_split.size - 1) + expected_split_dtype = max_split.dtype + return ( + tree.var_tree.dtype == expected_var_dtype + and tree.split_tree.dtype == expected_split_dtype + and jnp.issubdtype(max_split.dtype, jnp.unsignedinteger) + ) + + +@check +def check_sizes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool: + """Check that array sizes are coherent.""" + return tree.leaf_tree.size == 2 * tree.var_tree.size == 2 * tree.split_tree.size + + +@check +def check_unused_node( + tree: TreeHeaps, _max_split: UInt[Array, ' p'] +) -> Bool[Array, '']: + """Check that the unused node slot at index 0 is not dirty.""" + return (tree.var_tree[0] == 0) & (tree.split_tree[0] == 0) + + +@check +def check_leaf_values( + tree: TreeHeaps, _max_split: UInt[Array, ' p'] +) -> Bool[Array, '']: + """Check that all leaf values are not inf of nan.""" + return jnp.all(jnp.isfinite(tree.leaf_tree)) + + +@check +def check_stray_nodes( + tree: TreeHeaps, _max_split: UInt[Array, ' p'] +) -> Bool[Array, '']: + """Check if there is any marked-non-leaf node with a marked-leaf parent.""" + index = jnp.arange( + 2 * tree.split_tree.size, + dtype=minimal_unsigned_dtype(2 * tree.split_tree.size - 1), + ) + parent_index = index >> 1 + is_not_leaf = tree.split_tree.at[index].get(mode='fill', fill_value=0) != 0 + parent_is_leaf = 
tree.split_tree[parent_index] == 0 + stray = is_not_leaf & parent_is_leaf + stray = stray.at[1].set(False) + return ~jnp.any(stray) + + +@check +def check_rule_consistency( + tree: TreeHeaps, max_split: UInt[Array, ' p'] +) -> bool | Bool[Array, '']: + """Check that decision rules define proper subsets of ancestor rules.""" + if tree.var_tree.size < 4: + return True + + # initial boundaries of decision rules. use extreme integers instead of 0, + # max_split to avoid checking if there is something out of bounds. + dtype = tree.split_tree.dtype + small = jnp.iinfo(dtype).min + large = jnp.iinfo(dtype).max + lower = jnp.full(max_split.size, small, dtype) + upper = jnp.full(max_split.size, large, dtype) + # the split must be in (lower[var], upper[var]] + + def _check_recursive( + node: int, lower: UInt[Array, ' p'], upper: UInt[Array, ' p'] + ) -> Bool[Array, '']: + # read decision rule + var = tree.var_tree[node] + split = tree.split_tree[node] + + # get rule boundaries from ancestors. use fill value in case var is + # out of bounds, we don't want to check out of bounds in this function + lower_var = lower.at[var].get(mode='fill', fill_value=small) + upper_var = upper.at[var].get(mode='fill', fill_value=large) + + # check rule is in bounds + bad = jnp.where(split, (split <= lower_var) | (split > upper_var), False) + + # recurse + if node < tree.var_tree.size // 2: + idx = jnp.where(split, var, max_split.size) + bad |= _check_recursive(2 * node, lower, upper.at[idx].set(split - 1)) + bad |= _check_recursive(2 * node + 1, lower.at[idx].set(split), upper) + + return bad + + return ~_check_recursive(1, lower, upper) + + +@check +def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001 + """Check that #leaves = 1 + #(internal nodes).""" + is_leaf = is_actual_leaf(tree.split_tree, add_bottom_level=True) + num_leaves = jnp.count_nonzero(is_leaf) + num_internal = jnp.count_nonzero(tree.split_tree) + return num_leaves == 
num_internal + 1 + + +@check +def check_var_in_bounds( + tree: TreeHeaps, max_split: UInt[Array, ' p'] +) -> Bool[Array, '']: + """Check that variables are in [0, max_split.size).""" + decision_node = tree.split_tree.astype(bool) + in_bounds = (tree.var_tree >= 0) & (tree.var_tree < max_split.size) + return jnp.all(in_bounds | ~decision_node) + + +@check +def check_split_in_bounds( + tree: TreeHeaps, max_split: UInt[Array, ' p'] +) -> Bool[Array, '']: + """Check that splits are in [0, max_split[var]].""" + max_split_var = ( + max_split.astype(jnp.int32) + .at[tree.var_tree] + .get(mode='fill', fill_value=jnp.iinfo(jnp.int32).max) + ) + return jnp.all((tree.split_tree >= 0) & (tree.split_tree <= max_split_var)) + + +def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']: + """Check the validity of a tree. + + Use `describe_error` to parse the error code returned by this function. + + Parameters + ---------- + tree + The tree to check. + max_split + The maximum split value for each variable. + + Returns + ------- + An integer where each bit indicates whether a check failed. + """ + error_type = minimal_unsigned_dtype(2 ** len(CHECK_FUNCTIONS) - 1) + error = error_type(0) + for i, func in enumerate(CHECK_FUNCTIONS): + ok = func(tree, max_split) + ok = jnp.bool_(ok) + bit = (~ok) << i + error |= bit + return error + + +def describe_error(error: int | Integer[Array, '']) -> list[str]: + """Describe an error code returned by `check_trace`. + + Parameters + ---------- + error + An error code returned by `check_trace`. + + Returns + ------- + A list of the function names that implement the failed checks. + """ + return [func.__name__ for i, func in enumerate(CHECK_FUNCTIONS) if error & (1 << i)] + + +@jit +def check_trace( + trace: TreeHeaps, max_split: UInt[Array, ' p'] +) -> UInt[Array, '*batch_shape']: + """Check the validity of a set of trees. + + Use `describe_error` to parse the error codes returned by this function. 
+ + Parameters + ---------- + trace + The set of trees to check. This object can have additional attributes + beyond the tree arrays, they are ignored. + max_split + The maximum split value for each variable. + + Returns + ------- + A tensor of error codes for each tree. + """ + # vectorize check_tree over all batch dimensions + unpack_check_tree = lambda l, v, s: check_tree(TreesTrace(l, v, s), max_split) + is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim + signature = '(k,ts),(hts),(hts)->()' if is_mv else '(ts),(hts),(hts)->()' + vec_check_tree = jnp.vectorize(unpack_check_tree, signature=signature) + + # automatically batch over all batch dimensions + max_io_nbytes = 2**24 # 16 MiB + batch_ndim = trace.split_tree.ndim - 1 + batched_check_tree = vec_check_tree + for i in reversed(range(batch_ndim)): + batched_check_tree = autobatch(batched_check_tree, max_io_nbytes, i, i) + + return batched_check_tree(trace.leaf_tree, trace.var_tree, trace.split_tree) diff --git a/src/bartz/debug/_debuggbart.py b/src/bartz/debug/_debuggbart.py new file mode 100644 index 00000000..f5dcd469 --- /dev/null +++ b/src/bartz/debug/_debuggbart.py @@ -0,0 +1,316 @@ +# bartz/src/bartz/debug/_debuggbart.py +# +# Copyright (c) 2026, The Bartz Contributors +# +# This file is part of bartz. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Debugging utilities. The main functionality is the class `debug_mc_gbart`.""" + +from dataclasses import replace +from typing import Any + +from equinox import error_if +from jax import numpy as jnp +from jax import tree +from jax.sharding import PartitionSpec +from jax.tree_util import tree_map +from jaxtyping import Array, Bool, Float32, Int32, UInt + +from bartz.BART import gbart, mc_gbart +from bartz.debug._check import check_trace +from bartz.grove import ( + evaluate_forest, + forest_depth_distr, + format_tree, + points_per_node_distr, +) +from bartz.jaxext import equal_shards +from bartz.mcmcloop import TreesTrace + + +class debug_mc_gbart(mc_gbart): + """A subclass of `mc_gbart` that adds debugging functionality. + + Parameters + ---------- + *args + Passed to `mc_gbart`. + check_trees + If `True`, check all trees with `check_trace` after running the MCMC, + and assert that they are all valid. + check_replicated_trees + If the data is sharded across devices, check that the trees are equal + on all devices in the final state. Set to `False` to allow jax tracing. + **kwargs + Passed to `mc_gbart`. 
+ """ + + def __init__( + self, + *args: Any, + check_trees: bool = True, + check_replicated_trees: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + if check_trees: + bad = self.check_trees() + bad_count = jnp.count_nonzero(bad) + self._bart.__dict__['offset'] = error_if( + self._bart.offset, bad_count > 0, 'invalid trees found in trace' + ) + + state = self._mcmc_state + mesh = state.config.mesh + if check_replicated_trees and mesh is not None and 'data' in mesh.axis_names: + replicated_forest = replace(state.forest, leaf_indices=None) + equal = equal_shards( + replicated_forest, 'data', in_specs=PartitionSpec(), mesh=mesh + ) + equal_array = jnp.stack(tree.leaves(equal)) + all_equal = jnp.all(equal_array) + # we could use error_if here for traceability, but last time we + # tried it hanged on error, maybe it was due to sharding. + assert all_equal.item(), 'the trees are different across devices' + + def print_tree( + self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False + ) -> None: + """Print a single tree in human-readable format. + + Parameters + ---------- + i_chain + The index of the MCMC chain. + i_sample + The index of the (post-burnin) sample in the chain. + i_tree + The index of the tree in the sample. + print_all + If `True`, also print the content of unused node slots. + """ + tree = TreesTrace.from_dataclass(self._main_trace) + tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], tree) + s = format_tree(tree, print_all=print_all) + print(s) # noqa: T201, this method is intended for debug + + def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']: + """Return the harmonic mean of the error variance. + + Parameters + ---------- + prior + If `True`, use the prior distribution, otherwise use the full + conditional at the last MCMC iteration. + + Returns + ------- + The harmonic mean 1/E[1/sigma^2] in the selected distribution. 
+ """ + bart = self._mcmc_state + assert bart.error_cov_df is not None + assert bart.z is None + # inverse gamma prior: alpha = df / 2, beta = scale / 2 + if prior: + alpha = bart.error_cov_df / 2 + beta = bart.error_cov_scale / 2 + else: + alpha = bart.error_cov_df / 2 + bart.resid.size / 2 + norm2 = jnp.einsum('ij,ij->i', bart.resid, bart.resid) + beta = bart.error_cov_scale / 2 + norm2 / 2 + error_cov_inv = alpha / beta + return jnp.sqrt(jnp.reciprocal(error_cov_inv)) + + def compare_resid( + self, + ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]: + """Re-compute residuals to compare them with the updated ones. + + Returns + ------- + resid1 : Float32[Array, 'mc_cores n'] + The final state of the residuals updated during the MCMC. + resid2 : Float32[Array, 'mc_cores n'] + The residuals computed from the final state of the trees. + """ + bart = self._mcmc_state + resid1 = bart.resid + + forests = TreesTrace.from_dataclass(bart.forest) + trees = evaluate_forest(bart.X, forests, sum_batch_axis=-1) + + if bart.z is not None: + ref = bart.z + else: + ref = bart.y + resid2 = ref - (trees + bart.offset) + + return resid1, resid2 + + def avg_acc( + self, + ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]: + """Compute the average acceptance rates of tree moves. + + Returns + ------- + acc_grow : Float32[Array, 'mc_cores'] + The average acceptance rate of grow moves. + acc_prune : Float32[Array, 'mc_cores'] + The average acceptance rate of prune moves. + """ + trace = self._main_trace + + def acc(prefix: str) -> Float32[Array, ' mc_cores']: + acc = getattr(trace, f'{prefix}_acc_count') + prop = getattr(trace, f'{prefix}_prop_count') + return acc.sum(axis=1) / prop.sum(axis=1) + + return acc('grow'), acc('prune') + + def avg_prop( + self, + ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]: + """Compute the average proposal rate of grow and prune moves. 
+ + Returns + ------- + prop_grow : Float32[Array, 'mc_cores'] + The fraction of times grow was proposed instead of prune. + prop_prune : Float32[Array, 'mc_cores'] + The fraction of times prune was proposed instead of grow. + + Notes + ----- + This function does not take into account cases where no move was + proposed. + """ + trace = self._main_trace + + def prop(prefix: str) -> Array: + return getattr(trace, f'{prefix}_prop_count').sum(axis=1) + + pgrow = prop('grow') + pprune = prop('prune') + total = pgrow + pprune + return pgrow / total, pprune / total + + def avg_move( + self, + ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]: + """Compute the move rate. + + Returns + ------- + rate_grow : Float32[Array, 'mc_cores'] + The fraction of times a grow move was proposed and accepted. + rate_prune : Float32[Array, 'mc_cores'] + The fraction of times a prune move was proposed and accepted. + """ + agrow, aprune = self.avg_acc() + pgrow, pprune = self.avg_prop() + return agrow * pgrow, aprune * pprune + + def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']: + """Histogram of tree depths for each state of the trees. + + Returns + ------- + A matrix where each row contains a histogram of tree depths. + """ + out: Int32[Array, '*chains samples d'] + out = forest_depth_distr(self._main_trace.split_tree) + if out.ndim < 3: + out = out[None, :, :] + return out + + def _points_per_node_distr( + self, node_type: str + ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']: + out: Int32[Array, '*chains samples n+1'] + out = points_per_node_distr( + self._mcmc_state.X, + self._main_trace.var_tree, + self._main_trace.split_tree, + node_type, + sum_batch_axis=-1, + ) + if out.ndim < 3: + out = out[None, :, :] + return out + + def points_per_decision_node_distr( + self, + ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']: + """Histogram of number of points belonging to parent-of-leaf nodes. 
+
+        Returns
+        -------
+        For each chain, a matrix where each row contains a histogram of number of points.
+        """
+        return self._points_per_node_distr('leaf-parent')
+
+    def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
+        """Histogram of number of points belonging to leaves.
+
+        Returns
+        -------
+        A matrix where each row contains a histogram of number of points.
+        """
+        return self._points_per_node_distr('leaf')
+
+    def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']:
+        """Apply `check_trace` to all the tree draws."""
+        out: UInt[Array, '*chains samples num_trees']
+        out = check_trace(self._main_trace, self._mcmc_state.forest.max_split)
+        if out.ndim < 3:
+            out = out[None, :, :]
+        return out
+
+    def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']:
+        """Find iterations where a tree becomes invalid.
+
+        Returns
+        -------
+        An array where (i,j) is `True` if tree j is invalid at iteration i but not i-1.
+        """
+        bad = self.check_trees().astype(bool)
+        bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
+        return bad & ~bad_before
+
+
+class debug_gbart(debug_mc_gbart, gbart):
+    """A subclass of `gbart` that adds debugging functionality.
+
+    Parameters
+    ----------
+    *args
+        Passed to `gbart`.
+    check_trees
+        If `True`, check all trees with `check_trace` after running the MCMC,
+        and assert that they are all valid.
+    check_replicated_trees
+        If the data is sharded across devices, check that the trees are equal
+        on all devices in the final state. Set to `False` to allow jax tracing.
+    **kwargs
+        Passed to `gbart`.
+    """
diff --git a/src/bartz/debug/_prior.py b/src/bartz/debug/_prior.py
new file mode 100644
index 00000000..c2df29bb
--- /dev/null
+++ b/src/bartz/debug/_prior.py
@@ -0,0 +1,402 @@
+# bartz/src/bartz/debug/_prior.py
+#
+# Copyright (c) 2026, The Bartz Contributors
+#
+# This file is part of bartz.
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Debugging utilities. The main functionality is the class `debug_mc_gbart`.""" + +from dataclasses import replace +from functools import partial + +from equinox import Module +from jax import jit, lax, random +from jax import numpy as jnp +from jax.tree_util import tree_map +from jaxtyping import Array, Bool, Float32, Int32, Key, UInt + +from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc +from bartz.jaxext import split as split_key +from bartz.mcmcstep._moves import randint_masked + + +class SamplePriorStack(Module): + """Represent the manually managed stack used in `sample_prior`. + + Each level of the stack represents a recursion into a child node in a + binary tree of maximum depth `d`. 
+ """ + + nonterminal: Bool[Array, ' d-1'] + """Whether the node is valid or the recursion is into unused node slots.""" + + lower: UInt[Array, 'd-1 p'] + """The available cutpoints along ``var`` are in the integer range + ``[1 + lower[var], 1 + upper[var])``.""" + + upper: UInt[Array, 'd-1 p'] + """The available cutpoints along ``var`` are in the integer range + ``[1 + lower[var], 1 + upper[var])``.""" + + var: UInt[Array, ' d-1'] + """The variable of a decision node.""" + + split: UInt[Array, ' d-1'] + """The cutpoint of a decision node.""" + + @classmethod + def initial( + cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p'] + ) -> 'SamplePriorStack': + """Initialize the stack. + + Parameters + ---------- + p_nonterminal + The prior probability of a node being non-terminal conditional on + its ancestors and on having available decision rules, at each depth. + max_split + The number of cutpoints along each variable. + + Returns + ------- + A `SamplePriorStack` initialized to start the recursion. 
+ """ + var_dtype = minimal_unsigned_dtype(max_split.size - 1) + return cls( + nonterminal=jnp.ones(p_nonterminal.size, bool), + lower=jnp.zeros((p_nonterminal.size, max_split.size), max_split.dtype), + upper=jnp.broadcast_to(max_split, (p_nonterminal.size, max_split.size)), + var=jnp.zeros(p_nonterminal.size, var_dtype), + split=jnp.zeros(p_nonterminal.size, max_split.dtype), + ) + + +class SamplePriorTrees(Module): + """Object holding the trees generated by `sample_prior`.""" + + leaf_tree: Float32[Array, '* 2**d'] + """The array representing the trees, see `bartz.grove`.""" + + var_tree: UInt[Array, '* 2**(d-1)'] + """The array representing the trees, see `bartz.grove`.""" + + split_tree: UInt[Array, '* 2**(d-1)'] + """The array representing the trees, see `bartz.grove`.""" + + @classmethod + def initial( + cls, + key: Key[Array, ''], + sigma_mu: Float32[Array, ''], + p_nonterminal: Float32[Array, ' d-1'], + max_split: UInt[Array, ' p'], + ) -> 'SamplePriorTrees': + """Initialize the trees. + + The leaves are already correct and do not need to be changed. + + Parameters + ---------- + key + A jax random key. + sigma_mu + The prior standard deviation of each leaf. + p_nonterminal + The prior probability of a node being non-terminal conditional on + its ancestors and on having available decision rules, at each depth. + max_split + The number of cutpoints along each variable. + + Returns + ------- + Trees initialized with random leaves and stub tree structures. 
+ """ + heap_size = 2 ** (p_nonterminal.size + 1) + return cls( + leaf_tree=sigma_mu * random.normal(key, (heap_size,)), + var_tree=jnp.zeros( + heap_size // 2, dtype=minimal_unsigned_dtype(max_split.size - 1) + ), + split_tree=jnp.zeros(heap_size // 2, dtype=max_split.dtype), + ) + + +class SamplePriorCarry(Module): + """Object holding values carried along the recursion in `sample_prior`.""" + + key: Key[Array, ''] + """A jax random key used to sample decision rules.""" + + stack: SamplePriorStack + """The stack used to manage the recursion.""" + + trees: SamplePriorTrees + """The output arrays.""" + + @classmethod + def initial( + cls, + key: Key[Array, ''], + sigma_mu: Float32[Array, ''], + p_nonterminal: Float32[Array, ' d-1'], + max_split: UInt[Array, ' p'], + ) -> 'SamplePriorCarry': + """Initialize the carry object. + + Parameters + ---------- + key + A jax random key. + sigma_mu + The prior standard deviation of each leaf. + p_nonterminal + The prior probability of a node being non-terminal conditional on + its ancestors and on having available decision rules, at each depth. + max_split + The number of cutpoints along each variable. + + Returns + ------- + A `SamplePriorCarry` initialized to start the recursion. + """ + keys = split_key(key) + return cls( + keys.pop(), + SamplePriorStack.initial(p_nonterminal, max_split), + SamplePriorTrees.initial(keys.pop(), sigma_mu, p_nonterminal, max_split), + ) + + +class SamplePriorX(Module): + """Object representing the recursion scan in `sample_prior`. + + The sequence of nodes to visit is pre-computed recursively once, unrolling + the recursion schedule. 
+ """ + + node: Int32[Array, ' 2**(d-1)-1'] + """The heap index of the node to visit.""" + + depth: Int32[Array, ' 2**(d-1)-1'] + """The depth of the node.""" + + next_depth: Int32[Array, ' 2**(d-1)-1'] + """The depth of the next node to visit, either the left child or the right + sibling of the node or of an ancestor.""" + + @classmethod + def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX': + """Initialize the sequence of nodes to visit. + + Parameters + ---------- + p_nonterminal + The prior probability of a node being non-terminal conditional on + its ancestors and on having available decision rules, at each depth. + + Returns + ------- + A `SamplePriorX` initialized with the sequence of nodes to visit. + """ + seq = cls._sequence(p_nonterminal.size) + assert len(seq) == 2**p_nonterminal.size - 1 + node = [node for node, depth in seq] + depth = [depth for node, depth in seq] + next_depth = [*depth[1:], p_nonterminal.size] + return cls( + node=jnp.array(node), + depth=jnp.array(depth), + next_depth=jnp.array(next_depth), + ) + + @classmethod + def _sequence( + cls, max_depth: int, depth: int = 0, node: int = 1 + ) -> tuple[tuple[int, int], ...]: + """Recursively generate a sequence [(node, depth), ...].""" + if depth < max_depth: + out = ((node, depth),) + out += cls._sequence(max_depth, depth + 1, 2 * node) + out += cls._sequence(max_depth, depth + 1, 2 * node + 1) + return out + return () + + +def sample_prior_onetree( + key: Key[Array, ''], + max_split: UInt[Array, ' p'], + p_nonterminal: Float32[Array, ' d-1'], + sigma_mu: Float32[Array, ''], +) -> SamplePriorTrees: + """Sample a tree from the BART prior. + + Parameters + ---------- + key + A jax random key. + max_split + The maximum split value for each variable. + p_nonterminal + The prior probability of a node being non-terminal conditional on + its ancestors and on having available decision rules, at each depth. + sigma_mu + The prior standard deviation of each leaf. 
+ + Returns + ------- + An object containing a generated tree. + """ + carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split) + xs = SamplePriorX.initial(p_nonterminal) + + def loop(carry: SamplePriorCarry, x: SamplePriorX) -> tuple[SamplePriorCarry, None]: + keys = split_key(carry.key, 4) + + # get variables at current stack level + stack = carry.stack + nonterminal = stack.nonterminal[x.depth] + lower = stack.lower[x.depth, :] + upper = stack.upper[x.depth, :] + + # sample a random decision rule + available: Bool[Array, ' p'] = lower < upper + allowed = jnp.any(available) + var = randint_masked(keys.pop(), available) + split = 1 + random.randint(keys.pop(), (), lower[var], upper[var]) + + # cast to shorter integer types + var = var.astype(carry.trees.var_tree.dtype) + split = split.astype(carry.trees.split_tree.dtype) + + # decide whether to try to grow the node if it is growable + pnt = p_nonterminal[x.depth] + try_nonterminal: Bool[Array, ''] = random.bernoulli(keys.pop(), pnt) + nonterminal &= try_nonterminal & allowed + + # update trees + trees = carry.trees + trees = replace( + trees, + var_tree=trees.var_tree.at[x.node].set(var), + split_tree=trees.split_tree.at[x.node].set( + jnp.where(nonterminal, split, 0) + ), + ) + + def write_push_stack() -> SamplePriorStack: + """Update the stack to go to the left child.""" + return replace( + stack, + nonterminal=stack.nonterminal.at[x.next_depth].set(nonterminal), + lower=stack.lower.at[x.next_depth, :].set(lower), + upper=stack.upper.at[x.next_depth, :].set(upper.at[var].set(split - 1)), + var=stack.var.at[x.depth].set(var), + split=stack.split.at[x.depth].set(split), + ) + + def pop_push_stack() -> SamplePriorStack: + """Update the stack to go to the right sibling, possibly at lower depth.""" + var = stack.var[x.next_depth - 1] + split = stack.split[x.next_depth - 1] + lower = stack.lower[x.next_depth - 1, :] + upper = stack.upper[x.next_depth - 1, :] + return replace( + stack, + 
lower=stack.lower.at[x.next_depth, :].set(lower.at[var].set(split)),
+                upper=stack.upper.at[x.next_depth, :].set(upper),
+            )
+
+        # update stack
+        stack = lax.cond(x.next_depth > x.depth, write_push_stack, pop_push_stack)
+
+        # update carry
+        carry = replace(carry, key=keys.pop(), stack=stack, trees=trees)
+        return carry, None
+
+    carry, _ = lax.scan(loop, carry, xs)
+    return carry.trees
+
+
+@partial(vmap_nodoc, in_axes=(0, None, None, None))
+def sample_prior_forest(
+    keys: Key[Array, ' num_trees'],
+    max_split: UInt[Array, ' p'],
+    p_nonterminal: Float32[Array, ' d-1'],
+    sigma_mu: Float32[Array, ''],
+) -> SamplePriorTrees:
+    """Sample a set of independent trees from the BART prior.
+
+    Parameters
+    ----------
+    keys
+        A sequence of jax random keys, one for each tree. This determines the
+        number of trees sampled.
+    max_split
+        The maximum split value for each variable.
+    p_nonterminal
+        The prior probability of a node being non-terminal conditional on
+        its ancestors and on having available decision rules, at each depth.
+    sigma_mu
+        The prior standard deviation of each leaf.
+
+    Returns
+    -------
+    An object containing the generated trees.
+    """
+    return sample_prior_onetree(keys, max_split, p_nonterminal, sigma_mu)
+
+
+@partial(jit, static_argnums=(1, 2))
+def sample_prior(
+    key: Key[Array, ''],
+    trace_length: int,
+    num_trees: int,
+    max_split: UInt[Array, ' p'],
+    p_nonterminal: Float32[Array, ' d-1'],
+    sigma_mu: Float32[Array, ''],
+) -> SamplePriorTrees:
+    """Sample independent trees from the BART prior.
+
+    Parameters
+    ----------
+    key
+        A jax random key.
+    trace_length
+        The number of iterations.
+    num_trees
+        The number of trees for each iteration.
+    max_split
+        The number of cutpoints along each variable.
+    p_nonterminal
+        The prior probability of a node being non-terminal conditional on
+        its ancestors and on having available decision rules, at each depth.
+        This determines the maximum depth of the trees.
+ sigma_mu + The prior standard deviation of each leaf. + + Returns + ------- + An object containing the generated trees, with batch shape (trace_length, num_trees). + """ + keys = random.split(key, trace_length * num_trees) + trees = sample_prior_forest(keys, max_split, p_nonterminal, sigma_mu) + return tree_map(lambda x: x.reshape(trace_length, num_trees, -1), trees) diff --git a/src/bartz/debug/_traceconv.py b/src/bartz/debug/_traceconv.py new file mode 100644 index 00000000..b9dc5aad --- /dev/null +++ b/src/bartz/debug/_traceconv.py @@ -0,0 +1,245 @@ +# bartz/src/bartz/debug/_traceconv.py +# +# Copyright (c) 2026, The Bartz Contributors +# +# This file is part of bartz. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Debugging utilities. 
The main functionality is the class `debug_mc_gbart`.""" + +from math import ceil, log2 +from re import fullmatch + +import numpy +from equinox import Module, field +from jax import numpy as jnp +from jaxtyping import Array, Float32, UInt + +from bartz.BART._gbart import FloatLike +from bartz.grove import TreeHeaps +from bartz.jaxext import minimal_unsigned_dtype + + +def _get_next_line(s: str, i: int) -> tuple[str, int]: + """Get the next line from a string and the new index.""" + i_new = s.find('\n', i) + if i_new == -1: + return s[i:], len(s) + return s[i:i_new], i_new + 1 + + +class BARTTraceMeta(Module): + """Metadata of R BART tree traces.""" + + ndpost: int = field(static=True) + """The number of posterior draws.""" + + ntree: int = field(static=True) + """The number of trees in the model.""" + + numcut: UInt[Array, ' p'] + """The maximum split value for each variable.""" + + heap_size: int = field(static=True) + """The size of the heap required to store the trees.""" + + +def scan_BART_trees(trees: str) -> BARTTraceMeta: + """Scan an R BART tree trace checking for errors and parsing metadata. + + Parameters + ---------- + trees + The string representation of a trace of trees of the R BART package. + Can be accessed from ``mc_gbart(...).treedraws['trees']``. + + Returns + ------- + An object containing the metadata. + + Raises + ------ + ValueError + If the string is malformed or contains leftover characters. 
+ """ + # parse first line + line, i_char = _get_next_line(trees, 0) + i_line = 1 + match = fullmatch(r'(\d+) (\d+) (\d+)', line) + if match is None: + msg = f'Malformed header at {i_line=}' + raise ValueError(msg) + ndpost, ntree, p = map(int, match.groups()) + + # initial values for maxima + max_heap_index = 0 + numcut = numpy.zeros(p, int) + + # cycle over iterations and trees + for i_iter in range(ndpost): + for i_tree in range(ntree): + # parse first line of tree definition + line, i_char = _get_next_line(trees, i_char) + i_line += 1 + match = fullmatch(r'(\d+)', line) + if match is None: + msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}' + raise ValueError(msg) + num_nodes = int(line) + + # cycle over nodes + for i_node in range(num_nodes): + # parse node definition + line, i_char = _get_next_line(trees, i_char) + i_line += 1 + match = fullmatch( + r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line + ) + if match is None: + msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}' + raise ValueError(msg) + i_heap = int(match.group(1)) + var = int(match.group(2)) + split = int(match.group(3)) + + # update maxima + numcut[var] = max(numcut[var], split) + max_heap_index = max(max_heap_index, i_heap) + + assert i_char <= len(trees) + if i_char < len(trees): + msg = f'Leftover {len(trees) - i_char} characters in string' + raise ValueError(msg) + + # determine minimal integer type for numcut + numcut += 1 # because BART is 0-based + split_dtype = minimal_unsigned_dtype(numcut.max()) + numcut = jnp.array(numcut.astype(split_dtype)) + + # determine minimum heap size to store the trees + heap_size = 2 ** ceil(log2(max_heap_index + 1)) + + return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size) + + +class TraceWithOffset(Module): + """Implementation of `bartz.mcmcloop.Trace`.""" + + leaf_tree: Float32[Array, 'ndpost ntree 2**d'] + var_tree: UInt[Array, 'ndpost ntree 2**(d-1)'] + split_tree: 
UInt[Array, 'ndpost ntree 2**(d-1)'] + offset: Float32[Array, ' ndpost'] + + @classmethod + def from_trees_trace( + cls, trees: TreeHeaps, offset: Float32[Array, ''] + ) -> 'TraceWithOffset': + """Create a `TraceWithOffset` from a `TreeHeaps`.""" + ndpost, _, _ = trees.leaf_tree.shape + return cls( + leaf_tree=trees.leaf_tree, + var_tree=trees.var_tree, + split_tree=trees.split_tree, + offset=jnp.full(ndpost, offset), + ) + + +def trees_BART_to_bartz( + trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None +) -> tuple[TraceWithOffset, BARTTraceMeta]: + """Convert trees from the R BART format to the bartz format. + + Parameters + ---------- + trees + The string representation of a trace of trees of the R BART package. + Can be accessed from ``mc_gbart(...).treedraws['trees']``. + min_maxdepth + The maximum tree depth of the output will be set to the maximum + observed depth in the input trees. Use this parameter to require at + least this maximum depth in the output format. + offset + The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be + summed to the sum of trees. To match that behavior, this function + returns an offset as well, zero by default. Set with this parameter + otherwise. + + Returns + ------- + trace : TraceWithOffset + A representation of the trees compatible with the trace returned by + `bartz.mcmcloop.run_mcmc`. + meta : BARTTraceMeta + The metadata of the trace, containing the number of iterations, trees, + and the maximum split value. 
+ """ + # scan all the string checking for errors and determining sizes + meta = scan_BART_trees(trees) + + # skip first line + _, i_char = _get_next_line(trees, 0) + + heap_size = max(meta.heap_size, 2**min_maxdepth) + leaf_trees = numpy.zeros((meta.ndpost, meta.ntree, heap_size), dtype=numpy.float32) + var_trees = numpy.zeros( + (meta.ndpost, meta.ntree, heap_size // 2), + dtype=minimal_unsigned_dtype(meta.numcut.size - 1), + ) + split_trees = numpy.zeros( + (meta.ndpost, meta.ntree, heap_size // 2), dtype=meta.numcut.dtype + ) + + # cycle over iterations and trees + for i_iter in range(meta.ndpost): + for i_tree in range(meta.ntree): + # parse first line of tree definition + line, i_char = _get_next_line(trees, i_char) + num_nodes = int(line) + + is_internal = numpy.zeros(heap_size // 2, dtype=bool) + + # cycle over nodes + for _ in range(num_nodes): + # parse node definition + line, i_char = _get_next_line(trees, i_char) + values = line.split() + i_heap = int(values[0]) + var = int(values[1]) + split = int(values[2]) + leaf = float(values[3]) + + # update values + leaf_trees[i_iter, i_tree, i_heap] = leaf + is_internal[i_heap // 2] = True + if i_heap < heap_size // 2: + var_trees[i_iter, i_tree, i_heap] = var + split_trees[i_iter, i_tree, i_heap] = split + 1 + + is_internal[0] = False + split_trees[i_iter, i_tree, ~is_internal] = 0 + + return TraceWithOffset( + leaf_tree=jnp.array(leaf_trees), + var_tree=jnp.array(var_trees), + split_tree=jnp.array(split_trees), + offset=jnp.zeros(meta.ndpost) + if offset is None + else jnp.full(meta.ndpost, offset), + ), meta diff --git a/src/bartz/grove.py b/src/bartz/grove.py index 9ff9613b..785dca62 100644 --- a/src/bartz/grove.py +++ b/src/bartz/grove.py @@ -26,18 +26,18 @@ import math from functools import partial -from typing import Protocol +from typing import Literal, Protocol from jax import jit, lax, vmap from jax import numpy as jnp -from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Shaped, UInt +from 
jaxtyping import Array, Bool, Float32, Int32, Shaped, UInt try: from numpy.lib.array_utils import normalize_axis_tuple # numpy 2 except ImportError: from numpy.core.numeric import normalize_axis_tuple # numpy 1 -from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc +from bartz.jaxext import autobatch, minimal_unsigned_dtype, vmap_nodoc class TreeHeaps(Protocol): @@ -72,31 +72,6 @@ class TreeHeaps(Protocol): 0. Unused nodes also have split set to 0. This array can't be dirty.""" -def make_tree( - depth: int, dtype: DTypeLike, batch_shape: tuple[int, ...] = () -) -> Shaped[Array, '*batch_shape 2**{depth}']: - """ - Make an array to represent a binary tree. - - Parameters - ---------- - depth - The maximum depth of the tree. Depth 1 means that there is only a root - node. - dtype - The dtype of the array. - batch_shape - The leading shape of the array, to represent multiple trees and/or - multivariate trees. - - Returns - ------- - An array of zeroes with the appropriate shape. - """ - shape = (*batch_shape, 2**depth) - return jnp.zeros(shape, dtype) - - def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int: """ Return the maximum depth of a tree. @@ -104,8 +79,8 @@ def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int: Parameters ---------- tree - A tree created by `make_tree`. If the array is ND, the tree structure is - assumed to be along the last axis. + A tree array like those in a `TreeHeaps`. If the array is ND, the tree + structure is assumed to be along the last axis. 
Returns ------- @@ -140,7 +115,9 @@ def traverse_tree( jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)), ) - def loop(carry, _): + def loop( + carry: tuple[Bool[Array, ''], UInt[Array, '']], _: None + ) -> tuple[tuple[Bool[Array, ''], UInt[Array, '']], None]: leaf_found, index = carry split = split_tree[index] @@ -410,3 +387,223 @@ def scatter_add( scatter_add = vmap(scatter_add, in_axes=neg_i) return scatter_add(var_tree, is_internal) + + +def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str: + """Convert a tree to a human-readable string. + + Parameters + ---------- + tree + A single tree to format. + print_all + If `True`, also print the contents of unused node slots in the arrays. + + Returns + ------- + A string representation of the tree. + """ + tee = '├──' + corner = '└──' + join = '│ ' + space = ' ' + down = '┐' + bottom = '╢' # '┨' # + + def traverse_tree( + lines: list[str], + index: int, + depth: int, + indent: str, + first_indent: str, + next_indent: str, + unused: bool, + ) -> None: + if index >= len(tree.leaf_tree): + return + + var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item() + split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item() + + is_leaf = split == 0 + left_child = 2 * index + right_child = 2 * index + 1 + + if print_all: + if unused: + category = 'unused' + elif is_leaf: + category = 'leaf' + else: + category = 'decision' + node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})' + else: + assert not unused + if is_leaf: + node_str = f'{tree.leaf_tree[index]:#.2g}' + else: + node_str = f'x{var} < {split}' + + if not is_leaf or (print_all and left_child < len(tree.leaf_tree)): + link = down + elif not print_all and left_child >= len(tree.leaf_tree): + link = bottom + else: + link = ' ' + + max_number = len(tree.leaf_tree) - 1 + ndigits = len(str(max_number)) + number = str(index).rjust(ndigits) + + lines.append(f' {number} {indent}{first_indent}{link}{node_str}') 
+ + indent += next_indent + unused = unused or is_leaf + + if unused and not print_all: + return + + traverse_tree(lines, left_child, depth + 1, indent, tee, join, unused) + traverse_tree(lines, right_child, depth + 1, indent, corner, space, unused) + + lines = [] + traverse_tree(lines, 1, 0, '', '', '', False) + return '\n'.join(lines) + + +def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']: + """Measure the depth of the tree. + + Parameters + ---------- + split_tree + The cutpoints of the decision rules. + + Returns + ------- + The depth of the deepest leaf in the tree. The root is at depth 0. + """ + # this could be done just with split_tree != 0 + is_leaf = is_actual_leaf(split_tree, add_bottom_level=True) + depth = tree_depths(is_leaf.size) + depth = jnp.where(is_leaf, depth, 0) + return jnp.max(depth) + + +@jit +@partial(jnp.vectorize, signature='(nt,hts)->(d)') +def forest_depth_distr( + split_tree: UInt[Array, '*batch_shape num_trees 2**(d-1)'], +) -> Int32[Array, '*batch_shape d']: + """Histogram the depths of a set of trees. + + Parameters + ---------- + split_tree + The cutpoints of the decision rules of the trees. + + Returns + ------- + An integer vector where the i-th element counts how many trees have depth i. + """ + depth = tree_depth(split_tree) + 1 + depths = vmap(tree_actual_depth)(split_tree) + return jnp.bincount(depths, length=depth) + + +@partial(jit, static_argnames=('node_type', 'sum_batch_axis')) +def points_per_node_distr( + X: UInt[Array, 'p n'], + var_tree: UInt[Array, '*batch_shape 2**(d-1)'], + split_tree: UInt[Array, '*batch_shape 2**(d-1)'], + node_type: Literal['leaf', 'leaf-parent'], + *, + sum_batch_axis: int | tuple[int, ...] = (), +) -> Int32[Array, '*reduced_batch_shape n+1']: + """Histogram points-per-node counts in a set of trees. + + Count how many nodes in a tree select each possible amount of points, + over a certain subset of nodes. 
+ + Parameters + ---------- + X + The set of points to count. + var_tree + The variables of the decision rules. + split_tree + The cutpoints of the decision rules. + node_type + The type of nodes to consider. Can be: + + 'leaf' + Count only leaf nodes. + 'leaf-parent' + Count only parent-of-leaf nodes. + sum_batch_axis + Aggregate the histogram over these batch axes, counting how many nodes + have each possible amount of points over subsets of trees instead of + in each tree separately. + + Returns + ------- + A vector where the i-th element counts how many nodes have i points. + """ + batch_ndim = var_tree.ndim - 1 + axes = normalize_axis_tuple(sum_batch_axis, batch_ndim) + + def func( + var_tree: UInt[Array, '*batch_shape 2**(d-1)'], + split_tree: UInt[Array, '*batch_shape 2**(d-1)'], + ) -> Int32[Array, '*reduced_batch_shape n+1']: + indices: UInt[Array, '*batch_shape n'] + indices = traverse_forest(X, var_tree, split_tree) + + @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)') + def count_points( + split_tree: UInt[Array, '*batch_shape 2**(d-1)'], + indices: UInt[Array, '*batch_shape n'], + ) -> ( + tuple[UInt[Array, '*batch_shape 2**d'], Bool[Array, '*batch_shape 2**d']] + | tuple[ + UInt[Array, '*batch_shape 2**(d-1)'], + Bool[Array, '*batch_shape 2**(d-1)'], + ] + ): + if node_type == 'leaf-parent': + indices >>= 1 + predicate = is_leaves_parent(split_tree) + elif node_type == 'leaf': + predicate = is_actual_leaf(split_tree, add_bottom_level=True) + else: + raise ValueError(node_type) + count_tree = jnp.zeros(predicate.size, int).at[indices].add(1).at[0].set(0) + return count_tree, predicate + + count_tree, predicate = count_points(split_tree, indices) + + def count_nodes( + count_tree: UInt[Array, '*summed_batch_axes half_tree_size'], + predicate: Bool[Array, '*summed_batch_axes half_tree_size'], + ) -> Int32[Array, ' n+1']: + return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate) + + # vmap count_nodes over non-batched 
dims + for i in reversed(range(batch_ndim)): + neg_i = i - var_tree.ndim + if i not in axes: + count_nodes = vmap(count_nodes, in_axes=neg_i) + + return count_nodes(count_tree, predicate) + + # automatically batch over all batch dimensions + max_io_nbytes = 2**27 # 128 MiB + out_dim_shift = len(axes) + for i in reversed(range(batch_ndim)): + if i in axes: + out_dim_shift -= 1 + else: + func = autobatch(func, max_io_nbytes, i, i - out_dim_shift) + assert out_dim_shift == 0 + + return func(var_tree, split_tree) diff --git a/src/bartz/jaxext/__init__.py b/src/bartz/jaxext/__init__.py index 51322c79..dd9971d8 100644 --- a/src/bartz/jaxext/__init__.py +++ b/src/bartz/jaxext/__init__.py @@ -25,9 +25,15 @@ """Additions to jax.""" import math -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextlib import nullcontext from functools import partial +from typing import Any + +try: + from jax import shard_map # available since jax v0.6.1 +except ImportError: + from jax.experimental.shard_map import shard_map import jax from jax import ( @@ -36,20 +42,23 @@ device_count, ensure_compile_time_eval, jit, + lax, random, + tree, + typeof, vmap, ) from jax import numpy as jnp from jax.dtypes import prng_key -from jax.lax import scan from jax.scipy.special import ndtr -from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped +from jax.sharding import PartitionSpec +from jaxtyping import Array, Bool, Float32, Key, PyTree, Scalar, Shaped from bartz.jaxext._autobatch import autobatch # noqa: F401 from bartz.jaxext.scipy.special import ndtri -def vmap_nodoc(fun, *args, **kw): +def vmap_nodoc(fun: Callable, *args: Any, **kw: Any) -> Callable: """ Acts like `jax.vmap` but preserves the docstring of the function unchanged. 
@@ -62,7 +71,7 @@ def vmap_nodoc(fun, *args, **kw): return fun -def minimal_unsigned_dtype(value): +def minimal_unsigned_dtype(value: int) -> jnp.dtype: """Return the smallest unsigned integer dtype that can represent `value`.""" if value < 2**8: return jnp.uint8 @@ -103,14 +112,16 @@ def unique( return jnp.empty(0, x.dtype), 0 x = jnp.sort(x) - def loop(carry, x): + def loop( + carry: tuple[Scalar, Scalar, Shaped[Array, ' {size}']], x: Scalar + ) -> tuple[tuple[Scalar, Scalar, Shaped[Array, ' {size}']], None]: i_out, last, out = carry i_out = jnp.where(x == last, i_out, i_out + 1) out = out.at[i_out].set(x) return (i_out, x, out), None carry = 0, x[0], jnp.full(size, fill_value, x.dtype) - (actual_length, _, out), _ = scan(loop, carry, x[:size]) + (actual_length, _, out), _ = lax.scan(loop, carry, x[:size]) return out, actual_length + 1 @@ -136,7 +147,7 @@ class split: _keys: tuple[Key[Array, '*batch'], ...] _num_used: int - def __init__(self, key: Key[Array, '*batch'], num: int = 2): + def __init__(self, key: Key[Array, '*batch'], num: int = 2) -> None: if key.ndim: context = debug_key_reuse(False) else: @@ -147,7 +158,7 @@ def __init__(self, key: Key[Array, '*batch'], num: int = 2): self._keys = _split_unpack(key, num) self._num_used = 0 - def __len__(self): + def __len__(self) -> int: return len(self._keys) - self._num_used def pop(self, shape: int | tuple[int, ...] 
= ()) -> Key[Array, '*batch {shape}']: @@ -273,7 +284,7 @@ def truncated_normal_onesided( def get_default_device() -> Device: """Get the current default JAX device.""" with ensure_compile_time_eval(): - return jnp.zeros(()).device + return jnp.empty(0).device def get_device_count() -> int: @@ -285,3 +296,57 @@ def get_device_count() -> int: def is_key(x: object) -> bool: """Determine if `x` is a jax random key.""" return isinstance(x, Array) and jnp.issubdtype(x.dtype, prng_key) + + +def jit_active() -> bool: + """Check if we are under jit.""" + return not hasattr(jnp.empty(0), 'platform') + + +def _equal_shards(x: Array, axis_name: str) -> Bool[Array, '']: + """Check if all shards of `x` are equal, to be used in a `shard_map` context.""" + # get axis size, this could be `size = lax.axis_size(axis_name)`, but it's + # supported only since jax v0.6.1 + mesh = typeof(x).sharding.mesh + i = mesh.axis_names.index(axis_name) + size = mesh.axis_sizes[i] + + perm = [(i, (i + 1) % size) for i in range(size)] + perm_x = lax.ppermute(x, axis_name, perm) + diff = jnp.any(x != perm_x) + return jnp.logical_not(lax.psum(diff, axis_name)) + + +def equal_shards( + x: PyTree[Array, ' S'], axis_name: str, **shard_map_kwargs: Any +) -> PyTree[Bool[Array, ''], ' S']: + """Check that all shards of `x` are equal across axis `axis_name`. + + Parameters + ---------- + x + A pytree of arrays to check. Each array is checked separately. + axis_name + The mesh axis name across which equality is checked. It's not checked + across other axes. + **shard_map_kwargs + Additional arguments passed to `jax.shard_map` to set up the function + that checks equality. You may need to specify `in_specs` passing + the (pytree of) `jax.sharding.PartitionSpec` that specifies how `x` + is sharded, if the axes are not explicit, and `mesh` if there is not + a default mesh set by `jax.set_mesh`. + + Returns + ------- + A pytree of booleans indicating whether each leaf is equal across devices along the mesh axis. 
+ """ + equal_shards_leaf = partial(_equal_shards, axis_name=axis_name) + + def check_equal(x: PyTree[Array, ' S']) -> PyTree[Bool[Array, ''], ' S']: + return tree.map(equal_shards_leaf, x) + + sharded_check_equal = shard_map( + check_equal, out_specs=PartitionSpec(), **shard_map_kwargs + ) + + return sharded_check_equal(x) diff --git a/src/bartz/jaxext/_autobatch.py b/src/bartz/jaxext/_autobatch.py index 93d47ebe..aa405803 100644 --- a/src/bartz/jaxext/_autobatch.py +++ b/src/bartz/jaxext/_autobatch.py @@ -27,6 +27,7 @@ import math from collections.abc import Callable from functools import partial, wraps +from typing import Any from warnings import warn from jax.typing import DTypeLike @@ -36,28 +37,24 @@ except ImportError: from numpy.core.numeric import normalize_axis_index # numpy 1 -from jax import ShapeDtypeStruct, eval_shape, jit +from jax import ShapeDtypeStruct, eval_shape, jit, lax, tree from jax import numpy as jnp -from jax.lax import scan -from jax.tree import flatten as tree_flatten -from jax.tree import map as tree_map -from jax.tree import reduce as tree_reduce from jaxtyping import Array, PyTree, Shaped -def expand_axes(axes, tree): - """Expand `axes` such that they match the pytreedef of `tree`.""" +def expand_axes(axes: PyTree[int | None], tree_arg: PyTree) -> PyTree[int | None]: + """Expand `axes` such that they match the pytreedef of `tree_arg`.""" - def expand_axis(axis, subtree): - return tree_map(lambda _: axis, subtree) + def expand_axis(axis: int | None, subtree: PyTree) -> PyTree[int | None]: + return tree.map(lambda _: axis, subtree) - return tree_map(expand_axis, axes, tree, is_leaf=lambda x: x is None) + return tree.map(expand_axis, axes, tree_arg, is_leaf=lambda x: x is None) def normalize_axes( - axes: PyTree[int | None, ' T'], tree: PyTree[Array, ' T'] + axes: PyTree[int | None, ' T'], tree_arg: PyTree[Array, ' T'] ) -> PyTree[int | None, ' T']: - """Normalize axes to be non-negative and valid for the corresponding arrays in the 
tree.""" + """Normalize axes to be non-negative and valid for the corresponding arrays in the tree_arg.""" def normalize_axis(axis: int | None, x: Array) -> int | None: if axis is None: @@ -65,14 +62,14 @@ def normalize_axis(axis: int | None, x: Array) -> int | None: else: return normalize_axis_index(axis, len(x.shape)) - return tree_map(normalize_axis, axes, tree, is_leaf=lambda x: x is None) + return tree.map(normalize_axis, axes, tree_arg, is_leaf=lambda x: x is None) -def check_no_nones(axes, tree): - def check_not_none(_, axis): +def check_no_nones(axes: PyTree[int | None], tree_arg: PyTree) -> None: + def check_not_none(_: object, axis: int | None) -> None: assert axis is not None - tree_map(check_not_none, tree, axes, is_leaf=lambda x: x is None) + tree.map(check_not_none, tree_arg, axes, is_leaf=lambda x: x is None) def remove_axis( @@ -85,39 +82,39 @@ def remove_axis(x: ShapeDtypeStruct, axis: int) -> ShapeDtypeStruct: new_dtype = reduction_dtype(ufunc, x.dtype) return ShapeDtypeStruct(new_shape, new_dtype) - return tree_map(remove_axis, x, axis) + return tree.map(remove_axis, x, axis) -def extract_size(axes, tree): - """Get the size of each array in tree at the axis in axes, check they are equal and return it.""" +def extract_size(axes: PyTree[int | None], tree_arg: PyTree) -> int: + """Get the size of each array in tree_arg at the axis in axes, check they are equal and return it.""" - def get_size(x, axis): + def get_size(x: object, axis: int | None) -> int | None: if axis is None: return None else: return x.shape[axis] - sizes = tree_map(get_size, tree, axes) - sizes, _ = tree_flatten(sizes) + sizes = tree.map(get_size, tree_arg, axes) + sizes, _ = tree.flatten(sizes) assert all(s == sizes[0] for s in sizes) return sizes[0] -def sum_nbytes(tree): - def nbytes(x): +def sum_nbytes(tree_arg: PyTree[Array | ShapeDtypeStruct]) -> int: + def nbytes(x: Array | ShapeDtypeStruct) -> int: return math.prod(x.shape) * x.dtype.itemsize - return tree_reduce(lambda 
size, x: size + nbytes(x), tree, 0) + return tree.reduce(lambda size, x: size + nbytes(x), tree_arg, 0) -def next_divisor_small(dividend, min_divisor): +def next_divisor_small(dividend: int, min_divisor: int) -> int: for divisor in range(min_divisor, int(math.sqrt(dividend)) + 1): if dividend % divisor == 0: return divisor return dividend -def next_divisor_large(dividend, min_divisor): +def next_divisor_large(dividend: int, min_divisor: int) -> int: max_inv_divisor = dividend // min_divisor for inv_divisor in range(max_inv_divisor, 0, -1): if dividend % inv_divisor == 0: @@ -125,8 +122,8 @@ def next_divisor_large(dividend, min_divisor): return dividend -def next_divisor(dividend, min_divisor): - """Return divisor >= min_divisor such that divided % divisor == 0.""" +def next_divisor(dividend: int, min_divisor: int) -> int: + """Return divisor >= min_divisor such that dividend % divisor == 0.""" if dividend == 0: return min_divisor if min_divisor * min_divisor <= dividend: @@ -134,56 +131,60 @@ def next_divisor(dividend, min_divisor): return next_divisor_large(dividend, min_divisor) -def pull_nonbatched(axes, tree): - def pull_nonbatched(x, axis): +def pull_nonbatched( + axes: PyTree[int | None], tree_arg: PyTree +) -> tuple[PyTree, PyTree]: + def pull_nonbatched(x: object, axis: int | None) -> object: if axis is None: return None else: return x - return tree_map(pull_nonbatched, tree, axes), tree + return tree.map(pull_nonbatched, tree_arg, axes), tree_arg -def push_nonbatched(axes, tree, original_tree): - def push_nonbatched(original_x, x, axis): +def push_nonbatched( + axes: PyTree[int | None], tree_arg: PyTree, original_tree: PyTree +) -> PyTree[Any]: + def push_nonbatched(original_x: object, x: object, axis: int | None) -> object: if axis is None: return original_x else: return x - return tree_map(push_nonbatched, original_tree, tree, axes) + return tree.map(push_nonbatched, original_tree, tree_arg, axes) -def move_axes_out(axes, tree): - def move_axis_out(x, 
axis): +def move_axes_out(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]: + def move_axis_out(x: Array, axis: int) -> Array: return jnp.moveaxis(x, axis, 0) - return tree_map(move_axis_out, tree, axes) + return tree.map(move_axis_out, tree_arg, axes) -def move_axes_in(axes, tree): - def move_axis_in(x, axis): +def move_axes_in(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]: + def move_axis_in(x: Array, axis: int) -> Array: return jnp.moveaxis(x, 0, axis) - return tree_map(move_axis_in, tree, axes) + return tree.map(move_axis_in, tree_arg, axes) -def batch(tree: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']: +def batch(tree_arg: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']: """Split the first axis into two axes, the first of size `nbatches`.""" - def batch(x): + def batch(x: Array) -> Array: return x.reshape(nbatches, x.shape[0] // nbatches, *x.shape[1:]) - return tree_map(batch, tree) + return tree.map(batch, tree_arg) -def unbatch(tree: PyTree[Array, ' T']) -> PyTree[Array, ' T']: +def unbatch(tree_arg: PyTree[Array, ' T']) -> PyTree[Array, ' T']: """Merge the first two axes into a single axis.""" - def unbatch(x): + def unbatch(x: Array) -> Array: return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) - return tree_map(unbatch, tree) + return tree.map(unbatch, tree_arg) def reduce( @@ -198,7 +199,7 @@ def reduce( def reduce(x: Array, axis: int) -> Array: return ufunc.reduce(x, axis=axis) - return tree_map(reduce, x, axes) + return tree.map(reduce, x, axes) else: @@ -206,7 +207,7 @@ def reduce(x: Array, initial: Array, axis: int) -> Array: reduced = ufunc.reduce(x, axis=axis) return ufunc(initial, reduced) - return tree_map(reduce, x, initial, axes) + return tree.map(reduce, x, initial, axes) def identity( @@ -218,7 +219,7 @@ def identity(x: ShapeDtypeStruct) -> Array: identity = identity_for(ufunc, x.dtype) return jnp.broadcast_to(identity, x.shape) - return tree_map(identity, x) + return tree.map(identity, 
x) def reduction_dtype(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> DTypeLike: @@ -235,12 +236,12 @@ def identity_for(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> Shaped[Array, '']: return jnp.array(ufunc.identity, dtype) -def check_same(tree1, tree2): - def check_same(x1, x2): +def check_same(tree1: PyTree, tree2: PyTree) -> None: + def check_same(x1: Array | ShapeDtypeStruct, x2: Array | ShapeDtypeStruct) -> None: assert x1.shape == x2.shape assert x1.dtype == x2.dtype - tree_map(check_same, tree1, tree2) + tree.map(check_same, tree1, tree2) class NotDefined: @@ -310,7 +311,7 @@ def autobatch( @jit @wraps(func) - def autobatch_wrapper(*args): + def autobatch_wrapper(*args: PyTree) -> PyTree: return batched_func( func, max_io_nbytes, @@ -400,7 +401,7 @@ def batched_func( out_axes=out_axes, reduce_ufunc=reduce_ufunc, ) - reduced_result, result = scan(loop, initial, args) + reduced_result, result = lax.scan(loop, initial, args) # remove auxiliary batching axis and reverse transposition if reduce_ufunc is None: @@ -425,8 +426,15 @@ def batched_func( def batching_loop( - initial, args, *, func, nonbatched_args, in_axes, out_axes, reduce_ufunc -): + initial: PyTree[Array] | None, + args: PyTree[Array], + *, + func: Callable, + nonbatched_args: PyTree, + in_axes: PyTree[int | None], + out_axes: PyTree[int], + reduce_ufunc: jnp.ufunc | None, +) -> tuple[PyTree[Array], None] | tuple[None, PyTree[Array]]: """Implement the batching loop in `autobatch`.""" # evaluate the function args = move_axes_in(in_axes, args) diff --git a/src/bartz/jaxext/scipy/special.py b/src/bartz/jaxext/scipy/special.py index 96e81f01..ddb3e37e 100644 --- a/src/bartz/jaxext/scipy/special.py +++ b/src/bartz/jaxext/scipy/special.py @@ -1,6 +1,6 @@ # bartz/src/bartz/jaxext/scipy/special.py # -# Copyright (c) 2025, The Bartz Contributors +# Copyright (c) 2025-2026, The Bartz Contributors # # This file is part of bartz. 
# @@ -24,29 +24,33 @@ """Mockup of the :external:py:mod:`scipy.special` module.""" +from collections.abc import Callable, Sequence from functools import wraps +from typing import Any from jax import ShapeDtypeStruct, jit, pure_callback from jax import numpy as jnp +from jax.typing import DTypeLike +from jaxtyping import Array, Float from scipy.special import gammainccinv as scipy_gammainccinv -def _float_type(*args): +def _float_type(*args: DTypeLike | Array) -> jnp.dtype: """Determine the jax floating point result type given operands/types.""" t = jnp.result_type(*args) return jnp.sin(jnp.empty(0, t)).dtype -def _castto(func, dtype): +def _castto(func: Callable[..., Array], dtype: DTypeLike) -> Callable[..., Array]: @wraps(func) - def newfunc(*args, **kw): + def newfunc(*args: Any, **kw: Any) -> Array: return func(*args, **kw).astype(dtype) return newfunc @jit -def gammainccinv(a, y): +def gammainccinv(a: Float[Array, '*'], y: Float[Array, '*']) -> Float[Array, '*']: """Survival function inverse of the Gamma(a, 1) distribution.""" shape = jnp.broadcast_shapes(a.shape, y.shape) dtype = _float_type(a.dtype, y.dtype) @@ -74,7 +78,7 @@ def gammainccinv(a, y): from jax import debug_infs, lax -def ndtri(p): +def ndtri(p: Float[Array, '*']) -> Float[Array, '*']: """Compute the inverse of the CDF of the Normal distribution function. This is a patch of `jax.scipy.special.ndtri`. @@ -86,7 +90,7 @@ def ndtri(p): return _ndtri(p) -def _ndtri(p): +def _ndtri(p: Float[Array, '...']) -> Float[Array, '...']: # Constants used in piece-wise rational approximations. 
Taken from the cephes # library: # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html @@ -180,7 +184,9 @@ def _ndtri(p): dtype = lax.dtype(p).type shape = jnp.shape(p) - def _create_polynomial(var, coeffs): + def _create_polynomial( + var: Float[Array, '...'], coeffs: Sequence[float] + ) -> Float[Array, '...']: """Compute n_th order polynomial via Horner's method.""" coeffs = np.array(coeffs, dtype) if not coeffs.size: diff --git a/src/bartz/jaxext/scipy/stats.py b/src/bartz/jaxext/scipy/stats.py index cf1750d6..cac69680 100644 --- a/src/bartz/jaxext/scipy/stats.py +++ b/src/bartz/jaxext/scipy/stats.py @@ -1,6 +1,6 @@ # bartz/src/bartz/jaxext/scipy/stats.py # -# Copyright (c) 2025, The Bartz Contributors +# Copyright (c) 2025-2026, The Bartz Contributors # # This file is part of bartz. # @@ -24,6 +24,8 @@ """Mockup of the :external:py:mod:`scipy.stats` module.""" +from jaxtyping import Array, Float + from bartz.jaxext.scipy.special import gammainccinv @@ -31,6 +33,6 @@ class invgamma: """Class that represents the distribution InvGamma(a, 1).""" @staticmethod - def ppf(q, a): + def ppf(q: Float[Array, '*'], a: Float[Array, '*']) -> Float[Array, '*']: """Percentile point function.""" return 1 / gammainccinv(a, q) diff --git a/src/bartz/mcmcloop.py b/src/bartz/mcmcloop.py index 7a58da04..70ef7358 100644 --- a/src/bartz/mcmcloop.py +++ b/src/bartz/mcmcloop.py @@ -31,7 +31,7 @@ from dataclasses import fields from functools import partial, wraps from math import floor -from typing import Any, Protocol +from typing import Any, NamedTuple, Protocol, TypeVar import jax import numpy @@ -48,17 +48,28 @@ from jax import numpy as jnp from jax.nn import softmax from jax.sharding import Mesh, PartitionSpec -from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt +from jaxtyping import ( + Array, + ArrayLike, + Bool, + Float32, + Int32, + Integer, + Key, + PyTree, + Shaped, + UInt, +) from bartz import jaxext, mcmcstep from 
bartz._profiler import ( cond_if_not_profiling, get_profile_mode, jit_if_not_profiling, - scan_if_not_profiling, + while_loop_if_not_profiling, ) from bartz.grove import TreeHeaps, evaluate_forest, forest_fill, var_histogram -from bartz.jaxext import autobatch +from bartz.jaxext import autobatch, jit_active from bartz.mcmcstep import State from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains @@ -129,6 +140,23 @@ def from_state(cls, state: State) -> 'MainTrace': CallbackState = PyTree[Any, 'T'] +class RunMCMCResult(NamedTuple): + """Return value of `run_mcmc`.""" + + final_state: State + """The final MCMC state.""" + + burnin_trace: PyTree[ + Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...'] + ] + """The trace of the burn-in phase. For the default layout, see `BurninTrace`.""" + + main_trace: PyTree[ + Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...'] + ] + """The trace of the main phase. For the default layout, see `MainTrace`.""" + + class Callback(Protocol): """Callback type for `run_mcmc`.""" @@ -139,13 +167,12 @@ def __call__( bart: State, burnin: Bool[Array, ''], i_total: Int32[Array, ''], - i_skip: Int32[Array, ''], callback_state: CallbackState, n_burn: Int32[Array, ''], n_save: Int32[Array, ''], n_skip: Int32[Array, ''], i_outer: Int32[Array, ''], - inner_loop_length: int, + inner_loop_length: Int32[Array, ''], ) -> tuple[State, CallbackState] | None: """Do an arbitrary action after an iteration of the MCMC. @@ -159,9 +186,6 @@ def __call__( Whether the last iteration was in the burn-in phase. i_total The index of the last MCMC iteration (0-based). - i_skip - The number of MCMC updates from the last saved state. The initial - state counts as saved, even if it's not copied into the trace. 
callback_state The callback state, initially set to the argument passed to `run_mcmc`, afterwards to the value returned by the last invocation @@ -218,11 +242,7 @@ def run_mcmc( callback_state: CallbackState = None, burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state, main_extractor: Callable[[State], PyTree] = MainTrace.from_state, -) -> tuple[ - State, - PyTree[Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']], - PyTree[Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']], -]: +) -> RunMCMCResult: """ Run the MCMC for the BART posterior. @@ -266,12 +286,7 @@ def run_mcmc( Returns ------- - bart : State - The final MCMC state. - burnin_trace : PyTree[Shaped[Array, 'n_burn *']] - The trace of the burn-in phase. For the default layout, see `BurninTrace`. - main_trace : PyTree[Shaped[Array, 'n_save *']] - The trace of the main phase. For the default layout, see `MainTrace`. + A namedtuple with the final state, the burn-in trace, and the main trace. 
Raises ------ @@ -300,8 +315,7 @@ def run_mcmc( # same code path for benchmarking and testing # error if under jit and there are unrolled loops or profile mode is on - under_jit = not hasattr(jnp.empty(0), 'platform') - if under_jit and (n_outer > 1 or get_profile_mode()): + if jit_active() and (n_outer > 1 or get_profile_mode()): msg = ( '`run_mcmc` was called within a jit-compiled function and ' 'there are either more than 1 outer loops or profile mode is active, ' @@ -334,7 +348,7 @@ def run_mcmc( n_iters, ) - return carry.bart, carry.burnin_trace, carry.main_trace + return RunMCMCResult(carry.bart, carry.burnin_trace, carry.main_trace) def _replicate(x: Array, mesh: Mesh | None) -> Array: @@ -360,24 +374,13 @@ def _empty_trace( return jax.vmap(extractor, in_axes=None, out_axes=out_axes, axis_size=length)(bart) -@jit -def _compute_i_skip( - i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, ''] -) -> Int32[Array, '']: - """Compute the `i_skip` argument passed to `callback`.""" - burnin = i_total < n_burn - return jnp.where( - burnin, - i_total + 1, - (i_total - n_burn + 1) % n_skip - + jnp.where(i_total - n_burn + 1 < n_skip, n_burn, 0), - ) +T = TypeVar('T') class _CallCounter: """Wrap a callable to check it's not called more than once.""" - def __init__(self, func: Callable) -> None: + def __init__(self, func: Callable[..., T]) -> None: self.func = func self.n_calls = 0 @@ -385,7 +388,7 @@ def reset_call_counter(self) -> None: """Reset the call counter.""" self.n_calls = 0 - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> T: if self.n_calls and not get_profile_mode(): msg = ( 'The inner loop of `run_mcmc` was traced more than once, ' @@ -400,11 +403,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.func(*args, **kwargs) -@partial(jit_if_not_profiling, donate_argnums=(0,), static_argnums=(1, 2, 3, 4)) +@partial(jit_if_not_profiling, donate_argnums=(0,), 
static_argnums=(2, 3, 4)) @_CallCounter def _run_mcmc_inner_loop( carry: _Carry, - inner_loop_length: int, + inner_loop_length: Int32[Array, ''], callback: Callback | None, burnin_extractor: Callable[[State], PyTree], main_extractor: Callable[[State], PyTree], @@ -414,8 +417,15 @@ def _run_mcmc_inner_loop( i_outer: Int32[Array, ''], n_iters: Int32[Array, ''], ) -> _Carry: - def loop_impl(carry: _Carry) -> _Carry: - """Loop body to run if i_total < n_iters.""" + # determine number of iterations for this loop batch + i_upper = jnp.minimum(carry.i_total + inner_loop_length, n_iters) + + def cond(carry: _Carry) -> Bool[Array, '']: + """Whether to continue the MCMC loop.""" + return carry.i_total < i_upper + + def body(carry: _Carry) -> _Carry: + """Update the MCMC state.""" # split random key keys = jaxext.split(carry.key, 3) key = keys.pop() @@ -426,13 +436,11 @@ def loop_impl(carry: _Carry) -> _Carry: # invoke callback callback_state = carry.callback_state if callback is not None: - i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip) rt = callback( key=keys.pop(), bart=bart, burnin=carry.i_total < n_burn, i_total=carry.i_total, - i_skip=i_skip, callback_state=callback_state, n_burn=n_burn, n_save=n_save, @@ -464,18 +472,7 @@ def loop_impl(carry: _Carry) -> _Carry: callback_state=callback_state, ) - def loop_noop(carry: _Carry) -> _Carry: - """Loop body to run if i_total >= n_iters; it does nothing.""" - return carry - - def loop(carry: _Carry, _) -> tuple[_Carry, None]: - carry = cond_if_not_profiling( - carry.i_total < n_iters, loop_impl, loop_noop, carry - ) - return carry, None - - carry, _ = scan_if_not_profiling(loop, carry, None, inner_loop_length) - return carry + return while_loop_if_not_profiling(cond, body, carry) @partial(jit, donate_argnums=(0, 1), static_argnums=(2, 3)) @@ -522,11 +519,12 @@ def _set( def at_set( trace: Shaped[Array, 'chains samples *shape'] + | None | Shaped[Array, ' samples *shape'] | None, val: Shaped[Array, ' chains *shape'] | 
Shaped[Array, '*shape'] | None, chain_axis: int | None, - ): + ) -> Shaped[Array, 'chains samples *shape'] | None: if trace is None or trace.size == 0: # this handles the case where an array is empty because jax refuses # to index into an axis of length 0, even if just in the abstract, @@ -576,7 +574,7 @@ def make_default_callback( >>> run_mcmc(key, state, ..., **make_default_callback(state, ...)) """ - def as_replicated_array_or_none(val: None | Any) -> None | Array: + def as_replicated_array_or_none(val: ArrayLike | None) -> None | Array: return None if val is None else _replicate(jnp.asarray(val), state.config.mesh) return dict( @@ -608,8 +606,8 @@ def print_callback( n_save: Int32[Array, ''], n_skip: Int32[Array, ''], callback_state: PrintCallbackState, - **_, -): + **_: Any, +) -> None: """Print a dot and/or a report periodically during the MCMC.""" report_every = callback_state.report_every dot_every = callback_state.dot_every @@ -621,7 +619,7 @@ def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']: report_cond = get_cond(report_every) dot_cond = get_cond(dot_every) - def line_report_branch(): + def line_report_branch() -> None: if report_every is None: return if dot_every is None: @@ -643,7 +641,7 @@ def line_report_branch(): fill=forest_fill(bart.forest.split_tree), ) - def just_dot_branch(): + def just_dot_branch() -> None: if dot_every is None: return debug.callback( @@ -658,7 +656,7 @@ def just_dot_branch(): ) -def _convert_jax_arrays_in_args(func: Callable) -> Callable: +def _convert_jax_arrays_in_args(func: Callable[..., T]) -> Callable[..., T]: """Remove jax arrays from a function arguments. 
Converts all `jax.Array` instances in the arguments to either Python scalars @@ -666,7 +664,7 @@ def _convert_jax_arrays_in_args(func: Callable) -> Callable: """ def convert_jax_arrays(pytree: PyTree) -> PyTree: - def convert_jax_array(val: Any) -> Any: + def convert_jax_array(val: object) -> object: if not isinstance(val, Array): return val elif val.shape: @@ -677,7 +675,7 @@ def convert_jax_array(val: Any) -> Any: return tree.map(convert_jax_array, pytree) @wraps(func) - def new_func(*args, **kw): + def new_func(*args: Any, **kw: Any) -> T: args = convert_jax_arrays(args) kw = convert_jax_arrays(kw) return func(*args, **kw) @@ -701,7 +699,7 @@ def _print_report( prune_acc_count: float, prop_total: int, fill: float, -): +) -> None: """Print the report for `print_callback`.""" # compute fractions grow_prop = grow_prop_count / prop_total @@ -748,7 +746,7 @@ class TreesTrace(Module): split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)'] @classmethod - def from_dataclass(cls, obj: TreeHeaps): + def from_dataclass(cls, obj: TreeHeaps) -> 'TreesTrace': """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`.""" return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)}) diff --git a/src/bartz/mcmcstep/_moves.py b/src/bartz/mcmcstep/_moves.py index e81c0c61..d071c0da 100644 --- a/src/bartz/mcmcstep/_moves.py +++ b/src/bartz/mcmcstep/_moves.py @@ -616,7 +616,9 @@ def randint_exclude( return u, num_allowed -def _process_exclude(sup, exclude): +def _process_exclude( + sup: int | Integer[Array, ''], exclude: Integer[Array, ' n'] +) -> tuple[Integer[Array, ' n'], Integer[Array, '']]: exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup) num_allowed = sup - jnp.sum(exclude < sup) return exclude, num_allowed diff --git a/src/bartz/mcmcstep/_state.py b/src/bartz/mcmcstep/_state.py index 9772dbac..8275aa6a 100644 --- a/src/bartz/mcmcstep/_state.py +++ b/src/bartz/mcmcstep/_state.py @@ -25,19 +25,20 @@ """Module defining the BART MCMC state and 
initialization.""" from collections.abc import Callable, Hashable -from dataclasses import fields +from dataclasses import fields, replace from functools import partial, wraps from math import log2 from typing import Any, Literal, TypedDict, TypeVar import numpy -from equinox import Module, error_if +from equinox import Module, error_if, filter_jit from equinox import field as eqx_field from jax import ( NamedSharding, device_put, eval_shape, jit, + lax, make_mesh, random, tree, @@ -46,14 +47,13 @@ from jax import numpy as jnp from jax.scipy.linalg import solve_triangular from jax.sharding import AxisType, Mesh, PartitionSpec -from jax.tree import flatten from jaxtyping import Array, Bool, Float32, Int32, Integer, PyTree, Shaped, UInt -from bartz.grove import make_tree, tree_depths +from bartz.grove import tree_depths from bartz.jaxext import get_default_device, is_key, minimal_unsigned_dtype -def field(*, chains: bool = False, data: bool = False, **kwargs): +def field(*, chains: bool = False, data: bool = False, **kwargs: Any): # noqa: ANN202 """Extend `equinox.field` with two new parameters. 
Parameters @@ -117,27 +117,34 @@ def _find_metadata( x: PyTree[Any, ' S'], key: Hashable, if_true: T, if_false: T ) -> PyTree[T, ' S']: """Replace all subtrees of x marked with a metadata key.""" - if isinstance(x, Module): + + def is_lazy_array(x: object) -> bool: + return isinstance(x, _LazyArray) + + def is_module(x: object) -> bool: + return isinstance(x, Module) and not is_lazy_array(x) + + if is_module(x): args = [] for f in fields(x): v = getattr(x, f.name) if f.metadata.get('static', False): args.append(v) elif f.metadata.get(key, False): - subtree = tree.map(lambda _: if_true, v) + subtree = tree.map(lambda _: if_true, v, is_leaf=is_lazy_array) args.append(subtree) else: args.append(_find_metadata(v, key, if_true, if_false)) return x.__class__(*args) - def is_leaf(x) -> bool: - return isinstance(x, Module) - - def get_axes(x: Module | Any) -> PyTree[T]: - if isinstance(x, Module): + def get_axes(x: object) -> PyTree[T]: + if is_module(x): return _find_metadata(x, key, if_true, if_false) else: - return tree.map(lambda _: if_false, x) + return tree.map(lambda _: if_false, x, is_leaf=is_lazy_array) + + def is_leaf(x: object) -> bool: + return isinstance(x, Module) # this catches _LazyArray as well return tree.map(get_axes, x, is_leaf=is_leaf) @@ -413,7 +420,9 @@ def _parse_p_nonterminal( def make_p_nonterminal( - d: int, alpha: float | Float32[Array, ''], beta: float | Float32[Array, ''] + d: int, + alpha: float | Float32[Array, ''] = 0.95, + beta: float | Float32[Array, ''] = 2.0, ) -> Float32[Array, ' {d}-1']: """Prepare the `p_nonterminal` argument to `init`. @@ -441,6 +450,28 @@ def make_p_nonterminal( return alpha / (1 + depth).astype(float) ** beta +class _LazyArray(Module): + """Like `functools.partial` but specialized to array-creating functions like `jax.numpy.zeros`.""" + + array_creator: Callable + shape: tuple[int, ...] 
+ args: tuple + + def __init__( + self, array_creator: Callable, shape: tuple[int, ...], *args: Any + ) -> None: + self.array_creator = array_creator + self.shape = shape + self.args = args + + def __call__(self, **kwargs: Any) -> T: + return self.array_creator(self.shape, *self.args, **kwargs) + + @property + def ndim(self) -> int: + return len(self.shape) + + def init( *, X: UInt[Any, 'p n'], @@ -634,7 +665,6 @@ def init( # process multichain settings chain_shape = () if num_chains is None else (num_chains,) resid_shape = chain_shape + y.shape - tree_shape = (*chain_shape, num_trees) add_chains = partial(_add_chains, chain_shape=chain_shape) # determine batch sizes for reductions @@ -660,51 +690,56 @@ def init( ) offset = error_if(offset, jnp.sum(max_split == 0) > filter_splitless_vars, msg) + # determine shapes for trees + tree_shape = (*chain_shape, num_trees) + tree_size = 2**max_depth + # initialize all remaining stuff and put it in an unsharded state state = State( X=X, y=y, - z=jnp.full(resid_shape, offset) if is_binary else None, + z=_LazyArray(jnp.full, resid_shape, offset) if is_binary else None, offset=offset, - resid=jnp.zeros(resid_shape) + resid=_LazyArray(jnp.zeros, resid_shape) if is_binary - else jnp.broadcast_to(y - offset[..., None], resid_shape), + else None, # in this case, resid is created later after y and offset are sharded error_cov_inv=add_chains(error_cov_inv), - prec_scale=_get_prec_scale(error_scale), + prec_scale=error_scale, # temporarily set to error_scale, fix after sharding error_cov_df=error_cov_df, error_cov_scale=error_cov_scale, forest=Forest( - leaf_tree=make_tree(max_depth, jnp.float32, tree_shape + kshape), - var_tree=make_tree( - max_depth - 1, minimal_unsigned_dtype(p - 1), tree_shape + leaf_tree=_LazyArray( + jnp.zeros, (*tree_shape, *kshape, tree_size), jnp.float32 + ), + var_tree=_LazyArray( + jnp.zeros, (*tree_shape, tree_size // 2), minimal_unsigned_dtype(p - 1) ), - split_tree=make_tree(max_depth - 1, 
max_split.dtype, tree_shape), - affluence_tree=( - make_tree(max_depth - 1, bool, tree_shape) - .at[..., 1] - .set( - True - if min_points_per_decision_node is None - else n >= min_points_per_decision_node - ) + split_tree=_LazyArray( + jnp.zeros, (*tree_shape, tree_size // 2), max_split.dtype + ), + affluence_tree=_LazyArray( + _initial_affluence_tree, + (*tree_shape, tree_size // 2), + n, + min_points_per_decision_node, ), blocked_vars=_get_blocked_vars(filter_splitless_vars, max_split), max_split=max_split, - grow_prop_count=jnp.zeros(chain_shape, int), - grow_acc_count=jnp.zeros(chain_shape, int), - prune_prop_count=jnp.zeros(chain_shape, int), - prune_acc_count=jnp.zeros(chain_shape, int), - p_nonterminal=p_nonterminal[tree_depths(2**max_depth)], - p_propose_grow=p_nonterminal[tree_depths(2 ** (max_depth - 1))], - leaf_indices=jnp.ones( - (*tree_shape, n), minimal_unsigned_dtype(2**max_depth - 1) + grow_prop_count=_LazyArray(jnp.zeros, chain_shape, int), + grow_acc_count=_LazyArray(jnp.zeros, chain_shape, int), + prune_prop_count=_LazyArray(jnp.zeros, chain_shape, int), + prune_acc_count=_LazyArray(jnp.zeros, chain_shape, int), + p_nonterminal=p_nonterminal[tree_depths(tree_size)], + p_propose_grow=p_nonterminal[tree_depths(tree_size // 2)], + leaf_indices=_LazyArray( + jnp.ones, (*tree_shape, n), minimal_unsigned_dtype(tree_size - 1) ), min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node), min_points_per_leaf=_asarray_or_none(min_points_per_leaf), - log_trans_prior=jnp.zeros((*chain_shape, num_trees)) + log_trans_prior=_LazyArray(jnp.zeros, (*chain_shape, num_trees)) if save_ratios else None, - log_likelihood=jnp.zeros((*chain_shape, num_trees)) + log_likelihood=_LazyArray(jnp.zeros, (*chain_shape, num_trees)) if save_ratios else None, leaf_prior_cov_inv=leaf_prior_cov_inv, @@ -722,23 +757,62 @@ def init( ), ) + # delete big input arrays such that they can be deleted as soon as they + # are sharded, only those arrays that contain an 
(n,) sized axis + del X, y, error_scale + # move all arrays to the appropriate device - return _shard_state(state) + state = _shard_state(state) + + # calculate initial resid in the continuous outcome case, such that y and + # offset are already sharded if needed + if state.resid is None: + resid = _LazyArray(_initial_resid, resid_shape, state.y, state.offset) + resid = _shard_leaf(resid, 0, -1, state.config.mesh) + state = replace(state, resid=resid) + + # calculate prec_scale after sharding to do the calculation on the right + # devices + if state.prec_scale is not None: + prec_scale = _compute_prec_scale(state.prec_scale) + state = replace(state, prec_scale=prec_scale) + + # make all types strong to avoid unwanted recompilations + return _remove_weak_types(state) + + +def _initial_resid( + shape: tuple[int, ...], + y: Float32[Array, ' n'] | Float32[Array, 'k n'], + offset: Float32[Array, ''] | Float32[Array, ' k'], +) -> Float32[Array, ' n'] | Float32[Array, 'k n']: + """Calculate the initial value for `State.resid` in the continuous outcome case.""" + return jnp.broadcast_to(y - offset[..., None], shape) + + +def _initial_affluence_tree( + shape: tuple[int, ...], n: int, min_points_per_decision_node: int | None +) -> Array: + """Create the initial value of `Forest.affluence_tree`.""" + return ( + jnp.zeros(shape, bool) + .at[..., 1] + .set( + True + if min_points_per_decision_node is None + else n >= min_points_per_decision_node + ) + ) @partial(jit, donate_argnums=(0,)) -def _get_prec_scale( - error_scale: Float32[Array, ' n'] | None, -) -> Float32[Array, ' n'] | None: +def _compute_prec_scale(error_scale: Float32[Array, ' n']) -> Float32[Array, ' n']: """Compute 1 / error_scale**2. This is a separate function to use donate_argnums to avoid intermediate copies. 
""" - if error_scale is None: - return None - else: - return jnp.reciprocal(jnp.square(jnp.asarray(error_scale))) + return jnp.reciprocal(jnp.square(error_scale)) def _get_blocked_vars( @@ -827,18 +901,49 @@ def _auto_axes(mesh: Mesh) -> list[str]: ] +@partial(filter_jit, donate='all') +# jit and donate because otherwise type conversion would create copies +def _remove_weak_types(x: PyTree[Array, 'T']) -> PyTree[Array, 'T']: + """Make all types strong. + + This is to avoid recompilation in `run_mcmc` or `step`. + """ + + def remove_weak(x: T) -> T: + if isinstance(x, Array) and x.weak_type: + return x.astype(x.dtype) + else: + return x + + return tree.map(remove_weak, x) + + def _shard_state(state: State) -> State: - """Place all fields in the state on the appropriate devices.""" + """Place all arrays on the appropriate devices, and instantiate lazily defined arrays.""" mesh = state.config.mesh - if mesh is None: - return state + shard_leaf = partial(_shard_leaf, mesh=mesh) + return tree.map( + shard_leaf, + state, + chain_vmap_axes(state), + data_vmap_axes(state), + is_leaf=lambda x: x is None or isinstance(x, _LazyArray), + ) - def shard_leaf( - x: Array | None, chain_axis: int | None, data_axis: int | None - ) -> Array | None: - if x is None: - return None +def _shard_leaf( + x: Array | None | _LazyArray, + chain_axis: int | None, + data_axis: int | None, + mesh: Mesh | None, +) -> Array | None: + """Create `x` if it's lazy and shard it.""" + if x is None: + return None + + if mesh is None: + sharding = None + else: spec = [None] * x.ndim if chain_axis is not None and 'chains' in mesh.axis_names: spec[chain_axis] = 'chains' @@ -851,23 +956,34 @@ def shard_leaf( spec.pop() spec = PartitionSpec(*spec) - return device_put(x, NamedSharding(mesh, spec), donate=True) + sharding = NamedSharding(mesh, spec) - return tree.map( - shard_leaf, - state, - chain_vmap_axes(state), - data_vmap_axes(state), - is_leaf=lambda x: x is None, - ) + if isinstance(x, _LazyArray): + x = 
_concretize_lazy_array(x, sharding) + elif sharding is not None: + x = device_put(x, sharding, donate=True) + + return x + + +@filter_jit +# jit such that in recent jax versions the shards are created on the right +# devices immediately instead of being created on the wrong device and then +# copied +def _concretize_lazy_array(x: _LazyArray, sharding: NamedSharding | None) -> Array: + """Create an array from an abstract spec on the appropriate devices.""" + x = x() + if sharding is not None: + x = lax.with_sharding_constraint(x, sharding) + return x -def _all_none_or_not_none(*args): +def _all_none_or_not_none(*args: object) -> bool: is_none = [x is None for x in args] return all(is_none) or not any(is_none) -def _asarray_or_none(x): +def _asarray_or_none(x: object) -> Array | None: if x is None: return None return jnp.asarray(x) @@ -1026,7 +1142,7 @@ def get_num_chains(x: PyTree) -> int | None: traversal at nodes that define it. Check all values obtained invoking `num_chains` are equal, then return it. 
""" - leaves, _ = flatten(x, is_leaf=lambda x: hasattr(x, 'num_chains')) + leaves, _ = tree.flatten(x, is_leaf=lambda x: hasattr(x, 'num_chains')) num_chains = [x.num_chains() for x in leaves if hasattr(x, 'num_chains')] ref = num_chains[0] assert all(c == ref for c in num_chains) @@ -1037,7 +1153,7 @@ def _chain_axes_with_keys(x: PyTree) -> PyTree[int | None]: """Return `chain_vmap_axes(x)` but also set to 0 for random keys.""" axes = chain_vmap_axes(x) - def axis_if_key(x, axis): + def axis_if_key(x: object, axis: int | None) -> int | None: if is_key(x): return 0 else: @@ -1061,7 +1177,7 @@ def _find_mesh(x: PyTree) -> Mesh | None: class MeshFound(Exception): pass - def find_mesh(x: State | Any): + def find_mesh(x: object) -> None: if isinstance(x, State): raise MeshFound(x.config.mesh) @@ -1077,7 +1193,7 @@ def _split_all_keys(x: PyTree, num_chains: int) -> PyTree: """Split all random keys in `num_chains` keys.""" mesh = _find_mesh(x) - def split_key(x): + def split_key(x: object) -> object: if is_key(x): x = random.split(x, num_chains) if mesh is not None and 'chains' in mesh.axis_names: @@ -1093,14 +1209,14 @@ def vmap_chains( """Apply vmap on chain axes automatically if the inputs are multichain.""" @wraps(fun) - def auto_vmapped_fun(*args, **kwargs) -> T: + def auto_vmapped_fun(*args: Any, **kwargs: Any) -> T: all_args = args, kwargs num_chains = get_num_chains(all_args) if num_chains is not None: if auto_split_keys: all_args = _split_all_keys(all_args, num_chains) - def wrapped_fun(args, kwargs): + def wrapped_fun(args: tuple[Any, ...], kwargs: dict[str, Any]) -> T: return fun(*args, **kwargs) mc_in_axes = _chain_axes_with_keys(all_args) diff --git a/src/bartz/mcmcstep/_step.py b/src/bartz/mcmcstep/_step.py index 899d7bb8..52638c83 100644 --- a/src/bartz/mcmcstep/_step.py +++ b/src/bartz/mcmcstep/_step.py @@ -38,7 +38,6 @@ from equinox import Module, tree_at from jax import lax, random, vmap from jax import numpy as jnp -from jax.lax import cond from 
jax.scipy.linalg import solve_triangular from jax.scipy.special import gammaln, logsumexp from jax.sharding import Mesh, PartitionSpec @@ -445,7 +444,9 @@ def _compute_count_or_prec_trees( compute = vmap(_compute_count_or_prec_tree, in_axes=(None, 0, 0, None)) return compute(prec_scale, leaf_indices, moves, config) - def compute(args): + def compute( + args: tuple[UInt[Array, ' n'], Moves], + ) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]: leaf_indices, moves = args return _compute_count_or_prec_tree(prec_scale, leaf_indices, moves, config) @@ -644,8 +645,8 @@ def _precompute_likelihood_terms_uv( leaf_prior_cov_inv: Float32[Array, ''], move_precs: Precs | Counts, ) -> tuple[PreLkV, PreLk]: - sigma2 = lax.reciprocal(error_cov_inv) - sigma_mu2 = lax.reciprocal(leaf_prior_cov_inv) + sigma2 = jnp.reciprocal(error_cov_inv) + sigma_mu2 = jnp.reciprocal(leaf_prior_cov_inv) left = sigma2 + move_precs.left * sigma_mu2 right = sigma2 + move_precs.right * sigma_mu2 total = sigma2 + move_precs.total * sigma_mu2 @@ -752,7 +753,7 @@ def _precompute_leaf_terms_uv( z: Float32[Array, 'num_trees 2**d'] | None = None, ) -> PreLf: prec_lk = prec_trees * error_cov_inv - var_post = lax.reciprocal(prec_lk + leaf_prior_cov_inv) + var_post = jnp.reciprocal(prec_lk + leaf_prior_cov_inv) if z is None: z = random.normal(key, prec_trees.shape, error_cov_inv.dtype) return PreLf( @@ -885,7 +886,17 @@ def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]: The accepted/rejected moves, with `acc` and `to_prune` set. 
""" - def loop(resid, pt): + def loop( + resid: Float32[Array, ' n'] | Float32[Array, ' k n'], pt: SeqStageInPerTree + ) -> tuple[ + Float32[Array, ' n'] | Float32[Array, ' k n'], + tuple[ + Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'], + Bool[Array, ''], + Bool[Array, ''], + Float32[Array, ''] | None, + ], + ]: resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves( resid, SeqStageInAllTrees( @@ -1142,7 +1153,7 @@ def _scatter_add( return _scatter_add(values, indices) -def _get_shard_map_patch_kwargs(): +def _get_shard_map_patch_kwargs() -> dict[str, bool]: # see jax/issues/#34249, problem with vmap(shard_map(psum)) # we tried the config jax_disable_vmap_shmap_error but it didn't work if jax.__version__ in ('0.8.1', '0.8.2'): @@ -1201,7 +1212,9 @@ def _compute_likelihood_ratio_mv( right_resid: Float32[Array, ' k'], prelkv: PreLkV, ) -> Float32[Array, '']: - def _quadratic_form(r, mat): + def _quadratic_form( + r: Float32[Array, ' k'], mat: Float32[Array, 'k k'] + ) -> Float32[Array, '']: return r @ mat @ r qf_left = _quadratic_form(left_resid, prelkv.left) @@ -1577,7 +1590,7 @@ def step_sparse(key: Key[Array, ''], bart: State) -> State: Updated BART state with re-sampled `log_s` and `theta`. 
""" if bart.config.sparse_on_at is not None: - bart = cond( + bart = lax.cond( bart.config.steps_done < bart.config.sparse_on_at, lambda _key, bart: bart, _step_sparse, @@ -1587,7 +1600,7 @@ def step_sparse(key: Key[Array, ''], bart: State) -> State: return bart -def _step_sparse(key, bart): +def _step_sparse(key: Key[Array, ''], bart: State) -> State: keys = split(key) bart = step_s(keys.pop(), bart) if bart.forest.rho is not None: @@ -1597,7 +1610,7 @@ def _step_sparse(key, bart): @jit_if_profiling # jit to avoid the overhead of replace(_: Module) -def step_config(bart): +def step_config(bart: State) -> State: config = bart.config config = replace(config, steps_done=config.steps_done + 1) return replace(bart, config=config) diff --git a/src/bartz/prepcovars.py b/src/bartz/prepcovars.py index af9ef8f7..fc238202 100644 --- a/src/bartz/prepcovars.py +++ b/src/bartz/prepcovars.py @@ -1,6 +1,6 @@ # bartz/src/bartz/prepcovars.py # -# Copyright (c) 2024-2025, The Bartz Contributors +# Copyright (c) 2024-2026, The Bartz Contributors # # This file is part of bartz. # @@ -25,6 +25,7 @@ """Functions to preprocess data.""" from functools import partial +from typing import Any from jax import jit, vmap from jax import numpy as jnp @@ -100,7 +101,9 @@ def quantilized_splits_from_matrix( raise ValueError(msg) @partial(autobatch, max_io_nbytes=2**29) - def quantilize(X): + def quantilize( + X: Real[Array, 'p n'], + ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]: # wrap this function because autobatch needs traceable args return _quantilized_splits_from_matrix(X, out_length) @@ -221,7 +224,7 @@ def uniform_splits_from_matrix( @partial(jit, static_argnames=('method',)) def bin_predictors( - X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw + X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw: Any ) -> UInt[Array, 'p n']: """ Bin the predictors according to the given splits. 
@@ -247,7 +250,9 @@ def bin_predictors( @partial(autobatch, max_io_nbytes=2**29) @vmap - def bin_predictors(x, splits): + def bin_predictors( + x: Real[Array, 'p n'], splits: Real[Array, 'p m'] + ) -> UInt[Array, 'p n']: dtype = minimal_unsigned_dtype(splits.size) return jnp.searchsorted(splits, x, **kw).astype(dtype) diff --git a/src/bartz/testing/_dgp.py b/src/bartz/testing/_dgp.py index 6857f6e3..bad54183 100644 --- a/src/bartz/testing/_dgp.py +++ b/src/bartz/testing/_dgp.py @@ -26,10 +26,11 @@ """Define `gen_data` that generates simulated data for testing.""" from dataclasses import replace +from functools import partial from equinox import Module, error_if +from jax import jit, random from jax import numpy as jnp -from jax import random from jaxtyping import Array, Bool, Float, Int, Integer, Key from bartz.jaxext import split @@ -233,14 +234,14 @@ def generate_outcome( class DGP(Module): - """Quadratic multivariate DGP. + """Output of `gen_data`. Parameters ---------- x Predictors of shape (p, n), variance 1 y - Noisy outcomes of shape (k, n) + Noisy outcomes of shape (k, n) or (n,) partition Predictor-outcome assignment partition of shape (k, p) beta_shared @@ -279,7 +280,7 @@ class DGP(Module): # Main outputs x: Float[Array, 'p n'] - y: Float[Array, 'k n'] + y: Float[Array, 'k n'] | Float[Array, ' n'] # Intermediate results partition: Bool[Array, 'k p'] @@ -320,7 +321,17 @@ def sigma2_mean(self) -> Float[Array, '']: return self.sigma2_quad / (self.kurt_x - 1 + self.q) def split(self, n_train: int | None = None) -> tuple['DGP', 'DGP']: - """Split the data into training and test sets.""" + """Split the data into training and test sets. + + Parameters + ---------- + n_train + Number of training observations. If None, split in half. + + Returns + ------- + Two `DGP` object with the train and test splits. 
+ """ if n_train is None: n_train = self.x.shape[1] // 2 assert 0 < n_train < self.x.shape[1], 'n_train must be in (0, n)' @@ -351,12 +362,13 @@ def split(self, n_train: int | None = None) -> tuple['DGP', 'DGP']: return train, test +@partial(jit, static_argnames=('n', 'p', 'k')) def gen_data( key: Key[Array, ''], *, n: int, p: int, - k: int, + k: int | None = None, q: Integer[Array, ''] | int, lam: Float[Array, ''] | float, sigma2_lin: Float[Array, ''] | float, @@ -390,20 +402,18 @@ def gen_data( ------- An object with all generated data and parameters. """ + squeeze = k is None + if squeeze: + k = 1 + assert p >= k, 'p must be at least k' # check q - q = jnp.asarray(q) q = error_if(q, q % 2 != 0, 'q must be even') q = error_if(q, q >= p // k, 'q must be less than p // k') keys = split(key, 7) - lam = jnp.asarray(lam) - sigma2_lin = jnp.asarray(sigma2_lin) - sigma2_quad = jnp.asarray(sigma2_quad) - sigma2_eps = jnp.asarray(sigma2_eps) - x = generate_x(keys.pop(), n, p) partition = generate_partition(keys.pop(), p, k) beta_shared = generate_beta_shared(keys.pop(), p, sigma2_lin) @@ -418,6 +428,8 @@ def gen_data( muquad = combine_muquad(muquad_shared, muquad_separate, lam) mu = mulin + muquad y = generate_outcome(keys.pop(), mu, sigma2_eps) + if squeeze: + y = y.squeeze(0) return DGP( x=x, diff --git a/tests/conftest.py b/tests/conftest.py index 878f4416..283e40b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,23 +34,29 @@ from bartz.jaxext import get_default_device, split -# enable debug checks; these slow down unit tests +# enable debug checks; some of these slow down unit tests config.update('jax_debug_key_reuse', True) config.update('jax_debug_nans', True) config.update('jax_debug_infs', True) config.update('jax_legacy_prng_key', 'error') +if jax.__version_info__ >= (0, 9, 0): + config.update('jax_check_static_indices', True) + config.update('jax_explicit_x64_dtypes', 'error') -# enable virtual cpu devices to do multi-device testing on cpu 
-config.update('jax_num_cpu_devices', 10) # 2 * 5 +# enable logging arrays destroyed by the gc +config.update('jax_array_garbage_collection_guard', 'log') # enable compilation cache -config.update('jax_compilation_cache_dir', 'config/jax_cache') -config.update('jax_persistent_cache_min_entry_size_bytes', -1) -config.update('jax_persistent_cache_min_compile_time_secs', 0.1) +if jax.__version_info__ >= (0, 9, 0): + # enable only on latest jax because `make tests-old` fails if there is a + # cache created with a newer jax version + config.update('jax_compilation_cache_dir', 'config/jax_cache') + config.update('jax_persistent_cache_min_entry_size_bytes', -1) + config.update('jax_persistent_cache_min_compile_time_secs', 0.1) @pytest.fixture -def keys(request) -> split: +def keys(request: pytest.FixtureRequest) -> split: """ Return a deterministic per-test-case list of jax random keys. @@ -76,10 +82,28 @@ def pytest_addoption(parser: pytest.Parser) -> None: default='auto', help='JAX platform to use: cpu, gpu, or auto (default: auto)', ) + parser.addoption( + '--num-cpu-devices', + type=int, + default=10, + help='Number of virtual jax cpu devices to create (default: 10)', + ) def pytest_sessionstart(session: pytest.Session) -> None: - """Configure and print the jax device.""" + """Customizable jax setup.""" + setup_jax_num_cpu_devices(session) + setup_jax_platform(session) + + +def setup_jax_num_cpu_devices(session: pytest.Session) -> None: + """Configure the number of virtual jax cpu devices.""" + num_cpu_devices = session.config.getoption('--num-cpu-devices') + config.update('jax_num_cpu_devices', num_cpu_devices) + + +def setup_jax_platform(session: pytest.Session) -> None: + """Configure, check, and log the default jax platform.""" # Get the platform option platform = session.config.getoption('--platform') @@ -100,5 +124,6 @@ def pytest_sessionstart(session: pytest.Session) -> None: ctx = nullcontext() with ctx: - device_kind = get_default_device().device_kind - 
print(f'jax default device: {device_kind}') + dd = get_default_device() + num_devices = len(jax.devices(dd.platform)) + print(f'jax default device: {dd.device_kind}, num devices: {num_devices}') diff --git a/tests/rbartpackages/BART.py b/tests/rbartpackages/BART.py index 53c1957a..463b5a2c 100644 --- a/tests/rbartpackages/BART.py +++ b/tests/rbartpackages/BART.py @@ -1,6 +1,6 @@ # bartz/tests/rbartpackages/BART.py # -# Copyright (c) 2024-2025, The Bartz Contributors +# Copyright (c) 2024-2026, The Bartz Contributors # # This file is part of bartz. # @@ -24,6 +24,8 @@ """Wrapper for the R package BART.""" +# ruff: noqa: ANN002, ANN003 + from typing import NamedTuple, TypedDict, cast import numpy as np @@ -87,7 +89,7 @@ class mc_gbart(RObjectBase): yhat_train: Float64[ndarray, 'ndpost n'] yhat_train_mean: Float64[ndarray, ' n'] | None = None - def __init__(self, *args, **kw): + def __init__(self, *args, **kw) -> None: super().__init__(*args, **kw) # fix up attributes diff --git a/tests/rbartpackages/BART3.py b/tests/rbartpackages/BART3.py index 5bd921d8..6a460057 100644 --- a/tests/rbartpackages/BART3.py +++ b/tests/rbartpackages/BART3.py @@ -1,6 +1,6 @@ # bartz/tests/rbartpackages/BART3.py # -# Copyright (c) 2025, The Bartz Contributors +# Copyright (c) 2025-2026, The Bartz Contributors # # This file is part of bartz. 
# @@ -24,6 +24,8 @@ """Wrapper for the R package BART3.""" +# ruff: noqa: ANN002, ANN003 + from typing import NamedTuple, TypedDict import numpy as np @@ -92,7 +94,7 @@ class mc_gbart(RObjectBase): # noqa: D101 because the R doc is added automatica yhat_train_mean: Float64[ndarray, ' n'] | None = None yhat_train_upper: Float64[ndarray, ' n'] | None = None - def __init__(self, *args, **kw): + def __init__(self, *args, **kw) -> None: super().__init__(*args, **kw) # fix up attributes diff --git a/tests/rbartpackages/_base.py b/tests/rbartpackages/_base.py index fd38b4eb..9cbc411b 100644 --- a/tests/rbartpackages/_base.py +++ b/tests/rbartpackages/_base.py @@ -1,6 +1,6 @@ # bartz/tests/rbartpackages/_base.py # -# Copyright (c) 2024-2025, The Bartz Contributors +# Copyright (c) 2024-2026, The Bartz Contributors # # This file is part of bartz. # @@ -22,9 +22,10 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from collections.abc import Callable +from collections.abc import Callable, Iterable, Mapping from functools import wraps from re import fullmatch, match +from typing import Any import numpy as np from rpy2 import robjects @@ -50,7 +51,7 @@ pass else: - def polars_to_r(df): + def polars_to_r(df: polars.DataFrame) -> object: df = df.to_pandas() return pandas2ri.py2rpy(df) @@ -65,7 +66,7 @@ def polars_to_r(df): pass else: - def jax_to_r(x): + def jax_to_r(x: jax.Array) -> object: x = np.asarray(x) if x.ndim == 0: x = x[()] @@ -78,7 +79,7 @@ def jax_to_r(x): # converter for BoolVector (why isn't it in the numpy converter?) 
-def bool_vector_to_python(x): +def bool_vector_to_python(x: BoolVector) -> np.ndarray[Any, np.dtype[np.bool_]]: return np.array(x, bool) @@ -90,7 +91,7 @@ def bool_vector_to_python(x): dict_converter = conversion.Converter('dict') -def dict_to_r(x): +def dict_to_r(x: dict[str, Any]) -> robjects.ListVector: return robjects.ListVector(x) @@ -125,20 +126,20 @@ class RObjectBase: ) _convctx = conversion.localconverter(_converter) - def _py2r(self, x): + def _py2r(self, x: object) -> object: if isinstance(x, __class__): return x._robject with self._convctx: return self._converter.py2rpy(x) - def _r2py(self, x): + def _r2py(self, x: object) -> object: with self._convctx: return self._converter.rpy2py(x) - def _args2r(self, args): + def _args2r(self, args: Iterable[Any]) -> tuple[Any, ...]: return tuple(map(self._py2r, args)) - def _kw2r(self, kw): + def _kw2r(self, kw: Mapping[str, Any]) -> dict[str, Any]: return {key: self._py2r(value) for key, value in kw.items()} _rfuncname: str = NotImplemented @@ -153,7 +154,7 @@ def _library(self) -> str: raise ValueError(msg) return m.group(1) - def __init__(self, *args, **kw): + def __init__(self, *args: Any, **kw: Any) -> None: robjects.r(f'loadNamespace("{self._library}")') func = robjects.r(self._rfuncname) obj = func(*self._args2r(args), **self._kw2r(kw)) @@ -162,7 +163,7 @@ def __init__(self, *args, **kw): for s, v in obj.items(): setattr(self, s.replace('.', '_'), self._r2py(v)) - def __init_subclass__(cls, **kw): + def __init_subclass__(cls, **kw: Any) -> None: """Automatically add R documentation to subclasses.""" library, name = cls._rfuncname.split('::') page = Package(library).fetch(name) @@ -202,7 +203,7 @@ def rmethod(meth: Callable, *, rname: str | None = None) -> Callable: # can be determined at runtime @wraps(meth) - def impl(self, *args, **kw): + def impl(self: RObjectBase, *args: Any, **kw: Any) -> object: if isinstance(self._robject, RS4): func = robjects.r['$'](self._robject, rname) out = 
func(*self._args2r(args), **self._kw2r(kw)) diff --git a/tests/rbartpackages/bartMachine.py b/tests/rbartpackages/bartMachine.py index c16b9359..03a389e2 100644 --- a/tests/rbartpackages/bartMachine.py +++ b/tests/rbartpackages/bartMachine.py @@ -1,6 +1,6 @@ # bartz/tests/rbartpackages/bartMachine.py # -# Copyright (c) 2025, The Bartz Contributors +# Copyright (c) 2025-2026, The Bartz Contributors # # This file is part of bartz. # @@ -24,7 +24,7 @@ """Python wrapper of the R package bartMachine.""" -# ruff: noqa: D102 +# ruff: noqa: D102, ANN201, ANN002, ANN003 from rpy2 import robjects @@ -34,7 +34,9 @@ class bartMachine(RObjectBase): # noqa: D101, because the doc is pulled from R _rfuncname = 'bartMachine::bartMachine' - def __init__(self, *args, num_cores=None, megabytes=5000, **kw): + def __init__( + self, *args, num_cores: int | None = None, megabytes: int = 5000, **kw + ) -> None: robjects.r(f'options(java.parameters = "-Xmx{megabytes:d}m")') robjects.r('loadNamespace("bartMachine")') if num_cores is not None: diff --git a/tests/rbartpackages/dbarts.py b/tests/rbartpackages/dbarts.py index ca61a280..54fac0ab 100644 --- a/tests/rbartpackages/dbarts.py +++ b/tests/rbartpackages/dbarts.py @@ -1,6 +1,6 @@ # bartz/tests/rbartpackages/dbarts.py # -# Copyright (c) 2025, The Bartz Contributors +# Copyright (c) 2025-2026, The Bartz Contributors # # This file is part of bartz. 
# @@ -24,7 +24,7 @@ """Python wrapper of the R package `dbarts`.""" -# ruff: noqa: D101, D102 +# ruff: noqa: D101, D102, ANN201, ANN002, ANN003 from rpy2 import robjects @@ -44,7 +44,7 @@ class bart(RObjectBase): _rfuncname = 'dbarts::bart' _split_probs = 'splitprobs' - def __init__(self, *args, **kw): + def __init__(self, *args, **kw) -> None: split_probs = kw.get(self._split_probs) if isinstance(split_probs, dict): values = list(split_probs.values()) @@ -78,9 +78,9 @@ class bart2(bart): _rfuncname = 'dbarts::bart2' _split_probs = 'split_probs' - def __init__(self, formula, *args, **kw): - formula = robjects.Formula(formula) - super().__init__(formula, *args, **kw) + def __init__(self, formula: str, *args, **kw) -> None: + rformula = robjects.Formula(formula) + super().__init__(rformula, *args, **kw) class rbart_vi(bart2): diff --git a/tests/test_BART.py b/tests/test_BART.py index ab15c370..00da99e5 100644 --- a/tests/test_BART.py +++ b/tests/test_BART.py @@ -27,9 +27,11 @@ This is the main suite of tests. 
""" +from collections.abc import Generator from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, replace from functools import partial +from gc import collect from os import getpid, kill from signal import SIG_IGN, SIGINT, getsignal, signal from sys import version_info @@ -41,32 +43,42 @@ import numpy import polars as pl import pytest -from equinox import EquinoxRuntimeError -from jax import debug_nans, lax, random, vmap +from equinox import EquinoxRuntimeError, tree_at +from jax import block_until_ready, config, debug_nans, devices, random, tree, vmap from jax import numpy as jnp from jax.scipy.linalg import solve_triangular from jax.scipy.special import logit, ndtr -from jax.sharding import SingleDeviceSharding -from jax.tree import map_with_path -from jax.tree_util import KeyPath -from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real, UInt +from jax.sharding import Mesh, SingleDeviceSharding +from jax.tree_util import KeyPath, keystr +from jaxtyping import ( + Array, + Bool, + Float, + Float32, + Int32, + Key, + PyTree, + Real, + Shaped, + UInt, +) from numpy.testing import assert_allclose, assert_array_equal from pytest_subtests import SubTests from bartz import profile_mode -from bartz.debug import ( - TraceWithOffset, - check_trace, +from bartz.debug import TraceWithOffset, check_trace, sample_prior, trees_BART_to_bartz +from bartz.debug import debug_gbart as gbart +from bartz.debug import debug_mc_gbart as mc_gbart +from bartz.grove import ( forest_depth_distr, - sample_prior, + is_actual_leaf, tree_actual_depth, - trees_BART_to_bartz, + tree_depth, + tree_depths, ) -from bartz.debug import debug_gbart as gbart -from bartz.debug import debug_mc_gbart as mc_gbart -from bartz.grove import is_actual_leaf, tree_depth, tree_depths from bartz.jaxext import get_default_device, get_device_count, split from bartz.mcmcloop import compute_varcount, evaluate_trace +from bartz.mcmcstep import State from 
bartz.mcmcstep._state import chain_vmap_axes from tests.rbartpackages import BART3 from tests.test_mcmcstep import check_sharding, get_normal_spec, normalize_spec @@ -138,7 +150,7 @@ def gen_y( @pytest.fixture(params=list(range(1, N_VARIANTS + 1)), scope='module') -def variant(request) -> int: +def variant(request: pytest.FixtureRequest) -> int: """Return a parametrized indicator to select different BART configurations.""" return request.param @@ -300,12 +312,12 @@ def cachedbart(self, variant: int) -> CachedBart: return CachedBart(kwargs=kw, bart=bart) - def test_residuals_accuracy(self, cachedbart: CachedBart): + def test_residuals_accuracy(self, cachedbart: CachedBart) -> None: """Check that running residuals are close to the recomputed final residuals.""" accum_resid, actual_resid = cachedbart.bart.compare_resid() assert_close_matrices(accum_resid, actual_resid, rtol=1e-4) - def test_convergence(self, cachedbart: CachedBart): + def test_convergence(self, cachedbart: CachedBart) -> None: """Run multiple chains and check convergence with rhat.""" bart = cachedbart.bart nchains, _ = bart._mcmc_state.resid.shape @@ -363,7 +375,9 @@ def kw_bartz_to_BART3(self, key: Key[Array, ''], kw: dict, bart: mc_gbart) -> di return kw_BART - def check_rbart(self, kw, bart, rbart): + def check_rbart( + self, kw: dict[str, Any], bart: mc_gbart, rbart: BART3.mc_gbart + ) -> None: """Subroutine for `test_comparison_BART3`, check that the R BART output is self-consistent.""" # convert the trees to bartz format trees = rbart.treedraws['trees'] @@ -405,7 +419,9 @@ def check_rbart(self, kw, bart, rbart): prob_test, rbart.prob_test.astype(numpy.float32), rtol=1e-7 ) - def test_comparison_BART3(self, cachedbart: CachedBart, keys, subtests: SubTests): + def test_comparison_BART3( + self, cachedbart: CachedBart, keys: split, subtests: SubTests + ) -> None: """Check `bartz.BART` gives results similar to the R package BART3.""" bart = cachedbart.bart kw = cachedbart.kwargs @@ -511,15 +527,17 
@@ def test_comparison_BART3(self, cachedbart: CachedBart, keys, subtests: SubTests atol=1.7 * (p - 1) ** 0.5, ) - def test_different_chains(self, cachedbart: CachedBart): + def test_different_chains(self, cachedbart: CachedBart) -> None: """Check that different chains give different results.""" bart = cachedbart.bart step_theta = bart._mcmc_state.forest.rho is not None - def assert_different(x, **kwargs): - def assert_different(path: KeyPath, x, chain_axis: int | None): - str_path = ''.join(map(str, path)) + def assert_different(x: PyTree[Array], **kwargs: Any) -> None: + def assert_different( + path: KeyPath, x: Array | None, chain_axis: int | None + ) -> None: + str_path = keystr(path) if str_path.endswith('.theta') and not step_theta: return if x is not None and chain_axis is not None: @@ -534,7 +552,7 @@ def assert_different(path: KeyPath, x, chain_axis: int | None): ) axes = chain_vmap_axes(x) - map_with_path(assert_different, x, axes, is_leaf=lambda x: x is None) + tree.map_with_path(assert_different, x, axes, is_leaf=lambda x: x is None) assert_different(bart._mcmc_state, rtol=0.05) assert_different(bart._main_trace, rtol=0.03) @@ -546,7 +564,7 @@ def clipped_logit(x: Array, eps: float) -> Array: return logit(jnp.clip(x, eps, 1 - eps)) -def test_sequential_guarantee(kw: dict, subtests: SubTests): +def test_sequential_guarantee(kw: dict, subtests: SubTests) -> None: """Check that the way iterations are saved does not influence the result.""" # reference run kw['keepevery'] = 1 @@ -595,7 +613,7 @@ def test_sequential_guarantee(kw: dict, subtests: SubTests): ) -def test_output_shapes(kw): +def test_output_shapes(kw: dict[str, Any]) -> None: """Check the output shapes of all the array attributes of `bartz.BART.mc_gbart`.""" bart = mc_gbart(**kw) @@ -644,7 +662,7 @@ def test_output_shapes(kw): assert bart.yhat_train_mean.shape == (n,) -def test_output_types(kw): +def test_output_types(kw: dict[str, Any]) -> None: """Check the output types of all the attributes 
of BART.gbart.""" bart = mc_gbart(**kw) @@ -673,48 +691,50 @@ def test_output_types(kw): assert bart.yhat_train_mean.dtype == jnp.float32 -def test_predict(kw): +def test_predict(kw: dict[str, Any]) -> None: """Check that the public BART.gbart.predict method works.""" bart = mc_gbart(**kw) yhat_train = bart.predict(kw['x_train']) assert_array_equal(bart.yhat_train, yhat_train) -def test_varprob(kw): - """Basic checks of the `varprob` attribute.""" - bart = mc_gbart(**kw) - - # basic properties of probabilities - assert jnp.all(bart.varprob >= 0) - assert jnp.all(bart.varprob <= 1) - assert_allclose(bart.varprob.sum(axis=1), 1, rtol=1e-6) - - # probabilities are either 0 or 1/peff if sparsity is disabled - sparse = kw.get('sparse', False) - if not sparse: - unique = jnp.unique(bart.varprob) - assert unique.size in (1, 2) - if unique.size == 2: # pragma: no cover - assert unique[0] == 0 - - # the mean is the mean - assert_array_equal(bart.varprob_mean, bart.varprob.mean(axis=0)) +class TestVarprobAttr: + """Test the `mc_gbart.varprob` attribute.""" + def test_basic_properties(self, kw: dict[str, Any]) -> None: + """Basic checks of the `varprob` attribute.""" + bart = mc_gbart(**kw) -def test_varprob_blocked_vars(keys): - """Check that varprob = 0 on predictors blocked a priori.""" - X = gen_X(keys.pop(), 2, 30, 'continuous') - y = gen_y(keys.pop(), X, None, 'continuous') - with debug_nans(False): - xinfo = jnp.array([[jnp.nan], [0]]) - bart = mc_gbart(x_train=X, y_train=y, xinfo=xinfo, seed=keys.pop()) - assert_array_equal(bart._mcmc_state.forest.max_split, [0, 1]) - assert_array_equal(bart.varprob_mean, [0, 1]) - assert jnp.all(bart.varprob_mean == bart.varprob) + # basic properties of probabilities + assert jnp.all(bart.varprob >= 0) + assert jnp.all(bart.varprob <= 1) + assert_allclose(bart.varprob.sum(axis=1), 1, rtol=1e-6) + + # probabilities are either 0 or 1/peff if sparsity is disabled + sparse = kw.get('sparse', False) + if not sparse: + unique = 
jnp.unique(bart.varprob) + assert unique.size in (1, 2) + if unique.size == 2: # pragma: no cover + assert unique[0] == 0 + + # the mean is the mean + assert_array_equal(bart.varprob_mean, bart.varprob.mean(axis=0)) + + def test_blocked_vars(self, keys: split) -> None: + """Check that varprob = 0 on predictors blocked a priori.""" + X = gen_X(keys.pop(), 2, 30, 'continuous') + y = gen_y(keys.pop(), X, None, 'continuous') + with debug_nans(False): + xinfo = jnp.array([[jnp.nan], [0]]) + bart = mc_gbart(x_train=X, y_train=y, xinfo=xinfo, seed=keys.pop()) + assert_array_equal(bart._mcmc_state.forest.max_split, [0, 1]) + assert_array_equal(bart.varprob_mean, [0, 1]) + assert jnp.all(bart.varprob_mean == bart.varprob) @pytest.mark.parametrize('theta', ['fixed', 'free']) -def test_variable_selection(keys: split, theta: Literal['fixed', 'free']): +def test_variable_selection(keys: split, theta: Literal['fixed', 'free']) -> None: """Check that variable selection works.""" # data config p = 100 # number of predictors @@ -746,7 +766,7 @@ def test_variable_selection(keys: split, theta: Literal['fixed', 'free']): assert bart.varprob_mean[~mask].max().item() < 1 / (p - peff) -def test_scale_shift(kw): +def test_scale_shift(kw: dict[str, Any]) -> None: """Check self-consistency of rescaling the inputs.""" if kw['y_train'].dtype == bool: pytest.skip('Cannot rescale binary responses.') @@ -791,7 +811,7 @@ def test_scale_shift(kw): assert_allclose(bart1.sigma_mean, bart2.sigma_mean / scale, rtol=1e-6, atol=1e-6) -def test_min_points_per_decision_node(kw): +def test_min_points_per_decision_node(kw: dict[str, Any]) -> None: """Check that the limit of at least 10 datapoints per decision node is respected.""" kw.setdefault('bart_kwargs', {}).setdefault('init_kw', {}).update( min_points_per_leaf=None @@ -813,7 +833,7 @@ def test_min_points_per_decision_node(kw): assert jnp.any(distr_marg[min_points:] > 0) -def test_min_points_per_leaf(kw): +def test_min_points_per_leaf(kw: dict[str, 
Any]) -> None: """Check that the limit of at least 5 datapoints per leaf is respected.""" kw.setdefault('bart_kwargs', {}).setdefault('init_kw', {}).update( min_points_per_decision_node=None @@ -833,7 +853,7 @@ def test_min_points_per_leaf(kw): assert distr_marg[min_points] > 0 -def set_num_datapoints(kw: dict, n): +def set_num_datapoints(kw: dict, n: int) -> dict: """Set the number of datapoints in the kw dictionary.""" assert n <= kw['y_train'].size kw = kw.copy() @@ -844,7 +864,7 @@ def set_num_datapoints(kw: dict, n): return kw -def test_no_datapoints(kw): +def test_no_datapoints(kw: dict[str, Any]) -> None: """Check automatic data scaling with 0 datapoints.""" # remove all datapoints kw = set_num_datapoints(kw, 0) @@ -855,9 +875,22 @@ def test_no_datapoints(kw): xinfo = jnp.broadcast_to(jnp.arange(nsplits, dtype=jnp.float32), (p, nsplits)) kw.update(xinfo=xinfo) + # disable data sharding + kw.setdefault('bart_kwargs', {}).update(num_data_devices=None) + + # enable saving the likelihood ratio to check it's always 1 + kw.setdefault('bart_kwargs', {}).setdefault('init_kw', {}).update( + save_ratios=True, min_points_per_decision_node=None, min_points_per_leaf=None + ) + + # run bart bart = mc_gbart(**kw) + + # check there are indeed 0 datapoints in the output ndpost = kw['ndpost'] assert bart.yhat_train.shape == (ndpost, 0) + + # check default values that may be set in a special way if there are 0 datapoints assert bart.offset == 0 if kw['y_train'].dtype == bool: tau_num = 3 @@ -871,8 +904,12 @@ def test_no_datapoints(kw): rtol=1e-6, ) + # check the likelihood ratio is always 1 + assert_array_equal(bart._burnin_trace.log_likelihood, 0.0) + assert_array_equal(bart._main_trace.log_likelihood, 0.0) + -def test_one_datapoint(kw): +def test_one_datapoint(kw: dict[str, Any]) -> None: """Check automatic data scaling with 1 datapoint.""" kw = set_num_datapoints(kw, 1) @@ -887,6 +924,11 @@ def test_one_datapoint(kw): # disable data sharding kw.setdefault('bart_kwargs', 
{}).update(num_data_devices=None) + # enable saving the likelihood ratio to check it's always 1 + kw.setdefault('bart_kwargs', {}).setdefault('init_kw', {}).update( + save_ratios=True, min_points_per_decision_node=None, min_points_per_leaf=None + ) + bart = mc_gbart(**kw) if kw['y_train'].dtype == bool: tau_num = 3 @@ -902,18 +944,27 @@ def test_one_datapoint(kw): rtol=1e-6, ) + # check the likelihood ratio is always 1 + assert_array_equal(bart._burnin_trace.log_likelihood, 0.0) + assert_array_equal(bart._main_trace.log_likelihood, 0.0) + -def test_two_datapoints(kw): +def test_two_datapoints(kw: dict[str, Any]) -> None: """Check automatic data scaling with 2 datapoints.""" kw = set_num_datapoints(kw, 2) + kw.setdefault('bart_kwargs', {}).setdefault('init_kw', {}).update( + save_ratios=True, min_points_per_decision_node=None, min_points_per_leaf=None + ) bart = mc_gbart(**kw) if kw['y_train'].dtype != bool: assert_allclose(bart.sigest, kw['y_train'].std(), rtol=1e-6) if kw['usequants']: assert jnp.all(bart._mcmc_state.forest.max_split <= 1) + assert not jnp.all(bart._burnin_trace.log_likelihood == 0.0) + assert not jnp.all(bart._main_trace.log_likelihood == 0.0) -def test_few_datapoints(kw): +def test_few_datapoints(kw: dict[str, Any]) -> None: """Check that the trees cannot grow if there are not enough datapoints. 
If there are less than 10 datapoints, it is not possible to satisfy the 10 @@ -935,7 +986,7 @@ def test_few_datapoints(kw): assert jnp.all(bart.yhat_train == bart.yhat_train[:, :1]) -def test_xinfo(): +def test_xinfo() -> None: """Simple check that the `xinfo` parameter works.""" with debug_nans(False): xinfo = jnp.array( @@ -959,7 +1010,7 @@ def test_xinfo(): assert_array_equal(bart._mcmc_state.forest.max_split, [2, 3, 0]) -def test_xinfo_wrong_p(): +def test_xinfo_wrong_p() -> None: """Check that `xinfo` must have the same number of rows as `X`.""" with debug_nans(False): xinfo = jnp.array( @@ -981,12 +1032,82 @@ def test_xinfo_wrong_p(): (10, 255), # likely always available decision rules for all variables ], ) -def test_prior(keys, p, nsplits): +def test_prior(keys: split, p: int, nsplits: int, subtests: SubTests) -> None: """Check that the posterior without data is equivalent to the prior.""" - # sample from posterior without data - xinfo = jnp.broadcast_to(jnp.arange(nsplits, dtype=jnp.float32), (p, nsplits)) + # run bart without data + bart = run_bart_like_prior(keys.pop(), p, nsplits, subtests) + + # sample from prior + prior_trace = sample_prior_like(keys.pop(), bart, subtests) + + with subtests.test('number of stub trees'): + nstub_mcmc = count_stub_trees(bart._main_trace.split_tree) + nstub_prior = count_stub_trees(prior_trace.split_tree) + rhat_nstub = rhat([nstub_mcmc, nstub_prior]) + assert rhat_nstub < 1.01 + + if (p, nsplits) != (1, 1): + # all the following are equivalent to nstub in the 1-1 case + + with subtests.test('number of simple trees'): + nsimple_mcmc = count_simple_trees(bart._main_trace.split_tree) + nsimple_prior = count_simple_trees(prior_trace.split_tree) + rhat_nsimple = rhat([nsimple_mcmc, nsimple_prior]) + assert rhat_nsimple < 1.01 + + varcount_prior = compute_varcount( + bart._mcmc_state.forest.max_split.size, prior_trace + ) + + with subtests.test('varcount'): + rhat_varcount = multivariate_rhat([bart.varcount, varcount_prior]) 
+ if p == 10: + # varcount is p-dimensional + assert rhat_varcount < 1.4 + else: + assert rhat_varcount < 1.05 + + with subtests.test('number of nodes'): + sum_varcount_mcmc = bart.varcount.sum(axis=1) + sum_varcount_prior = varcount_prior.sum(axis=1) + rhat_sum_varcount = rhat([sum_varcount_mcmc, sum_varcount_prior]) + assert rhat_sum_varcount < 1.05 + + with subtests.test('imbalance index'): + imb_mcmc = avg_imbalance_index(bart._main_trace.split_tree) + imb_prior = avg_imbalance_index(prior_trace.split_tree) + rhat_imb = rhat([imb_mcmc, imb_prior]) + assert rhat_imb < 1.02 + + with subtests.test('average max tree depth'): + maxd_mcmc = avg_max_tree_depth(bart._main_trace.split_tree) + maxd_prior = avg_max_tree_depth(prior_trace.split_tree) + rhat_maxd = rhat([maxd_mcmc, maxd_prior]) + assert rhat_maxd < 1.02 + + with subtests.test('max tree depth distribution'): + dd_mcmc = bart.depth_distr() + dd_prior = forest_depth_distr(prior_trace.split_tree) + rhat_dd = multivariate_rhat([dd_mcmc.squeeze(0), dd_prior]) + assert rhat_dd < 1.05 + + with subtests.test('y_test'): + X = random.randint(keys.pop(), (p, 30), 0, nsplits + 1) + yhat_mcmc = bart._bart._predict(X) + yhat_prior = evaluate_trace(X, prior_trace) + rhat_yhat = multivariate_rhat([yhat_mcmc, yhat_prior]) + assert rhat_yhat < 1.1 + + +def run_bart_like_prior( + key: Key[Array, ''], p: int, nsplits: int, subtests: SubTests +) -> mc_gbart: + """Run `mc_gbart` without datapoints to sample the prior distribution.""" # set the split grid manually because automatic setting relies on datapoints - kw = dict( + xinfo = jnp.broadcast_to(jnp.arange(nsplits, dtype=jnp.float32), (p, nsplits)) + + # configure bart to run many mcmc iterations, without data + kw: dict = dict( x_train=jnp.empty((p, 0)), y_train=jnp.empty(0), ntree=20, @@ -994,15 +1115,32 @@ def test_prior(keys, p, nsplits): nskip=3000, printevery=None, xinfo=xinfo, - seed=keys.pop(), + seed=key, mc_cores=1, bart_kwargs=dict( - 
init_kw=dict(min_points_per_decision_node=None, min_points_per_leaf=None) + init_kw=dict( + # unset limits on datapoints per node because there's no data + min_points_per_decision_node=None, + min_points_per_leaf=None, + # save likelihood ratio to check it's 1 + save_ratios=True, + ) ), - # unset limits on datapoints per node because there's no data ) + bart = mc_gbart(**kw) + with subtests.test('likelihood ratio = 1'): + assert_array_equal(bart._burnin_trace.log_likelihood, 0.0) + assert_array_equal(bart._main_trace.log_likelihood, 0.0) + + return bart + + +def sample_prior_like( + key: Key[Array, ''], bart: mc_gbart, subtests: SubTests +) -> TraceWithOffset: + """Sample from the prior with the same settings used in `bart`.""" # extract p_nonterminal in original format from mcmc state p_nonterminal = bart._mcmc_state.forest.p_nonterminal max_depth = tree_depth(p_nonterminal) @@ -1011,77 +1149,21 @@ def test_prior(keys, p, nsplits): # sample from prior prior_trees = sample_prior( - keys.pop(), - kw['ndpost'], - kw['ntree'], + key, + bart.ndpost, + len(bart._mcmc_state.forest.leaf_tree), bart._mcmc_state.forest.max_split, p_nonterminal, - jnp.sqrt(lax.reciprocal(bart._mcmc_state.forest.leaf_prior_cov_inv)), + jnp.sqrt(jnp.reciprocal(bart._mcmc_state.forest.leaf_prior_cov_inv)), ) - prior_trace = TraceWithOffset.from_trees_trace(prior_trees, bart.offset) - - # check prior samples - bad = check_trace(prior_trees, bart._mcmc_state.forest.max_split) - bad_count = jnp.count_nonzero(bad) - assert bad_count == 0 - # compare number of stub trees - nstub_mcmc = count_stub_trees(bart._main_trace.split_tree) - nstub_prior = count_stub_trees(prior_trace.split_tree) - rhat_nstub = rhat([nstub_mcmc, nstub_prior]) - assert rhat_nstub < 1.01 + with subtests.test('check prior trees'): + bad = check_trace(prior_trees, bart._mcmc_state.forest.max_split) + bad_count = jnp.count_nonzero(bad) + assert bad_count == 0 - if (p, nsplits) != (1, 1): - # all the following are equivalent to 
nstub in the 1-1 case - - # compare number of "simple" trees - nsimple_mcmc = count_simple_trees(bart._main_trace.split_tree) - nsimple_prior = count_simple_trees(prior_trace.split_tree) - rhat_nsimple = rhat([nsimple_mcmc, nsimple_prior]) - assert rhat_nsimple < 1.01 - - # compare varcount - varcount_prior = compute_varcount( - bart._mcmc_state.forest.max_split.size, prior_trace - ) - rhat_varcount = multivariate_rhat([bart.varcount, varcount_prior]) - if p == 10: - # varcount is p-dimensional - assert rhat_varcount < 1.4 - else: - assert rhat_varcount < 1.05 - - # compare number of nodes. since #leaves = 1 + #(internal nodes) in binary - # trees, I only check #(internal nodes) = sum(varcount). - sum_varcount_mcmc = bart.varcount.sum(axis=1) - sum_varcount_prior = varcount_prior.sum(axis=1) - rhat_sum_varcount = rhat([sum_varcount_mcmc, sum_varcount_prior]) - assert rhat_sum_varcount < 1.05 - - # compare imbalance index - imb_mcmc = avg_imbalance_index(bart._main_trace.split_tree) - imb_prior = avg_imbalance_index(prior_trace.split_tree) - rhat_imb = rhat([imb_mcmc, imb_prior]) - assert rhat_imb < 1.02 - - # compare average max tree depth - maxd_mcmc = avg_max_tree_depth(bart._main_trace.split_tree) - maxd_prior = avg_max_tree_depth(prior_trace.split_tree) - rhat_maxd = rhat([maxd_mcmc, maxd_prior]) - assert rhat_maxd < 1.02 - - # compare max tree depth distribution - dd_mcmc = bart.depth_distr() - dd_prior = forest_depth_distr(prior_trace.split_tree) - rhat_dd = multivariate_rhat([dd_mcmc.squeeze(0), dd_prior]) - assert rhat_dd < 1.05 - - # compare y - X = random.randint(keys.pop(), (p, 30), 0, nsplits + 1) - yhat_mcmc = bart._bart._predict(X) - yhat_prior = evaluate_trace(X, prior_trace) - rhat_yhat = multivariate_rhat([yhat_mcmc, yhat_prior]) - assert rhat_yhat < 1.1 + # pack up trees together with offset + return TraceWithOffset.from_trees_trace(prior_trees, bart.offset) def count_stub_trees( @@ -1153,7 +1235,9 @@ def multivariate_rhat(chains: Real[Any, 'chain 
sample dim']) -> Float[Array, ''] chain_means = jnp.mean(chains, axis=1) - def compute_chain_cov(chain_samples, chain_mean): + def compute_chain_cov( + chain_samples: Float[Array, 'sample dim'], chain_mean: Float[Array, ' dim'] + ) -> Float[Array, 'dim dim']: centered = chain_samples - chain_mean return jnp.dot(centered.T, centered) / (n - 1) @@ -1199,7 +1283,7 @@ def rhat(chains: Real[Any, 'chain sample']) -> Float[Array, '']: return multivariate_rhat(chains[:, :, None]) -def test_rhat(keys): +def test_rhat(keys: split) -> None: """Test the multivariate R-hat implementation.""" chains, divergent_chains = random.normal(keys.pop(), (2, 2, 1000, 10)) mean_offset = jnp.arange(len(chains)) @@ -1210,18 +1294,18 @@ def test_rhat(keys): assert rhat_divergent > 5 -def test_jit(kw): +def test_jit(kw: dict[str, Any]) -> None: """Test that jitting around the whole interface works.""" # set printevery to None to move all iterations to the inner loop and avoid # multiple compilation kw.update(printevery=None) - # do not check trees because the assert breaks abstract tracing - kw.update(check_trees=False) - # do not count splitless variables because it breaks tracing kw.update(rm_const=False) + # do not check tree replicas because it breaks tracing + kw.update(check_replicated_trees=False) + # set device as under jit it can not be inferred from the array platform = kw['y_train'].platform() kw.setdefault('bart_kwargs', {}).update(devices=jax.devices(platform)) @@ -1235,7 +1319,12 @@ def test_jit(kw): w = kw.pop('w', None) key = kw.pop('seed') - def task(X, y, w, key): + def task( + X: Shaped[Array, 'p n'], + y: Shaped[Array, ' n'], + w: Float32[Array, ' n'] | None, + key: Key[Array, ''], + ) -> tuple[State, Shaped[Array, 'ndpost n']]: bart = mc_gbart(X, y, w=w, **kw, seed=key) return bart._mcmc_state, bart.yhat_train @@ -1256,35 +1345,31 @@ class PeriodicSigintTimer: Time in seconds to wait before sending the first SIGINT. interval Time in seconds between subsequent SIGINTs. 
- announce - Whether to print messages when sending SIGINTs and when stopping. """ - def __init__(self, *, first_after: float, interval: float, announce: bool): + def __init__(self, *, first_after: float, interval: float) -> None: self.first_after = max(0.0, float(first_after)) self.interval = max(0.001, float(interval)) self.pid = getpid() self._stop = Event() self._thread: Thread | None = None self.sent = 0 - self.announce = announce def _run(self) -> None: """Run the main loop of the timer.""" t0 = monotonic() + # Wait initial delay (cancellable) - if self._stop.wait(self.first_after): + if self._stop.wait(self.first_after): # pragma: no cover return + # Periodically send SIGINT until stopped - while not self._stop.is_set(): + while not self._stop.is_set(): # pragma: no branch kill(self.pid, SIGINT) self.sent += 1 - if self.announce: - elapsed = monotonic() - t0 - print( - f'[PeriodicSigintTimer] sent SIGINT #{self.sent} at t={elapsed:.2f}s' - ) - if self._stop.wait(self.interval): + elapsed = monotonic() - t0 + print(f'[PeriodicSigintTimer] sent SIGINT #{self.sent} at t={elapsed:.2f}s') + if self._stop.wait(self.interval): # pragma: no branch break def start(self) -> None: @@ -1303,18 +1388,17 @@ def cancel(self) -> None: try: self._stop.set() - if self.announce: - print(f'[PeriodicSigintTimer] stopped after {self.sent} SIGINT(s)') + print(f'[PeriodicSigintTimer] stopped after {self.sent} SIGINT(s)') finally: signal(SIGINT, prev) @contextmanager -def periodic_sigint(*, first_after: float, interval: float, announce: bool): +def periodic_sigint( + *, first_after: float, interval: float +) -> Generator[PeriodicSigintTimer, None, None]: """Context manager to periodically send SIGINT to the main thread.""" - timer = PeriodicSigintTimer( - first_after=first_after, interval=interval, announce=announce - ) + timer = PeriodicSigintTimer(first_after=first_after, interval=interval) timer.start() try: yield timer @@ -1325,7 +1409,7 @@ def periodic_sigint(*, first_after: 
float, interval: float, announce: bool): @pytest.mark.flaky # it's flaky because the interrupt may be caught and converted by jax internals (#33054) @pytest.mark.timeout(32) -def test_interrupt(kw): +def test_interrupt(kw: dict[str, Any]) -> None: """Test that the MCMC can be interrupted with ^C.""" kw['printevery'] = 1 kw.update(ndpost=0, nskip=10000) @@ -1334,16 +1418,16 @@ def test_interrupt(kw): # a first interruptible phase of jax compilation. Then send ^C every second, # in case the first ^C landed during a second non-interruptible compilation phase # that eats ^C and ignores it. - with periodic_sigint(first_after=3.0, interval=1.0, announce=True): + with periodic_sigint(first_after=3.0, interval=1.0): try: with pytest.raises(KeyboardInterrupt): mc_gbart(**kw) - except KeyboardInterrupt: + except KeyboardInterrupt: # pragma: no cover # Stray ^C during/after __exit__; treat as expected. pass -def test_polars(kw): +def test_polars(kw: dict[str, Any]) -> None: """Test passing data as DataFrame and Series.""" bart = mc_gbart(**kw) pred = bart.predict(kw['x_test']) @@ -1366,7 +1450,7 @@ def test_polars(kw): assert_close_matrices(pred, pred2, rtol=rtol) -def test_data_format_mismatch(kw): +def test_data_format_mismatch(kw: dict[str, Any]) -> None: """Test that passing predictors with mismatched formats raises an error.""" kw.update( x_train=pl.DataFrame(numpy.array(kw['x_train']).T), @@ -1378,14 +1462,14 @@ def test_data_format_mismatch(kw): bart.predict(kw['x_test'].to_numpy().T) -def test_automatic_integer_types(kw): +def test_automatic_integer_types(kw: dict[str, Any]) -> None: """Test that integer variables in the MCMC state have the correct type. Some integer variables change type automatically to be as small as possible. 
""" bart = mc_gbart(**kw) - def select_type(cond): + def select_type(cond: bool) -> type: return jnp.uint8 if cond else jnp.uint16 leaf_indices_type = select_type(kw['bart_kwargs']['maxdepth'] <= 8) @@ -1399,7 +1483,7 @@ def select_type(cond): assert bart._mcmc_state.forest.max_split.dtype == split_trees_type -def test_gbart_multichain_error(keys): +def test_gbart_multichain_error(keys: split) -> None: """Check that `bartz.BART.gbart` does not support `mc_cores`.""" X = gen_X(keys.pop(), 10, 100, 'continuous') y = gen_y(keys.pop(), X, None, 'continuous') @@ -1411,66 +1495,45 @@ def test_gbart_multichain_error(keys): gbart(X, y, mc_cores='gatto') -PLATFORM = get_default_device().platform -PYTHON_VERSION = version_info[:2] -OLD_PYTHON = get_old_python_tuple() -EXACT_CHECK = PLATFORM != 'gpu' and PYTHON_VERSION != OLD_PYTHON - +def test_same_result_profiling(variant: int, kw: dict) -> None: + """Check that the result is the same in profiling mode.""" + bart = mc_gbart(**kw) + with profile_mode(True): + kw.update(seed=random.clone(kw['seed'])) + bartp = mc_gbart(**kw) -class TestProfile: - """Test the behavior of `mc_gbart` in profiling mode.""" + platform = get_default_device().platform + python_version = version_info[:2] + old_python = get_old_python_tuple() + exact_check = platform != 'gpu' and python_version != old_python - @pytest.mark.xfail( - not EXACT_CHECK, reason='exact equality fails on old toolchain or gpu' - ) - def test_same_result(self, kw: dict): - """Check that the result is the same in profiling mode.""" - bart = mc_gbart(**kw) - with profile_mode(True): - kw.update(seed=random.clone(kw['seed'])) - bartp = mc_gbart(**kw) - - def check_same(_path, x, xp): + def check_same(_path: KeyPath, x: Array, xp: Array) -> None: + if exact_check: assert_array_equal(xp, x) - - map_with_path(check_same, bart._mcmc_state, bartp._mcmc_state) - map_with_path(check_same, bart._main_trace, bartp._main_trace) - - @pytest.mark.skipif( - EXACT_CHECK, reason='run only when 
same_result is expected to fail' - ) - def test_similar_result(self, kw: dict, variant: int): - """Check that the result is similar in profiling mode.""" - bart = mc_gbart(**kw) - with profile_mode(True): - kw.update(seed=random.clone(kw['seed'])) - bartp = mc_gbart(**kw) - - def check_same(_path, x, xp): + else: assert_allclose(xp, x, atol=1e-5, rtol=1e-5) - # maybe this should be close_matrices - try: - map_with_path(check_same, bart._mcmc_state, bartp._mcmc_state) - map_with_path(check_same, bart._main_trace, bartp._main_trace) - except AssertionError as a: - if ( - '\nNot equal to tolerance ' in str(a) - and PYTHON_VERSION == OLD_PYTHON - and variant in (1, 3) - ): - pytest.xfail('unsolved bug with old toolchain') - else: - raise + try: + tree.map_with_path(check_same, bart._mcmc_state, bartp._mcmc_state) + tree.map_with_path(check_same, bart._main_trace, bartp._main_trace) + except AssertionError as a: + if ( + '\nNot equal to tolerance ' in str(a) + and not exact_check + and python_version == old_python + and variant in (1, 3) + ): + pytest.xfail('unsolved bug with old toolchain') + else: + raise -def test_sharding(kw: dict): - """Check that chains live on their own devices throughout the interface.""" - # determine whether we expect sharding to be set up based on the arguments +def get_expect_sharded(kw: dict) -> bool: + """Check whether we expect sharding to be set up based on the arguments.""" bart_kwargs = kw.get('bart_kwargs', {}) num_chain_devices = bart_kwargs.get('num_chain_devices') num_data_devices = bart_kwargs.get('num_data_devices') - expect_sharded = ( + return ( num_chain_devices is not None or num_data_devices is not None or ( @@ -1481,9 +1544,38 @@ def test_sharding(kw: dict): ) ) + +def check_data_sharding(x: Array | None, mesh: Mesh) -> None: + """Check the sharding of `x` assuming it may be sharded only along the last 'data' axis.""" + if x is None: + return + elif mesh is None: + assert isinstance(x.sharding, SingleDeviceSharding) + elif 
'data' in mesh.axis_names: + expected_num_devices = min(2, get_device_count()) + assert x.sharding.num_devices == expected_num_devices + expected_spec = (None,) * (x.ndim - 1) + ('data',) + assert get_normal_spec(x) == normalize_spec(expected_spec, mesh, x.shape) + + +def check_chain_sharding(x: Array | None, mesh: Mesh) -> None: + """Check the sharding of `x` assuming it may be sharded only along the first 'chains' axis.""" + if x is None: + return + elif mesh is None: + assert isinstance(x.sharding, SingleDeviceSharding) + elif 'chains' in mesh.axis_names: + expected_num_devices = min(2, get_device_count()) + assert x.sharding.num_devices == expected_num_devices + assert get_normal_spec(x) == ('chains',) + (None,) * (x.ndim - 1) + + +def test_sharding(kw: dict) -> None: + """Check that chains live on their own devices throughout the interface.""" bart = mc_gbart(**kw) # check the mesh is set up iff we expect sharding + expect_sharded = get_expect_sharded(kw) mesh = bart._mcmc_state.config.mesh assert expect_sharded == (mesh is not None) @@ -1492,45 +1584,38 @@ def test_sharding(kw: dict): check(bart._burnin_trace) check(bart._main_trace) - def check_chain_sharding(x: Array | None): - if x is None: - return - elif mesh is None: - assert isinstance(x.sharding, SingleDeviceSharding) - elif 'chains' in mesh.axis_names: - expected_num_devices = min(2, get_device_count()) - assert x.sharding.num_devices == expected_num_devices - assert get_normal_spec(x) == ('chains',) + (None,) * (x.ndim - 1) - - check_chain_sharding(bart.yhat_test) - check_chain_sharding(bart.prob_test) - check_chain_sharding(bart.prob_train) + check_chain = partial(check_chain_sharding, mesh=mesh) + + check_chain(bart.yhat_test) + check_chain(bart.prob_test) + check_chain(bart.prob_train) if bart.sigma is not None: - check_chain_sharding(bart.sigma.T) - check_chain_sharding(bart.sigma_) - check_chain_sharding(bart.varcount) - check_chain_sharding(bart.varprob) - check_chain_sharding(bart.yhat_train) 
- - def check_data_sharding(x: Array | None): - if x is None: - return - elif mesh is None: - assert isinstance(x.sharding, SingleDeviceSharding) - elif 'data' in mesh.axis_names: - expected_num_devices = min(2, get_device_count()) - assert x.sharding.num_devices == expected_num_devices - expected_spec = (None,) * (x.ndim - 1) + ('data',) - assert get_normal_spec(x) == normalize_spec(expected_spec, mesh, x.shape) + check_chain(bart.sigma.T) + check_chain(bart.sigma_) + check_chain(bart.varcount) + check_chain(bart.varprob) + check_chain(bart.yhat_train) + + check_data = partial(check_data_sharding, mesh=mesh) - check_data_sharding(bart.prob_train) - check_data_sharding(bart.prob_train_mean) - check_data_sharding(bart.yhat_train) - check_data_sharding(bart.yhat_train_mean) + check_data(bart.prob_train) + check_data(bart.prob_train_mean) + check_data(bart.yhat_train) + check_data(bart.yhat_train_mean) + assert bart.offset.is_fully_replicated + if bart.sigest is not None: + assert bart.sigest.is_fully_replicated + if bart.sigma_mean is not None: + assert bart.sigma_mean.is_fully_replicated + assert bart.varcount_mean.is_fully_replicated + assert bart.varprob_mean.is_fully_replicated + if bart.yhat_test_mean is not None: + assert bart.yhat_test_mean.is_fully_replicated -class TestVarprob: - """Test the `varprob` parameter thoroughly.""" + +class TestVarprobParam: + """Test the `varprob` parameter.""" def test_biased_predictor_choice(self, keys: split, kw: dict) -> None: """Check that if `varprob[i]` is high then predictor `i` is used more than others.""" @@ -1560,3 +1645,97 @@ def test_positive(self, kw: dict, subtests: SubTests) -> None: kw.update(varprob=varprob) with pytest.raises(EquinoxRuntimeError, match='varprob must be > 0'): mc_gbart(**kw) + + +def run_bart_and_block(kw: dict) -> None: + """Run bart and block until all outputs are ready.""" + bart = mc_gbart(**kw) + stuff = ( + bart.yhat_test, + bart.prob_test, + bart.prob_train, + bart.sigma, + bart.sigma_, + 
bart.varcount, + bart.varprob, + bart.yhat_train, + ) + block_until_ready((bart, *stuff)) + + +def test_array_no_gc(kw: dict) -> None: + """Check that arrays are not garbage collected.""" + setting = 'jax_array_garbage_collection_guard' + prev = getattr(config, setting) + config.update(setting, 'fatal') + try: + run_bart_and_block(kw) + collect() + finally: + config.update(setting, prev) + + +def test_equiv_sharding(kw: dict, subtests: SubTests) -> None: + """Check that the result is the same with/without sharding.""" + if len(devices()) < 2: + pytest.skip('Need at least 2 devices for this test') + + # baseline without sharding + baseline_kw = tree.map(lambda x: x, kw) # deep copy of structure + baseline_kw.setdefault('bart_kwargs', {}).update( + num_chain_devices=None, num_data_devices=None + ) + baseline_kw.update(nskip=0, ndpost=20, mc_cores=2) + bart = mc_gbart(**baseline_kw) + + def check_equal(path: KeyPath, xb: Array, xs: Array) -> None: + assert_close_matrices( + xs, xb, err_msg=f'{keystr(path)}: ', rtol=1e-5, reduce_rank=True + ) + + def remove_mesh(bart: mc_gbart) -> mc_gbart: + config = bart._mcmc_state.config + config = replace(config, mesh=None) + return tree_at(lambda bart: bart._mcmc_state.config, bart, config) + + with subtests.test('shard chains'): + chains_kw = tree.map(lambda x: x, baseline_kw) + chains_kw.setdefault('bart_kwargs', {}).update(num_chain_devices=2) + bart_chains = mc_gbart(**chains_kw) + bart_chains = remove_mesh(bart_chains) + tree.map_with_path(check_equal, bart, bart_chains) + + with subtests.test('shard data'): + data_kw = tree.map(lambda x: x, baseline_kw) + data_kw.setdefault('bart_kwargs', {}).update(num_data_devices=2) + bart_data = mc_gbart(**data_kw) + bart_data = remove_mesh(bart_data) + tree.map_with_path(check_equal, bart, bart_data) + + if len(devices()) >= 4: + with subtests.test('shard data and chains'): + both_kw = tree.map(lambda x: x, baseline_kw) + both_kw.setdefault('bart_kwargs', {}).update( + 
num_chain_devices=2, num_data_devices=2 + ) + bart_both = mc_gbart(**both_kw) + bart_both = remove_mesh(bart_both) + tree.map_with_path(check_equal, bart, bart_both) + + +def test_num_trees(kw: dict, subtests: SubTests) -> None: + """Test the number of trees.""" + kw.update(nskip=0, ndpost=0) + + with subtests.test('given ntree'): + bart = mc_gbart(**kw) + assert bart._bart.num_trees == kw['ntree'] + + with subtests.test('default ntree'): + if kw['y_train'].dtype == bool: + default_ntree = 50 + else: + default_ntree = 200 + kw.pop('ntree') + bart = mc_gbart(**kw) + assert bart._bart.num_trees == default_ntree diff --git a/tests/test_debug.py b/tests/test_debug.py index 7b9862bd..a15e56a9 100644 --- a/tests/test_debug.py +++ b/tests/test_debug.py @@ -1,6 +1,6 @@ # bartz/tests/test_debug.py # -# Copyright (c) 2025, The Bartz Contributors +# Copyright (c) 2025-2026, The Bartz Contributors # # This file is part of bartz. # @@ -33,12 +33,13 @@ from scipy import stats from scipy.stats import ks_1samp -from bartz.debug import check_trace, format_tree, sample_prior -from bartz.jaxext import minimal_unsigned_dtype +from bartz.debug import check_trace, sample_prior +from bartz.grove import format_tree +from bartz.jaxext import minimal_unsigned_dtype, split from tests.util import manual_tree -def test_format_tree(): +def test_format_tree() -> None: """Check the output of `format_tree` on a single example.""" tree = manual_tree( [[1.0], [2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], [[4], [1, 2]], [[15], [0, 3]] @@ -63,7 +64,7 @@ class TestSamplePrior: ) @pytest.fixture - def args(self, keys): + def args(self, keys: split) -> Args: """Prepare arguments for `sample_prior`.""" # config trace_length = 1000 @@ -84,7 +85,7 @@ def args(self, keys): keys.pop(), trace_length, num_trees, max_split, p_nonterminal, sigma_mu ) - def test_valid_trees(self, args: Args): + def test_valid_trees(self, args: Args) -> None: """Check all sampled trees are valid.""" trees = sample_prior(*args) batch_shape = 
(args.trace_length, args.num_trees) @@ -96,7 +97,7 @@ def test_valid_trees(self, args: Args): num_bad = jnp.count_nonzero(bad).item() assert num_bad == 0 - def test_max_depth(self, keys, args: Args): + def test_max_depth(self, keys: split, args: Args) -> None: """Check that trees stop growing when p_nonterminal = 0.""" for max_depth in range(args.p_nonterminal.size + 1): p_nonterminal = jnp.zeros_like(args.p_nonterminal) @@ -107,7 +108,7 @@ def test_max_depth(self, keys, args: Args): assert jnp.all(trees.split_tree[:, :, 1 : 2**max_depth]) assert not jnp.any(trees.split_tree[:, :, 2**max_depth :]) - def test_forest_sdev(self, keys, args: Args): + def test_forest_sdev(self, keys: split, args: Args) -> None: """Check that the sum of trees is standard Normal.""" trees = sample_prior(*args) leaf_indices = random.randint( @@ -122,7 +123,7 @@ def test_forest_sdev(self, keys, args: Args): test = ks_1samp(sum_of_trees, stats.norm.cdf) assert test.pvalue > 0.1 - def test_trees_differ(self, args: Args): + def test_trees_differ(self, args: Args) -> None: """Check that trees are different across iterations.""" trees = sample_prior(*args) for attr in ('leaf_tree', 'var_tree', 'split_tree'): diff --git a/tests/test_dgp.py b/tests/test_dgp.py index c9746cd0..4166a859 100644 --- a/tests/test_dgp.py +++ b/tests/test_dgp.py @@ -25,11 +25,12 @@ """Tests `bartz.testing.gen_data`.""" from collections.abc import Mapping +from dataclasses import replace from functools import partial from types import MappingProxyType import pytest -from jax import jit, vmap +from jax import jit, random, tree, vmap from jax import numpy as jnp from jaxtyping import Array, Bool, Float, Key from numpy.testing import assert_allclose, assert_array_equal, assert_array_less @@ -77,7 +78,7 @@ def dgps_lambda_one(keys: split) -> DGP: return generate_dgps(keys.pop(REPS), 1.0) -def test_shapes_and_dtypes(keys: split): +def test_shapes_and_dtypes(keys: split) -> None: """Test that all DGP attributes have correct 
shapes and dtypes.""" dgp = gen_data(keys.pop(), lam=0.5, **KWARGS) n, p, k = KWARGS['n'], KWARGS['p'], KWARGS['k'] @@ -128,7 +129,7 @@ def test_shapes_and_dtypes(keys: split): class TestGenerateX: """Test the _generate_x method.""" - def test_x_mean(self, dgps: DGP): + def test_x_mean(self, dgps: DGP) -> None: """Test that x has mean close to 0.""" x_samples = dgps.x # Shape: (REPS, P, N) n_reps = x_samples.shape[0] @@ -141,7 +142,7 @@ def test_x_mean(self, dgps: DGP): z_scores = jnp.abs(means / stds_of_mean) assert_array_less(z_scores, SIGMA_THRESHOLD) - def test_x_variance(self, dgps: DGP): + def test_x_variance(self, dgps: DGP) -> None: """Test that x has variance close to 1.""" x_samples = dgps.x # Shape: (REPS, P, N) n_reps = x_samples.shape[0] @@ -161,7 +162,7 @@ def test_x_variance(self, dgps: DGP): class TestGeneratePartition: """Test the _generate_partition method.""" - def test_partition_coverage(self, dgps: DGP): + def test_partition_coverage(self, dgps: DGP) -> None: """Test that each predictor is assigned to exactly one component.""" partitions = dgps.partition # Shape: (REPS, K, P) @@ -169,7 +170,7 @@ def test_partition_coverage(self, dgps: DGP): col_sums = jnp.sum(partitions, axis=1) # Shape: (N_REPS, P) assert_array_equal(col_sums, 1) - def test_partition_counts(self, dgps: DGP): + def test_partition_counts(self, dgps: DGP) -> None: """Test that counts are either p//c or p//c + 1.""" partitions = dgps.partition # Shape: (REPS, K, P) p, k = partitions.shape[2], partitions.shape[1] @@ -181,7 +182,7 @@ def test_partition_counts(self, dgps: DGP): valid = (counts == floor_count) | (counts == floor_count + 1) assert_array_equal(valid, True) - def test_partition_balance(self, dgps: DGP): + def test_partition_balance(self, dgps: DGP) -> None: """Test that predictors are roughly balanced across components.""" partitions = dgps.partition # Shape: (REPS, K, P) n_reps, k, p = partitions.shape @@ -200,7 +201,7 @@ def test_partition_balance(self, dgps: DGP): 
class TestGenerateBetaShared: """Test the _generate_beta_shared method.""" - def test_beta_shared_mean(self, dgps: DGP): + def test_beta_shared_mean(self, dgps: DGP) -> None: """Test that beta_shared has mean close to 0.""" beta_samples = dgps.beta_shared # Shape: (REPS, P) n_reps = beta_samples.shape[0] @@ -215,7 +216,7 @@ def test_beta_shared_mean(self, dgps: DGP): class TestGenerateBetaSeparate: """Test the _generate_beta_separate method.""" - def test_beta_separate_mean(self, dgps: DGP): + def test_beta_separate_mean(self, dgps: DGP) -> None: """Test that beta_separate has mean close to 0.""" beta_samples = dgps.beta_separate # Shape: (REPS, K, P) n_reps = beta_samples.shape[0] @@ -226,7 +227,7 @@ def test_beta_separate_mean(self, dgps: DGP): z_scores = jnp.abs(means / stds_of_mean) assert_array_less(z_scores, SIGMA_THRESHOLD) - def test_beta_separate_independence(self, dgps: DGP): + def test_beta_separate_independence(self, dgps: DGP) -> None: """Test that rows of beta_separate are independent.""" beta_samples = dgps.beta_separate # Shape: (REPS, K, P) n_reps = beta_samples.shape[0] @@ -261,7 +262,7 @@ def test_beta_separate_independence(self, dgps: DGP): 'y', ], ) -def test_outcome_prior_variance(dgps: DGP, which: str): +def test_outcome_prior_variance(dgps: DGP, which: str) -> None: """Test that latent mean and outcome have the expected elementwise variance.""" samples = getattr(dgps, which) # Shape: (REPS, K?, N) n_reps = samples.shape[0] @@ -276,7 +277,7 @@ def test_outcome_prior_variance(dgps: DGP, which: str): expected_var = dgps.sigma2_pri - dgps.sigma2_eps elif which == 'y': expected_var = dgps.sigma2_pri - else: + else: # pragma: no cover raise KeyError(which) expected_var = expected_var[0].item() @@ -299,7 +300,7 @@ def test_outcome_prior_variance(dgps: DGP, which: str): 'y', ], ) -def test_outcome_pop_variance(dgps: DGP, which: str): +def test_outcome_pop_variance(dgps: DGP, which: str) -> None: """Test that latent mean and outcome have the expected 
elementwise variance.""" samples = getattr(dgps, which) # Shape: (REPS, K?, N) n_reps = samples.shape[0] @@ -315,7 +316,7 @@ def test_outcome_pop_variance(dgps: DGP, which: str): expected_var = dgps.sigma2_pop - dgps.sigma2_eps elif which == 'y': expected_var = dgps.sigma2_pop - else: + else: # pragma: no cover raise KeyError(which) expected_var = expected_var[0].item() @@ -325,7 +326,7 @@ def test_outcome_pop_variance(dgps: DGP, which: str): assert_array_less(z_scores, SIGMA_THRESHOLD) -def test_variance_relationships(dgps: DGP): +def test_variance_relationships(dgps: DGP) -> None: """Check some simple inequalities on variances.""" assert jnp.all(dgps.sigma2_pri >= 0) assert jnp.all(dgps.sigma2_pop >= 0) @@ -346,7 +347,7 @@ def test_variance_relationships(dgps: DGP): 'y', ], ) -def test_rows_independent(dgps_lambda_zero: DGP, which: str): +def test_rows_independent(dgps_lambda_zero: DGP, which: str) -> None: """Test that rows are independent when lambda=0.""" samples = getattr(dgps_lambda_zero, which) # Shape: (REPS, K, N or P) n_reps = samples.shape[0] @@ -368,7 +369,7 @@ def test_rows_independent(dgps_lambda_zero: DGP, which: str): @pytest.mark.parametrize('which', ['mulin', 'muquad', 'mu']) -def test_rows_identical(dgps_lambda_one: DGP, which: str): +def test_rows_identical(dgps_lambda_one: DGP, which: str) -> None: """Test that rows are identical when lambda=1.""" samples = getattr(dgps_lambda_one, which) # Shape: (REPS, K, N) @@ -387,15 +388,15 @@ def pattern(self) -> Bool[Array, 'p p']: """Return the predictor interaction pattern.""" return interaction_pattern(p=10, q=4) - def test_symmetry(self, pattern: Bool[Array, 'p p']): + def test_symmetry(self, pattern: Bool[Array, 'p p']) -> None: """Test that interaction pattern is symmetric.""" assert_array_equal(pattern, pattern.T) - def test_diagonal(self, pattern: Bool[Array, 'p p']): + def test_diagonal(self, pattern: Bool[Array, 'p p']) -> None: """Test that diagonal is True.""" 
assert_array_equal(jnp.diag(pattern), True) - def test_row_sums(self, pattern: Bool[Array, 'p p']): + def test_row_sums(self, pattern: Bool[Array, 'p p']) -> None: """Test that each row sums to q+1.""" row_sums = jnp.sum(pattern, axis=1) assert_array_equal(row_sums, 4 + 1) @@ -422,7 +423,7 @@ def pattern(self, partition: Bool[Array, 'k p'], q: int) -> Bool[Array, 'k p p'] def test_respects_partition( self, partition: Bool[Array, 'k p'], pattern: Bool[Array, 'k p p'] - ): + ) -> None: """Test that pattern only has True values within partition blocks.""" # For each component, check that True values only occur where partition is True # pattern[i, r, s] can only be True if partition[i, r] and partition[i, s] are True @@ -431,7 +432,7 @@ def test_respects_partition( def test_diagonal_within_partition( self, partition: Bool[Array, 'k p'], pattern: Bool[Array, 'k p p'] - ): + ) -> None: """Test that diagonal elements within partition are True.""" k, _, _ = pattern.shape for i in range(k): @@ -440,12 +441,25 @@ def test_diagonal_within_partition( def test_row_sums( self, pattern: Bool[Array, 'k p p'], partition: Bool[Array, 'k p'], q: int - ): + ) -> None: """Test that each row sums to q+1.""" row_sums = jnp.sum(pattern, axis=2) target = jnp.where(partition, q + 1, 0) assert_array_equal(row_sums, target) - def test_symmetry(self, pattern: Bool[Array, 'k p p']): + def test_symmetry(self, pattern: Bool[Array, 'k p p']) -> None: """Test that interaction pattern is symmetric.""" assert_array_equal(pattern, jnp.swapaxes(pattern, 1, 2)) + + +def test_univariate(keys: split) -> None: + """Check that k=None produces the same result with squeezed y.""" + key = keys.pop() + kw = dict(KWARGS) + kw.update(lam=0.5, k=1) + dgp_mv = gen_data(key, **kw) + kw.update(k=None) + key = random.clone(key) + dgp_uv = gen_data(key, **kw) + dgp_mv = replace(dgp_mv, y=dgp_mv.y.squeeze(0)) + tree.map(partial(assert_array_equal, strict=True), dgp_mv, dgp_uv) diff --git a/tests/test_jaxext.py 
b/tests/test_jaxext.py index ed04322d..eaf5a999 100644 --- a/tests/test_jaxext.py +++ b/tests/test_jaxext.py @@ -25,21 +25,39 @@ """Test bartz.jaxext.""" from functools import partial +from inspect import signature from itertools import product from warnings import catch_warnings +try: + from jax import shard_map # available since jax v0.6.1 +except ImportError: + from jax.experimental.shard_map import shard_map + import numpy import pytest -from jax import debug_infs, jit, random, tree +from jax import ( + NamedSharding, + debug_infs, + device_put, + devices, + jit, + lax, + make_mesh, + random, + tree, +) from jax import numpy as jnp from jax.scipy.special import ndtri -from numpy.testing import assert_allclose +from jax.sharding import AxisType, Mesh, PartitionSpec +from jaxtyping import Array, Float, Float32, Key, Shaped +from numpy.testing import assert_allclose, assert_array_equal from pytest_subtests import SubTests from scipy.stats import invgamma as scipy_invgamma from scipy.stats import ks_1samp, truncnorm from bartz import jaxext -from bartz.jaxext import split +from bartz.jaxext import equal_shards, split from bartz.jaxext.scipy.special import ndtri as patched_ndtri from bartz.jaxext.scipy.stats import invgamma from tests.util import assert_close_matrices @@ -48,7 +66,7 @@ class TestUnique: """Test jaxext.unique.""" - def test_sort(self): + def test_sort(self) -> None: """Check that it's equivalent to sort if no values are repeated.""" x = jnp.arange(10)[::-1] out, length = jaxext.unique(x, x.size, 666) @@ -56,7 +74,7 @@ def test_sort(self): assert out.dtype == x.dtype assert length == x.size - def test_fill(self): + def test_fill(self) -> None: """Check that the trailing fill value is used correctly.""" x = jnp.ones(10) out, length = jaxext.unique(x, x.size, 666) @@ -64,7 +82,7 @@ def test_fill(self): assert out.dtype == x.dtype assert length == 1 - def test_empty_input(self): + def test_empty_input(self) -> None: """Check that the function works on 
empty input.""" x = jnp.array([]) out, length = jaxext.unique(x, 2, 666) @@ -72,7 +90,7 @@ def test_empty_input(self): assert out.dtype == x.dtype assert length == 0 - def test_empty_output(self): + def test_empty_output(self) -> None: """Check that the function works if the output is forced to be empty.""" x = jnp.array([1, 1, 1]) out, length = jaxext.unique(x, 0, 666) @@ -87,10 +105,14 @@ class TestAutoBatch: @pytest.mark.parametrize('target_nbatches', [1, 7]) @pytest.mark.parametrize('with_margin', [False, True]) @pytest.mark.parametrize('additional_size', [3, 0]) - def test_batch_size(self, keys, target_nbatches, with_margin, additional_size): + def test_batch_size( + self, keys: split, target_nbatches: int, with_margin: bool, additional_size: int + ) -> None: """Check batch sizes are correct in various conditions.""" - def func(a, b, c): + def func( + a: Float[Array, 'n m'], b: Float[Array, ' n'], c: Float[Array, 'p n'] + ) -> tuple[Float[Array, ' n'], Float[Array, 'p n']]: return (a * b[:, None]).sum(1), c * b[None, :] atomic_batch_size = additional_size + 12 @@ -125,10 +147,10 @@ def func(a, b, c): @pytest.mark.parametrize('max_memory', [32, 1024]) # test with large max memory to trigger noop code path - def test_unbatched_arg(self, max_memory: int): + def test_unbatched_arg(self, max_memory: int) -> None: """Check the function with batching disabled on a scalar argument.""" - def func(a, b): + def func(a: Shaped[Array, ' n'], b: int) -> Shaped[Array, ' n']: return a + b batched_func = jaxext.autobatch(func, max_memory, (0, None)) @@ -141,10 +163,10 @@ def func(a, b): numpy.testing.assert_array_max_ulp(out1, out2) - def test_batch_axis_pytree(self): + def test_batch_axis_pytree(self) -> None: """Check the that a batch axis can be specified for a whole sub-pytree.""" - def func(a, b): + def func(a: int, b: dict[str, Shaped[Array, ' n']]) -> Shaped[Array, ' n']: return a + b['foo'] + b['bar'] batched_func = jaxext.autobatch(func, 32, (None, 0)) @@ -157,22 
+179,22 @@ def func(a, b): numpy.testing.assert_array_max_ulp(out1, out2) - def test_large_batch_warning(self): + def test_large_batch_warning(self) -> None: """Check the function emits a warning if the size limit can't be honored.""" x = jnp.arange(10_000).reshape(10, 1000) - def f(x): + def f(x: Shaped[Array, 'n m']) -> Shaped[Array, 'n m']: return x g = jaxext.autobatch(f, 100) with pytest.warns(UserWarning, match=' > max_io_nbytes = '): g(x) - def test_empty_values(self): + def test_empty_values(self) -> None: """Check that the function works with batchable empty arrays.""" x = jnp.empty((10, 0)) - def f(x): + def f(x: Shaped[Array, 'n m']) -> Shaped[Array, 'n m']: return x g = jaxext.autobatch(f, 100, return_nbatches=True) @@ -180,11 +202,11 @@ def f(x): assert nbatches == 1 assert jnp.all(y == x) - def test_zero_size(self): + def test_zero_size(self) -> None: """Check the function works with a batch axis with length 0.""" x = jnp.empty((0, 10)) - def f(x): + def f(x: Shaped[Array, 'n m']) -> Shaped[Array, 'n m']: return x g = jaxext.autobatch(f, 100, return_nbatches=True) @@ -192,7 +214,7 @@ def f(x): assert nbatches == 1 assert jnp.all(y == x) - def test_reduction_basic(self, keys: split, subtests: SubTests): + def test_reduction_basic(self, keys: split, subtests: SubTests) -> None: """Check that reduction produces the expected result.""" # use an internal loop instead of pytest.mark.parametrize because there # are too many combinations of parameters @@ -231,7 +253,9 @@ def test_reduction_basic(self, keys: split, subtests: SubTests): dtype=dtype.dtype.name, ): - def func(*args, nin=nin): + def func( + *args: Shaped[Array, '*shape'], nin: int = nin + ) -> Shaped[Array, '*shape'] | tuple[Shaped[Array, '*shape'], ...]: out = sum(args) if nin == 1: return out @@ -251,7 +275,7 @@ def func(*args, nin=nin): (jnp.iinfo(dtype).max + 1) // 2, dtype, ) - elif jnp.issubdtype(dtype, jnp.bool_): + elif jnp.issubdtype(dtype, jnp.bool_): # pragma: no branch args = 
random.bernoulli(keys.pop(), 0.5, (nin, *shape)) expected = tree.map(partial(reduction, axis=axis), func(*args)) @@ -279,10 +303,10 @@ def func(*args, nin=nin): tree.map(partial(assert_close_matrices, rtol=1e-6), result, expected) - def test_reduction_with_unbatched_input(self, keys): + def test_reduction_with_unbatched_input(self, keys: split) -> None: """Check reduction works with unbatched (None) input arguments.""" - def func(x, scalar): + def func(x: Float[Array, 'n m'], scalar: float) -> Float[Array, 'n m']: return x * scalar x = random.uniform(keys.pop(), (50, 8)) @@ -295,10 +319,10 @@ def func(x, scalar): assert result.shape == (8,) assert_allclose(result, expected, rtol=1e-6) - def test_reduction_with_return_nbatches(self, keys): + def test_reduction_with_return_nbatches(self, keys: split) -> None: """Check reduce_ufunc works together with return_nbatches.""" - def func(x): + def func(x: Float[Array, 'n m']) -> Float[Array, 'n m']: return x x = random.uniform(keys.pop(), (100, 10)) @@ -316,12 +340,12 @@ def func(x): assert_allclose(result, expected, rtol=1e-6) -def different_keys(keya, keyb): +def different_keys(keya: Key[Array, ''], keyb: Key[Array, '']) -> bool: """Return True iff two jax random keys are different.""" return jnp.any(random.key_data(keya) != random.key_data(keyb)).item() -def test_split(keys): +def test_split(keys: split) -> None: """Test jaxext.split.""" key = keys.pop() ks = jaxext.split(key, 3) @@ -370,14 +394,14 @@ def test_split(keys): class TestJaxPatches: """Check that some jax stuff I patch is correct and still to be patched.""" - def test_invgamma_missing(self): + def test_invgamma_missing(self) -> None: """Check that jax does not implement the inverse gamma distribution.""" with pytest.raises(ImportError, match=r'gammainccinv'): from jax.scipy.special import gammainccinv # noqa: F401, PLC0415 with pytest.raises(ImportError, match=r'invgamma'): from jax.scipy.stats import invgamma # noqa: F401, PLC0415 - def 
test_invgamma_correct(self, keys): + def test_invgamma_correct(self, keys: split) -> None: """Compare my implementation of invgamma against scipy's.""" p = random.uniform(keys.pop(), (100,), float, 0.01, 0.99) alpha = 3.5 @@ -386,13 +410,13 @@ def test_invgamma_correct(self, keys): assert_allclose(x1, x0, rtol=1e-6) @pytest.mark.xfail(reason='Fixed in jax 0.6.2.') - def test_ndtri_bugged(self, keys): + def test_ndtri_bugged(self, keys: split) -> None: """Check that `jax.scipy.special.ndtri` triggers `jax.debug_infs`.""" x = random.uniform(keys.pop(), (100,), float, 0.01, 0.99) with debug_infs(True), pytest.raises(FloatingPointError, match=r'inf'): ndtri(x) - def test_ndtri_correct(self, keys): + def test_ndtri_correct(self, keys: split) -> None: """Check that my copy-pasted ndtri impl is equivalent to the jax one.""" x = random.uniform(keys.pop(), (100,), float, 0.01, 0.99) with debug_infs(False): @@ -404,7 +428,7 @@ def test_ndtri_correct(self, keys): class TestTruncatedNormalOneSided: """Test `jaxext.truncated_normal_onesided`.""" - def test_truncated_normal_incorrect(self, keys): + def test_truncated_normal_incorrect(self, keys: split) -> None: """Check that `jax.random.truncated_normal` is wrong out of 5 sigma.""" nsamples = 1000 lower, upper = jnp.array([(-100.0, -5.0), (5.0, 100.0)]).T @@ -415,7 +439,7 @@ def test_truncated_normal_incorrect(self, keys): test = ks_1samp(sample, truncnorm(l, u).cdf) assert test.pvalue < 0.01 - def test_correct(self, keys): + def test_correct(self, keys: split) -> None: """Check the samples come from the right distribution.""" nparams = 20 nsamples = 1000 @@ -430,7 +454,7 @@ def test_correct(self, keys): test = ks_1samp(sample, truncnorm(left, right).cdf) assert test.pvalue > 0.01 - def test_accurate(self, keys): + def test_accurate(self, keys: split) -> None: """Check that it does not over/under shoot.""" x = jaxext.truncated_normal_onesided( keys.pop(), (), jnp.bool_(True), jnp.float32(-12) @@ -441,21 +465,21 @@ def 
test_accurate(self, keys): ) assert 12 < x <= 12.1 - def test_finite(self, keys): + def test_finite(self, keys: split) -> None: """Check that the outputs are always finite.""" # shape and n_loops combined shall be enough that all possible # float32 values in [0, 1) are drawn by random.uniform shape = (1_000_000,) n_loops = 100 - keys = random.split(keys.pop(), n_loops) + keys = keys.pop(n_loops) platform = keys.device.platform clip = platform == 'gpu' @jit - def loop_body(key): - keys = jaxext.split(key, 3) + def loop_body(key: Key[Array, '']) -> Float32[Array, ' n']: + keys = split(key, 3) upper = random.bernoulli(keys.pop(), 0.5, shape) bound = random.uniform(keys.pop(), shape, float, -1, 1) return jaxext.truncated_normal_onesided( @@ -467,7 +491,7 @@ def loop_body(key): assert jnp.all(jnp.isfinite(vals)) -def test_is_key(keys): +def test_is_key(keys: split) -> None: """Test jaxext.is_key.""" # JAX keys should be recognized key = keys.pop() @@ -491,3 +515,76 @@ def test_is_key(keys): # NumPy arrays should not be recognized assert not jaxext.is_key(numpy.array([1, 2, 3])) + + +def make_broken_replicated_array(x: Array, axis_name: str, mesh: Mesh) -> Array: + """Replicate `x` across devices, but make it different on each device across an axis.""" + + @partial( + shard_map, + mesh=mesh, + in_specs=PartitionSpec(), + out_specs=PartitionSpec(), + # this disables the check that would notice the inconsistency + **_get_check_vma_false_kwargs(), + ) + def breaker(x: Array) -> Array: + return x + lax.axis_index(axis_name) + + return breaker(x) + + +def _get_check_vma_false_kwargs() -> dict[str, bool]: + """Get `dict(check_vma=False)` or the equivalent for old jax versions.""" + sig = signature(shard_map) + if 'check_vma' in sig.parameters: + # since jax v0.6.1 + return dict(check_vma=False) + else: + return dict(check_rep=False) + + +def test_make_broken_replicated_array() -> None: + """Test `make_broken_replicated_array`.""" + nd = len(devices()) + if nd < 2: + 
pytest.skip('Requires at least 2 devices') + mesh = make_mesh((nd,), ('a',), axis_types=(AxisType.Auto,)) + x = jnp.arange(nd) + xb = make_broken_replicated_array(x, 'a', mesh) + for i, shard in enumerate(xb.addressable_shards): + data: Array = shard.data + if i == 0: + assert_array_equal(data, x, strict=True) + else: + assert jnp.all(data != x) + + +@pytest.mark.parametrize('equal', [True, False]) +@pytest.mark.parametrize('replicated', [True, False]) +def test_equal_shards(equal: bool, replicated: bool) -> None: + """Test `jaxext.equal_shards`.""" + nd = len(devices()) + if nd < 2: + pytest.skip('Requires at least 2 devices') + + # define mesh + mesh = make_mesh((nd,), ('a',), axis_types=(AxisType.Auto,)) + + # create dummy array + if equal: + x = jnp.zeros(nd) + elif replicated: + x = jnp.zeros(nd) + x = make_broken_replicated_array(x, 'a', mesh) + else: + x = jnp.arange(nd) + + # shard x + spec = PartitionSpec() if replicated else PartitionSpec('a') + sharding = NamedSharding(mesh, spec) + x = device_put(x, sharding) + + # check the shards are equal or different + result = equal_shards(x, 'a', mesh=mesh, in_specs=spec) + assert result.item() == equal diff --git a/tests/test_mcmcloop.py b/tests/test_mcmcloop.py index 8fa1e3c8..131456be 100644 --- a/tests/test_mcmcloop.py +++ b/tests/test_mcmcloop.py @@ -24,14 +24,16 @@ """Test `bartz.mcmcloop`.""" +from dataclasses import replace from functools import partial +from typing import Any import pytest from equinox import filter_jit -from jax import debug_key_reuse, jit, vmap +from jax import NamedSharding, debug_key_reuse, device_put, jit, make_mesh, tree, vmap from jax import numpy as jnp -from jax.tree import map_with_path -from jax.tree_util import tree_map +from jax.sharding import AxisType, PartitionSpec +from jax.tree_util import KeyPath, tree_map from jaxtyping import Array, Float32, UInt8 from numpy.testing import assert_array_equal from pytest import FixtureRequest # noqa: PT013 @@ -39,8 +41,8 @@ from bartz 
import profile_mode from bartz.jaxext import get_default_device, split -from bartz.mcmcloop import run_mcmc -from bartz.mcmcstep import State, init +from bartz.mcmcloop import BurninTrace, MainTrace, run_mcmc +from bartz.mcmcstep import State, init, make_p_nonterminal from bartz.mcmcstep._state import chain_vmap_axes @@ -63,16 +65,10 @@ def gen_data( return X, y, max_split -def make_p_nonterminal(maxdepth: int) -> Float32[Array, ' {maxdepth}-1']: - """Prepare the p_nonterminal argument to `mcmcstep.init`.""" - depth = jnp.arange(maxdepth - 1) - base = 0.95 - power = 2 - return base / (1 + depth).astype(float) ** power - - @filter_jit -def simple_init(p: int, n: int, ntree: int, k: int | None = None, **kwargs) -> State: +def simple_init( + p: int, n: int, ntree: int, k: int | None = None, **kwargs: Any +) -> State: """Simplified version of `bartz.mcmcstep.init` with data pre-filled.""" X, y, max_split = gen_data(p, n, k) eye = 1.0 if k is None else jnp.eye(k) @@ -110,7 +106,7 @@ def initial_state(self, num_chains: int | None, k: int | None) -> State: """Prepare state for tests.""" return simple_init(10, 100, 20, k, num_chains=num_chains) - def test_final_state_overflow(self, keys: split, initial_state: State): + def test_final_state_overflow(self, keys: split, initial_state: State) -> None: """Check that the final state is the one in the trace even if there's overflow.""" with debug_key_reuse(initial_state.forest.num_chains() != 0): final_state, _, main_trace = run_mcmc( @@ -133,7 +129,7 @@ def test_final_state_overflow(self, keys: split, initial_state: State): final_state.error_cov_inv, main_trace.error_cov_inv[last_index] ) - def test_zero_iterations(self, keys: split, initial_state: State): + def test_zero_iterations(self, keys: split, initial_state: State) -> None: """Check 0 iterations produces a noop.""" with debug_key_reuse(initial_state.forest.num_chains() != 0): final_state, burnin_trace, main_trace = run_mcmc( @@ -142,7 +138,9 @@ def 
test_zero_iterations(self, keys: split, initial_state: State): tree_map(partial(assert_array_equal, strict=True), initial_state, final_state) - def assert_empty_trace(_path, x, chain_axis): + def assert_empty_trace( + _path: KeyPath, x: Array | None, chain_axis: int | None + ) -> None: if initial_state.forest.num_chains() is None or chain_axis is None: sample_axis = 0 else: @@ -150,8 +148,8 @@ def assert_empty_trace(_path, x, chain_axis): if x is not None: assert x.shape[sample_axis] == 0 - def check_trace(trace): - map_with_path( + def check_trace(trace: MainTrace | BurninTrace) -> None: + tree.map_with_path( assert_empty_trace, trace, chain_vmap_axes(trace), @@ -161,8 +159,10 @@ def check_trace(trace): check_trace(burnin_trace) check_trace(main_trace) - def test_jit_error(self, keys: split, subtests: SubTests): - """Check that an error is raised under jit in some conditions.""" + def test_predicted_double_compilation( + self, keys: split, subtests: SubTests + ) -> None: + """Check that an error is raised under jit if the configuration would lead to double compilation.""" initial_state = simple_init(10, 100, 20) compiled_run_mcmc = jit( @@ -180,3 +180,17 @@ def test_jit_error(self, keys: split, subtests: SubTests): pytest.raises(RuntimeError, match=msg), ): compiled_run_mcmc(keys.pop(), initial_state, 1) + + def test_detected_double_compilation(self, keys: split) -> None: + """Check that double compilation is detected.""" + state = simple_init(10, 100, 20) + + mesh = make_mesh((1,), ('a',), axis_types=(AxisType.Auto,)) + sharding = NamedSharding(mesh, PartitionSpec()) + resid = device_put(state.resid, sharding) + state = replace(state, resid=resid) + + with pytest.raises( + RuntimeError, match='The inner loop of `run_mcmc` was traced more than once' + ): + run_mcmc(keys.pop(), state, 2, inner_loop_length=1) diff --git a/tests/test_mcmcstep.py b/tests/test_mcmcstep.py index 3bd88a0c..20975bf0 100644 --- a/tests/test_mcmcstep.py +++ b/tests/test_mcmcstep.py @@ 
-26,18 +26,16 @@ from collections.abc import Sequence from math import prod -from typing import Literal +from typing import Literal, NamedTuple import jax import pytest from beartype import beartype -from jax import debug_key_reuse, make_mesh, random, vmap +from jax import debug_key_reuse, make_mesh, random, tree, vmap from jax import numpy as jnp -from jax.random import bernoulli, clone, normal, permutation, randint from jax.sharding import AxisType, Mesh, PartitionSpec, SingleDeviceSharding -from jax.tree import map_with_path -from jax.tree_util import KeyPath -from jaxtyping import Array, Bool, Int32, Key, PyTree, jaxtyped +from jax.tree_util import KeyPath, keystr +from jaxtyping import Array, Bool, Int32, Key, PyTree, UInt8, jaxtyped from numpy.testing import assert_array_equal from pytest_subtests import SubTests from scipy import stats @@ -55,6 +53,21 @@ from tests.util import assert_close_matrices, manual_tree +class VarTreeData(NamedTuple): + """Fixture data pairing a variable tree with its max-split array.""" + + var_tree: UInt8[Array, ' nodes'] + max_split: UInt8[Array, ' p'] + + +class SplitRangeData(NamedTuple): + """Fixture data pairing variable/split trees with a max-split array.""" + + var_tree: UInt8[Array, ' nodes'] + split_tree: UInt8[Array, ' nodes'] + max_split: UInt8[Array, ' p'] + + def vmap_randint_masked( key: Key[Array, ''], mask: Bool[Array, ' n'], size: int ) -> Int32[Array, '* n']: @@ -67,26 +80,26 @@ def vmap_randint_masked( class TestRandintMasked: """Test `mcmcstep.randint_masked`.""" - def test_all_false(self, keys): + def test_all_false(self, keys: split) -> None: """Check what happens when no value is allowed.""" for size in range(1, 10): u = randint_masked(keys.pop(), jnp.zeros(size, bool)) assert u == size - def test_all_true(self, keys): + def test_all_true(self, keys: split) -> None: """Check it's equivalent to `randint` when all values are allowed.""" key = keys.pop() size = 10_000 u1 = randint_masked(key, jnp.ones(size, 
bool)) - u2 = randint(clone(key), (), 0, size) + u2 = random.randint(random.clone(key), (), 0, size) assert u1 == u2 - def test_no_disallowed_values(self, keys): + def test_no_disallowed_values(self, keys: split) -> None: """Check disallowed values are never selected.""" key = keys.pop() for _ in range(100): keys = split(key, 3) - mask = bernoulli(keys.pop(), 0.5, (10,)) + mask = random.bernoulli(keys.pop(), 0.5, (10,)) if not jnp.any(mask): # pragma: no cover, rarely happens continue u = randint_masked(keys.pop(), mask) @@ -94,14 +107,14 @@ def test_no_disallowed_values(self, keys): assert mask[u] key = keys.pop() - def test_correct_distribution(self, keys): + def test_correct_distribution(self, keys: split) -> None: """Check the distribution of values is uniform.""" # create mask num_allowed = 10 mask = jnp.zeros(2 * num_allowed, bool) mask = mask.at[:num_allowed].set(True) indices = jnp.arange(mask.size) - indices = permutation(keys.pop(), indices) + indices = random.permutation(keys.pop(), indices) mask = mask[indices] # sample values @@ -123,7 +136,7 @@ class TestAncestorVariables: """Test `mcmcstep._moves.ancestor_variables`.""" @pytest.fixture - def depth2_tree(self): + def depth2_tree(self) -> VarTreeData: R""" Tree with var_tree of size 4 (tree_depth=2, max_num_ancestors=1). @@ -141,10 +154,10 @@ def depth2_tree(self): ) var_tree = tree.var_tree.astype(jnp.uint8) max_split = jnp.full(5, 10, jnp.uint8) - return var_tree, max_split + return VarTreeData(var_tree, max_split) @pytest.fixture - def depth3_tree(self): + def depth3_tree(self) -> VarTreeData: """ Tree with var_tree of size 8 (tree_depth=3, max_num_ancestors=2). 
@@ -157,9 +170,9 @@ def depth3_tree(self): ) var_tree = tree.var_tree.astype(jnp.uint8) max_split = jnp.full(10, 10, jnp.uint8) - return var_tree, max_split + return VarTreeData(var_tree, max_split) - def test_root_node(self, depth2_tree): + def test_root_node(self, depth2_tree: VarTreeData) -> None: """Check that root node has no ancestors (all slots filled with p).""" var_tree, max_split = depth2_tree @@ -169,7 +182,7 @@ def test_root_node(self, depth2_tree): # All slots should be p (sentinel) since root has no ancestors assert_array_equal(result, [max_split.size]) - def test_child_of_root(self, depth2_tree): + def test_child_of_root(self, depth2_tree: VarTreeData) -> None: """Check that children of root have one ancestor (the root's variable).""" var_tree, max_split = depth2_tree @@ -182,7 +195,7 @@ def test_child_of_root(self, depth2_tree): result = ancestor_variables(var_tree, max_split, jnp.int32(3)) assert_array_equal(result, [2]) - def test_deep_node(self, depth3_tree): + def test_deep_node(self, depth3_tree: VarTreeData) -> None: """Check ancestors for nodes at depth 3.""" var_tree, max_split = depth3_tree @@ -203,7 +216,7 @@ def test_deep_node(self, depth3_tree): result = ancestor_variables(var_tree, max_split, jnp.int32(7)) assert_array_equal(result, [3, 1]) - def test_intermediate_node(self, depth3_tree): + def test_intermediate_node(self, depth3_tree: VarTreeData) -> None: """Check ancestors for an intermediate (non-leaf) node.""" var_tree, max_split = depth3_tree @@ -215,7 +228,7 @@ def test_intermediate_node(self, depth3_tree): result = ancestor_variables(var_tree, max_split, jnp.int32(3)) assert_array_equal(result, [max_split.size, 3]) - def test_single_variable(self): + def test_single_variable(self) -> None: """Check with only one variable (p=1).""" tree = manual_tree( [[0.0], [0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], [[0], [0, 0]], [[4], [3, 5]] @@ -231,7 +244,7 @@ def test_single_variable(self): result = ancestor_variables(var_tree, max_split, 
jnp.int32(1)) assert_array_equal(result, [max_split.size]) - def test_type_edge(self, depth3_tree): + def test_type_edge(self, depth3_tree: VarTreeData) -> None: """Check that types are handled correctly when using uint8 and uint16 together.""" var_tree, max_split = depth3_tree var_tree = var_tree.astype(jnp.uint8) @@ -250,16 +263,16 @@ def test_type_edge(self, depth3_tree): class TestRandintExclude: """Test `mcmcstep._moves.randint_exclude`.""" - def test_empty_exclude(self, keys): + def test_empty_exclude(self, keys: split) -> None: """If exclude is empty, it's equivalent to randint(key, (), 0, sup).""" key = keys.pop() sup = 10_000 u1, num_allowed = randint_exclude(key, sup, jnp.array([], jnp.int32)) - u2 = randint(clone(key), (), 0, sup) + u2 = random.randint(random.clone(key), (), 0, sup) assert num_allowed == sup assert u1 == u2 - def test_exclude_out_of_range_is_ignored(self, keys): + def test_exclude_out_of_range_is_ignored(self, keys: split) -> None: """Values >= sup are ignored for both u and num_allowed.""" key = keys.pop() sup = 7 @@ -268,7 +281,7 @@ def test_exclude_out_of_range_is_ignored(self, keys): assert num_allowed == sup assert 0 <= u < sup - def test_duplicate_excludes_ignored(self, keys): + def test_duplicate_excludes_ignored(self, keys: split) -> None: """Duplicates should be de-duplicated (set semantics for allowed count).""" sup = 10 exclude_with_dupes = jnp.array([1, 1, 1, 3, 3, 9]) @@ -280,7 +293,7 @@ def test_duplicate_excludes_ignored(self, keys): assert u1 == u2 assert n1 == n2 == (sup - 3) - def test_all_values_excluded_returns_sup(self, keys): + def test_all_values_excluded_returns_sup(self, keys: split) -> None: """If all values are excluded, u must be sup and num_allowed=0.""" for sup in range(1, 30, 5): exclude = jnp.arange(sup) @@ -288,13 +301,13 @@ def test_all_values_excluded_returns_sup(self, keys): assert num_allowed == 0 assert u == sup - def test_never_returns_excluded_values(self, keys): + def 
test_never_returns_excluded_values(self, keys: split) -> None: """Across repeated sampling, u is always in [0,sup) and not excluded, unless num_allowed=0.""" sup = 20 reps = 200 # Use a fixed-length exclude array; include invalid values so masking paths are hit. - exclude = randint(keys.pop(), (reps, 30), 0, sup + 10) + exclude = random.randint(keys.pop(), (reps, 30), 0, sup + 10) randint_exclude_v = vmap(randint_exclude, in_axes=(0, None, 0)) keys_v = keys.pop(reps) u, num_allowed = randint_exclude_v(keys_v, sup, exclude) @@ -310,12 +323,14 @@ def test_never_returns_excluded_values(self, keys): ) ) - def test_num_allowed_matches_count(self, keys): + def test_num_allowed_matches_count(self, keys: split) -> None: """num_allowed must match sup - |unique(exclude ∩ [0,sup))|.""" sup = 50 reps = 50 - exclude = randint(keys.pop(), (reps, 80), 0, sup + 25) # includes some >= sup + exclude = random.randint( + keys.pop(), (reps, 80), 0, sup + 25 + ) # includes some >= sup randint_exclude_v = vmap(randint_exclude, in_axes=(0, None, 0)) keys_v = keys.pop(reps) @@ -332,7 +347,7 @@ def test_num_allowed_matches_count(self, keys): assert jnp.all(num_allowed == expected_num_allowed) - def test_correct_distribution_single_excluded(self, keys): + def test_correct_distribution_single_excluded(self, keys: split) -> None: """ With one excluded value, u should be uniform over the remaining sup-1 values. @@ -366,14 +381,14 @@ class TestSplitRange: """Test `mcmcstep._moves.split_range`.""" @pytest.fixture - def max_split(self): + def max_split(self) -> UInt8[Array, ' p']: """Maximum split indices for 3 variables.""" # max_split[v] = maximum split index for variable v # split_range returns [l, r) in *1-based* split indices, so initial r = 1 + max_split[v] return jnp.array([10, 10, 10], dtype=jnp.uint8) @pytest.fixture - def depth3_tree(self, max_split): + def depth3_tree(self, max_split: UInt8[Array, ' p']) -> SplitRangeData: R""" Small depth-3 tree (var_tree size 8 => nodes 1..7 exist). 
@@ -393,9 +408,9 @@ def depth3_tree(self, max_split): ) var_tree = tree.var_tree.astype(jnp.uint8) split_tree = tree.split_tree.astype(jnp.uint8) - return var_tree, split_tree, max_split + return SplitRangeData(var_tree, split_tree, max_split) - def test_dtypes(self, depth3_tree): + def test_dtypes(self, depth3_tree: SplitRangeData) -> None: """Check the output types.""" var_tree, split_tree, max_split = depth3_tree l, r = split_range( @@ -404,7 +419,7 @@ def test_dtypes(self, depth3_tree): assert l.dtype == jnp.int32 assert r.dtype == jnp.int32 - def test_ref_var_out_of_bounds(self, depth3_tree): + def test_ref_var_out_of_bounds(self, depth3_tree: SplitRangeData) -> None: """If ref_var is out of bounds, l=r=1.""" var_tree, split_tree, max_split = depth3_tree l, r = split_range( @@ -413,7 +428,7 @@ def test_ref_var_out_of_bounds(self, depth3_tree): assert l == 1 assert r == 1 - def test_root_node_no_constraints(self, depth3_tree): + def test_root_node_no_constraints(self, depth3_tree: SplitRangeData) -> None: """Root has no ancestors => range should be the full [1, 1+max_split[var]).""" var_tree, split_tree, max_split = depth3_tree @@ -422,7 +437,9 @@ def test_root_node_no_constraints(self, depth3_tree): assert l == 1 assert r == 1 + max_split[0] - def test_unrelated_variable_no_constraints(self, depth3_tree): + def test_unrelated_variable_no_constraints( + self, depth3_tree: SplitRangeData + ) -> None: """If ancestors don't use ref_var, range should be full [1, 1+max_split[ref_var]).""" var_tree, split_tree, max_split = depth3_tree @@ -432,7 +449,7 @@ def test_unrelated_variable_no_constraints(self, depth3_tree): assert l == 1 assert r == 1 + max_split[2] - def test_left_child_sets_upper_bound(self, depth3_tree): + def test_left_child_sets_upper_bound(self, depth3_tree: SplitRangeData) -> None: """For left subtree of an ancestor split on ref_var, r should be tightened to that split.""" var_tree, split_tree, max_split = depth3_tree @@ -442,7 +459,7 @@ def 
test_left_child_sets_upper_bound(self, depth3_tree): assert l == 1 assert r == 5 - def test_right_child_sets_lower_bound(self, depth3_tree): + def test_right_child_sets_lower_bound(self, depth3_tree: SplitRangeData) -> None: """For right subtree of an ancestor split on ref_var, l should be raised to that split+1.""" var_tree, split_tree, max_split = depth3_tree @@ -452,7 +469,7 @@ def test_right_child_sets_lower_bound(self, depth3_tree): assert l == 6 assert r == 1 + max_split[0] - def test_two_ancestors_combine_bounds(self, depth3_tree): + def test_two_ancestors_combine_bounds(self, depth3_tree: SplitRangeData) -> None: """Bounds from multiple ancestors on the same variable should combine (max lower, min upper).""" var_tree, split_tree, max_split = depth3_tree @@ -463,7 +480,9 @@ def test_two_ancestors_combine_bounds(self, depth3_tree): assert l == 6 assert r == 8 - def test_ref_var_constraints_from_parent_only(self, depth3_tree): + def test_ref_var_constraints_from_parent_only( + self, depth3_tree: SplitRangeData + ) -> None: """If only a deeper ancestor matches ref_var, constraints should come only from those matches.""" var_tree, split_tree, max_split = depth3_tree @@ -474,7 +493,9 @@ def test_ref_var_constraints_from_parent_only(self, depth3_tree): assert l == 1 assert r == 7 - def test_no_allowed_splits_when_bounds_cross(self, max_split): + def test_no_allowed_splits_when_bounds_cross( + self, max_split: UInt8[Array, ' p'] + ) -> None: """ If constraints make the interval empty, l can become >= r. 
@@ -499,7 +520,7 @@ def test_no_allowed_splits_when_bounds_cross(self, max_split): assert l == 9 assert r == 3 - def test_minimal_tree(self): + def test_minimal_tree(self) -> None: """Test the minimal tree.""" # We want the shortest possible `var_tree`/`split_tree` arrays that still # represent a valid tree for the function: @@ -529,14 +550,14 @@ def init_kwargs(self, keys: split) -> dict: numcut = 10 num_trees = 5 return dict( - X=randint(keys.pop(), (p, self.n), 0, numcut + 1, jnp.uint32), - y=normal(keys.pop(), (k, self.n)), - offset=normal(keys.pop(), (k,)), + X=random.randint(keys.pop(), (p, self.n), 0, numcut + 1, jnp.uint32), + y=random.normal(keys.pop(), (k, self.n)), + offset=random.normal(keys.pop(), (k,)), max_split=jnp.full(p, numcut + 1, jnp.uint32), num_trees=num_trees, p_nonterminal=jnp.full(d - 1, 0.9), leaf_prior_cov_inv=jnp.eye(k) * num_trees, - error_cov_df=2.0, + error_cov_df=2.0, # keep this a weak type error_cov_scale=2 * jnp.eye(k), ) @@ -549,7 +570,7 @@ def test_basic( shard_data: bool, subtests: SubTests, keys: split, - ): + ) -> None: """Create a multichain `State` with `init` and step it once.""" mesh = {} @@ -576,6 +597,7 @@ def test_basic( typechecking_init = jaxtyped(init, typechecker=beartype) state = typechecking_init(**init_kwargs, num_chains=num_chains, mesh=mesh) assert state.forest.num_chains() == num_chains + check_strong_types(state) check_sharding(state, state.config.mesh) with subtests.test('step'): @@ -584,12 +606,13 @@ def test_basic( # key reuse checks trigger with empty key array apparently new_state = typechecking_step(keys.pop(), state) assert new_state.forest.num_chains() == num_chains + check_strong_types(new_state) check_sharding(new_state, state.config.mesh) @pytest.mark.parametrize('profile', [False, True]) def test_multichain_equiv_stack( self, init_kwargs: dict, keys: split, profile: bool - ): + ) -> None: """Check that stacking multiple chains is equivalent to a multichain trace.""" num_chains = 4 num_iters = 10 
@@ -632,13 +655,13 @@ def stack_leaf( return jnp.stack(sc_xs, axis=chain_axis) chain_axes = chain_vmap_axes(mc_state) - stacked_state = map_with_path( + stacked_state = tree.map_with_path( stack_leaf, chain_axes, mc_state, *sc_states, is_leaf=lambda x: x is None ) # check the mc state is equal to the stacked state - def check_equal(path: KeyPath, mc: Array, stacked: Array): - str_path = ''.join(map(str, path)) + def check_equal(path: KeyPath, mc: Array, stacked: Array) -> None: + str_path = keystr(path) exact = mc.platform() == 'cpu' or jnp.issubdtype(mc.dtype, jnp.integer) assert_close_matrices( mc, @@ -648,12 +671,12 @@ def check_equal(path: KeyPath, mc: Array, stacked: Array): reduce_rank=True, ) - map_with_path(check_equal, mc_state, stacked_state) + tree.map_with_path(check_equal, mc_state, stacked_state) def chain_vmap_axes(self, state: State) -> State: """Old manual version of `chain_vmap_axes(_: State)`.""" - def choose_vmap_index(path, _) -> Literal[0, None]: + def choose_vmap_index(path: KeyPath, _: Array) -> Literal[0, None]: no_vmap_attrs = ( '.X', '.y', @@ -674,18 +697,17 @@ def choose_vmap_index(path, _) -> Literal[0, None]: '.config.sparse_on_at', '.config.steps_done', ) - str_path = ''.join(map(str, path)) - if str_path in no_vmap_attrs: + if keystr(path) in no_vmap_attrs: return None else: return 0 - return map_with_path(choose_vmap_index, state) + return tree.map_with_path(choose_vmap_index, state) def data_vmap_axes(self, state: State) -> State: """Hardcoded version of `data_vmap_axes(_: State)`.""" - def choose_vmap_index(path: KeyPath, _) -> Literal[-1, None]: + def choose_vmap_index(path: KeyPath, _: Array) -> Literal[-1, None]: vmap_attrs = ( '.X', '.y', @@ -694,15 +716,14 @@ def choose_vmap_index(path: KeyPath, _) -> Literal[-1, None]: '.prec_scale', '.forest.leaf_indices', ) - str_path = ''.join(map(str, path)) - if str_path in vmap_attrs: + if keystr(path) in vmap_attrs: return -1 else: return None - return map_with_path(choose_vmap_index, 
state) + return tree.map_with_path(choose_vmap_index, state) - def test_vmap_axes(self, init_kwargs: dict): + def test_vmap_axes(self, init_kwargs: dict) -> None: """Check `data_vmap_axes` and `chain_vmap_axes` on a `State`.""" state = init(**init_kwargs) @@ -712,37 +733,39 @@ def test_vmap_axes(self, init_kwargs: dict): ref_chain_axes = self.chain_vmap_axes(state) ref_data_axes = self.data_vmap_axes(state) - def assert_equal(_path: KeyPath, axis: int | None, ref_axis: int | None): + def assert_equal( + _path: KeyPath, axis: int | None, ref_axis: int | None + ) -> None: assert axis == ref_axis - map_with_path(assert_equal, chain_axes, ref_chain_axes) - map_with_path(assert_equal, data_axes, ref_data_axes) + tree.map_with_path(assert_equal, chain_axes, ref_chain_axes) + tree.map_with_path(assert_equal, data_axes, ref_data_axes) - def test_normalize_spec(self): + def test_normalize_spec(self) -> None: """Test `normalize_spec`.""" devices = jax.devices('cpu')[:3] mesh = make_mesh( - (3, 1), + (len(devices), 1), ('ciao', 'bau'), axis_types=(AxisType.Auto, AxisType.Auto), devices=devices, ) assert normalize_spec(['ciao'], mesh, (1, 1, 1)) == PartitionSpec( - 'ciao', None, None + 'ciao' if len(devices) > 1 else None, None, None ) assert normalize_spec([None, 'bau'], mesh, (1, 1)) == PartitionSpec(None, None) assert normalize_spec(['ciao'], mesh, (0,)) == PartitionSpec(None) assert normalize_spec([None, 'ciao'], mesh, (0, 1)) == PartitionSpec(None, None) -def check_sharding(x: PyTree, mesh: Mesh | None): +def check_sharding(x: PyTree, mesh: Mesh | None) -> None: """Check that chains and data are sharded as expected.""" chain_axes = chain_vmap_axes(x) data_axes = data_vmap_axes(x) def check_leaf( _path: KeyPath, x: Array | None, chain_axis: int | None, data_axis: int | None - ): + ) -> None: if x is None: return elif mesh is None: @@ -759,7 +782,9 @@ def check_leaf( assert spec == expected_spec - map_with_path(check_leaf, x, chain_axes, data_axes, is_leaf=lambda x: x is 
None) + tree.map_with_path( + check_leaf, x, chain_axes, data_axes, is_leaf=lambda x: x is None + ) def get_normal_spec(x: Array) -> PartitionSpec: @@ -788,3 +813,12 @@ def normalize_spec( assert len(s) == ndim return PartitionSpec(*s) + + +def check_strong_types(x: PyTree[Array]) -> None: + """Check all arrays in `x` have strong types.""" + + def check_leaf(path: KeyPath, x: Array) -> None: + assert not x.weak_type, f'{keystr(path)} has weak type' + + tree.map_with_path(check_leaf, x) diff --git a/tests/test_meta.py b/tests/test_meta.py index ab8c3e03..0dfb3ab5 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -30,26 +30,29 @@ from jax import debug_nans, jit, random from jax import numpy as jnp from jax.errors import KeyReuseError +from jaxtyping import Array, Float, Key + +from bartz.jaxext import split @pytest.fixture -def keys1(keys): +def keys1(keys: split) -> split: """Pass-through the `keys` fixture.""" return keys @pytest.fixture -def keys2(keys): +def keys2(keys: split) -> split: """Pass-through the `keys` fixture.""" return keys -def test_random_keys_do_not_depend_on_fixture(keys1, keys2): +def test_random_keys_do_not_depend_on_fixture(keys1: split, keys2: split) -> None: """Check that the `keys` fixture is per-test-case, not per-fixture.""" assert keys1 is keys2 -def test_number_of_random_keys(keys): +def test_number_of_random_keys(keys: split) -> None: """Check the fixed number of available keys. 
This is here just as reference for the `test_random_keys_are_consumed` test @@ -59,21 +62,25 @@ def test_number_of_random_keys(keys): @pytest.fixture -def consume_one_key(keys): # noqa: D103 +def consume_one_key(keys: split) -> Key[Array, '']: # noqa: D103 return keys.pop() @pytest.fixture -def consume_another_key(keys): # noqa: D103 +def consume_another_key(keys: split) -> Key[Array, '']: # noqa: D103 return keys.pop() -def test_random_keys_are_consumed(consume_one_key, consume_another_key, keys): # noqa: ARG001 +def test_random_keys_are_consumed( + consume_one_key: Key[Array, ''], # noqa: ARG001 + consume_another_key: Key[Array, ''], # noqa: ARG001 + keys: split, +) -> None: """Check that the random keys in `keys` can't be used more than once.""" assert len(keys) == 126 -def test_debug_key_reuse(keys): +def test_debug_key_reuse(keys: split) -> None: """Check that the jax debug_key_reuse option works.""" key = keys.pop() random.uniform(key) @@ -81,11 +88,11 @@ def test_debug_key_reuse(keys): random.uniform(key) -def test_debug_key_reuse_within_jit(keys): +def test_debug_key_reuse_within_jit(keys: split) -> None: """Check that the jax debug_key_reuse option works within a jitted function.""" @jit - def func(key): + def func(key: Key[Array, '']) -> Float[Array, '']: return random.uniform(key) + random.uniform(key) with pytest.raises(KeyReuseError): @@ -95,7 +102,7 @@ def func(key): class TestJaxNoCopyBehavior: """Check whether jax makes actual copies of arrays in various conditions.""" - def test_unconditional_buffer_donation(self): + def test_unconditional_buffer_donation(self) -> None: """Test jax donates buffers even if they are small.""" # donation disabled under debug_nans, see jax/issues/#33949 with debug_nans(False): @@ -104,7 +111,7 @@ def test_unconditional_buffer_donation(self): xp = x.unsafe_buffer_pointer() @partial(jit, donate_argnums=(0,)) - def noop(x): + def noop(x: Array) -> Array: return x y = noop(x) @@ -114,7 +121,7 @@ def noop(x): with 
pytest.raises(RuntimeError, match=r'delete'): x[0] - def test_jnp_array_copy_no_jit(self): + def test_jnp_array_copy_no_jit(self) -> None: """Test jnp.array makes copies outside jitted functions.""" y = jnp.arange(100) yp = y.unsafe_buffer_pointer() @@ -124,7 +131,7 @@ def test_jnp_array_copy_no_jit(self): assert zp != yp - def test_jnp_array_no_copy_jit(self): + def test_jnp_array_no_copy_jit(self) -> None: """Check jnp.array does not make copies within jit.""" # donation disabled under debug_nans, see jax/issues/#33949 with debug_nans(False): @@ -132,7 +139,7 @@ def test_jnp_array_no_copy_jit(self): yp = y.unsafe_buffer_pointer() @partial(jit, donate_argnums=(0,)) - def array(x): + def array(x: Array) -> Array: return jnp.array(x) q = array(y) diff --git a/tests/test_mvbart.py b/tests/test_mvbart.py index 154950fd..3d2cbfde 100644 --- a/tests/test_mvbart.py +++ b/tests/test_mvbart.py @@ -25,13 +25,17 @@ """Test multivariate BART components.""" from dataclasses import replace +from typing import NamedTuple import pytest from jax import numpy as jnp from jax import random, vmap +from jaxtyping import Array, Float, Float32, Int32, Key, UInt32 from numpy.testing import assert_allclose, assert_array_equal +from pytest import FixtureRequest # noqa: PT013 from scipy.stats import chi2, ks_1samp, ks_2samp +from bartz.jaxext import split from bartz.mcmcstep import State, init, step from bartz.mcmcstep._step import ( Counts, @@ -49,51 +53,58 @@ from tests.util import assert_close_matrices +class Data(NamedTuple): + """Toy dataset for testing.""" + + X: Int32[Array, 'p n'] + y: Float32[Array, ' n'] + max_split: UInt32[Array, ' p'] + + class TestWishart: """Test the basic properties of the wishart sampler output.""" # Parameterize with (k, df) pairs @pytest.fixture(params=[(1, 3), (3, 3), (3, 5), (3, 100), (100, 102)]) - def wishart_params(self, request): + def wishart_params(self, request: FixtureRequest) -> tuple[int, int]: """Provide (k, df) pairs for testing.""" k, df = 
request.param return k, df - def random_pd_matrix(self, key, k): + def random_pd_matrix(self, key: Key[Array, ''], k: int) -> Float[Array, '{k} {k}']: """Generate a random positive definite matrix.""" A = random.normal(key, (k, k)) return A @ A.T + jnp.eye(k) - def ill_conditioned_matrix(self, key, k, condition_number=1e6, exact_psd=True): + def ill_conditioned_matrix( + self, key: Key[Array, ''], k: int, condition_number: float = 1e6 + ) -> Float[Array, '{k} {k}']: """Generate a ill conditioned random positive semi-definite matrix.""" A = random.normal(key, (k, k)) U, _ = jnp.linalg.qr(A) - if exact_psd: - if k == 1: - eigs = jnp.zeros(1) - else: - smalls = jnp.geomspace(1.0, 1.0 / condition_number, num=k - 1) - eigs = jnp.concatenate([smalls, jnp.array([0.0])]) + if k == 1: + eigs = jnp.zeros(1) else: - eigs = jnp.geomspace(1.0, 1.0 / condition_number, num=k) + smalls = jnp.geomspace(1.0, 1.0 / condition_number, num=k - 1) + eigs = jnp.concatenate([smalls, jnp.array([0.0])]) return (U * eigs) @ U.T - def test_size(self, keys, wishart_params): + def test_size(self, keys: split, wishart_params: tuple[int, int]) -> None: """Check that the sample generated by wishart sampler is of shape k*k.""" k, df = wishart_params scale = self.random_pd_matrix(keys.pop(), k) sample = _sample_wishart_bartlett(keys.pop(), df, scale) assert sample.shape == (k, k) - def test_symmetric(self, keys, wishart_params): + def test_symmetric(self, keys: split, wishart_params: tuple[int, int]) -> None: """Check that the sample generated by wishart sampler is symmetric.""" k, df = wishart_params scale = self.random_pd_matrix(keys.pop(), k) sample = _sample_wishart_bartlett(keys.pop(), df, scale) assert_close_matrices(sample, sample.T, rtol=1e-6) - def test_pos_def(self, keys, wishart_params): + def test_pos_def(self, keys: split, wishart_params: tuple[int, int]) -> None: """Check that the sample generated by wishart sampler is positive definite.""" k, df = wishart_params scale = 
self.random_pd_matrix(keys.pop(), k) @@ -101,25 +112,27 @@ def test_pos_def(self, keys, wishart_params): eigs = jnp.linalg.eigvalsh(sample) assert jnp.all(eigs > 0) - def test_near_singular_scale(self, keys, wishart_params): + def test_near_singular_scale( + self, keys: split, wishart_params: tuple[int, int] + ) -> None: """Check that the wishart sampler still works with singular or near singular matrix.""" k, df = wishart_params ill_conditioned_scale = self.ill_conditioned_matrix(keys.pop(), k) sample = _sample_wishart_bartlett(keys.pop(), df, ill_conditioned_scale) assert jnp.all(jnp.isfinite(sample)) - def test_wishart_dist(self, keys, wishart_params): + def test_wishart_dist(self, keys: split, wishart_params: tuple[int, int]) -> None: """Check that the sample generated by wishart sampler follows a wishart distribution.""" k, df = wishart_params sigma = self.random_pd_matrix(keys.pop(), k) scale_inv = jnp.linalg.inv(sigma) a = random.normal(keys.pop(), (k,)) - denumerator = a.T @ sigma @ a + denominator = a.T @ sigma @ a sampler = vmap(_sample_wishart_bartlett, in_axes=(0, None, None)) W = sampler(keys.pop(1000), float(df), scale_inv) - t = jnp.einsum('ijk,j,k->i', W, a, a) / denumerator + t = jnp.einsum('ijk,j,k->i', W, a, a) / denominator test = ks_1samp(t, chi2(df).cdf) assert test.pvalue > 0.01 @@ -129,16 +142,16 @@ class TestPrecomputeTerms: """Test _precompute_likelihood_terms_mv and _precompute_leaf_terms_mv correctness and stability.""" @pytest.fixture(params=[1, 2, 5, 10]) - def k(self, request): + def k(self, request: FixtureRequest) -> int: """Provide different ks for testing.""" return request.param - def random_pd_matrix(self, key, k): + def random_pd_matrix(self, key: Key[Array, ''], k: int) -> Float[Array, '{k} {k}']: """Generate a random positive definite matrix.""" A = random.normal(key, (k, k)) return A @ A.T + jnp.eye(k) - def test_shapes_leaf(self, keys, k): + def test_shapes_leaf(self, keys: split, k: int) -> None: """Check that shapes of 
outputs are correct.""" num_trees, tree_size = 3, 4 prec_trees = jnp.ones((num_trees, tree_size)) @@ -151,7 +164,7 @@ def test_shapes_leaf(self, keys, k): assert result.mean_factor.shape == (num_trees, k, k, tree_size) assert result.centered_leaves.shape == (num_trees, k, tree_size) - def test_likelihood_equiv(self, keys): + def test_likelihood_equiv(self, keys: split) -> None: """Check that _compute_likelihood_ratio_uv and _compute_likelihood_ratio_mv agree when k = 1.""" inv_sigma2 = random.uniform(keys.pop(), (), minval=0.1, maxval=5.0) leaf_prior_cov_inv_uv = random.uniform(keys.pop(), (), minval=0.1, maxval=5.0) @@ -187,7 +200,7 @@ def test_likelihood_equiv(self, keys): ) assert_allclose(likelihood_mv, likelihood_uv, rtol=1e-6, atol=1e-6) - def test_leaf_terms_equiv(self, keys): + def test_leaf_terms_equiv(self, keys: split) -> None: """Check that _precompute_leaf_terms_uv and _precompute_leaf_terms_mv agree when k = 1.""" num_trees, tree_size = 2, 3 inv_sigma2 = random.uniform(keys.pop(), (), minval=0.1, maxval=5.0) @@ -221,25 +234,25 @@ def test_leaf_terms_equiv(self, keys): @pytest.fixture(params=[(10, 2), (20, 5), (3, 100), (50, 50)]) -def data_shape(request): +def data_shape(request: FixtureRequest) -> tuple[int, int]: """Provide (n, p) pairs for testing.""" return request.param @pytest.fixture -def data(data_shape): +def data(data_shape: tuple[int, int]) -> Data: """Generate a toy dataset.""" n, p = data_shape X = jnp.arange(n * p).reshape(p, n) y = jnp.linspace(-1, 1, n) max_split = jnp.full(p, 5, dtype=jnp.uint32) - return X, y, max_split + return Data(X, y, max_split) class TestMVBartIntegration: """Test equivalence between Univariate and Multivariate (k=1) modes.""" - def test_init_equivalence(self, data): + def test_init_equivalence(self, data: Data) -> None: """Test that init produces compatible structures for UV and MV(k=1).""" X, y, max_split = data y_mv = y[None, :] @@ -295,7 +308,7 @@ def test_init_equivalence(self, data): 
assert_array_equal(bart_uv.forest.p_propose_grow, bart_mv.forest.p_propose_grow) assert_array_equal(bart_uv.forest.affluence_tree, bart_mv.forest.affluence_tree) - def test_step_sigma_distribution_match(self, keys, data): + def test_step_sigma_distribution_match(self, keys: split, data: Data) -> None: """ Test that _step_error_cov_inv_uv and _step_error_cov_inv_mv (k = 1) sample from the same posterior. @@ -337,10 +350,10 @@ def test_step_sigma_distribution_match(self, keys, data): config=None, ) - def sample_uv(k): + def sample_uv(k: Key[Array, '']) -> Float32[Array, '']: return _step_error_cov_inv_uv(k, st_uv).error_cov_inv - def sample_mv(k): + def sample_mv(k: Key[Array, '']) -> Float32[Array, '']: return _step_error_cov_inv_mv(k, st_mv).error_cov_inv.reshape(()) n_samples = 10000 @@ -356,7 +369,7 @@ def sample_mv(k): class TestMVBartSteps: """Test the full MCMC step trajectory (init + multiple steps).""" - def test_step_trees_exact_match(self, keys, data): + def test_step_trees_exact_match(self, keys: split, data: Data) -> None: """Test that MV tree logic is Identical to UV logic.""" X, y, max_split = data y_mv = y[None, :] @@ -436,7 +449,7 @@ def test_step_trees_exact_match(self, keys, data): uv_state.forest.prune_acc_count, mv_state.forest.prune_acc_count ) - def test_mv_steps(self, keys, data): + def test_mv_steps(self, keys: split, data: Data) -> None: """Test that mv mode can run without crashing.""" X, y_uv, max_split = data k = 3 @@ -458,7 +471,7 @@ def test_mv_steps(self, keys, data): count_num_batches=None, ) - for key in random.split(keys.pop(), 10): + for key in keys.pop(10): mv_state = step(key, mv_state) assert jnp.all(jnp.isfinite(mv_state.resid)) diff --git a/tests/test_prepcovars.py b/tests/test_prepcovars.py index bc212ad9..511e7820 100644 --- a/tests/test_prepcovars.py +++ b/tests/test_prepcovars.py @@ -1,6 +1,6 @@ # bartz/tests/test_prepcovars.py # -# Copyright (c) 2024-2025, The Bartz Contributors +# Copyright (c) 2024-2026, The Bartz 
Contributors # # This file is part of bartz. # @@ -38,7 +38,7 @@ class TestQuantilizer: @pytest.mark.parametrize( 'fill_value', [jnp.finfo(jnp.float32).max, jnp.iinfo(jnp.int32).max] ) - def test_splits_fill(self, fill_value): + def test_splits_fill(self, fill_value: float | int) -> None: """Check how predictors with less unique values are right-padded.""" with debug_infs(not jnp.isinf(fill_value)): fill_value = jnp.array(fill_value) @@ -47,13 +47,13 @@ def test_splits_fill(self, fill_value): expected_splits = [[2, fill_value, fill_value], [2, 4, fill_value], [2, 4, 6]] assert_array_equal(splits, expected_splits) - def test_max_splits(self): + def test_max_splits(self) -> None: """Check that the number of splits per predictor is counted correctly.""" x = jnp.array([[1, 1, 1, 1], [4, 4, 1, 1], [2, 1, 3, 2], [1, 4, 2, 3]]) _, max_split = quantilized_splits_from_matrix(x, 100) assert_array_equal(max_split, jnp.arange(4)) - def test_integer_splits_overflow(self): + def test_integer_splits_overflow(self) -> None: """Check that the splits are computed correctly at the limit of overflow.""" x = jnp.array([[-(2**31), 2**31 - 2]]) splits, _ = quantilized_splits_from_matrix(x, 100) @@ -61,13 +61,13 @@ def test_integer_splits_overflow(self): assert_array_equal(splits, expected_splits) @pytest.mark.parametrize('dtype', [int, float]) - def test_splits_type(self, dtype): + def test_splits_type(self, dtype: type) -> None: """Check that the input type is preserved.""" x = jnp.arange(10, dtype=dtype)[None, :] splits, _ = quantilized_splits_from_matrix(x, 100) assert splits.dtype == x.dtype - def test_splits_length(self): + def test_splits_length(self) -> None: """Check that the correct number of splits is returned in corner cases.""" x = jnp.linspace(0, 1, 10)[None, :] @@ -83,33 +83,33 @@ def test_splits_length(self): no_splits, _ = quantilized_splits_from_matrix(x, 1) assert no_splits.shape == (1, 0) - def test_round_trip(self): + def test_round_trip(self) -> None: """Check that 
`bin_predictors` is the ~inverse of `quantilized_splits_from_matrix`.""" x = jnp.arange(10)[None, :] splits, _ = quantilized_splits_from_matrix(x, 100) b = bin_predictors(x, splits) assert_array_equal(x, b) - def test_one_value(self): + def test_one_value(self) -> None: """Check there's only 1 bin (0 splits) if there is 1 datapoint.""" x = jnp.arange(10)[:, None] _, max_split = quantilized_splits_from_matrix(x, 100) assert_array_equal(max_split, jnp.full(len(x), 0)) - def test_zero_values(self): + def test_zero_values(self) -> None: """Check what happens when no binning is possible.""" x = jnp.empty((1, 0)) with pytest.raises(ValueError, match='at least 1'): quantilized_splits_from_matrix(x, 100) - def test_zero_bins(self): + def test_zero_bins(self) -> None: """Check what happens when no binning is possible.""" x = jnp.arange(10)[None, :] with pytest.raises(ValueError, match='at least 1'): quantilized_splits_from_matrix(x, 0) -def test_binner_left_boundary(): +def test_binner_left_boundary() -> None: """Check that the first bin is right-closed.""" splits = jnp.array([[1, 2, 3]]) @@ -118,7 +118,7 @@ def test_binner_left_boundary(): assert_array_equal(b, [[0, 0]]) -def test_binner_right_boundary(): +def test_binner_right_boundary() -> None: """Check that the next-to-last bin is right-closed.""" splits = jnp.array([[1, 2, 3, 2**31 - 1]]) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 9d0ae9bc..c9ddfa16 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -28,10 +28,12 @@ from functools import partial from pstats import Stats from time import perf_counter, sleep +from typing import NamedTuple import pytest from jax import debug_infs, debug_nans, jit, pure_callback, random from jax import numpy as jnp +from jaxtyping import Array, Bool, Float32, Int32, Integer from numpy.testing import assert_array_equal from bartz._profiler import ( @@ -40,8 +42,8 @@ jit_and_block_if_profiling, jit_if_not_profiling, profile_mode, - 
scan_if_not_profiling, set_profile_mode, + while_loop_if_not_profiling, ) from bartz.jaxext import get_default_device @@ -49,18 +51,18 @@ class TestFlag: """Test the functionality of the global profile mode flag.""" - def test_initial_state(self): + def test_initial_state(self) -> None: """Check profiling mode is off by default.""" assert not get_profile_mode() - def test_getter_setter(self): + def test_getter_setter(self) -> None: """Test setting and getting the profile mode.""" set_profile_mode(True) assert get_profile_mode() set_profile_mode(False) assert not get_profile_mode() - def test_context_manager(self): + def test_context_manager(self) -> None: """Test the profile mode context manager.""" with profile_mode(True): assert get_profile_mode() @@ -84,32 +86,40 @@ class TestScanIfNotProfiling: """Test `scan_if_not_profiling`.""" @pytest.mark.parametrize('mode', [True, False]) - def test_result(self, mode: bool): + def test_result(self, mode: bool) -> None: """Test that `scan_if_not_profiling` has the right output on a simple example.""" - def body(carry, _): - return carry + 1, None + def cond(carry: Integer[Array, '']) -> Bool[Array, '']: + return carry < 5 + + def body(carry: Integer[Array, '']) -> Integer[Array, '']: + return carry + 1 with profile_mode(mode): - carry, ys = scan_if_not_profiling(body, 0, None, 5) - assert ys is None + carry = while_loop_if_not_profiling(cond, body, 0) assert carry == 5 - def test_does_not_jit(self): + def test_does_not_jit(self) -> None: """Check that `scan_if_not_profiling` does not jit the function in profiling mode.""" - def body(carry, _): - return carry.block_until_ready(), None - # block_until_ready errors under jit + class Carry(NamedTuple): + i: Int32[Array, ''] + state: Int32[Array, ''] + + def cond(carry: Carry) -> Bool[Array, '']: + return carry.i < 5 + + def body(carry: Carry) -> Carry: + return Carry(carry.i + 1, carry.state.block_until_ready()) with profile_mode(True): - scan_if_not_profiling(body, 
jnp.int32(0), None, 5) + while_loop_if_not_profiling(cond, body, Carry(jnp.int32(0), jnp.int32(0))) with pytest.raises( AttributeError, match='DynamicJaxprTracer has no attribute block_until_ready', ): - scan_if_not_profiling(body, 0, None, 5) + while_loop_if_not_profiling(cond, body, Carry(jnp.int32(0), jnp.int32(0))) class TestCondIfNotProfiling: @@ -117,7 +127,7 @@ class TestCondIfNotProfiling: @pytest.mark.parametrize('mode', [True, False]) @pytest.mark.parametrize('pred', [True, False]) - def test_result(self, mode: bool, pred: bool): + def test_result(self, mode: bool, pred: bool) -> None: """Test that `cond_if_not_profiling` has the right output on a simple example.""" with profile_mode(mode): out = cond_if_not_profiling( @@ -125,7 +135,7 @@ def test_result(self, mode: bool, pred: bool): ) assert out == (4 if pred else 6) - def test_does_not_jit(self): + def test_does_not_jit(self) -> None: """Check that `cond_if_not_profiling` does not jit the function in profiling mode.""" with profile_mode(True): cond_if_not_profiling( @@ -151,10 +161,10 @@ class TestJitIfNotProfiling: """Test `jit_if_not_profiling`.""" @pytest.mark.parametrize('mode', [True, False]) - def test_result(self, mode: bool): + def test_result(self, mode: bool) -> None: """Test that `jit_if_not_profiling` has the right output in both modes.""" - def func(x): + def func(x: Integer[Array, '']) -> Integer[Array, '']: return x * 2 + 1 jitted_func = jit_if_not_profiling(func) @@ -163,10 +173,10 @@ def func(x): result = jitted_func(5) assert result == 11 - def test_does_not_jit(self): + def test_does_not_jit(self) -> None: """Check that `jit_if_not_profiling` does not jit the function in profiling mode.""" - def func(x): + def func(x: Int32[Array, '']) -> Int32[Array, '']: return x.block_until_ready() # block_until_ready errors under jit @@ -187,10 +197,10 @@ class TestJitAndBlockIfProfiling: """Test `jit_and_block_if_profiling`.""" @pytest.mark.parametrize('mode', [True, False]) - def 
test_result(self, mode: bool): + def test_result(self, mode: bool) -> None: """Test that `jit_and_block_if_profiling` has the right output in both modes.""" - def func(x): + def func(x: Integer[Array, '']) -> Integer[Array, '']: return x * 2 + 1 jitted_func = jit_and_block_if_profiling(func) @@ -199,10 +209,10 @@ def func(x): result = jitted_func(5) assert result == 11 - def test_jits_when_profiling(self): + def test_jits_when_profiling(self) -> None: """Check that `jit_and_block_if_profiling` jits when profiling is enabled.""" - def func(x): + def func(x: Int32[Array, '']) -> Int32[Array, '']: return x.block_until_ready() # block_until_ready errors under jit @@ -222,10 +232,10 @@ def func(x): with profile_mode(False): jitted_func(jnp.int32(0)) - def test_static_args(self): + def test_static_args(self) -> None: """Check that it works with static arguments.""" - def func(n: int): + def func(n: int) -> Integer[Array, ' {n}']: return jnp.arange(n) jitted_func = jit_and_block_if_profiling(func, static_argnums=(0,)) @@ -236,7 +246,7 @@ def func(n: int): @pytest.mark.flaky(max_runs=3) # flaky because it involves comparing time measurements done on the fly - def test_blocks_execution(self): + def test_blocks_execution(self) -> None: """Check that `jit_and_block_if_profiling` blocks execution when profiling.""" with debug_nans(False), debug_infs(False): platform = get_default_device().platform @@ -287,15 +297,16 @@ def test_blocks_execution(self): f'Expected async execution << {expected:#.2g}s, got {elapsed:#.2g}s' ) - def test_profile(self): + def test_profile(self) -> None: """Test `jit_and_block_if_profiling` under the Python profiler.""" runtime = 0.1 @jit_and_block_if_profiling - def awlkugh(): # weird name to make sure identifiers are legit + # weird name to make sure identifiers are legit + def awlkugh() -> Int32[Array, '']: x = jnp.int32(0) - def sleeper(x): + def sleeper(x: Int32[Array, '']) -> Int32[Array, '']: sleep(runtime) return x @@ -319,7 +330,7 @@ def 
sleeper(x): @partial(jit, static_argnums=(0,)) -def idle(n: int): +def idle(n: int) -> Float32[Array, ' {n} {n}']: """Waste time in jax computation.""" key = random.key(0) x = random.normal(key, (n, n)) diff --git a/tests/util.py b/tests/util.py index 7f8ab8c5..f93030bc 100644 --- a/tests/util.py +++ b/tests/util.py @@ -28,6 +28,7 @@ from dataclasses import replace from operator import ge, le from pathlib import Path +from typing import Any import numpy as np import tomli @@ -35,7 +36,7 @@ from jaxtyping import ArrayLike from scipy import linalg -from bartz.debug import check_tree, describe_error +from bartz.debug import check_trace, describe_error from bartz.jaxext import minimal_unsigned_dtype from bartz.mcmcloop import TreesTrace @@ -51,7 +52,7 @@ def manual_tree( """Facilitate the hardcoded definition of tree heaps.""" assert len(leaf) == len(var) + 1 == len(split) + 1 - def check_powers_of_2(seq: list[list]): + def check_powers_of_2(seq: list[list]) -> bool: """Check if the lengths of the lists in `seq` are powers of 2.""" return all(len(x) == 2**i for i, x in enumerate(seq)) @@ -75,7 +76,7 @@ def check_powers_of_2(seq: list[list]): split_tree=tree.split_tree.astype(split_type), ) - error = check_tree(tree, max_split) + error = check_trace(tree, max_split) descr = describe_error(error) bad = any(d not in ignore_errors for d in descr) assert not bad, descr @@ -94,7 +95,7 @@ def assert_close_matrices( ord: int | float | str | None = 2, # noqa: A002 err_msg: str = '', reduce_rank: bool = False, -): +) -> None: """ Check if two matrices are similar. @@ -183,7 +184,7 @@ def assert_close_matrices( assert op(adnorm, atol + rtol * dnorm), msg -def assert_different_matrices(*args, **kwargs): +def assert_different_matrices(*args: ArrayLike, **kwargs: Any) -> None: """Invoke `assert_close_matrices` with negate=True and default inf tolerance.""" default_kwargs: dict = dict(rtol=np.inf, atol=np.inf) default_kwargs.update(kwargs)