Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
fc624ca
chore: push only new version tag
Gattocrucco Feb 13, 2026
87e88a2
chore: remove redundant gpu install on ci
Gattocrucco Feb 13, 2026
a7a09e7
chore: rename make target all => help
Gattocrucco Feb 13, 2026
7140660
chore: correction in release workflow help
Gattocrucco Feb 13, 2026
7c0eb22
chore: simpler gpu memory allocation in unit tests
Gattocrucco Feb 13, 2026
90c30fc
chore: make lint target + activate ANN rules
Gattocrucco Feb 13, 2026
8c0011d
chore: more relaxed ANN
Gattocrucco Feb 13, 2026
872bc9a
refactor: add all missing None return type hints
Gattocrucco Feb 13, 2026
b0050cb
refactor: missing typing in speed.py and util.py
Gattocrucco Feb 13, 2026
e30949a
refactor: add all missing type annotations
Gattocrucco Feb 15, 2026
77b0ee1
chore: list benchmarks in release workflow
Gattocrucco Feb 15, 2026
028dc89
chore: rule against committing to main
Gattocrucco Feb 15, 2026
baf9928
refactor: use jax.tree_util.keystr
Gattocrucco Feb 15, 2026
5f759df
style: typo in variable name
Gattocrucco Feb 15, 2026
8c9f8f6
test: two additional jax error settings
Gattocrucco Feb 15, 2026
449c6d3
refactor: merge test_same/similar_result
Gattocrucco Feb 15, 2026
c5c8ee6
test: enable jax cache only on latest jax
Gattocrucco Feb 15, 2026
61abdf2
docs: fix spacing issue with attributes
Gattocrucco Feb 15, 2026
65566a1
docs: fix typo in docstring
Gattocrucco Feb 15, 2026
3af4080
chore: drop redundant uv sync in make docs-latest
Gattocrucco Feb 15, 2026
2ef78ee
feat: make check_trees=True traceable
Gattocrucco Feb 15, 2026
bea11f4
refactor: simplify init_kw and run_mcmc_kw
Gattocrucco Feb 15, 2026
0a7da2c
chore: enforce some import conventions
Gattocrucco Feb 15, 2026
739e7bc
refactor: scan => while_loop in run_mcmc
Gattocrucco Feb 15, 2026
5a0fa97
refactor: jaxext.jit_active() to DRY
Gattocrucco Feb 15, 2026
b918832
perf: avoid large copies when squashing chains
Gattocrucco Feb 15, 2026
053d7ad
feat: default hypers in make_p_nonterminal
Gattocrucco Feb 15, 2026
1d76d77
perf: create new state arrays already sharded
Gattocrucco Feb 16, 2026
2863429
test: check that trees are replicated correctly
Gattocrucco Feb 17, 2026
7b1fd23
fix: shard_map import with old jax version
Gattocrucco Feb 17, 2026
7f869cf
fix: further problems with old dependencies
Gattocrucco Feb 18, 2026
fd0312f
feat: bartz.__version_info__
Gattocrucco Feb 18, 2026
0801dc9
chore: pre-commit hook to update uv.lock
Gattocrucco Feb 18, 2026
c8b8991
chore: add a smoke test before uploading release
Gattocrucco Feb 18, 2026
6775869
chore: make clean removes docs build
Gattocrucco Feb 18, 2026
5167b07
test: check the output of init() is strongly typed
Gattocrucco Feb 18, 2026
16a980a
chore: increase self-coverage of test-suite
Gattocrucco Feb 18, 2026
76fad83
refactor: distinguish varprob attr from param
Gattocrucco Feb 18, 2026
244015e
test: check if double compilation detection works
Gattocrucco Feb 18, 2026
2235061
fix: error in TestVarprobAttr
Gattocrucco Feb 18, 2026
eae0fff
test: check likelihood ratio=1 if n=0/1
Gattocrucco Feb 18, 2026
23484af
test: check n=2 => likelihood ratio != 0
Gattocrucco Feb 18, 2026
7d923f7
test: check arrays are not deleted by gc
Gattocrucco Feb 18, 2026
c0e7e79
test: check replicated arrays in test_sharding
Gattocrucco Feb 18, 2026
9b662b8
test: cmdline opt to set num cpu devices
Gattocrucco Feb 18, 2026
22b6336
chore: tests on ci w/ 1 cpu device
Gattocrucco Feb 18, 2026
27d1e67
fix: missing quotes in workflow
Gattocrucco Feb 18, 2026
259a562
fix: stray comma in json in workflow
Gattocrucco Feb 18, 2026
358350d
perf: run benchmarks on ci with single cpu device
Gattocrucco Feb 18, 2026
639d51c
refactor: simplify test_array_no_gc
Gattocrucco Feb 18, 2026
cf23878
test: check same result w/ & w/o sharding
Gattocrucco Feb 18, 2026
2f91c43
feat!: remove i_skip
Gattocrucco Feb 18, 2026
fc442b1
perf: inner_loop_length traceable in compiled loop
Gattocrucco Feb 19, 2026
b009d25
refactor: run_mcmc returns a named tuple
Gattocrucco Feb 19, 2026
aebc09c
chore: set pydoclint to forbid __init__ docstrings
Gattocrucco Feb 19, 2026
92814c0
perf: benchmark multivariate step
Gattocrucco Feb 19, 2026
131e9c5
perf: benchmark gbart.predict
Gattocrucco Feb 19, 2026
aa6d5f9
docs: explain more development workflow
Gattocrucco Feb 19, 2026
8be39c0
feat: Bart.num_trees property
Gattocrucco Feb 19, 2026
339b29b
refactor: rename Bart(ntree) to Bart(num_trees)
Gattocrucco Feb 19, 2026
314ccea
feat!: default Bart(num_trees=200) unconditionally
Gattocrucco Feb 19, 2026
39bf67f
feat!: Bart(nskip=1000) by default
Gattocrucco Feb 19, 2026
1d5f1c7
feat!: Bart(keepevery=1) default unconditionally
Gattocrucco Feb 19, 2026
57d4751
feat!: Bart(numcut=255) new default
Gattocrucco Feb 19, 2026
751ae5e
feat!: new default Bart(num_chains=4)
Gattocrucco Feb 19, 2026
b8052cd
feat!: Bart does not set min_points_per_leaf
Gattocrucco Feb 19, 2026
094002a
fix: make mc_gbart(mc_cores=1) run with no chains
Gattocrucco Feb 19, 2026
bf892b9
chore: ban jax.random.PRNGKey
Gattocrucco Feb 19, 2026
7e047ba
refactor: simplify a few uses of jaxext.split
Gattocrucco Feb 19, 2026
37a1f4c
refactor: move some stuff debug => grove
Gattocrucco Feb 19, 2026
3c24e71
refactor: make CheckFunc a Protocol
Gattocrucco Feb 19, 2026
fbcaad2
refactor: split and clean up bartz.debug
Gattocrucco Feb 19, 2026
4ccee91
feat: gen_data(k=None) with squeezed outcome axis
Gattocrucco Feb 20, 2026
601ef90
perf: jit gen_data
Gattocrucco Feb 20, 2026
45b8aea
docs: html docs of bartz.testing
Gattocrucco Feb 20, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 42 additions & 23 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ env:
jobs:
pre-commit:
runs-on: ubuntu-latest
env:
SKIP: no-commit-to-branch
steps:
- uses: actions/checkout@v6
with:
Expand Down Expand Up @@ -84,11 +86,36 @@ jobs:
- id: set-matrix
run: |
BASE_MATRIX='[
{"os":"ubuntu-latest","runs-on":"ubuntu-latest","target-suffix":""},
{"os":"ubuntu-22.04","runs-on":"ubuntu-22.04","target-suffix":"-old"}
{
"id": "latest",
"runs-on": "ubuntu-latest",
"target-suffix": "",
"tests-args": ""
},
{
"id": "old",
"runs-on": "ubuntu-22.04",
"target-suffix": "-old",
"tests-args": ""
},
{
"id": "single-cpu",
"runs-on": "ubuntu-latest",
"target-suffix": "",
"tests-args": "ARGS=--num-cpu-devices=1"
}
]'
BENCHMARKS_BASE_MATRIX='[
{"os":"ubuntu-latest","runs-on":"ubuntu-latest","target-suffix":""}
{
"id": "latest",
"runs-on": "ubuntu-latest",
"vars": ""
},
{
"id": "single-cpu",
"runs-on": "ubuntu-latest",
"vars": "BARTZ_BENCHMARKS_SINGLE_CPU_DEVICE=1 "
}
]'

