233 changes: 187 additions & 46 deletions src/bartz/_interface.py
@@ -251,7 +251,7 @@ class Bart(Module):
The prior mean of the latent mean function.
sigest : Float32[Array, ''] | None
The estimated standard deviation of the error used to set `lamda`.
yhat_test : Float32[Array, 'ndpost m'] | None
yhat_test : Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] | None
The conditional posterior mean at `x_test` for each MCMC iteration.

References
@@ -273,12 +273,15 @@ class Bart(Module):
ndpost: int = field(static=True)
offset: Float32[Array, '']
sigest: Float32[Array, ''] | None = None
yhat_test: Float32[Array, 'ndpost m'] | None = None
yhat_test: Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] | None = None

def __init__(
self,
x_train: Real[Array, 'p n'] | DataFrame,
y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
y_train: Bool[Array, ' n']
| Float32[Array, ' n']
| Float32[Array, 'k n']
| Series,
*,
x_test: Real[Array, 'p m'] | DataFrame | None = None,
type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002
@@ -290,15 +293,15 @@ def __init__(
xinfo: Float[Array, 'p n'] | None = None,
usequants: bool = False,
rm_const: bool | None = True,
sigest: FloatLike | None = None,
        sigest: FloatLike | Float32[Array, ' k'] | None = None,
sigdf: FloatLike = 3.0,
sigquant: FloatLike = 0.9,
k: FloatLike = 2.0,
power: FloatLike = 2.0,
base: FloatLike = 0.95,
lamda: FloatLike | None = None,
tau_num: FloatLike | None = None,
offset: FloatLike | None = None,
lamda: FloatLike | None = None, # to change?
tau_num: FloatLike | None = None, # to change?
Comment on lines +302 to +303 (Collaborator):
For lambda and tau_num: try to work out a good multivariate generalization of this scheme.

I guess tau_num can be a scalar, with leaf_prior_cov_inv just being 1/sigma_mu^2 on the diagonal; that part is fine as is.

For error_cov_scale you do multivariate linear regression and get an estimate of the error covariance matrix, but then what to do with the chi-squared quantile trick used in the univariate case? I guess there should be something with the maximum eigenvalue of the error covariance matrix, i.e., set the quantile on the weighted outcome average with the largest residual variance.

I haven't read your code in detail, so maybe you've already done the above, but I need to go to bed now.

Reply from @miaoqingyu2 (Contributor Author), Jan 12, 2026:

For tau_num: currently I allow it to be a vector (length $k$), computed from the range of each response dimension individually. This lets the leaf-node variance scale differently for each output if needed.

For error_cov_scale: I currently mimic the univariate logic for each dimension independently. I run the standard quantile calibration on each row of $y$ to get per-dimension $\lambda$ values and place them on the diagonal of the prior scale matrix. This is mainly to ensure that when $k=1$, the multivariate initialization path yields exactly the same priors as the existing univariate path.

However, I was also considering using the eigenvalue information to make the scale matrix reflect the correlation structure of the outcomes. Any preference or comments?
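To make the eigenvalue idea concrete, here is a minimal sketch (an editor's illustration, not code from this PR: the helper name is made up, and only the quantile calibration mirrors the diagonal path in `_process_error_variance_settings_mv` below). It applies the univariate rule to the largest eigenvalue of the estimated error covariance and rescales the whole matrix, so the inverse-Wishart scale keeps the estimated correlation structure instead of being purely diagonal:

# Hypothetical sketch of the max-eigenvalue calibration discussed above.
import jax.numpy as jnp
from scipy.stats import invgamma


def iw_scale_from_max_eigenvalue(sigest_cov, sigdf, sigquant):
    """Calibrate an inverse-Wishart scale matrix from a (k, k) residual covariance.

    The univariate rule lamda = sigest2 / (invchi2 * sigdf) is applied to the
    largest eigenvalue of `sigest_cov`, i.e. to the weighted outcome average
    with the largest residual variance, and the resulting scalar rescales the
    whole matrix so the prior keeps the estimated correlation structure.
    """
    # eigvalsh returns eigenvalues in ascending order, so [-1] is the largest
    max_eig = jnp.linalg.eigvalsh(sigest_cov)[-1]
    # same quantile calibration as in the diagonal construction below
    alpha = sigdf / 2.0
    invchi2 = invgamma.ppf(sigquant, alpha) / 2.0
    lamda_max = max_eig / (invchi2 * sigdf)
    # univariate scale is sigdf * lamda; spread it over the full covariance
    return (sigdf * lamda_max / max_eig) * sigest_cov

At $k=1$ this reduces to sigdf * lamda, the same scale as the univariate path, so it would preserve the consistency property mentioned above while keeping off-diagonal information; whether that trade-off is wanted is exactly the open question in this thread.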

offset: FloatLike | None = None, # to change?
Comment (Collaborator):
For offset: also support shape (k,)
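For reference, a hypothetical caller-side sketch of that request (the data and offset values are made up, and whether the full multivariate path runs end-to-end depends on the rest of this PR): a scalar offset would broadcast to all $k$ outcomes, while a length-$k$ array sets one offset per outcome.

# Hypothetical usage of a per-outcome offset with a (k, n) response.
import jax.numpy as jnp
from jax import random
from bartz._interface import Bart

key = random.key(0)
x_train = random.normal(random.fold_in(key, 0), (5, 100))  # (p, n) predictors
y_train = random.normal(random.fold_in(key, 1), (2, 100))  # (k, n) responses
model = Bart(x_train, y_train, offset=jnp.array([0.0, 1.5]))  # offset shape (k,)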

w: Float[Array, ' n'] | Series | None = None,
ntree: int | None = None,
numcut: int = 100,
@@ -315,13 +318,17 @@ def __init__(
# check data and put it in the right format
x_train, x_train_fmt = self._process_predictor_input(x_train)
y_train = self._process_response_input(y_train)

self._check_same_length(x_train, y_train)
self._validate_compatibility(y_train, w, type)

if w is not None:
w = self._process_response_input(w)
self._check_same_length(x_train, w)

# check data types are correct for continuous/binary regression
self._check_type_settings(y_train, type, w)
if y_train.ndim == 1:
self._check_type_settings(y_train, type, w)
# from here onwards, the type is determined by y_train.dtype == bool

# set defaults that depend on type of regression
@@ -338,8 +345,11 @@ def __init__(
# process "standardization" settings
offset = self._process_offset_settings(y_train, offset)
sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
lamda, sigest = self._process_error_variance_settings(
x_train, y_train, sigest, sigdf, sigquant, lamda

error_cov_df, error_cov_scale, leaf_prior_cov_inv, sigest = (
self._configure_priors(
x_train, y_train, sigma_mu, sigest, sigdf, sigquant, lamda
)
)

# determine splits
@@ -353,9 +363,12 @@ def __init__(
offset,
w,
max_split,
lamda,
sigma_mu,
sigdf,
leaf_prior_cov_inv,
error_cov_df,
error_cov_scale,
power,
base,
maxdepth,
@@ -377,7 +390,7 @@ def __init__(
# set public attributes
self.offset = final_state.offset # from the state because of buffer donation
self.ndpost = main_trace.grow_prop_count.size
self.sigest = sigest
self.sigest = sigest if y_train.ndim == 1 else None

# set private attributes
self._main_trace = main_trace
@@ -423,7 +436,7 @@ def prob_train_mean(self) -> Float32[Array, ' n'] | None:
return self.prob_train.mean(axis=0)

@cached_property
def sigma(
    def sigma(  # TODO: adapt to a full error covariance matrix in the multivariate case
self,
) -> (
Float32[Array, ' nskip+ndpost']
@@ -447,7 +460,7 @@ def sigma(
)
)

@cached_property
    @cached_property  # TODO: adapt to a full error covariance matrix in the multivariate case
def sigma_(self) -> Float32[Array, 'ndpost'] | None:
"""The standard deviation of the error, only over the post-burnin samples and flattened."""
error_cov_inv = self._main_trace.error_cov_inv
@@ -508,13 +521,13 @@ def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
return self.yhat_test.mean(axis=0)

@cached_property
def yhat_train(self) -> Float32[Array, 'ndpost n']:
def yhat_train(self) -> Float32[Array, 'ndpost n'] | Float32[Array, 'ndpost k n']:
"""The conditional posterior mean at `x_train` for each MCMC iteration."""
x_train = self._mcmc_state.X
return self._predict(x_train)

@cached_property
def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
def yhat_train_mean(self) -> Float32[Array, ' n'] | Float32[Array, ' k n'] | None:
"""The marginal posterior mean at `x_train`.

Not defined with binary regression because it's error-prone, typically
@@ -564,12 +577,62 @@ def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
return x, fmt

@staticmethod
def _process_response_input(y) -> Shaped[Array, ' n']:
def _validate_compatibility(y_train, w, type): # noqa: A002
"""Validate inputs based on regression type (Univariate/Multivariate)."""
if y_train.ndim == 2:
if w is not None:
msg = "Weights 'w' are not supported for multivariate regression."
raise ValueError(msg)
if type != 'wbart':
msg = "Multivariate regression implies type='wbart'."
raise ValueError(msg)
if y_train.dtype == bool:
                msg = 'Multivariate regression does not yet support binary responses.'
raise TypeError(msg)

def _configure_priors(
self, x_train, y_train, sigma_mu, sigest, sigdf, sigquant, lamda
):
"""Configure error covariance/variance priors and leaf priors."""
if y_train.ndim == 2:
error_cov_df, error_cov_scale = self._process_error_variance_settings_mv(
x_train, y_train, sigest, sigdf, sigquant, lamda
)
leaf_prior_cov_inv = (1.0 / (sigma_mu**2)) * jnp.eye(
y_train.shape[0], dtype=jnp.float32
)
return error_cov_df, error_cov_scale, leaf_prior_cov_inv, None
else:
lamda_val, sigest_val = self._process_error_variance_settings(
x_train, y_train, sigest, sigdf, sigquant, lamda
)
leaf_prior_cov_inv = jnp.reciprocal(jnp.square(sigma_mu))

if y_train.dtype == bool:
error_cov_df = None
error_cov_scale = None
else:
error_cov_df = sigdf
error_cov_scale = lamda_val * sigdf

return error_cov_df, error_cov_scale, leaf_prior_cov_inv, sigest_val

@staticmethod
def _process_response_input(y) -> Shaped[Array, ' n'] | Shaped[Array, ' k n']:
if hasattr(y, 'to_numpy'):
y = y.to_numpy()
y = jnp.asarray(y)
assert y.ndim == 1
return y

if y.ndim == 1:
return y
elif y.ndim == 2:
if y.dtype == bool:
msg = 'mvBART is continuous-only: y_train must be floating (not bool).'
raise ValueError(msg)
return y.astype(jnp.float32)
else:
msg = f'y_train must be 1D (n,) or 2D (k,n). Got {y.ndim=}.'
raise ValueError(msg)

@staticmethod
def _check_same_length(x1, x2):
@@ -608,6 +671,75 @@ def _process_error_variance_settings(
invchi2rid = invchi2 * sigdf
return sigest2 / invchi2rid, jnp.sqrt(sigest2)

@staticmethod
def _process_error_variance_settings_mv(
x_train: Real[Array, 'p n'],
y_train: Float32[Array, 'k n'],
sigest: Float32[Array, ' k'] | None,
sigdf: float,
sigquant: float,
lamda: float | Float32[Array, ' k'] | None,
*,
t0: float | None = None,
s0: Float32[Array, 'k k'] | None = None,
) -> tuple[Float32[Array, 'k k'] | None, Float32[Array, 'k k'] | None]:
p = x_train.shape[0]
k, n = y_train.shape

# df of IW prior
if t0 is None:
t0 = float(sigdf + k - 1)
if t0 <= k - 1:
            msg = f'Degrees of freedom `t0` must be > {k - 1}, got {t0}'
raise ValueError(msg)

# scale of IW prior:
if s0 is not None:
if s0.shape != (k, k):
                msg = (
                    f'Scale matrix `s0` must have shape ({k}, {k}), got {s0.shape}'
                )
                raise ValueError(msg)
            s0 = jnp.asarray(s0, dtype=jnp.float32)
return jnp.asarray(t0, dtype=jnp.float32), s0

        # if t0 and s0 are None, use a diagonal construction
if lamda is not None:
lamda = jnp.atleast_1d(lamda).astype(jnp.float32)
else:
if sigest is not None:
sigest = jnp.asarray(sigest, dtype=jnp.float32)
if sigest.shape != (k,):
msg = f'sigest must have shape ({k},), got {sigest.shape}'
raise ValueError(msg)
sigest2_vec = jnp.square(sigest)
elif n < 2:
sigest2_vec = jnp.ones((k,), dtype=jnp.float32)
elif n <= p:
sigest2_vec = jnp.var(y_train, axis=1)

else:
# OLS with implicit intercept via centering
# Xc: (n,p), Yc: (n,k)
Xc = x_train.T - x_train.mean(axis=1, keepdims=True).T
Yc = y_train.T - y_train.mean(axis=1, keepdims=True).T

coef, _, rank, _ = jnp.linalg.lstsq(Xc, Yc, rcond=None) # coef: (p,k)
R = Yc - Xc @ coef # (n,k)

# match univariate: chisq = sum residual^2, dof = n - rank
chisq_vec = jnp.sum(jnp.square(R), axis=0) # (k,)
dof = jnp.maximum(1, n - rank)
sigest2_vec = chisq_vec / dof

alpha = sigdf / 2.0
invchi2 = invgamma.ppf(sigquant, alpha) / 2.0
invchi2rid = invchi2 * sigdf
lamda = jnp.atleast_1d(sigest2_vec / invchi2rid).astype(jnp.float32) # (k,)

s0 = jnp.diag(sigdf * lamda).astype(jnp.float32)
return jnp.asarray(t0, dtype=jnp.float32), s0

@staticmethod
@jit
def _linear_regression(
@@ -672,27 +804,38 @@ def _process_sparsity_settings(

@staticmethod
def _process_offset_settings(
y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
offset: float | Float32[Any, ''] | None,
) -> Float32[Array, '']:
"""Return offset."""
if offset is not None:
return jnp.asarray(offset)
elif y_train.size < 1:
return jnp.array(0.0)
else:
mean = y_train.mean()
off = jnp.asarray(offset, dtype=jnp.float32)

            if y_train.ndim == 2:
                k = y_train.shape[0]
                if off.ndim == 0:
                    return jnp.broadcast_to(off, (k,))
                if off.shape != (k,):
                    msg = f'Expected offset shape ({k},), got {off.shape=}'
                    raise ValueError(msg)
            return off

if y_train.ndim == 2:
return y_train.mean(axis=1)
if y_train.size < 1:
return jnp.array(0.0)
mean = y_train.mean()
if y_train.dtype == bool:
bound = 1 / (1 + y_train.size)
mean = jnp.clip(mean, bound, 1 - bound)
return ndtri(mean)
else:
return mean

return mean

@staticmethod
def _process_leaf_sdev_settings(
y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
k: float,
ntree: int,
tau_num: FloatLike | None,
@@ -701,11 +844,15 @@ def _process_leaf_sdev_settings(
if tau_num is None:
if y_train.dtype == bool:
tau_num = 3.0
            elif y_train.ndim == 2:
                if y_train.shape[1] < 2:
                    # note: `k` here is the prior hyperparameter, not the number
                    # of outcomes, so size the vector with y_train.shape[0]
                    tau_num = jnp.ones(y_train.shape[0])
                else:
                    tau_num = (y_train.max(axis=1) - y_train.min(axis=1)) / 2
elif y_train.size < 2:
tau_num = 1.0
else:
tau_num = (y_train.max() - y_train.min()) / 2

return tau_num / (k * math.sqrt(ntree))

@staticmethod
@@ -734,13 +881,16 @@ def _bin_predictors(
@staticmethod
def _setup_mcmc(
x_train: Real[Array, 'p n'],
y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
offset: Float32[Array, ''],
w: Float[Array, ' n'] | None,
max_split: UInt[Array, ' p'],
lamda: Float32[Array, ''] | None,
sigma_mu: FloatLike,
sigdf: FloatLike,
        leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
        error_cov_df: FloatLike | Float32[Array, ''] | None,
        error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None,
power: FloatLike,
base: FloatLike,
maxdepth: int,
@@ -758,15 +908,6 @@ def _setup_mcmc(
depth = jnp.arange(maxdepth - 1)
p_nonterminal = base / (1 + depth).astype(float) ** power

if y_train.dtype == bool:
error_cov_df = None
error_cov_scale = None
else:
assert lamda is not None
# inverse gamma prior: alpha = df / 2, beta = scale / 2
error_cov_df = sigdf
error_cov_scale = lamda * sigdf

kw: dict = dict(
X=x_train,
# copy y_train because it's going to be donated in the mcmc loop
@@ -776,7 +917,7 @@ def _setup_mcmc(
max_split=max_split,
num_trees=ntree,
p_nonterminal=p_nonterminal,
leaf_prior_cov_inv=jnp.reciprocal(jnp.square(sigma_mu)),
leaf_prior_cov_inv=leaf_prior_cov_inv,
error_cov_df=error_cov_df,
error_cov_scale=error_cov_scale,
min_points_per_decision_node=10,
@@ -844,4 +985,4 @@ def _run_mcmc(
def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']:
"""Evaluate trees on already quantized `x`."""
out = evaluate_trace(x, self._main_trace)
return collapse(out, 0, -1)
return collapse(out, 0, 2)