From 304e6bc189fe8c36eb19d85e8c8690fc98012283 Mon Sep 17 00:00:00 2001 From: Miaoqing Yu Date: Sat, 3 Jan 2026 22:17:15 +0000 Subject: [PATCH 1/5] add mvbart to interface --- src/bartz/_interface.py | 337 +++++++++++++++++++++++++++++++++++++++- src/bartz/grove.py | 33 +++- src/bartz/mcmcloop.py | 4 +- tests/test_mvbart.py | 24 ++- 4 files changed, 387 insertions(+), 11 deletions(-) diff --git a/src/bartz/_interface.py b/src/bartz/_interface.py index 6883530b..0cc46a7c 100644 --- a/src/bartz/_interface.py +++ b/src/bartz/_interface.py @@ -1,6 +1,6 @@ # bartz/src/bartz/_interface.py # -# Copyright (c) 2025, The Bartz Contributors +# Copyright (c) 2025-2026, The Bartz Contributors # # This file is part of bartz. # @@ -930,3 +930,338 @@ def _evaluate_chains( trace: mcmcloop.MainTrace, x: UInt[Array, 'p m'] ) -> Float32[Array, 'mc_cores ndpost/mc_cores m']: return mcmcloop.evaluate_trace(trace, x) + + +class mvBart(Module): + """Multivariate version BART (mvBART) for continuous regression with multiple outputs.""" + + _main_trace: mcmcloop.MainTrace + _burnin_trace: mcmcloop.BurninTrace + _mcmc_state: mcmcstep.State + _splits: Real[Array, 'p max_num_splits'] + _x_train_fmt: Any = field(static=True) + + ndpost: int = field(static=True) + offset: Float32[Array, ' k'] + sigest: Float32[Array, 'k k'] | None = None + yhat_test: Float32[Array, 'ndpost k m'] | None = None + + def __init__( + self, + x_train: Real[Array, 'p n'] | DataFrame, + y_train: Float32[Array, 'k n'] | DataFrame, + *, + x_test: Real[Array, 'p m'] | DataFrame | None = None, + sparse: bool = False, + theta: FloatLike | None = None, + a: FloatLike = 0.5, + b: FloatLike = 1.0, + rho: FloatLike | None = None, + xinfo: Float[Array, 'p n'] | None = None, + usequants: bool = False, + rm_const: bool | None = True, + sigest: Float32[Array, 'k k'] | None = None, + sigdf: FloatLike = 3.0, + sigquant: FloatLike = 0.9, + k: FloatLike = 2.0, + power: FloatLike = 2.0, + base: FloatLike = 0.95, + lamda: Float32[Array, 'k 
k'] | None = None, + tau_num: FloatLike | Float32[Array, ' k'] | None = None, + offset: Float32[Array, ' k'] | None = None, + w: None = None, + ntree: int = 200, + numcut: int = 100, + ndpost: int = 1000, + nskip: int = 100, + keepevery: int = 1, + printevery: int | None = None, + mc_cores: int = 2, + seed: int | Key[Array, ''] = 0, + maxdepth: int = 6, + init_kw: dict | None = None, + run_mcmc_kw: dict | None = None, + ): + if w is not None: + msg = 'Weights are not supported for multivariate BART.' + raise ValueError(msg) + + # check data and put it in the right format + x_train, x_train_fmt = Bart._process_predictor_input(x_train) # noqa: SLF001 + y_train = self._process_mv_response_input(y_train) + self._check_same_n(x_train, y_train) + + # process sparsity settings + theta, a, b, rho = Bart._process_sparsity_settings( # noqa: SLF001 + x_train, sparse, theta, a, b, rho + ) + + # process "standardization" settings + offset = self._process_offset_settings_mv(y_train, offset) + sigma_mu = self._process_leaf_sdev_settings_mv(y_train, k, ntree, tau_num) + error_cov_df, error_cov_scale = self._process_error_variance_settings_mv( + x_train, y_train, sigest, sigdf, sigquant, lamda + ) + + # determine splits + splits, max_split = Bart._determine_splits(x_train, usequants, numcut, xinfo) # noqa: SLF001 + x_train_binned = Bart._bin_predictors(x_train, splits) # noqa: SLF001 + + # setup and run mcmc + initial_state = self._setup_mcmc_mv( + x_train_binned, + y_train, + offset, + w, + max_split, + sigma_mu, + error_cov_df, + error_cov_scale, + power, + base, + maxdepth, + ntree, + init_kw, + rm_const, + theta, + a, + b, + rho, + ) + final_state, burnin_trace, main_trace = Bart._run_mcmc( # noqa: SLF001 + initial_state, + mc_cores, + ndpost, + nskip, + keepevery, + printevery, + seed, + run_mcmc_kw, + sparse, + ) + + # set public attributes + self.offset = final_state.offset # from the state because of buffer donation + self.ndpost = main_trace.grow_prop_count.size + 
self.sigest = sigest + + # set private attributes + self._main_trace = main_trace + self._burnin_trace = burnin_trace + self._mcmc_state = final_state + self._splits = splits + self._x_train_fmt = x_train_fmt + + # predict at test points + if x_test is not None: + self.yhat_test = self.predict(x_test) + + # ----------------------- + # MV-specific helpers + # ----------------------- + @staticmethod + def _process_mv_response_input(y) -> Float32[Array, 'k n']: + y = jnp.asarray(y) + if y.ndim != 2: + msg = f'mvBART requires y_train to be 2D (k,n). Got {y.ndim=}.' + raise ValueError(msg) + if y.dtype == bool: + msg = 'mvBART is continuous-only: y_train must be floating (not bool).' + raise ValueError(msg) + return y.astype(jnp.float32) + + @staticmethod + def _check_same_n(x: Real[Array, 'p n'], y: Float32[Array, 'k n']) -> None: + if x.shape[1] != y.shape[1]: + msg = f'Mismatch: x_train has n={x.shape[1]}, y_train has n={y.shape[1]}' + raise ValueError(msg) + + @staticmethod + def _process_offset_settings_mv( + y_train: Float32[Array, 'k n'], offset: float | Float32[Array, ' k'] | None + ) -> Float32[Array, ' k']: + if offset is None: + return y_train.mean(axis=1) + off = jnp.asarray(offset, dtype=jnp.float32) + if off.ndim == 0: + return jnp.broadcast_to(off, (y_train.shape[0],)) + if off.shape != (y_train.shape[0],): + msg = f'Expected offset shape (k,), got {off.shape=}' + raise ValueError(msg) + return off + + @staticmethod + def _process_leaf_sdev_settings_mv( + y_train: Float32[Array, 'k n'], k: float, ntree: int, tau_num: float | None + ) -> Float32[Array, 'k k']: + k = y_train.shape[0] + + if tau_num is None: + if y_train.shape[1] < 2: + tau_num = jnp.ones(k) + else: + tau_num = (y_train.max(axis=1) - y_train.min(axis=1)) / 2 + + return tau_num / (k * math.sqrt(ntree)) + + @staticmethod + def _process_error_variance_settings_mv( + x_train: Real[Array, 'p n'], + y_train: Float32[Array, 'k n'], + sigest: Float32[Array, ' k'] | None, + sigdf: float, + sigquant: 
float, + lamda_vec: float | Float32[Array, ' k'] | None, + *, + t0: float | None = None, + s0: Float32[Array, 'k k'] | None = None, + ) -> tuple[Float32[Array, 'k k'] | None, Float32[Array, 'k k'] | None]: + p = x_train.shape[0] + k, n = y_train.shape + + # df of IW prior + if t0 is None: + t0 = float(sigdf + k - 1) + if t0 <= k - 1: + msg = f'Degrees of freedom `t0` must be > {k - 1}' + raise ValueError(msg) + + # scale of IW prior: + if s0 is not None: + if s0.shape != (k, k): + msg = ValueError( + f'Scale matrix `s0` must have shape ({k}, {k}), got {s0.shape}' + ) + raise ValueError(msg) + s0 = jnp.diag(jnp.asarray(s0, dtype=jnp.float32)) + return jnp.asarray(t0, dtype=jnp.float32), s0 + + # if t0 and s0 are none, use a diagonal construction + if sigest is not None: + sigest = jnp.asarray(sigest, dtype=jnp.float32) + if sigest.shape != (k,): + msg = f'sigest must have shape ({k},), got {sigest.shape}' + raise ValueError(msg) + sigest2_vec = jnp.square(sigest) + + elif n < 2: + sigest2_vec = jnp.ones((k,), dtype=jnp.float32) + + elif n <= p: + sigest2_vec = jnp.var(y_train, axis=1) + + else: + # OLS with implicit intercept via centering + # Xc: (n,p), Yc: (n,k) + Xc = x_train.T - x_train.mean(axis=1, keepdims=True).T + Yc = y_train.T - y_train.mean(axis=1, keepdims=True).T + + coef, _, rank, _ = jnp.linalg.lstsq(Xc, Yc, rcond=None) # coef: (p,k) + R = Yc - Xc @ coef # (n,k) + + # match univariate: chisq = sum residual^2, dof = n - rank + chisq_vec = jnp.sum(jnp.square(R), axis=0) # (k,) + dof = jnp.maximum(1, n - rank) + sigest2_vec = chisq_vec / dof + + alpha = sigdf / 2.0 + invchi2 = invgamma.ppf(sigquant, alpha) / 2.0 + invchi2rid = invchi2 * sigdf + lamda_vec = jnp.atleast_1d(sigest2_vec / invchi2rid).astype(jnp.float32) # (k,) + + s0 = jnp.diag(t0 * lamda_vec).astype(jnp.float32) + return jnp.asarray(t0, dtype=jnp.float32), s0 + + @staticmethod + def _setup_mcmc_mv( + x_train: Real[Array, 'p n'], + y_train: Float32[Array, 'k n'], + offset: Float32[Array, ' 
k'], + w: Float[Array, ' n'] | None, + max_split: UInt[Array, ' p'], + # lamda: Float32[Array, 'k k'] | None, + sigma_mu: float, + # sigdf: float, + error_cov_df: float, + error_cov_scale: Float32[Array, 'k k'], + power: float, + base: float, + maxdepth: int, + ntree: int, + init_kw: dict[str, Any] | None, + rm_const: bool | None, + theta: float | None, + a: float, + b: float, + rho: float | None, + ) -> mcmcstep.State: + # depth prior p_nonterminal: same construction as Bart + depth = jnp.arange(maxdepth - 1) + p_nonterminal = base / (1 + depth).astype(float) ** power + + kdim = y_train.shape[0] + leaf_prior_cov_inv = (1.0 / (sigma_mu**2)) * jnp.eye(kdim, dtype=jnp.float32) + + kw = dict( + X=x_train, + y=jnp.array(y_train), + offset=offset, + error_scale=w, + max_split=max_split, + num_trees=ntree, + p_nonterminal=p_nonterminal, + leaf_prior_cov_inv=leaf_prior_cov_inv, + error_cov_df=error_cov_df, + error_cov_scale=error_cov_scale, + theta=theta, + a=a, + b=b, + rho=rho, + kind='mv', + ) + + if rm_const is None: + kw.update(filter_splitless_vars=False) + elif rm_const: + kw.update(filter_splitless_vars=True) + else: + n_empty = jnp.count_nonzero(max_split == 0) + if n_empty: + msg = f'There are {n_empty}/{max_split.size} predictors without decision rules' + raise ValueError(msg) + kw.update(filter_splitless_vars=False) + + if init_kw: + kw.update(init_kw) + + return mcmcstep.init(**kw) + + # ----------------------- + # Predictions + # ----------------------- + + def yhat_train(self) -> Float32[Array, 'ndpost k n']: + x_train = self._mcmc_state.X + return self._predict(x_train) + + def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost k m']: + return self._evaluate_chains_flattened(self._main_trace, x) + + @classmethod + @partial(jax.jit, static_argnums=(0,)) + def _evaluate_chains_flattened( + cls, trace: mcmcloop.MainTrace, x: UInt[Array, 'p m'] + ) -> Float32[Array, 'ndpost k m']: + out = cls._evaluate_chains(trace, x) # (mc_cores, steps, k, m) + 
if out.ndim != 4:
+            msg = f'Expected MV output (mc_cores, steps, k, m). Got {out.shape=}'
+            raise ValueError(msg)
+        mc_cores, steps, kdim, m = out.shape
+        return out.reshape(mc_cores * steps, kdim, m)
+
+    @staticmethod
+    @partial(jax.vmap, in_axes=(0, None))
+    def _evaluate_chains(
+        trace: mcmcloop.MainTrace, x: UInt[Array, 'p m']
+    ) -> Float32[Array, 'mc_cores t k m']:
+        return mcmcloop.evaluate_trace(trace, x)
diff --git a/src/bartz/grove.py b/src/bartz/grove.py
index 9f670864..9c944c49 100644
--- a/src/bartz/grove.py
+++ b/src/bartz/grove.py
@@ -1,6 +1,6 @@
 # bartz/src/bartz/grove.py
 #