MATRIX=$(echo "$BASE_MATRIX" | jq -c ".")
Expand All @@ -98,7 +125,7 @@ jobs:
echo "benchmarks-matrix={\"include\":$BENCHMARKS_MATRIX}" >> $GITHUB_OUTPUT

tests:
name: tests-${{ matrix.os }}${{ matrix.target-suffix }}
name: tests-${{ matrix.id }}
needs: setup-matrix
runs-on: ${{ matrix.runs-on }}
continue-on-error: ${{ matrix.target-suffix == '-gpu' }}
Expand All @@ -115,7 +142,7 @@ jobs:
uses: astral-sh/setup-uv@v7
with:
version: ${{ env.UV_VERSION }}
cache-suffix: -${{ matrix.os }}${{ matrix.target-suffix }}
cache-suffix: -${{ matrix.id }}

- name: Install R
uses: r-lib/actions/setup-r@v2
Expand All @@ -134,21 +161,15 @@ jobs:
ldd "$(R RHOME)/library/methods/libs/methods.so" || true
uv run --dev python -c "import rpy2.situation as s; print('\n'.join(map(str, s.iter_info())))"

- name: Install jax GPU support
if: matrix.target-suffix == '-gpu'
run: |
CUDA_VERSION=`nvidia-smi 2>/dev/null | grep -o 'CUDA Version: [0-9]*' | cut -d' ' -f3`
uv pip install "jax[cuda$CUDA_VERSION]"

