add mvbart to _interface.py #63
Open: miaoqingyu2 wants to merge 10 commits into `bartz-org:main` from `miaoqingyu2:feature/mv-implementation`.
Commits (10, all by miaoqingyu2):

- `304e6bc` add mvbart to interface
- `85177c8` Merge remote-tracking branch 'origin/main' into feature/mv-implementa…
- `5d68de7` Merge remote changes, accepting theirs for grove and mcmcloop
- `6bb2747` move everything into bart class
- `80606a0` merge remote changes, remove is_mv
- `61854f1` add mvbart convergence test, not finished yet
- `cceebd5` Merge remote-tracking branch 'origin/main' into feature/mv-implementa…
- `17024a0` Merge remote-tracking branch 'origin/main' into feature/mv-implementa…
- `d73beb3` edit tests on mvbart interface
- `b896b32` remove redundant code
```diff
@@ -251,7 +251,7 @@ class Bart(Module):
         The prior mean of the latent mean function.
     sigest : Float32[Array, ''] | None
         The estimated standard deviation of the error used to set `lamda`.
-    yhat_test : Float32[Array, 'ndpost m'] | None
+    yhat_test : Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] | None
         The conditional posterior mean at `x_test` for each MCMC iteration.

     References
```
```diff
@@ -273,12 +273,15 @@ class Bart(Module):
     ndpost: int = field(static=True)
     offset: Float32[Array, '']
     sigest: Float32[Array, ''] | None = None
-    yhat_test: Float32[Array, 'ndpost m'] | None = None
+    yhat_test: Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] | None = None

     def __init__(
         self,
         x_train: Real[Array, 'p n'] | DataFrame,
-        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
+        y_train: Bool[Array, ' n']
+        | Float32[Array, ' n']
+        | Float32[Array, 'k n']
+        | Series,
         *,
         x_test: Real[Array, 'p m'] | DataFrame | None = None,
         type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
```
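To make the extended signature concrete, here is a minimal usage sketch. The module path `bartz._interface` and the data are assumptions for illustration; only `x_train`, `y_train`, and `x_test` are passed, leaving every other argument at its default.

```python
import jax.numpy as jnp
from jax import random

from bartz._interface import Bart  # hypothetical import path

p, n, k, m = 5, 100, 3, 20
key = random.PRNGKey(0)
x_train = random.normal(random.fold_in(key, 0), (p, n))
x_test = random.normal(random.fold_in(key, 1), (p, m))
# multivariate response: one row per outcome, shape (k, n)
y_train = jnp.stack([x_train.sum(axis=0) * (i + 1) for i in range(k)])

fit = Bart(x_train, y_train, x_test=x_test)
# per the new annotations, this should be (ndpost, k, m)
print(fit.yhat_test.shape)
```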
```diff
@@ -290,15 +293,15 @@ def __init__(
         xinfo: Float[Array, 'p n'] | None = None,
         usequants: bool = False,
         rm_const: bool | None = True,
-        sigest: FloatLike | None = None,
+        sigest: FloatLike | Float32[Array, 'k k'] | None = None,
         sigdf: FloatLike = 3.0,
         sigquant: FloatLike = 0.9,
         k: FloatLike = 2.0,
         power: FloatLike = 2.0,
         base: FloatLike = 0.95,
-        lamda: FloatLike | None = None,
-        tau_num: FloatLike | None = None,
-        offset: FloatLike | None = None,
+        lamda: FloatLike | None = None,  # to change?
+        tau_num: FloatLike | None = None,  # to change?
+        offset: FloatLike | None = None,  # to change?
         w: Float[Array, ' n'] | Series | None = None,
         ntree: int | None = None,
         numcut: int = 100,
```

**Collaborator** (on the `offset` parameter): For `offset`: also support shape `(k,)`.
```diff
@@ -315,13 +318,17 @@
         # check data and put it in the right format
         x_train, x_train_fmt = self._process_predictor_input(x_train)
         y_train = self._process_response_input(y_train)

         self._check_same_length(x_train, y_train)
+        self._validate_compatibility(y_train, w, type)

         if w is not None:
             w = self._process_response_input(w)
             self._check_same_length(x_train, w)

         # check data types are correct for continuous/binary regression
-        self._check_type_settings(y_train, type, w)
+        if y_train.ndim == 1:
+            self._check_type_settings(y_train, type, w)
         # from here onwards, the type is determined by y_train.dtype == bool

         # set defaults that depend on type of regression
```
```diff
@@ -338,8 +345,11 @@
         # process "standardization" settings
         offset = self._process_offset_settings(y_train, offset)
         sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
-        lamda, sigest = self._process_error_variance_settings(
-            x_train, y_train, sigest, sigdf, sigquant, lamda
-        )
+        error_cov_df, error_cov_scale, leaf_prior_cov_inv, sigest = (
+            self._configure_priors(
+                x_train, y_train, sigma_mu, sigest, sigdf, sigquant, lamda
+            )
+        )

         # determine splits
```
```diff
@@ -353,9 +363,12 @@
             offset,
             w,
             max_split,
-            lamda,
-            sigma_mu,
-            sigdf,
+            leaf_prior_cov_inv,
+            error_cov_df,
+            error_cov_scale,
+            # lamda,
+            # sigma_mu,
+            # sigdf,
             power,
             base,
             maxdepth,
```
```diff
@@ -377,7 +390,7 @@
         # set public attributes
         self.offset = final_state.offset  # from the state because of buffer donation
         self.ndpost = main_trace.grow_prop_count.size
-        self.sigest = sigest
+        self.sigest = sigest if y_train.ndim == 1 else None

         # set private attributes
         self._main_trace = main_trace
```
```diff
@@ -423,7 +436,7 @@ def prob_train_mean(self) -> Float32[Array, ' n'] | None:
         return self.prob_train.mean(axis=0)

     @cached_property
-    def sigma(
+    def sigma(  # need to change to adapt to matrix covariance matrix
         self,
     ) -> (
         Float32[Array, ' nskip+ndpost']
```
```diff
@@ -447,7 +460,7 @@ def sigma(
             )
         )

-    @cached_property
+    @cached_property  # need to change to adapt to matrix covariance matrix
     def sigma_(self) -> Float32[Array, 'ndpost'] | None:
         """The standard deviation of the error, only over the post-burnin samples and flattened."""
         error_cov_inv = self._main_trace.error_cov_inv
```
```diff
@@ -508,13 +521,13 @@ def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
         return self.yhat_test.mean(axis=0)

     @cached_property
-    def yhat_train(self) -> Float32[Array, 'ndpost n']:
+    def yhat_train(self) -> Float32[Array, 'ndpost n'] | Float32[Array, 'ndpost k n']:
         """The conditional posterior mean at `x_train` for each MCMC iteration."""
         x_train = self._mcmc_state.X
         return self._predict(x_train)

     @cached_property
-    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
+    def yhat_train_mean(self) -> Float32[Array, ' n'] | Float32[Array, 'k n'] | None:
         """The marginal posterior mean at `x_train`.

         Not defined with binary regression because it's error-prone, typically
```
```diff
@@ -564,12 +577,62 @@ def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
         return x, fmt

     @staticmethod
-    def _process_response_input(y) -> Shaped[Array, ' n']:
+    def _validate_compatibility(y_train, w, type):  # noqa: A002
+        """Validate inputs based on regression type (univariate/multivariate)."""
+        if y_train.ndim == 2:
+            if w is not None:
+                msg = "Weights 'w' are not supported for multivariate regression."
+                raise ValueError(msg)
+            if type != 'wbart':
+                msg = "Multivariate regression implies type='wbart'."
+                raise ValueError(msg)
+            if y_train.dtype == bool:
+                msg = 'Multivariate regression does not yet support binary responses.'
+                raise TypeError(msg)
+
+    def _configure_priors(
+        self, x_train, y_train, sigma_mu, sigest, sigdf, sigquant, lamda
+    ):
+        """Configure error covariance/variance priors and leaf priors."""
+        if y_train.ndim == 2:
+            error_cov_df, error_cov_scale = self._process_error_variance_settings_mv(
+                x_train, y_train, sigest, sigdf, sigquant, lamda
+            )
+            leaf_prior_cov_inv = (1.0 / (sigma_mu**2)) * jnp.eye(
+                y_train.shape[0], dtype=jnp.float32
+            )
+            return error_cov_df, error_cov_scale, leaf_prior_cov_inv, None
+        else:
+            lamda_val, sigest_val = self._process_error_variance_settings(
+                x_train, y_train, sigest, sigdf, sigquant, lamda
+            )
+            leaf_prior_cov_inv = jnp.reciprocal(jnp.square(sigma_mu))
+
+            if y_train.dtype == bool:
+                error_cov_df = None
+                error_cov_scale = None
+            else:
+                error_cov_df = sigdf
+                error_cov_scale = lamda_val * sigdf
+
+            return error_cov_df, error_cov_scale, leaf_prior_cov_inv, sigest_val
+
+    @staticmethod
+    def _process_response_input(y) -> Shaped[Array, ' n'] | Shaped[Array, 'k n']:
         if hasattr(y, 'to_numpy'):
             y = y.to_numpy()
         y = jnp.asarray(y)
-        assert y.ndim == 1
-        return y
+        if y.ndim == 1:
+            return y
+        elif y.ndim == 2:
+            if y.dtype == bool:
+                msg = 'mvBART is continuous-only: y_train must be floating (not bool).'
+                raise ValueError(msg)
+            return y.astype(jnp.float32)
+        else:
+            msg = f'y_train must be 1D (n,) or 2D (k, n). Got {y.ndim=}.'
+            raise ValueError(msg)

     @staticmethod
     def _check_same_length(x1, x2):
```
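For intuition, the two `leaf_prior_cov_inv` branches above are the same quantity in different shapes: the multivariate case puts the univariate precision on the diagonal of a $k \times k$ matrix. A standalone sketch (values arbitrary):

```python
import jax.numpy as jnp

sigma_mu, k = 0.35, 3
# multivariate branch: precision matrix (1 / sigma_mu^2) * I_k
leaf_prior_cov_inv_mv = (1.0 / sigma_mu**2) * jnp.eye(k, dtype=jnp.float32)
# univariate branch: the scalar special case of the same quantity
leaf_prior_cov_inv_uni = jnp.reciprocal(jnp.square(jnp.float32(sigma_mu)))
assert jnp.allclose(jnp.diag(leaf_prior_cov_inv_mv), leaf_prior_cov_inv_uni)
```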
```diff
@@ -608,6 +671,75 @@ def _process_error_variance_settings(
         invchi2rid = invchi2 * sigdf
         return sigest2 / invchi2rid, jnp.sqrt(sigest2)

+    @staticmethod
+    def _process_error_variance_settings_mv(
+        x_train: Real[Array, 'p n'],
+        y_train: Float32[Array, 'k n'],
+        sigest: Float32[Array, ' k'] | None,
+        sigdf: float,
+        sigquant: float,
+        lamda: float | Float32[Array, ' k'] | None,
+        *,
+        t0: float | None = None,
+        s0: Float32[Array, 'k k'] | None = None,
+    ) -> tuple[Float32[Array, 'k k'] | None, Float32[Array, 'k k'] | None]:
+        p = x_train.shape[0]
+        k, n = y_train.shape
+
+        # df of IW prior
+        if t0 is None:
+            t0 = float(sigdf + k - 1)
+        if t0 <= k - 1:
+            msg = f'Degrees of freedom `t0` must be > {k - 1}'
+            raise ValueError(msg)
+
+        # scale of IW prior
+        if s0 is not None:
+            if s0.shape != (k, k):
+                msg = f'Scale matrix `s0` must have shape ({k}, {k}), got {s0.shape}'
+                raise ValueError(msg)
+            s0 = jnp.asarray(s0, dtype=jnp.float32)
+            return jnp.asarray(t0, dtype=jnp.float32), s0
+
+        # if t0 and s0 are None, use a diagonal construction
+        if lamda is not None:
+            lamda = jnp.atleast_1d(lamda).astype(jnp.float32)
+        else:
+            if sigest is not None:
+                sigest = jnp.asarray(sigest, dtype=jnp.float32)
+                if sigest.shape != (k,):
+                    msg = f'sigest must have shape ({k},), got {sigest.shape}'
+                    raise ValueError(msg)
+                sigest2_vec = jnp.square(sigest)
+            elif n < 2:
+                sigest2_vec = jnp.ones((k,), dtype=jnp.float32)
+            elif n <= p:
+                sigest2_vec = jnp.var(y_train, axis=1)
+            else:
+                # OLS with implicit intercept via centering
+                # Xc: (n, p), Yc: (n, k)
+                Xc = x_train.T - x_train.mean(axis=1, keepdims=True).T
+                Yc = y_train.T - y_train.mean(axis=1, keepdims=True).T
+
+                coef, _, rank, _ = jnp.linalg.lstsq(Xc, Yc, rcond=None)  # coef: (p, k)
+                R = Yc - Xc @ coef  # (n, k)
+
+                # match univariate: chisq = sum residual^2, dof = n - rank
+                chisq_vec = jnp.sum(jnp.square(R), axis=0)  # (k,)
+                dof = jnp.maximum(1, n - rank)
+                sigest2_vec = chisq_vec / dof
+
+            alpha = sigdf / 2.0
+            invchi2 = invgamma.ppf(sigquant, alpha) / 2.0
+            invchi2rid = invchi2 * sigdf
+            lamda = jnp.atleast_1d(sigest2_vec / invchi2rid).astype(jnp.float32)  # (k,)
+
+        s0 = jnp.diag(sigdf * lamda).astype(jnp.float32)
+        return jnp.asarray(t0, dtype=jnp.float32), s0

     @staticmethod
     @jit
     def _linear_regression(
```
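As a cross-check of the diagonal branch above, the same calibration can be reproduced standalone. This sketch assumes `invgamma` is `scipy.stats.invgamma`, matching the univariate code's usage:

```python
import numpy as np
from scipy.stats import invgamma

def diag_iw_scale(sigest2_vec, sigdf=3.0, sigquant=0.9):
    """Mirror of the diagonal branch: per-dimension quantile calibration.

    Each lamda_i comes from the same scaled-inverse-chi-squared quantile
    rule as the univariate code, applied to sigest2_i; the IW scale is
    then diag(sigdf * lamda).
    """
    alpha = sigdf / 2.0
    invchi2 = invgamma.ppf(sigquant, alpha) / 2.0
    invchi2rid = invchi2 * sigdf
    lamda = np.asarray(sigest2_vec, dtype=np.float32) / invchi2rid
    return np.diag(sigdf * lamda)

print(diag_iw_scale([1.0, 4.0, 0.25]))  # (3, 3) diagonal scale matrix
```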
```diff
@@ -672,27 +804,38 @@ def _process_sparsity_settings(

     @staticmethod
     def _process_offset_settings(
-        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
         offset: float | Float32[Any, ''] | None,
     ) -> Float32[Array, '']:
         """Return offset."""
         if offset is not None:
-            return jnp.asarray(offset)
-        elif y_train.size < 1:
-            return jnp.array(0.0)
-        else:
-            mean = y_train.mean()
+            off = jnp.asarray(offset, dtype=jnp.float32)
+
+            if y_train.ndim == 2:
+                k = y_train.shape[0]
+                if off.ndim == 0:
+                    return jnp.broadcast_to(off, (k,))
+                if off.shape != (k,):
+                    msg = f'Expected offset shape ({k},), got {off.shape=}'
+                    raise ValueError(msg)
+            return off
+
+        if y_train.ndim == 2:
+            return y_train.mean(axis=1)
+        if y_train.size < 1:
+            return jnp.array(0.0)
+        mean = y_train.mean()
         if y_train.dtype == bool:
             bound = 1 / (1 + y_train.size)
             mean = jnp.clip(mean, bound, 1 - bound)
             return ndtri(mean)
-        else:
-            return mean
+        return mean

     @staticmethod
     def _process_leaf_sdev_settings(
-        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
         k: float,
         ntree: int,
         tau_num: FloatLike | None,
```
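Given that dispatch, the static helper can be exercised directly; a sketch reusing the hypothetical import from earlier:

```python
import jax.numpy as jnp

from bartz._interface import Bart  # hypothetical import path, as above

k, n = 3, 50
y_mv = jnp.linspace(0.0, 1.0, k * n).reshape(k, n)

off = Bart._process_offset_settings(y_mv, 0.5)
assert off.shape == (k,)  # scalar offset broadcasts to one entry per outcome

off_default = Bart._process_offset_settings(y_mv, None)
assert jnp.allclose(off_default, y_mv.mean(axis=1))  # default: per-row mean
```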
```diff
@@ -701,11 +844,15 @@ def _process_leaf_sdev_settings(
         if tau_num is None:
             if y_train.dtype == bool:
                 tau_num = 3.0
+            elif y_train.ndim == 2:
+                if y_train.shape[1] < 2:
+                    tau_num = jnp.ones(y_train.shape[0])
+                else:
+                    tau_num = (y_train.max(axis=1) - y_train.min(axis=1)) / 2
             elif y_train.size < 2:
                 tau_num = 1.0
             else:
                 tau_num = (y_train.max() - y_train.min()) / 2

         return tau_num / (k * math.sqrt(ntree))

     @staticmethod
```
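Worked example of the per-outcome leaf scale this produces, using the `tau_num / (k * sqrt(ntree))` formula from the code (numbers chosen arbitrarily):

```python
import math
import jax.numpy as jnp

k_hyper, ntree = 2.0, 200
y = jnp.array([[0.0, 4.0, 2.0],
               [10.0, 30.0, 20.0]])            # (k, n) = (2, 3)
tau_num = (y.max(axis=1) - y.min(axis=1)) / 2  # [2.0, 10.0]
sigma_mu = tau_num / (k_hyper * math.sqrt(ntree))
print(sigma_mu)  # one leaf prior sdev per outcome: ~[0.0707, 0.3536]
```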
```diff
@@ -734,13 +881,16 @@ def _bin_predictors(
     @staticmethod
     def _setup_mcmc(
         x_train: Real[Array, 'p n'],
-        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
         offset: Float32[Array, ''],
         w: Float[Array, ' n'] | None,
         max_split: UInt[Array, ' p'],
-        lamda: Float32[Array, ''] | None,
-        sigma_mu: FloatLike,
-        sigdf: FloatLike,
+        # lamda: Float32[Array, ''] | None,
+        # sigma_mu: FloatLike,
+        # sigdf: FloatLike,
+        leaf_prior_cov_inv,
+        error_cov_df,
+        error_cov_scale,
         power: FloatLike,
         base: FloatLike,
         maxdepth: int,
```
```diff
@@ -758,15 +908,6 @@ def _setup_mcmc(
         depth = jnp.arange(maxdepth - 1)
         p_nonterminal = base / (1 + depth).astype(float) ** power

-        if y_train.dtype == bool:
-            error_cov_df = None
-            error_cov_scale = None
-        else:
-            assert lamda is not None
-            # inverse gamma prior: alpha = df / 2, beta = scale / 2
-            error_cov_df = sigdf
-            error_cov_scale = lamda * sigdf
-
         kw: dict = dict(
             X=x_train,
             # copy y_train because it's going to be donated in the mcmc loop
```
```diff
@@ -776,7 +917,7 @@ def __init__(
             max_split=max_split,
             num_trees=ntree,
             p_nonterminal=p_nonterminal,
-            leaf_prior_cov_inv=jnp.reciprocal(jnp.square(sigma_mu)),
+            leaf_prior_cov_inv=leaf_prior_cov_inv,
             error_cov_df=error_cov_df,
             error_cov_scale=error_cov_scale,
             min_points_per_decision_node=10,
```
```diff
@@ -844,4 +985,4 @@ def _run_mcmc(
     def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']:
         """Evaluate trees on already quantized `x`."""
         out = evaluate_trace(x, self._main_trace)
-        return collapse(out, 0, -1)
+        return collapse(out, 0, 2)
```
**Collaborator:**

For `lamda` and `tau_num`: try to work out a good multivariate generalization of this.

I guess `tau_num` can be a scalar, with `leaf_prior_cov_inv` just being $1/\sigma_\mu^2$ on the diagonal; that part is fine as is.

For `error_cov_scale` you do multivariate linear regression and get an estimate of the error covariance matrix, but then what to do with the chi-squared quantile rule used in the univariate case? I guess there should be something involving the maximum eigenvalue of the error covariance matrix, i.e., set the quantile on the weighted outcome average with the largest residual variance.

I haven't read your code in detail, so maybe you've already done the above, but I need to go to bed now.
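One possible reading of the eigenvalue idea, sketched under the assumption that `resid` holds the multivariate OLS residuals; this is an illustration, not something the PR implements:

```python
import numpy as np
from scipy.stats import invgamma

def eigen_calibrated_scale(resid, sigdf=3.0, sigquant=0.9):
    """resid: (n, k) OLS residuals. Returns a (k, k) IW scale matrix."""
    n, k = resid.shape
    S = resid.T @ resid / max(n - 1, 1)         # residual covariance estimate
    top_var = float(np.linalg.eigvalsh(S)[-1])  # largest residual variance
    # univariate quantile rule applied to the worst direction:
    # lamda = sigest2 / (invgamma.ppf(sigquant, sigdf / 2) / 2 * sigdf)
    invchi2rid = invgamma.ppf(sigquant, sigdf / 2.0) / 2.0 * sigdf
    lamda_top = top_var / invchi2rid
    # rescale S so its top eigenvalue equals the calibrated lamda_top,
    # preserving the estimated correlation structure
    return sigdf * (lamda_top / top_var) * S
```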
**miaoqingyu2 (author):**

For `tau_num`: currently I allow it to be a vector (length $k$) computed from the range of each response dimension individually. This lets the leaf node variance scale differently for each output if needed.

For `error_cov_scale`: I currently mimic the univariate logic for each dimension independently. I run the standard quantile calibration on each row of $y$ to get per-dimension $\lambda$ values and place them on the diagonal of the prior scale matrix. I did this to ensure that when $k = 1$, the multivariate initialization path yields exactly the same priors as the existing univariate path.

However, I was also thinking about using information from the eigenvalues to make the scale matrix reflect the correlation structure of the $y$'s. Any preference or comments?
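A quick numerical check of that $k = 1$ claim, as a standalone sketch of both rules (values arbitrary):

```python
import numpy as np
from scipy.stats import invgamma

sigdf, sigquant, sigest2 = 3.0, 0.9, 2.5
invchi2rid = invgamma.ppf(sigquant, sigdf / 2.0) / 2.0 * sigdf

lamda_uni = sigest2 / invchi2rid                           # univariate rule
s0 = np.diag(sigdf * np.atleast_1d(sigest2) / invchi2rid)  # mv diagonal rule, k = 1

assert np.isclose(s0[0, 0], sigdf * lamda_uni)
```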