-# Copyright (c) 2024-2025, The Bartz Contributors
+# Copyright (c) 2024-2026, The Bartz Contributors
 #
 # This file is part of bartz.
 #
@@ -192,17 +192,36 @@ def evaluate_forest(
     Returns
     -------
     The (sum of) the values of the trees at the points in `X`.
+
+    Raises
+    ------
+    ValueError
+        If `trees.leaf_tree.ndim` is not 2 (univariate) or 3 (multivariate).
     """
     indices = traverse_forest(X, trees.var_tree, trees.split_tree)
-    num_trees, _ = trees.leaf_tree.shape
-    tree_index = jnp.arange(num_trees, dtype=minimal_unsigned_dtype(num_trees - 1))
-    leaves = trees.leaf_tree[tree_index[:, None], indices]
-    if sum_trees:
-        return jnp.sum(leaves, axis=0, dtype=jnp.float32)
+    num_trees = trees.leaf_tree.shape[0]
+
+    if trees.leaf_tree.ndim == 2:
+        tree_index = jnp.arange(num_trees, dtype=minimal_unsigned_dtype(num_trees - 1))
+        leaves = trees.leaf_tree[tree_index[:, None], indices]
+        if sum_trees:
+            return jnp.sum(leaves, axis=0, dtype=jnp.float32)
+        else:
+            return leaves
         # this sum suggests to swap the vmaps, but I think it's better for X
         # copying to keep it that way
+
+    elif trees.leaf_tree.ndim == 3:
+        leaves = jnp.take_along_axis(
+            trees.leaf_tree, indices[:, None, :], axis=-1
+        )  # (num_trees, k, n)
+        if sum_trees:
+            return jnp.sum(leaves, axis=0, dtype=jnp.float32)
+        else:
+            return leaves
     else:
-        return leaves
+        msg = f'Expected trees.leaf_tree.ndim to be 2 or 3, got 
{trees.leaf_tree.ndim}' + raise ValueError(msg) def is_actual_leaf( diff --git a/src/bartz/mcmcloop.py b/src/bartz/mcmcloop.py index 07189f35..ad2ecad5 100644 --- a/src/bartz/mcmcloop.py +++ b/src/bartz/mcmcloop.py @@ -1,6 +1,6 @@ # bartz/src/bartz/mcmcloop.py # -# Copyright (c) 2024-2025, The Bartz Contributors +# Copyright (c) 2024-2026, The Bartz Contributors # # This file is part of bartz. # @@ -677,7 +677,7 @@ def evaluate_trace( def loop(_, item): offset, trees = item values = evaluate_trees(X, trees) - return None, offset + jnp.sum(values, axis=0, dtype=jnp.float32) + return None, offset[..., None] + jnp.sum(values, axis=0, dtype=jnp.float32) _, y = lax.scan(loop, None, (trace.offset, trees)) return y diff --git a/tests/test_mvbart.py b/tests/test_mvbart.py index 04dd726c..d76eabf0 100644 --- a/tests/test_mvbart.py +++ b/tests/test_mvbart.py @@ -1,6 +1,6 @@ # bartz/tests/test_mvbart.py # -# Copyright (c) 2025, The Bartz Contributors +# Copyright (c) 2025-2026, The Bartz Contributors # # This file is part of bartz. 
# @@ -32,6 +32,7 @@ from numpy.testing import assert_allclose, assert_array_equal from scipy.stats import chi2, ks_1samp, ks_2samp +from bartz._interface import mvBart from bartz.mcmcstep import State, init, step from bartz.mcmcstep._step import ( Counts, @@ -479,3 +480,24 @@ def test_mv_steps(self, keys, data): assert mv_state.error_cov_inv.shape == (k, k) assert mv_state.resid.shape == (k, y.shape[1]) + + +class TestMVBartInterface: + """Tests for mvBart Interface.""" + + def test_initialization_and_shapes(self, keys): + """Test that mvBart predicts with correct shapes.""" + n, n_test, p, k_dim = 60, 40, 5, 3 + nskip, ndpost = 10, 50 + + X = random.normal(keys.pop(), (p, n)) + B = random.normal(keys.pop(), (p, k_dim)) + Y = B.T @ X + 0.1 * random.normal(keys.pop(), (k_dim, n)) + + model = mvBart( + x_train=X, y_train=Y, ntree=10, ndpost=ndpost, nskip=nskip, mc_cores=2 + ) + + X_test = random.normal(random.key(1), (p, n_test)) + y_pred = model._predict(X_test) + assert y_pred.shape == (ndpost, k_dim, n_test) From 6bb2747fe874a8366490ca6f35675cf96e8971fc Mon Sep 17 00:00:00 2001 From: Miaoqing Yu Date: Sun, 4 Jan 2026 23:54:13 +0000 Subject: [PATCH 2/5] move everything into bart class --- src/bartz/_interface.py | 595 +++++++++++++++------------------------- tests/test_mvbart.py | 6 +- 2 files changed, 218 insertions(+), 383 deletions(-) diff --git a/src/bartz/_interface.py b/src/bartz/_interface.py index f65948ae..53bf6f7b 100644 --- a/src/bartz/_interface.py +++ b/src/bartz/_interface.py @@ -216,6 +216,8 @@ class Bart(Module): to the maximum value of an unsigned integer type, like 255. Ignored if `xinfo` is specified. + is_mv + An indicator of whether y being multivariate or not. ndpost The number of MCMC samples to save, after burn-in. `ndpost` is the total number of samples across all chains. 
`ndpost` is rounded up to the @@ -271,12 +273,15 @@ class Bart(Module): ndpost: int = field(static=True) offset: Float32[Array, ''] sigest: Float32[Array, ''] | None = None - yhat_test: Float32[Array, 'ndpost m'] | None = None + yhat_test: Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost m k'] | None = None def __init__( self, x_train: Real[Array, 'p n'] | DataFrame, - y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series, + y_train: Bool[Array, ' n'] + | Float32[Array, ' n'] + | Float32[Array, 'k n'] + | Series, *, x_test: Real[Array, 'p m'] | DataFrame | None = None, type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002 @@ -288,18 +293,19 @@ def __init__( xinfo: Float[Array, 'p n'] | None = None, usequants: bool = False, rm_const: bool | None = True, - sigest: FloatLike | None = None, + sigest: FloatLike | Float32[Array, 'k k'] | None = None, sigdf: FloatLike = 3.0, sigquant: FloatLike = 0.9, k: FloatLike = 2.0, power: FloatLike = 2.0, base: FloatLike = 0.95, - lamda: FloatLike | None = None, - tau_num: FloatLike | None = None, - offset: FloatLike | None = None, + lamda: FloatLike | None = None, # to change? + tau_num: FloatLike | None = None, # to change? + offset: FloatLike | None = None, # to change? 
w: Float[Array, ' n'] | None = None, ntree: int | None = None, numcut: int = 100, + is_mv: bool = field(static=True), ndpost: int = 1000, nskip: int = 100, keepevery: int | None = None, @@ -313,13 +319,18 @@ def __init__( # check data and put it in the right format x_train, x_train_fmt = self._process_predictor_input(x_train) y_train = self._process_response_input(y_train) + is_mv = y_train.ndim == 2 + self._check_same_length(x_train, y_train) + self._validate_compatibility(is_mv, y_train, w, type) + if w is not None: w = self._process_response_input(w) self._check_same_length(x_train, w) # check data types are correct for continuous/binary regression - self._check_type_settings(y_train, type, w) + if not is_mv: + self._check_type_settings(y_train, type, w) # from here onwards, the type is determined by y_train.dtype == bool # set defaults that depend on type of regression @@ -336,10 +347,26 @@ def __init__( # process "standardization" settings offset = self._process_offset_settings(y_train, offset) sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num) - lamda, sigest = self._process_error_variance_settings( - x_train, y_train, sigest, sigdf, sigquant, lamda + + error_cov_df, error_cov_scale, leaf_prior_cov_inv, sigest = ( + self._configure_priors( + is_mv, x_train, y_train, sigma_mu, sigest, sigdf, sigquant, lamda + ) ) + if is_mv: # Multivariate standardization + error_cov_df, error_cov_scale = self._process_error_variance_settings_mv( + x_train, y_train, sigest, sigdf, sigquant, lamda + ) + leaf_prior_cov_inv = (1.0 / (sigma_mu**2)) * jnp.eye( + y_train.shape[0], dtype=jnp.float32 + ) + else: # Univariate standardization + lamda, sigest = self._process_error_variance_settings( + x_train, y_train, sigest, sigdf, sigquant, lamda + ) + leaf_prior_cov_inv = lax.reciprocal(jnp.square(sigma_mu)) + # determine splits splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo) x_train = self._bin_predictors(x_train, splits) @@ -351,9 
+378,12 @@ def __init__( offset, w, max_split, - lamda, - sigma_mu, - sigdf, + leaf_prior_cov_inv, + error_cov_df, + error_cov_scale, + # lamda, + # sigma_mu, + # sigdf, power, base, maxdepth, @@ -380,7 +410,7 @@ def __init__( # set public attributes self.offset = final_state.offset # from the state because of buffer donation self.ndpost = main_trace.grow_prop_count.size - self.sigest = sigest + self.sigest = sigest if not is_mv else None # set private attributes self._main_trace = main_trace @@ -526,13 +556,13 @@ def yhat_test_mean(self) -> Float32[Array, ' m'] | None: return self.yhat_test.mean(axis=0) @cached_property - def yhat_train(self) -> Float32[Array, 'ndpost n']: + def yhat_train(self) -> Float32[Array, 'ndpost n'] | Float32[Array, 'ndpost k n']: """The conditional posterior mean at `x_train` for each MCMC iteration.""" x_train = self._mcmc_state.X return self._predict(x_train) @cached_property - def yhat_train_mean(self) -> Float32[Array, ' n'] | None: + def yhat_train_mean(self) -> Float32[Array, ' n'] | Float32[Array, ' k n'] | None: """The marginal posterior mean at `x_train`. Not defined with binary regression because it's error-prone, typically @@ -582,12 +612,62 @@ def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]: return x, fmt @staticmethod - def _process_response_input(y) -> Shaped[Array, ' n']: + def _validate_compatibility(is_mv, y_train, w, type): # noqa: A002 + """Validate inputs based on regression type (Univariate/Multivariate).""" + if is_mv: + if w is not None: + msg = "Weights 'w' are not supported for multivariate regression." + raise ValueError(msg) + if type != 'wbart': + msg = "Multivariate regression implies type='wbart'." + raise ValueError(msg) + if y_train.dtype == bool: + msg = 'Multivariate regression not yet support binary responses.' 
+ raise TypeError(msg) + + def _configure_priors( + self, is_mv, x_train, y_train, sigma_mu, sigest, sigdf, sigquant, lamda + ): + """Configure error covariance/variance priors and leaf priors.""" + if is_mv: + error_cov_df, error_cov_scale = self._process_error_variance_settings_mv( + x_train, y_train, sigest, sigdf, sigquant, lamda + ) + leaf_prior_cov_inv = (1.0 / (sigma_mu**2)) * jnp.eye( + y_train.shape[0], dtype=jnp.float32 + ) + return error_cov_df, error_cov_scale, leaf_prior_cov_inv, None + else: + lamda_val, sigest_val = self._process_error_variance_settings( + x_train, y_train, sigest, sigdf, sigquant, lamda + ) + leaf_prior_cov_inv = lax.reciprocal(jnp.square(sigma_mu)) + + if y_train.dtype == bool: + error_cov_df = None + error_cov_scale = None + else: + error_cov_df = sigdf + error_cov_scale = lamda_val * sigdf + + return error_cov_df, error_cov_scale, leaf_prior_cov_inv, sigest_val + + @staticmethod + def _process_response_input(y) -> Shaped[Array, ' n'] | Shaped[Array, ' k n']: if hasattr(y, 'to_numpy'): y = y.to_numpy() y = jnp.asarray(y) - assert y.ndim == 1 - return y + + if y.ndim == 1: + return y + elif y.ndim == 2: + if y.dtype == bool: + msg = 'mvBART is continuous-only: y_train must be floating (not bool).' + raise ValueError(msg) + return y.astype(jnp.float32) + else: + msg = f'y_train must be 1D (n,) or 2D (k,n). Got {y.ndim=}.' 
+ raise ValueError(msg) @staticmethod def _check_same_length(x1, x2): @@ -631,6 +711,74 @@ def _process_error_variance_settings( invchi2rid = invchi2 * sigdf return sigest2 / invchi2rid, jnp.sqrt(sigest2) + @staticmethod + def _process_error_variance_settings_mv( + x_train: Real[Array, 'p n'], + y_train: Float32[Array, 'k n'], + sigest: Float32[Array, ' k'] | None, + sigdf: float, + sigquant: float, + lamda_vec: float | Float32[Array, ' k'] | None, + *, + t0: float | None = None, + s0: Float32[Array, 'k k'] | None = None, + ) -> tuple[Float32[Array, 'k k'] | None, Float32[Array, 'k k'] | None]: + p = x_train.shape[0] + k, n = y_train.shape + + # df of IW prior + if t0 is None: + t0 = float(sigdf + k - 1) + if t0 <= k - 1: + msg = f'Degrees of freedom `t0` must be > {k - 1}' + raise ValueError(msg) + + # scale of IW prior: + if s0 is not None: + if s0.shape != (k, k): + msg = ValueError( + f'Scale matrix `s0` must have shape ({k}, {k}), got {s0.shape}' + ) + raise ValueError(msg) + s0 = jnp.diag(jnp.asarray(s0, dtype=jnp.float32)) + return jnp.asarray(t0, dtype=jnp.float32), s0 + + # if t0 and s0 are none, use a diagonal construction + if sigest is not None: + sigest = jnp.asarray(sigest, dtype=jnp.float32) + if sigest.shape != (k,): + msg = f'sigest must have shape ({k},), got {sigest.shape}' + raise ValueError(msg) + sigest2_vec = jnp.square(sigest) + + elif n < 2: + sigest2_vec = jnp.ones((k,), dtype=jnp.float32) + + elif n <= p: + sigest2_vec = jnp.var(y_train, axis=1) + + else: + # OLS with implicit intercept via centering + # Xc: (n,p), Yc: (n,k) + Xc = x_train.T - x_train.mean(axis=1, keepdims=True).T + Yc = y_train.T - y_train.mean(axis=1, keepdims=True).T + + coef, _, rank, _ = jnp.linalg.lstsq(Xc, Yc, rcond=None) # coef: (p,k) + R = Yc - Xc @ coef # (n,k) + + # match univariate: chisq = sum residual^2, dof = n - rank + chisq_vec = jnp.sum(jnp.square(R), axis=0) # (k,) + dof = jnp.maximum(1, n - rank) + sigest2_vec = chisq_vec / dof + + alpha = sigdf / 2.0 
+ invchi2 = invgamma.ppf(sigquant, alpha) / 2.0 + invchi2rid = invchi2 * sigdf + lamda_vec = jnp.atleast_1d(sigest2_vec / invchi2rid).astype(jnp.float32) # (k,) + + s0 = jnp.diag(t0 * lamda_vec).astype(jnp.float32) + return jnp.asarray(t0, dtype=jnp.float32), s0 + @staticmethod def _check_type_settings(y_train, type, w): # noqa: A002 match type: @@ -680,26 +828,37 @@ def _process_sparsity_settings( @staticmethod def _process_offset_settings( - y_train: Float32[Array, ' n'] | Bool[Array, ' n'], + y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'], offset: float | Float32[Any, ''] | None, - ) -> Float32[Array, '']: + ) -> Float32[Array, '...']: if offset is not None: - return jnp.asarray(offset) - elif y_train.size < 1: - return jnp.array(0.0) - else: - mean = y_train.mean() + off = jnp.asarray(offset, dtype=jnp.float32) + + if y_train.ndim == 2: + k = y_train.shape[0] + if off.ndim == 0: + return jnp.broadcast_to(off, (k,)) + if off.shape != (k,): + msg = f'Expected offset shape ({k},), got {off.shape=}' + raise ValueError(msg) + else: + return off + if y_train.ndim == 2: + return y_train.mean(axis=1) + if y_train.size < 1: + return jnp.array(0.0) + mean = y_train.mean() if y_train.dtype == bool: bound = 1 / (1 + y_train.size) mean = jnp.clip(mean, bound, 1 - bound) return ndtri(mean) - else: - return mean + + return mean @staticmethod def _process_leaf_sdev_settings( - y_train: Float32[Array, ' n'] | Bool[Array, ' n'], + y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'], k: float, ntree: int, tau_num: FloatLike | None, @@ -707,11 +866,15 @@ def _process_leaf_sdev_settings( if tau_num is None: if y_train.dtype == bool: tau_num = 3.0 + elif y_train.ndim == 2: + if y_train.shape[1] < 2: + tau_num = jnp.ones(k) + else: + tau_num = (y_train.max(axis=1) - y_train.min(axis=1)) / 2 elif y_train.size < 2: tau_num = 1.0 else: tau_num = (y_train.max() - y_train.min()) / 2 - return tau_num / (k * math.sqrt(ntree)) @staticmethod @@ 
-740,13 +903,16 @@ def _bin_predictors( @staticmethod def _setup_mcmc( x_train: Real[Array, 'p n'], - y_train: Float32[Array, ' n'] | Bool[Array, ' n'], + y_train: Float32[Array, ' n'] | Float32[Array, 'p n'] | Bool[Array, ' n'], offset: Float32[Array, ''], w: Float[Array, ' n'] | None, max_split: UInt[Array, ' p'], - lamda: Float32[Array, ''] | None, - sigma_mu: FloatLike, - sigdf: FloatLike, + # lamda: Float32[Array, ''] | None, + # sigma_mu: FloatLike, + # sigdf: FloatLike, + leaf_prior_cov_inv, + error_cov_df, + error_cov_scale, power: FloatLike, base: FloatLike, maxdepth: int, @@ -761,14 +927,6 @@ def _setup_mcmc( depth = jnp.arange(maxdepth - 1) p_nonterminal = base / (1 + depth).astype(float) ** power - if y_train.dtype == bool: - error_cov_df = None - error_cov_scale = None - else: - # inverse gamma prior: alpha = df / 2, beta = scale / 2 - error_cov_df = sigdf - error_cov_scale = lamda * sigdf - kw = dict( X=x_train, # copy y_train because it's going to be donated in the mcmc loop @@ -778,7 +936,7 @@ def _setup_mcmc( max_split=max_split, num_trees=ntree, p_nonterminal=p_nonterminal, - leaf_prior_cov_inv=lax.reciprocal(jnp.square(sigma_mu)), + leaf_prior_cov_inv=leaf_prior_cov_inv, error_cov_df=error_cov_df, error_cov_scale=error_cov_scale, min_points_per_decision_node=10, @@ -911,356 +1069,33 @@ def choose_vmap_index(path, _) -> Literal[0, None]: return map_with_path(choose_vmap_index, state) - def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']: + def _predict( + self, x: UInt[Array, 'p m'] + ) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost m k']: return self._evaluate_chains_flattened(self._main_trace, x) @classmethod @partial(jax.jit, static_argnums=(0,)) def _evaluate_chains_flattened( cls, trace: mcmcloop.MainTrace, x: UInt[Array, 'p m'] - ) -> Float32[Array, 'ndpost m']: + ) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost m k']: out = cls._evaluate_chains(trace, x) - mc_cores, ndpost_per_chain, m = out.shape - 
return out.reshape(mc_cores * ndpost_per_chain, m) - - @staticmethod - @partial(jax.vmap, in_axes=(0, None)) - def _evaluate_chains( - trace: mcmcloop.MainTrace, x: UInt[Array, 'p m'] - ) -> Float32[Array, 'mc_cores ndpost/mc_cores m']: - return mcmcloop.evaluate_trace(trace, x) - - -class mvBart(Module): - """Multivariate version BART (mvBART) for continuous regression with multiple outputs.""" - - _main_trace: mcmcloop.MainTrace - _burnin_trace: mcmcloop.BurninTrace - _mcmc_state: mcmcstep.State - _splits: Real[Array, 'p max_num_splits'] - _x_train_fmt: Any = field(static=True) - - ndpost: int = field(static=True) - offset: Float32[Array, ' k'] - sigest: Float32[Array, 'k k'] | None = None - yhat_test: Float32[Array, 'ndpost k m'] | None = None - - def __init__( - self, - x_train: Real[Array, 'p n'] | DataFrame, - y_train: Float32[Array, 'k n'] | DataFrame, - *, - x_test: Real[Array, 'p m'] | DataFrame | None = None, - sparse: bool = False, - theta: FloatLike | None = None, - a: FloatLike = 0.5, - b: FloatLike = 1.0, - rho: FloatLike | None = None, - xinfo: Float[Array, 'p n'] | None = None, - usequants: bool = False, - rm_const: bool | None = True, - sigest: Float32[Array, 'k k'] | None = None, - sigdf: FloatLike = 3.0, - sigquant: FloatLike = 0.9, - k: FloatLike = 2.0, - power: FloatLike = 2.0, - base: FloatLike = 0.95, - lamda: Float32[Array, 'k k'] | None = None, - tau_num: FloatLike | Float32[Array, ' k'] | None = None, - offset: Float32[Array, ' k'] | None = None, - w: None = None, - ntree: int = 200, - numcut: int = 100, - ndpost: int = 1000, - nskip: int = 100, - keepevery: int = 1, - printevery: int | None = None, - mc_cores: int = 2, - seed: int | Key[Array, ''] = 0, - maxdepth: int = 6, - init_kw: dict | None = None, - run_mcmc_kw: dict | None = None, - ): - if w is not None: - msg = 'Weights are not supported for multivariate BART.' 
- raise ValueError(msg) - - # check data and put it in the right format - x_train, x_train_fmt = Bart._process_predictor_input(x_train) # noqa: SLF001 - y_train = self._process_mv_response_input(y_train) - self._check_same_n(x_train, y_train) - - # process sparsity settings - theta, a, b, rho = Bart._process_sparsity_settings( # noqa: SLF001 - x_train, sparse, theta, a, b, rho - ) - - # process "standardization" settings - offset = self._process_offset_settings_mv(y_train, offset) - sigma_mu = self._process_leaf_sdev_settings_mv(y_train, k, ntree, tau_num) - error_cov_df, error_cov_scale = self._process_error_variance_settings_mv( - x_train, y_train, sigest, sigdf, sigquant, lamda - ) - - # determine splits - splits, max_split = Bart._determine_splits(x_train, usequants, numcut, xinfo) # noqa: SLF001 - x_train_binned = Bart._bin_predictors(x_train, splits) # noqa: SLF001 - - # setup and run mcmc - initial_state = self._setup_mcmc_mv( - x_train_binned, - y_train, - offset, - w, - max_split, - sigma_mu, - error_cov_df, - error_cov_scale, - power, - base, - maxdepth, - ntree, - init_kw, - rm_const, - theta, - a, - b, - rho, - ) - final_state, burnin_trace, main_trace = Bart._run_mcmc( # noqa: SLF001 - initial_state, - mc_cores, - ndpost, - nskip, - keepevery, - printevery, - seed, - run_mcmc_kw, - sparse, - ) - - # set public attributes - self.offset = final_state.offset # from the state because of buffer donation - self.ndpost = main_trace.grow_prop_count.size - self.sigest = sigest - - # set private attributes - self._main_trace = main_trace - self._burnin_trace = burnin_trace - self._mcmc_state = final_state - self._splits = splits - self._x_train_fmt = x_train_fmt - - # predict at test points - if x_test is not None: - self.yhat_test = self.predict(x_test) - - # ----------------------- - # MV-specific helpers - # ----------------------- - @staticmethod - def _process_mv_response_input(y) -> Float32[Array, 'k n']: - y = jnp.asarray(y) - if y.ndim != 2: - msg = 
f'mvBART requires y_train to be 2D (k,n). Got {y.ndim=}.' - raise ValueError(msg) - if y.dtype == bool: - msg = 'mvBART is continuous-only: y_train must be floating (not bool).' - raise ValueError(msg) - return y.astype(jnp.float32) - - @staticmethod - def _check_same_n(x: Real[Array, 'p n'], y: Float32[Array, 'k n']) -> None: - if x.shape[1] != y.shape[1]: - msg = f'Mismatch: x_train has n={x.shape[1]}, y_train has n={y.shape[1]}' - raise ValueError(msg) - - @staticmethod - def _process_offset_settings_mv( - y_train: Float32[Array, 'k n'], offset: float | Float32[Array, ' k'] | None - ) -> Float32[Array, ' k']: - if offset is None: - return y_train.mean(axis=1) - off = jnp.asarray(offset, dtype=jnp.float32) - if off.ndim == 0: - return jnp.broadcast_to(off, (y_train.shape[0],)) - if off.shape != (y_train.shape[0],): - msg = f'Expected offset shape (k,), got {off.shape=}' - raise ValueError(msg) - return off - - @staticmethod - def _process_leaf_sdev_settings_mv( - y_train: Float32[Array, 'k n'], k: float, ntree: int, tau_num: float | None - ) -> Float32[Array, 'k k']: - k = y_train.shape[0] - - if tau_num is None: - if y_train.shape[1] < 2: - tau_num = jnp.ones(k) - else: - tau_num = (y_train.max(axis=1) - y_train.min(axis=1)) / 2 - - return tau_num / (k * math.sqrt(ntree)) - - @staticmethod - def _process_error_variance_settings_mv( - x_train: Real[Array, 'p n'], - y_train: Float32[Array, 'k n'], - sigest: Float32[Array, ' k'] | None, - sigdf: float, - sigquant: float, - lamda_vec: float | Float32[Array, ' k'] | None, - *, - t0: float | None = None, - s0: Float32[Array, 'k k'] | None = None, - ) -> tuple[Float32[Array, 'k k'] | None, Float32[Array, 'k k'] | None]: - p = x_train.shape[0] - k, n = y_train.shape - - # df of IW prior - if t0 is None: - t0 = float(sigdf + k - 1) - if t0 <= k - 1: - msg = f'Degrees of freedom `t0` must be > {k - 1}' - raise ValueError(msg) - - # scale of IW prior: - if s0 is not None: - if s0.shape != (k, k): - msg = ValueError( - 
f'Scale matrix `s0` must have shape ({k}, {k}), got {s0.shape}' - ) - raise ValueError(msg) - s0 = jnp.diag(jnp.asarray(s0, dtype=jnp.float32)) - return jnp.asarray(t0, dtype=jnp.float32), s0 - - # if t0 and s0 are none, use a diagonal construction - if sigest is not None: - sigest = jnp.asarray(sigest, dtype=jnp.float32) - if sigest.shape != (k,): - msg = f'sigest must have shape ({k},), got {sigest.shape}' - raise ValueError(msg) - sigest2_vec = jnp.square(sigest) - - elif n < 2: - sigest2_vec = jnp.ones((k,), dtype=jnp.float32) - - elif n <= p: - sigest2_vec = jnp.var(y_train, axis=1) - + if out.ndim == 4: + mc_cores, ndpost_per_chain, m, k = out.shape + return out.reshape(mc_cores * ndpost_per_chain, m, k) + elif out.ndim == 3: + mc_cores, ndpost_per_chain, m = out.shape + return out.reshape(mc_cores * ndpost_per_chain, m) else: - # OLS with implicit intercept via centering - # Xc: (n,p), Yc: (n,k) - Xc = x_train.T - x_train.mean(axis=1, keepdims=True).T - Yc = y_train.T - y_train.mean(axis=1, keepdims=True).T - - coef, _, rank, _ = jnp.linalg.lstsq(Xc, Yc, rcond=None) # coef: (p,k) - R = Yc - Xc @ coef # (n,k) - - # match univariate: chisq = sum residual^2, dof = n - rank - chisq_vec = jnp.sum(jnp.square(R), axis=0) # (k,) - dof = jnp.maximum(1, n - rank) - sigest2_vec = chisq_vec / dof - - alpha = sigdf / 2.0 - invchi2 = invgamma.ppf(sigquant, alpha) / 2.0 - invchi2rid = invchi2 * sigdf - lamda_vec = jnp.atleast_1d(sigest2_vec / invchi2rid).astype(jnp.float32) # (k,) - - s0 = jnp.diag(t0 * lamda_vec).astype(jnp.float32) - return jnp.asarray(t0, dtype=jnp.float32), s0 - - @staticmethod - def _setup_mcmc_mv( - x_train: Real[Array, 'p n'], - y_train: Float32[Array, 'k n'], - offset: Float32[Array, ' k'], - w: Float[Array, ' n'] | None, - max_split: UInt[Array, ' p'], - # lamda: Float32[Array, 'k k'] | None, - sigma_mu: float, - # sigdf: float, - error_cov_df: float, - error_cov_scale: Float32[Array, 'k k'], - power: float, - base: float, - maxdepth: int, - 
ntree: int, - init_kw: dict[str, Any] | None, - rm_const: bool | None, - theta: float | None, - a: float, - b: float, - rho: float | None, - ) -> mcmcstep.State: - # depth prior p_nonterminal: same construction as Bart - depth = jnp.arange(maxdepth - 1) - p_nonterminal = base / (1 + depth).astype(float) ** power - - kdim = y_train.shape[0] - leaf_prior_cov_inv = (1.0 / (sigma_mu**2)) * jnp.eye(kdim, dtype=jnp.float32) - - kw = dict( - X=x_train, - y=jnp.array(y_train), - offset=offset, - error_scale=w, - max_split=max_split, - num_trees=ntree, - p_nonterminal=p_nonterminal, - leaf_prior_cov_inv=leaf_prior_cov_inv, - error_cov_df=error_cov_df, - error_cov_scale=error_cov_scale, - theta=theta, - a=a, - b=b, - rho=rho, - kind='mv', - ) - - if rm_const is None: - kw.update(filter_splitless_vars=False) - elif rm_const: - kw.update(filter_splitless_vars=True) - else: - n_empty = jnp.count_nonzero(max_split == 0) - if n_empty: - msg = f'There are {n_empty}/{max_split.size} predictors without decision rules' - raise ValueError(msg) - kw.update(filter_splitless_vars=False) - - if init_kw: - kw.update(init_kw) - - return mcmcstep.init(**kw) - - # ----------------------- - # Predictions - # ----------------------- - - def yhat_train(self) -> Float32[Array, 'ndpost k n']: - x_train = self._mcmc_state.X - return self._predict(x_train) - - def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost k m']: - return self._evaluate_chains_flattened(self._main_trace, x) - - @classmethod - @partial(jax.jit, static_argnums=(0,)) - def _evaluate_chains_flattened( - cls, trace: mcmcloop.MainTrace, x: UInt[Array, 'p m'] - ) -> Float32[Array, 'ndpost k m']: - out = cls._evaluate_chains(trace, x) # (mc_cores, steps, k, m) - if out.ndim != 4: - msg = f'Expected MV output (mc_cores, steps, k, m). Got {out.shape=}' + msg = f'Expected output has dimension 3 or 4. 
Got {out.shape=}' raise ValueError(msg) - mc_cores, steps, kdim, m = out.shape - return out.reshape(mc_cores * steps, kdim, m) @staticmethod @partial(jax.vmap, in_axes=(0, None)) def _evaluate_chains( trace: mcmcloop.MainTrace, x: UInt[Array, 'p m'] - ) -> Float32[Array, 'mc_cores t k m']: + ) -> ( + Float32[Array, 'mc_cores ndpost/mc_cores m'] + | Float32[Array, 'mc_cores ndpost/mc_cores m k'] + ): return mcmcloop.evaluate_trace(trace, x) diff --git a/tests/test_mvbart.py b/tests/test_mvbart.py index 681f25c8..c9dacdd1 100644 --- a/tests/test_mvbart.py +++ b/tests/test_mvbart.py @@ -32,7 +32,7 @@ from numpy.testing import assert_allclose, assert_array_equal from scipy.stats import chi2, ks_1samp, ks_2samp -from bartz._interface import mvBart +from bartz._interface import Bart from bartz.mcmcstep import State, init, step from bartz.mcmcstep._step import ( Counts, @@ -489,10 +489,10 @@ def test_initialization_and_shapes(self, keys): B = random.normal(keys.pop(), (p, k_dim)) Y = B.T @ X + 0.1 * random.normal(keys.pop(), (k_dim, n)) - model = mvBart( + model = Bart( x_train=X, y_train=Y, ntree=10, ndpost=ndpost, nskip=nskip, mc_cores=2 ) X_test = random.normal(random.key(1), (p, n_test)) y_pred = model._predict(X_test) - assert y_pred.shape == (ndpost, k_dim, n_test) + assert y_pred.shape == (ndpost, n_test, k_dim) From 61854f1b83f625e709c5df733b5c01a67501ffbb Mon Sep 17 00:00:00 2001 From: Miaoqing Yu Date: Tue, 6 Jan 2026 06:27:57 +0000 Subject: [PATCH 3/5] add mvbart convergence test, not finished yet --- src/bartz/_interface.py | 65 +++++++++++++++++---------------- tests/test_BART.py | 81 ++--------------------------------------- tests/test_mvbart.py | 61 ++++++++++++++++++++++++++++++- tests/util.py | 81 ++++++++++++++++++++++++++++++++++++++++- 4 files changed, 176 insertions(+), 112 deletions(-) diff --git a/src/bartz/_interface.py b/src/bartz/_interface.py index 33430a38..6eb18763 100644 --- a/src/bartz/_interface.py +++ b/src/bartz/_interface.py @@ -250,7 
+250,7 @@ class Bart(Module): The prior mean of the latent mean function. sigest : Float32[Array, ''] | None The estimated standard deviation of the error used to set `lamda`. - yhat_test : Float32[Array, 'ndpost m'] | None + yhat_test : Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] | None The conditional posterior mean at `x_test` for each MCMC iteration. References @@ -272,7 +272,7 @@ class Bart(Module): ndpost: int = field(static=True) offset: Float32[Array, ''] sigest: Float32[Array, ''] | None = None - yhat_test: Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost m k'] | None = None + yhat_test: Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] | None = None def __init__( self, @@ -722,39 +722,42 @@ def _process_error_variance_settings_mv( return jnp.asarray(t0, dtype=jnp.float32), s0 # if t0 and s0 are none, use a diagonal construction - if sigest is not None: - sigest = jnp.asarray(sigest, dtype=jnp.float32) - if sigest.shape != (k,): - msg = f'sigest must have shape ({k},), got {sigest.shape}' - raise ValueError(msg) - sigest2_vec = jnp.square(sigest) - - elif n < 2: - sigest2_vec = jnp.ones((k,), dtype=jnp.float32) - - elif n <= p: - sigest2_vec = jnp.var(y_train, axis=1) - + if lamda_vec is not None: + lamda_vec = jnp.atleast_1d(lamda_vec).astype(jnp.float32) else: - # OLS with implicit intercept via centering - # Xc: (n,p), Yc: (n,k) - Xc = x_train.T - x_train.mean(axis=1, keepdims=True).T - Yc = y_train.T - y_train.mean(axis=1, keepdims=True).T + if sigest is not None: + sigest = jnp.asarray(sigest, dtype=jnp.float32) + if sigest.shape != (k,): + msg = f'sigest must have shape ({k},), got {sigest.shape}' + raise ValueError(msg) + sigest2_vec = jnp.square(sigest) + elif n < 2: + sigest2_vec = jnp.ones((k,), dtype=jnp.float32) + elif n <= p: + sigest2_vec = jnp.var(y_train, axis=1) - coef, _, rank, _ = jnp.linalg.lstsq(Xc, Yc, rcond=None) # coef: (p,k) - R = Yc - Xc @ coef # (n,k) + else: + # OLS with implicit intercept via centering 
+ # Xc: (n,p), Yc: (n,k) + Xc = x_train.T - x_train.mean(axis=1, keepdims=True).T + Yc = y_train.T - y_train.mean(axis=1, keepdims=True).T - # match univariate: chisq = sum residual^2, dof = n - rank - chisq_vec = jnp.sum(jnp.square(R), axis=0) # (k,) - dof = jnp.maximum(1, n - rank) - sigest2_vec = chisq_vec / dof + coef, _, rank, _ = jnp.linalg.lstsq(Xc, Yc, rcond=None) # coef: (p,k) + R = Yc - Xc @ coef # (n,k) - alpha = sigdf / 2.0 - invchi2 = invgamma.ppf(sigquant, alpha) / 2.0 - invchi2rid = invchi2 * sigdf - lamda_vec = jnp.atleast_1d(sigest2_vec / invchi2rid).astype(jnp.float32) # (k,) + # match univariate: chisq = sum residual^2, dof = n - rank + chisq_vec = jnp.sum(jnp.square(R), axis=0) # (k,) + dof = jnp.maximum(1, n - rank) + sigest2_vec = chisq_vec / dof - s0 = jnp.diag(t0 * lamda_vec).astype(jnp.float32) + alpha = sigdf / 2.0 + invchi2 = invgamma.ppf(sigquant, alpha) / 2.0 + invchi2rid = invchi2 * sigdf + lamda_vec = jnp.atleast_1d(sigest2_vec / invchi2rid).astype( + jnp.float32 + ) # (k,) + + s0 = jnp.diag(sigdf * lamda_vec).astype(jnp.float32) return jnp.asarray(t0, dtype=jnp.float32), s0 @staticmethod @@ -884,7 +887,7 @@ def _bin_predictors( @staticmethod def _setup_mcmc( x_train: Real[Array, 'p n'], - y_train: Float32[Array, ' n'] | Float32[Array, 'p n'] | Bool[Array, ' n'], + y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'], offset: Float32[Array, ''], w: Float[Array, ' n'] | None, max_split: UInt[Array, ' p'], diff --git a/tests/test_BART.py b/tests/test_BART.py index 5151b700..8d0567a7 100644 --- a/tests/test_BART.py +++ b/tests/test_BART.py @@ -44,11 +44,10 @@ from jax import debug_nans, lax, random, vmap from jax import numpy as jnp from jax.lax import collapse -from jax.scipy.linalg import solve_triangular from jax.scipy.special import logit, ndtr from jax.tree import map_with_path from jax.tree_util import KeyPath -from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real, UInt +from jaxtyping import 
Array, Bool, Float32, Int32, Key, Real, UInt from numpy.testing import assert_allclose, assert_array_equal from pytest_subtests import SubTests @@ -72,6 +71,8 @@ assert_close_matrices, assert_different_matrices, get_old_python_tuple, + multivariate_rhat, + rhat, ) @@ -1064,82 +1065,6 @@ def avg_max_tree_depth( return depth.mean(-1) -def multivariate_rhat(chains: Real[Any, 'chain sample dim']) -> Float[Array, '']: - """ - Compute the multivariate Gelman-Rubin R-hat. - - Parameters - ---------- - chains - Independent chains of samples of a vector. - - Returns - ------- - Multivariate R-hat statistic. - - Raises - ------ - ValueError - If there are not enough chains or samples. - """ - chains = jnp.asarray(chains) - m, n, p = chains.shape - - if m < 2: # pragma: no cover - msg = 'Need at least 2 chains' - raise ValueError(msg) - if n < 2: # pragma: no cover - msg = 'Need at least 2 samples per chain' - raise ValueError(msg) - - chain_means = jnp.mean(chains, axis=1) - - def compute_chain_cov(chain_samples, chain_mean): - centered = chain_samples - chain_mean - return jnp.dot(centered.T, centered) / (n - 1) - - within_chain_covs = vmap(compute_chain_cov)(chains, chain_means) - W = jnp.mean(within_chain_covs, axis=0) - - overall_mean = jnp.mean(chain_means, axis=0) - chain_mean_diffs = chain_means - overall_mean - B = (n / (m - 1)) * jnp.dot(chain_mean_diffs.T, chain_mean_diffs) - - V_hat = ((n - 1) / n) * W + ((m + 1) / (m * n)) * B - - # Add regularization to W for numerical stability - gershgorin = jnp.max(jnp.sum(jnp.abs(W), axis=1)) - regularization = jnp.finfo(W.dtype).eps * len(W) * gershgorin - W_reg = W + regularization * jnp.eye(p) - - # Compute max(eigvals(W^-1 V_hat)) - L = jnp.linalg.cholesky(W_reg) - # Solve L @ L.T @ x = V_hat @ x = λ @ W @ x - # This is equivalent to solving (L^-1 V_hat L^-T) @ y = λ @ y - L_1V = solve_triangular(L, V_hat, lower=True) - L_1VL_T = solve_triangular(L, L_1V.T, lower=True).T - eigenvals = jnp.linalg.eigvalsh(L_1VL_T) - - 
return jnp.max(eigenvals) - - -def rhat(chains: Real[Any, 'chain sample']) -> Float[Array, '']: - """ - Compute the univariate Gelman-Rubin R-hat. - - Parameters - ---------- - chains - Independent chains of samples of a scalar. - - Returns - ------- - Univariate R-hat statistic. - """ - chains = jnp.asarray(chains) - return multivariate_rhat(chains[:, :, None]) - - def test_rhat(keys): """Test the multivariate R-hat implementation.""" chains, divergent_chains = random.normal(keys.pop(), (2, 2, 1000, 10)) diff --git a/tests/test_mvbart.py b/tests/test_mvbart.py index 9bc8c509..196b387a 100644 --- a/tests/test_mvbart.py +++ b/tests/test_mvbart.py @@ -47,7 +47,7 @@ _step_error_cov_inv_uv, step_trees, ) -from tests.util import assert_close_matrices +from tests.util import assert_close_matrices, rhat class TestWishart: @@ -497,4 +497,61 @@ def test_initialization_and_shapes(self, keys): y_pred = model.predict(X_test) assert y_pred.shape == (ndpost, k_dim, n_test) - # def test_mvbart_convergence(self, keys): + def test_mvbart_convergence(self, keys): + """Test that MV Bart chains converge using R-hat.""" + n_train = 200 + p, k_dim = 5, 2 + sigma_noise = 0.1 + + mc_cores = 4 + ndpost = 2000 + nsamples_per_chain = ndpost // mc_cores + nskip = 1000 + keepevery = 5 + ntree = 100 + + key_x, key_b, key_n1 = random.split(keys.pop(), 3) + X_train = random.normal(key_x, (p, n_train)) + B = random.uniform(key_b, (p, k_dim), minval=-1, maxval=1) + F_train = B.T @ X_train + Y_train = F_train + sigma_noise * random.normal(key_n1, (k_dim, n_train)) + + model = Bart( + x_train=X_train, + y_train=Y_train, + ntree=ntree, + ndpost=ndpost, + nskip=nskip, + keepevery=keepevery, + mc_cores=mc_cores, + seed=0, + ) + + # Check yhat Convergence + yhat_train = model.yhat_train.reshape( + mc_cores, nsamples_per_chain, k_dim, n_train + ) + summ = yhat_train.mean(axis=-1) # (mc_cores, nsamples_per_chain, k_dim) + + max_rhats_yhat = [rhat(summ[:, :, j]) for j in range(k_dim)] + rhat_mean = 
jnp.max(jnp.stack(max_rhats_yhat)) + print('Rhat on mean(yhat_train) per response:', rhat_mean) + + global_max_rhat = jnp.max(jnp.array(max_rhats_yhat)) + assert global_max_rhat < 1.1 + + # Check Covariance Matrix Convergence + prec_trace = model._main_trace.error_cov_inv + if prec_trace.ndim == 3: # (ndpost, k, k) -> reshape + prec_trace = prec_trace.reshape(mc_cores, nsamples_per_chain, k_dim, k_dim) + + prec_flat = prec_trace.reshape( + mc_cores, nsamples_per_chain, -1 + ) # Result shape: (chains, samples, k*k) + assert jnp.all(jnp.std(prec_flat, axis=1) > 1e-8), 'Sigma is not updating!' + + max_rhats_prec = [rhat(prec_flat[:, :, j]) for j in range(k_dim * k_dim)] + max_rhat_sigma = jnp.max(jnp.array(max_rhats_prec)) + print(f'Max R-hat for precision matrix: {max_rhat_sigma}') + + assert max_rhat_sigma < 1.1 diff --git a/tests/util.py b/tests/util.py index 3f4a4d07..0ef69168 100644 --- a/tests/util.py +++ b/tests/util.py @@ -28,11 +28,14 @@ from dataclasses import replace from operator import ge, le from pathlib import Path +from typing import Any import numpy as np import tomli from jax import numpy as jnp -from jaxtyping import ArrayLike +from jax import vmap +from jax.scipy.linalg import solve_triangular +from jaxtyping import Array, ArrayLike, Float, Real from scipy import linalg from bartz.debug import check_tree, describe_error @@ -198,3 +201,79 @@ def update_version(): """Update the version file.""" version = get_version() Path('src/bartz/_version.py').write_text(f'__version__ = {version!r}\n') + + +def multivariate_rhat(chains: Real[Any, 'chain sample dim']) -> Float[Array, '']: + """ + Compute the multivariate Gelman-Rubin R-hat. + + Parameters + ---------- + chains + Independent chains of samples of a vector. + + Returns + ------- + Multivariate R-hat statistic. + + Raises + ------ + ValueError + If there are not enough chains or samples. 
+ """ + chains = jnp.asarray(chains) + m, n, p = chains.shape + + if m < 2: # pragma: no cover + msg = 'Need at least 2 chains' + raise ValueError(msg) + if n < 2: # pragma: no cover + msg = 'Need at least 2 samples per chain' + raise ValueError(msg) + + chain_means = jnp.mean(chains, axis=1) + + def compute_chain_cov(chain_samples, chain_mean): + centered = chain_samples - chain_mean + return jnp.dot(centered.T, centered) / (n - 1) + + within_chain_covs = vmap(compute_chain_cov)(chains, chain_means) + W = jnp.mean(within_chain_covs, axis=0) + + overall_mean = jnp.mean(chain_means, axis=0) + chain_mean_diffs = chain_means - overall_mean + B = (n / (m - 1)) * jnp.dot(chain_mean_diffs.T, chain_mean_diffs) + + V_hat = ((n - 1) / n) * W + ((m + 1) / (m * n)) * B + + # Add regularization to W for numerical stability + gershgorin = jnp.max(jnp.sum(jnp.abs(W), axis=1)) + regularization = jnp.finfo(W.dtype).eps * len(W) * gershgorin + W_reg = W + regularization * jnp.eye(p) + + # Compute max(eigvals(W^-1 V_hat)) + L = jnp.linalg.cholesky(W_reg) + # Solve L @ L.T @ x = V_hat @ x = λ @ W @ x + # This is equivalent to solving (L^-1 V_hat L^-T) @ y = λ @ y + L_1V = solve_triangular(L, V_hat, lower=True) + L_1VL_T = solve_triangular(L, L_1V.T, lower=True).T + eigenvals = jnp.linalg.eigvalsh(L_1VL_T) + + return jnp.max(eigenvals) + + +def rhat(chains: Real[Any, 'chain sample']) -> Float[Array, '']: + """ + Compute the univariate Gelman-Rubin R-hat. + + Parameters + ---------- + chains + Independent chains of samples of a scalar. + + Returns + ------- + Univariate R-hat statistic. 
+ """ + chains = jnp.asarray(chains) + return multivariate_rhat(chains[:, :, None]) From d73beb32b22bce4eddd8efe9b9937d53e8438764 Mon Sep 17 00:00:00 2001 From: Miaoqing Yu Date: Mon, 12 Jan 2026 06:52:55 +0000 Subject: [PATCH 4/5] edit tests on mvbart interface --- src/bartz/_interface.py | 4 +-- tests/test_mvbart.py | 66 ++++++++++++++++++++++++++--------------- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/src/bartz/_interface.py b/src/bartz/_interface.py index 56d6ac05..5c806ea7 100644 --- a/src/bartz/_interface.py +++ b/src/bartz/_interface.py @@ -449,7 +449,7 @@ def prob_train_mean(self) -> Float32[Array, ' n'] | None: return self.prob_train.mean(axis=0) @cached_property - def sigma( + def sigma( # need to change to adapt to matrix covariance matrix self, ) -> ( Float32[Array, ' nskip+ndpost'] @@ -473,7 +473,7 @@ def sigma( ) ) - @cached_property + @cached_property # need to change to adapt to matrix covariance matrix def sigma_(self) -> Float32[Array, 'ndpost'] | None: """The standard deviation of the error, only over the post-burnin samples and flattened.""" error_cov_inv = self._main_trace.error_cov_inv diff --git a/tests/test_mvbart.py b/tests/test_mvbart.py index 196b387a..4ee45cb2 100644 --- a/tests/test_mvbart.py +++ b/tests/test_mvbart.py @@ -480,14 +480,37 @@ def test_mv_steps(self, keys, data): class TestMVBartInterface: """Tests for mvBart Interface.""" - def test_initialization_and_shapes(self, keys): + @pytest.fixture(params=[(10, 2, 2), (20, 5, 3), (3, 100, 4), (50, 50, 5)]) + def data_shape(self, request): + """Provide (n, p, k) triples for testing.""" + n, p, k = request.param + return n, p, k + + @pytest.fixture + def data(self, keys, data_shape): + """Generate a toy dataset. 
Mimic dgp from test_BART.py.""" + n, p, k = data_shape + sigma_noise = 0.1 + + key_x, key_eps = random.split(keys.pop(), 2) + X = random.uniform(key_x, (p, n), float, -2, 2) + + s = jnp.ones((k, p)) + norm_s = jnp.sqrt(jnp.sum(s * s, axis=1, keepdims=True)) # (k, 1) + + # F[d, i] = (s_d @ cos(pi * x_i)) / ||s_d|| + F = (s @ jnp.cos(jnp.pi * X)) / norm_s # (k, n) + + # iid N(0, sigma^2) noise across dims and obs + y = F + sigma_noise * random.normal(key_eps, (k, n)) + return X, y + + def test_initialization_and_shapes(self, data): """Test that mvBart predicts with correct shapes.""" - n, n_test, p, k_dim = 60, 40, 5, 3 + X, Y = data nskip, ndpost = 10, 50 - - X = random.normal(keys.pop(), (p, n)) - B = random.normal(keys.pop(), (p, k_dim)) - Y = B.T @ X + 0.1 * random.normal(keys.pop(), (k_dim, n)) + n_test = 40 + p, k_dim = X.shape[0], Y.shape[0] model = Bart( x_train=X, y_train=Y, ntree=10, ndpost=ndpost, nskip=nskip, mc_cores=2 @@ -497,25 +520,19 @@ def test_initialization_and_shapes(self, keys): y_pred = model.predict(X_test) assert y_pred.shape == (ndpost, k_dim, n_test) - def test_mvbart_convergence(self, keys): + def test_mvbart_convergence(self, data): """Test that MV Bart chains converge using R-hat.""" - n_train = 200 - p, k_dim = 5, 2 - sigma_noise = 0.1 + X_train, Y_train = data + _, n_train = X_train.shape + k_dim = Y_train.shape[0] mc_cores = 4 ndpost = 2000 nsamples_per_chain = ndpost // mc_cores - nskip = 1000 + nskip = 4000 keepevery = 5 ntree = 100 - key_x, key_b, key_n1 = random.split(keys.pop(), 3) - X_train = random.normal(key_x, (p, n_train)) - B = random.uniform(key_b, (p, k_dim), minval=-1, maxval=1) - F_train = B.T @ X_train - Y_train = F_train + sigma_noise * random.normal(key_n1, (k_dim, n_train)) - model = Bart( x_train=X_train, y_train=Y_train, @@ -531,9 +548,10 @@ def test_mvbart_convergence(self, keys): yhat_train = model.yhat_train.reshape( mc_cores, nsamples_per_chain, k_dim, n_train ) - summ = yhat_train.mean(axis=-1) # (mc_cores, 
nsamples_per_chain, k_dim) - - max_rhats_yhat = [rhat(summ[:, :, j]) for j in range(k_dim)] + yhat_train_mean = yhat_train.mean( + axis=-1 + ) # (mc_cores, nsamples_per_chain, k_dim) + max_rhats_yhat = [rhat(yhat_train_mean[:, :, j]) for j in range(k_dim)] rhat_mean = jnp.max(jnp.stack(max_rhats_yhat)) print('Rhat on mean(yhat_train) per response:', rhat_mean) @@ -542,16 +560,16 @@ def test_mvbart_convergence(self, keys): # Check Covariance Matrix Convergence prec_trace = model._main_trace.error_cov_inv - if prec_trace.ndim == 3: # (ndpost, k, k) -> reshape + if prec_trace.ndim == 3: prec_trace = prec_trace.reshape(mc_cores, nsamples_per_chain, k_dim, k_dim) - prec_flat = prec_trace.reshape( - mc_cores, nsamples_per_chain, -1 - ) # Result shape: (chains, samples, k*k) + prec_flat = prec_trace.reshape(mc_cores, nsamples_per_chain, -1) assert jnp.all(jnp.std(prec_flat, axis=1) > 1e-8), 'Sigma is not updating!' max_rhats_prec = [rhat(prec_flat[:, :, j]) for j in range(k_dim * k_dim)] max_rhat_sigma = jnp.max(jnp.array(max_rhats_prec)) + print(f'R-hat for precision matrix: {jnp.array(max_rhats_prec)}') print(f'Max R-hat for precision matrix: {max_rhat_sigma}') + assert all(max_rhats_prec) < 1.1 assert max_rhat_sigma < 1.1 From b896b3219fbd14e98437940f155a4424a53c0ac3 Mon Sep 17 00:00:00 2001 From: Miaoqing Yu Date: Mon, 12 Jan 2026 07:18:58 +0000 Subject: [PATCH 5/5] remove redundant code --- src/bartz/_interface.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/src/bartz/_interface.py b/src/bartz/_interface.py index 5c806ea7..9d78d93e 100644 --- a/src/bartz/_interface.py +++ b/src/bartz/_interface.py @@ -352,19 +352,6 @@ def __init__( ) ) - if y_train.ndim == 2: # Multivariate standardization - error_cov_df, error_cov_scale = self._process_error_variance_settings_mv( - x_train, y_train, sigest, sigdf, sigquant, lamda - ) - leaf_prior_cov_inv = (1.0 / (sigma_mu**2)) * jnp.eye( - y_train.shape[0], dtype=jnp.float32 - ) - 
else: # Univariate standardization - lamda, sigest = self._process_error_variance_settings( - x_train, y_train, sigest, sigdf, sigquant, lamda - ) - leaf_prior_cov_inv = jnp.reciprocal(jnp.square(sigma_mu)) - # determine splits splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo) x_train = self._bin_predictors(x_train, splits) @@ -691,7 +678,7 @@ def _process_error_variance_settings_mv( sigest: Float32[Array, ' k'] | None, sigdf: float, sigquant: float, - lamda_vec: float | Float32[Array, ' k'] | None, + lamda: float | Float32[Array, ' k'] | None, *, t0: float | None = None, s0: Float32[Array, 'k k'] | None = None, @@ -717,8 +704,8 @@ def _process_error_variance_settings_mv( return jnp.asarray(t0, dtype=jnp.float32), s0 # if t0 and s0 are none, use a diagonal construction - if lamda_vec is not None: - lamda_vec = jnp.atleast_1d(lamda_vec).astype(jnp.float32) + if lamda is not None: + lamda = jnp.atleast_1d(lamda).astype(jnp.float32) else: if sigest is not None: sigest = jnp.asarray(sigest, dtype=jnp.float32) @@ -748,11 +735,9 @@ def _process_error_variance_settings_mv( alpha = sigdf / 2.0 invchi2 = invgamma.ppf(sigquant, alpha) / 2.0 invchi2rid = invchi2 * sigdf - lamda_vec = jnp.atleast_1d(sigest2_vec / invchi2rid).astype( - jnp.float32 - ) # (k,) + lamda = jnp.atleast_1d(sigest2_vec / invchi2rid).astype(jnp.float32) # (k,) - s0 = jnp.diag(sigdf * lamda_vec).astype(jnp.float32) + s0 = jnp.diag(sigdf * lamda).astype(jnp.float32) return jnp.asarray(t0, dtype=jnp.float32), s0 @staticmethod