- name: Run unit tests
timeout-minutes: 30
run: make COVERAGE_SUFFIX=-${{ matrix.os }}${{ matrix.target-suffix }} tests${{ matrix.target-suffix }}
run: make COVERAGE_SUFFIX=-${{ matrix.id }} tests${{ matrix.target-suffix }} ${{ matrix.tests-args }}

- name: Save coverage information
uses: actions/upload-artifact@v6
with:
name: coverage.tests-${{ matrix.os }}${{ matrix.target-suffix }}
path: .coverage.tests-${{ matrix.os }}${{ matrix.target-suffix }}
name: coverage.tests-${{ matrix.id }}
path: .coverage.tests-${{ matrix.id }}
include-hidden-files: true
if-no-files-found: error

Expand Down Expand Up @@ -279,7 +300,7 @@ jobs:
uses: actions/deploy-pages@v4

benchmarks:
name: benchmarks-${{ matrix.os }}${{ matrix.target-suffix }}
name: benchmarks-${{ matrix.id }}
needs: setup-matrix
runs-on: ${{ matrix.runs-on }}
continue-on-error: ${{ matrix.target-suffix == '-gpu' }}
Expand All @@ -296,19 +317,14 @@ jobs:
uses: astral-sh/setup-uv@v7
with:
version: ${{ env.UV_VERSION }}
cache-suffix: -${{ matrix.os }}${{ matrix.target-suffix }}
cache-suffix: -${{ matrix.id }}

- name: Get machine information for asv
run: RPY2_CFFI_MODE=ABI uv run --dev python -m asv machine --yes

- name: Install jax GPU support
if: matrix.target-suffix == '-gpu'
run: |
CUDA_VERSION=`nvidia-smi 2>/dev/null | grep -o 'CUDA Version: [0-9]*' | cut -d' ' -f3`
uv pip install "jax[cuda$CUDA_VERSION]"

- name: Run asv in quick mode
run: make asv-quick ARGS='--attribute timeout=120' # long timeout because CI runners are slow
# long timeout because CI runners are slow
run: ${{ matrix.vars }}make asv-quick ARGS='--attribute timeout=120'

dev-commands:
runs-on: ubuntu-latest
Expand All @@ -328,6 +344,9 @@ jobs:
- name: run IPython with old toolchain
run: RPY2_CFFI_MODE=ABI make ipython-old ARGS=config/ipython/profile_default/startup/startup.ipy

- name: Update package version
run: make copy-version

required-dummy:
runs-on: ubuntu-latest
needs:
Expand Down
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ repos:
exclude: tests/(rbartpackages/.+\.py|util\.py)$
- id: trailing-whitespace
exclude: '^renv/activate\.R$' # because renv edits it automatically, with trailing whitespace
- id: no-commit-to-branch
- repo: https://github.com/sbrunner/hooks
rev: 1.6.1
hooks:
Expand All @@ -55,4 +56,8 @@ repos:
- repo: https://github.com/jorisroovers/gitlint
rev: v0.19.1
hooks:
- id: gitlint # this is already defined on stage commit-msg
- id: gitlint # this is already defined to run in stage commit-msg
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.10.4
hooks:
- id: uv-lock
29 changes: 20 additions & 9 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ UV_RUN = uv run --dev $(EXTRAS)
OLD_PYTHON = $(shell grep 'requires-python' pyproject.toml | sed 's/.*>=\([0-9.]*\).*/\1/')
UV_RUN_OLD = $(UV_RUN) --python=$(OLD_PYTHON) --resolution=lowest-direct --exclude-newer=2025-05-15 --isolated

.PHONY: all
all:
.PHONY: help
help:
@echo "Available targets:"
@echo "- setup: create R and Python environments for development"
@echo "- tests: run unit tests on cpu, saving coverage information"
Expand All @@ -61,8 +61,10 @@ all:
@echo "- asv-quick: run quick benchmarks on current code, no saving"
@echo "- ipython: start an ipython shell with stuff pre-imported"
@echo "- ipython-old: start an ipython shell with oldest supported python and dependencies"
@echo "- lint: run the linter used in pre-commit"
@echo
@echo "Release workflow:"
@echo "- do a PR that re-runs benchmarks"
@echo "- create a new branch"
@echo "- $$ uv version --bump major|minor|patch"
@echo "- describe release in docs/changelog.md"
Expand All @@ -71,7 +73,7 @@ all:
@echo "- $$ make release (iterate to fix problems)"
@echo "- if CI does not pass, debug and go back to make release"
@echo "- merge PR"
@echo "- if CI does not pass, debug and go back to make release"
@echo "- if CI does not pass, debug and go back to open PR"
@echo "- $$ make upload"
@echo "- publish github release (updates zenodo automatically)"
@echo "- if the online docs are not up-to-date, merge another PR to trigger a new merge CI"
Expand All @@ -82,14 +84,17 @@ setup:
Rscript -e "renv::restore()"
$(UV_RUN) pre-commit install --install-hooks

.PHONY: lint
lint:
$(UV_RUN) pre-commit run --all-files ruff-check

################# TESTS #################

TESTS_VARS = COVERAGE_FILE=.coverage.tests$(COVERAGE_SUFFIX)
TESTS_COMMAND = python -m pytest --cov --cov-context=test --dist=worksteal --durations=1000
TESTS_CPU_VARS = $(TESTS_VARS) JAX_PLATFORMS=cpu
TESTS_CPU_COMMAND = $(TESTS_COMMAND) --platform=cpu --numprocesses=2
TESTS_GPU_VARS = $(TESTS_VARS) XLA_PYTHON_CLIENT_MEM_FRACTION=.20
TESTS_GPU_VARS = $(TESTS_VARS) XLA_PYTHON_CLIENT_PREALLOCATE=false
TESTS_GPU_COMMAND = $(TESTS_COMMAND) --platform=gpu --numprocesses=3

.PHONY: tests
Expand Down Expand Up @@ -129,7 +134,6 @@ docs-latest:
WORKTREE_DIR=$$(mktemp -d) && \
trap "git worktree remove --force '$$WORKTREE_DIR' 2>/dev/null || rm -rf '$$WORKTREE_DIR'" EXIT && \
git worktree add --detach "$$WORKTREE_DIR" "$$LATEST_TAG" && \
uv sync --all-groups --directory "$$WORKTREE_DIR" && \
$(MAKE) -C "$$WORKTREE_DIR" docs && \
test ! -d _site/docs || rm -r _site/docs && \
mv "$$WORKTREE_DIR/_site/docs-dev" _site/docs
Expand Down Expand Up @@ -172,6 +176,7 @@ clean:
rm -fr .venv
rm -fr dist
rm -fr config/jax_cache
rm -fr docs/_build

.PHONY: release
release: clean update-deps copy-version check-committed tests tests-old docs
Expand All @@ -180,11 +185,17 @@ release: clean update-deps copy-version check-committed tests tests-old docs
.PHONY: version-tag
version-tag: copy-version check-committed
git fetch --tags
git tag v$(shell uv run python -c 'import bartz; print(bartz.__version__)')
git push --tags
$(eval VERSION_TAG := v$(shell uv run python -c 'import bartz; print(bartz.__version__)'))
git tag $(VERSION_TAG)
git push origin $(VERSION_TAG)

.PHONY: smoke-test
smoke-test:
uv run --isolated --no-project --with dist/*.whl python -c 'import bartz'
uv run --isolated --no-project --with dist/*.tar.gz python -c 'import bartz'

.PHONY: upload
upload: version-tag
upload: smoke-test version-tag
@echo "Enter PyPI token:"
@read -s UV_PUBLISH_TOKEN && \
export UV_PUBLISH_TOKEN="$$UV_PUBLISH_TOKEN" && \
Expand All @@ -194,7 +205,7 @@ upload: version-tag
uv tool run --with="bartz==$$VERSION" python -c 'import bartz; print(bartz.__version__)'

.PHONY: upload-test
upload-test: check-committed
upload-test: smoke-test check-committed
@echo "Enter TestPyPI token:"
@read -s UV_PUBLISH_TOKEN && \
export UV_PUBLISH_TOKEN="$$UV_PUBLISH_TOKEN" && \
Expand Down
5 changes: 4 additions & 1 deletion benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@

"""Benchmarking code run by asv."""

from os import getenv

from jax import config

from benchmarks._vendor_latest_bartz import ensure_vendored

ensure_vendored()

config.update('jax_num_cpu_devices', 16)
if getenv('BARTZ_BENCHMARKS_SINGLE_CPU_DEVICE') is None:
config.update('jax_num_cpu_devices', 16)
2 changes: 1 addition & 1 deletion benchmarks/rmse.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def run_sim_impl(
# adapt for older versions
sig = signature(mc_gbart)

def drop_if_missing(arg: str):
def drop_if_missing(arg: str) -> None:
if arg not in sig.parameters:
kw.pop(arg)

Expand Down
Loading