diff --git a/src/bartz/_interface.py b/src/bartz/_interface.py index 7e767e4..9d78d93 100644 --- a/src/bartz/_interface.py +++ b/src/bartz/_interface.py @@ -251,7 +251,7 @@ class Bart(Module): The prior mean of the latent mean function. sigest : Float32[Array, ''] | None The estimated standard deviation of the error used to set `lamda`. - yhat_test : Float32[Array, 'ndpost m'] | None + yhat_test : Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] | None The conditional posterior mean at `x_test` for each MCMC iteration. References @@ -273,12 +273,15 @@ class Bart(Module): ndpost: int = field(static=True) offset: Float32[Array, ''] sigest: Float32[Array, ''] | None = None - yhat_test: Float32[Array, 'ndpost m'] | None = None + yhat_test: Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] | None = None def __init__( self, x_train: Real[Array, 'p n'] | DataFrame, - y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series, + y_train: Bool[Array, ' n'] + | Float32[Array, ' n'] + | Float32[Array, 'k n'] + | Series, *, x_test: Real[Array, 'p m'] | DataFrame | None = None, type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002 @@ -290,15 +293,15 @@ def __init__( xinfo: Float[Array, 'p n'] | None = None, usequants: bool = False, rm_const: bool | None = True, - sigest: FloatLike | None = None, + sigest: FloatLike | Float32[Array, 'k k'] | None = None, sigdf: FloatLike = 3.0, sigquant: FloatLike = 0.9, k: FloatLike = 2.0, power: FloatLike = 2.0, base: FloatLike = 0.95, - lamda: FloatLike | None = None, - tau_num: FloatLike | None = None, - offset: FloatLike | None = None, + lamda: FloatLike | None = None, # to change? + tau_num: FloatLike | None = None, # to change? + offset: FloatLike | None = None, # to change? 
w: Float[Array, ' n'] | Series | None = None, ntree: int | None = None, numcut: int = 100, @@ -315,13 +318,17 @@ def __init__( # check data and put it in the right format x_train, x_train_fmt = self._process_predictor_input(x_train) y_train = self._process_response_input(y_train) + self._check_same_length(x_train, y_train) + self._validate_compatibility(y_train, w, type) + if w is not None: w = self._process_response_input(w) self._check_same_length(x_train, w) # check data types are correct for continuous/binary regression - self._check_type_settings(y_train, type, w) + if y_train.ndim == 1: + self._check_type_settings(y_train, type, w) # from here onwards, the type is determined by y_train.dtype == bool # set defaults that depend on type of regression @@ -338,8 +345,11 @@ def __init__( # process "standardization" settings offset = self._process_offset_settings(y_train, offset) sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num) - lamda, sigest = self._process_error_variance_settings( - x_train, y_train, sigest, sigdf, sigquant, lamda + + error_cov_df, error_cov_scale, leaf_prior_cov_inv, sigest = ( + self._configure_priors( + x_train, y_train, sigma_mu, sigest, sigdf, sigquant, lamda + ) ) # determine splits @@ -353,9 +363,12 @@ def __init__( offset, w, max_split, - lamda, - sigma_mu, - sigdf, + leaf_prior_cov_inv, + error_cov_df, + error_cov_scale, + # lamda, + # sigma_mu, + # sigdf, power, base, maxdepth, @@ -377,7 +390,7 @@ def __init__( # set public attributes self.offset = final_state.offset # from the state because of buffer donation self.ndpost = main_trace.grow_prop_count.size - self.sigest = sigest + self.sigest = sigest if y_train.ndim == 1 else None # set private attributes self._main_trace = main_trace @@ -423,7 +436,7 @@ def prob_train_mean(self) -> Float32[Array, ' n'] | None: return self.prob_train.mean(axis=0) @cached_property - def sigma( + def sigma( # need to change to adapt to matrix covariance matrix self, ) -> ( 
Float32[Array, ' nskip+ndpost'] @@ -447,7 +460,7 @@ def sigma( ) ) - @cached_property + @cached_property # need to change to adapt to matrix covariance matrix def sigma_(self) -> Float32[Array, 'ndpost'] | None: """The standard deviation of the error, only over the post-burnin samples and flattened.""" error_cov_inv = self._main_trace.error_cov_inv @@ -508,13 +521,13 @@ def yhat_test_mean(self) -> Float32[Array, ' m'] | None: return self.yhat_test.mean(axis=0) @cached_property - def yhat_train(self) -> Float32[Array, 'ndpost n']: + def yhat_train(self) -> Float32[Array, 'ndpost n'] | Float32[Array, 'ndpost k n']: """The conditional posterior mean at `x_train` for each MCMC iteration.""" x_train = self._mcmc_state.X return self._predict(x_train) @cached_property - def yhat_train_mean(self) -> Float32[Array, ' n'] | None: + def yhat_train_mean(self) -> Float32[Array, ' n'] | Float32[Array, ' k n'] | None: """The marginal posterior mean at `x_train`. Not defined with binary regression because it's error-prone, typically @@ -564,12 +577,62 @@ def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]: return x, fmt @staticmethod - def _process_response_input(y) -> Shaped[Array, ' n']: + def _validate_compatibility(y_train, w, type): # noqa: A002 + """Validate inputs based on regression type (Univariate/Multivariate).""" + if y_train.ndim == 2: + if w is not None: + msg = "Weights 'w' are not supported for multivariate regression." + raise ValueError(msg) + if type != 'wbart': + msg = "Multivariate regression implies type='wbart'." + raise ValueError(msg) + if y_train.dtype == bool: + msg = 'Multivariate regression not yet support binary responses.' 
+ raise TypeError(msg) + + def _configure_priors( + self, x_train, y_train, sigma_mu, sigest, sigdf, sigquant, lamda + ): + """Configure error covariance/variance priors and leaf priors.""" + if y_train.ndim == 2: + error_cov_df, error_cov_scale = self._process_error_variance_settings_mv( + x_train, y_train, sigest, sigdf, sigquant, lamda + ) + leaf_prior_cov_inv = (1.0 / (sigma_mu**2)) * jnp.eye( + y_train.shape[0], dtype=jnp.float32 + ) + return error_cov_df, error_cov_scale, leaf_prior_cov_inv, None + else: + lamda_val, sigest_val = self._process_error_variance_settings( + x_train, y_train, sigest, sigdf, sigquant, lamda + ) + leaf_prior_cov_inv = jnp.reciprocal(jnp.square(sigma_mu)) + + if y_train.dtype == bool: + error_cov_df = None + error_cov_scale = None + else: + error_cov_df = sigdf + error_cov_scale = lamda_val * sigdf + + return error_cov_df, error_cov_scale, leaf_prior_cov_inv, sigest_val + + @staticmethod + def _process_response_input(y) -> Shaped[Array, ' n'] | Shaped[Array, ' k n']: if hasattr(y, 'to_numpy'): y = y.to_numpy() y = jnp.asarray(y) - assert y.ndim == 1 - return y + + if y.ndim == 1: + return y + elif y.ndim == 2: + if y.dtype == bool: + msg = 'mvBART is continuous-only: y_train must be floating (not bool).' + raise ValueError(msg) + return y.astype(jnp.float32) + else: + msg = f'y_train must be 1D (n,) or 2D (k,n). Got {y.ndim=}.' 
+            raise ValueError(msg)
 
     @staticmethod
     def _check_same_length(x1, x2):
@@ -608,6 +671,75 @@ def _process_error_variance_settings(
         invchi2rid = invchi2 * sigdf
         return sigest2 / invchi2rid, jnp.sqrt(sigest2)
 
+    @staticmethod
+    def _process_error_variance_settings_mv(
+        x_train: Real[Array, 'p n'],
+        y_train: Float32[Array, 'k n'],
+        sigest: Float32[Array, ' k'] | None,
+        sigdf: float,
+        sigquant: float,
+        lamda: float | Float32[Array, ' k'] | None,
+        *,
+        t0: float | None = None,
+        s0: Float32[Array, 'k k'] | None = None,
+    ) -> tuple[Float32[Array, 'k k'] | None, Float32[Array, 'k k'] | None]:
+        p = x_train.shape[0]
+        k, n = y_train.shape
+
+        # df of IW prior
+        if t0 is None:
+            t0 = float(sigdf + k - 1)
+        if t0 <= k - 1:
+            msg = f'Degrees of freedom `t0` must be > {k - 1}'
+            raise ValueError(msg)
+
+        # scale of IW prior:
+        if s0 is not None:
+            if s0.shape != (k, k):
+                # fix: msg must be the message string itself; it was wrapped in
+                # ValueError(...) and then passed to ValueError again on raise
+                msg = f'Scale matrix `s0` must have shape ({k}, {k}), got {s0.shape}'
+                raise ValueError(msg)
+            # fix: keep the full user-supplied (k, k) scale matrix; jnp.diag on a
+            # 2-D array extracts its diagonal, returning a (k,) vector and
+            # contradicting both the declared return type and the diagonal
+            # construction below (which returns a (k, k) matrix)
+            s0 = jnp.asarray(s0, dtype=jnp.float32)
+            return jnp.asarray(t0, dtype=jnp.float32), s0
+
+        # if t0 and s0 are none, use a diagonal construction
+        if lamda is not None:
+            lamda = jnp.atleast_1d(lamda).astype(jnp.float32)
+        else:
+            if sigest is not None:
+                sigest = jnp.asarray(sigest, dtype=jnp.float32)
+                if sigest.shape != (k,):
+                    msg = f'sigest must have shape ({k},), got {sigest.shape}'
+                    raise ValueError(msg)
+                sigest2_vec = jnp.square(sigest)
+            elif n < 2:
+                sigest2_vec = jnp.ones((k,), dtype=jnp.float32)
+            elif n <= p:
+                sigest2_vec = jnp.var(y_train, axis=1)
+
+            else:
+                # OLS with implicit intercept via centering
+                # Xc: (n,p), Yc: (n,k)
+                Xc = x_train.T - x_train.mean(axis=1, keepdims=True).T
+                Yc = y_train.T - y_train.mean(axis=1, keepdims=True).T
+
+                coef, _, rank, _ = jnp.linalg.lstsq(Xc, Yc, rcond=None)  # coef: (p,k)
+                R = Yc - Xc @ coef  # (n,k)
+
+                # match univariate: chisq = sum residual^2, dof = n - rank
+                chisq_vec = jnp.sum(jnp.square(R), axis=0)  # (k,)
+                dof = 
jnp.maximum(1, n - rank)
+                sigest2_vec = chisq_vec / dof
+
+        alpha = sigdf / 2.0
+        invchi2 = invgamma.ppf(sigquant, alpha) / 2.0
+        invchi2rid = invchi2 * sigdf
+        lamda = jnp.atleast_1d(sigest2_vec / invchi2rid).astype(jnp.float32)  # (k,)
+
+        s0 = jnp.diag(sigdf * lamda).astype(jnp.float32)
+        return jnp.asarray(t0, dtype=jnp.float32), s0
+
     @staticmethod
     @jit
     def _linear_regression(
@@ -672,27 +804,38 @@ def _process_sparsity_settings(
     @staticmethod
     def _process_offset_settings(
-        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
         offset: float | Float32[Any, ''] | None,
     ) -> Float32[Array, '']:
         """Return offset."""
         if offset is not None:
-            return jnp.asarray(offset)
-        elif y_train.size < 1:
-            return jnp.array(0.0)
-        else:
-            mean = y_train.mean()
+            off = jnp.asarray(offset, dtype=jnp.float32)
+
+            if y_train.ndim == 2:
+                k = y_train.shape[0]
+                if off.ndim == 0:
+                    return jnp.broadcast_to(off, (k,))
+                if off.shape != (k,):
+                    msg = f'Expected offset shape ({k},), got {off.shape=}'
+                    raise ValueError(msg)
+                # fix: a correctly shaped (k,) user offset must be returned here;
+                # otherwise control falls through and the data mean silently
+                # overrides the offset the caller supplied
+                return off
+            else:
+                return off
+        if y_train.ndim == 2:
+            return y_train.mean(axis=1)
+        if y_train.size < 1:
+            return jnp.array(0.0)
+        mean = y_train.mean()
         if y_train.dtype == bool:
             bound = 1 / (1 + y_train.size)
             mean = jnp.clip(mean, bound, 1 - bound)
             return ndtri(mean)
-        else:
-            return mean
+
+        return mean
 
     @staticmethod
     def _process_leaf_sdev_settings(
-        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
         k: float,
         ntree: int,
         tau_num: FloatLike | None,
@@ -701,11 +844,15 @@
         if tau_num is None:
             if y_train.dtype == bool:
                 tau_num = 3.0
+            elif y_train.ndim == 2:
+                if y_train.shape[1] < 2:
+                    # fix: `k` here is the leaf-scaling hyperparameter (a float),
+                    # not the number of responses; the vector of defaults must
+                    # have length y_train.shape[0]
+                    tau_num = jnp.ones(y_train.shape[0])
+                else:
+                    tau_num = (y_train.max(axis=1) - y_train.min(axis=1)) / 2
             elif y_train.size < 2:
                 tau_num = 1.0
             else:
                 tau_num = (y_train.max() - y_train.min()) / 2
-
         return tau_num / (k * 
math.sqrt(ntree)) @staticmethod @@ -734,13 +881,16 @@ def _bin_predictors( @staticmethod def _setup_mcmc( x_train: Real[Array, 'p n'], - y_train: Float32[Array, ' n'] | Bool[Array, ' n'], + y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'], offset: Float32[Array, ''], w: Float[Array, ' n'] | None, max_split: UInt[Array, ' p'], - lamda: Float32[Array, ''] | None, - sigma_mu: FloatLike, - sigdf: FloatLike, + # lamda: Float32[Array, ''] | None, + # sigma_mu: FloatLike, + # sigdf: FloatLike, + leaf_prior_cov_inv, + error_cov_df, + error_cov_scale, power: FloatLike, base: FloatLike, maxdepth: int, @@ -758,15 +908,6 @@ def _setup_mcmc( depth = jnp.arange(maxdepth - 1) p_nonterminal = base / (1 + depth).astype(float) ** power - if y_train.dtype == bool: - error_cov_df = None - error_cov_scale = None - else: - assert lamda is not None - # inverse gamma prior: alpha = df / 2, beta = scale / 2 - error_cov_df = sigdf - error_cov_scale = lamda * sigdf - kw: dict = dict( X=x_train, # copy y_train because it's going to be donated in the mcmc loop @@ -776,7 +917,7 @@ def _setup_mcmc( max_split=max_split, num_trees=ntree, p_nonterminal=p_nonterminal, - leaf_prior_cov_inv=jnp.reciprocal(jnp.square(sigma_mu)), + leaf_prior_cov_inv=leaf_prior_cov_inv, error_cov_df=error_cov_df, error_cov_scale=error_cov_scale, min_points_per_decision_node=10, @@ -844,4 +985,4 @@ def _run_mcmc( def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']: """Evaluate trees on already quantized `x`.""" out = evaluate_trace(x, self._main_trace) - return collapse(out, 0, -1) + return collapse(out, 0, 2) diff --git a/tests/test_BART.py b/tests/test_BART.py index 3bbec48..2794b4c 100644 --- a/tests/test_BART.py +++ b/tests/test_BART.py @@ -43,12 +43,11 @@ import pytest from jax import debug_nans, lax, make_mesh, random, vmap from jax import numpy as jnp -from jax.scipy.linalg import solve_triangular from jax.scipy.special import logit, ndtr from jax.sharding import 
AxisType, SingleDeviceSharding from jax.tree import map_with_path from jax.tree_util import KeyPath -from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real, UInt +from jaxtyping import Array, Bool, Float32, Int32, Key, Real, UInt from numpy.testing import assert_allclose, assert_array_equal from pytest_subtests import SubTests @@ -73,6 +72,8 @@ assert_close_matrices, assert_different_matrices, get_old_python_tuple, + multivariate_rhat, + rhat, ) @@ -1103,82 +1104,6 @@ def avg_max_tree_depth( return depth.mean(-1) -def multivariate_rhat(chains: Real[Any, 'chain sample dim']) -> Float[Array, '']: - """ - Compute the multivariate Gelman-Rubin R-hat. - - Parameters - ---------- - chains - Independent chains of samples of a vector. - - Returns - ------- - Multivariate R-hat statistic. - - Raises - ------ - ValueError - If there are not enough chains or samples. - """ - chains = jnp.asarray(chains) - m, n, p = chains.shape - - if m < 2: # pragma: no cover - msg = 'Need at least 2 chains' - raise ValueError(msg) - if n < 2: # pragma: no cover - msg = 'Need at least 2 samples per chain' - raise ValueError(msg) - - chain_means = jnp.mean(chains, axis=1) - - def compute_chain_cov(chain_samples, chain_mean): - centered = chain_samples - chain_mean - return jnp.dot(centered.T, centered) / (n - 1) - - within_chain_covs = vmap(compute_chain_cov)(chains, chain_means) - W = jnp.mean(within_chain_covs, axis=0) - - overall_mean = jnp.mean(chain_means, axis=0) - chain_mean_diffs = chain_means - overall_mean - B = (n / (m - 1)) * jnp.dot(chain_mean_diffs.T, chain_mean_diffs) - - V_hat = ((n - 1) / n) * W + ((m + 1) / (m * n)) * B - - # Add regularization to W for numerical stability - gershgorin = jnp.max(jnp.sum(jnp.abs(W), axis=1)) - regularization = jnp.finfo(W.dtype).eps * len(W) * gershgorin - W_reg = W + regularization * jnp.eye(p) - - # Compute max(eigvals(W^-1 V_hat)) - L = jnp.linalg.cholesky(W_reg) - # Solve L @ L.T @ x = V_hat @ x = λ @ W @ x - # This is 
equivalent to solving (L^-1 V_hat L^-T) @ y = λ @ y - L_1V = solve_triangular(L, V_hat, lower=True) - L_1VL_T = solve_triangular(L, L_1V.T, lower=True).T - eigenvals = jnp.linalg.eigvalsh(L_1VL_T) - - return jnp.max(eigenvals) - - -def rhat(chains: Real[Any, 'chain sample']) -> Float[Array, '']: - """ - Compute the univariate Gelman-Rubin R-hat. - - Parameters - ---------- - chains - Independent chains of samples of a scalar. - - Returns - ------- - Univariate R-hat statistic. - """ - chains = jnp.asarray(chains) - return multivariate_rhat(chains[:, :, None]) - - def test_rhat(keys): """Test the multivariate R-hat implementation.""" chains, divergent_chains = random.normal(keys.pop(), (2, 2, 1000, 10)) diff --git a/tests/test_mvbart.py b/tests/test_mvbart.py index 819d32e..4ee45cb 100644 --- a/tests/test_mvbart.py +++ b/tests/test_mvbart.py @@ -32,6 +32,7 @@ from numpy.testing import assert_allclose, assert_array_equal from scipy.stats import chi2, ks_1samp, ks_2samp +from bartz._interface import Bart from bartz.mcmcstep import State, init, step from bartz.mcmcstep._step import ( Counts, @@ -46,7 +47,7 @@ _step_error_cov_inv_uv, step_trees, ) -from tests.util import assert_close_matrices +from tests.util import assert_close_matrices, rhat class TestWishart: @@ -474,3 +475,101 @@ def test_mv_steps(self, keys, data): assert mv_state.error_cov_inv.shape == (k, k) assert mv_state.resid.shape == (k, y.shape[1]) + + +class TestMVBartInterface: + """Tests for mvBart Interface.""" + + @pytest.fixture(params=[(10, 2, 2), (20, 5, 3), (3, 100, 4), (50, 50, 5)]) + def data_shape(self, request): + """Provide (n, p, k) triples for testing.""" + n, p, k = request.param + return n, p, k + + @pytest.fixture + def data(self, keys, data_shape): + """Generate a toy dataset. 
Mimic dgp from test_BART.py.""" + n, p, k = data_shape + sigma_noise = 0.1 + + key_x, key_eps = random.split(keys.pop(), 2) + X = random.uniform(key_x, (p, n), float, -2, 2) + + s = jnp.ones((k, p)) + norm_s = jnp.sqrt(jnp.sum(s * s, axis=1, keepdims=True)) # (k, 1) + + # F[d, i] = (s_d @ cos(pi * x_i)) / ||s_d|| + F = (s @ jnp.cos(jnp.pi * X)) / norm_s # (k, n) + + # iid N(0, sigma^2) noise across dims and obs + y = F + sigma_noise * random.normal(key_eps, (k, n)) + return X, y + + def test_initialization_and_shapes(self, data): + """Test that mvBart predicts with correct shapes.""" + X, Y = data + nskip, ndpost = 10, 50 + n_test = 40 + p, k_dim = X.shape[0], Y.shape[0] + + model = Bart( + x_train=X, y_train=Y, ntree=10, ndpost=ndpost, nskip=nskip, mc_cores=2 + ) + + X_test = random.normal(random.key(1), (p, n_test)) + y_pred = model.predict(X_test) + assert y_pred.shape == (ndpost, k_dim, n_test) + + def test_mvbart_convergence(self, data): + """Test that MV Bart chains converge using R-hat.""" + X_train, Y_train = data + _, n_train = X_train.shape + k_dim = Y_train.shape[0] + + mc_cores = 4 + ndpost = 2000 + nsamples_per_chain = ndpost // mc_cores + nskip = 4000 + keepevery = 5 + ntree = 100 + + model = Bart( + x_train=X_train, + y_train=Y_train, + ntree=ntree, + ndpost=ndpost, + nskip=nskip, + keepevery=keepevery, + mc_cores=mc_cores, + seed=0, + ) + + # Check yhat Convergence + yhat_train = model.yhat_train.reshape( + mc_cores, nsamples_per_chain, k_dim, n_train + ) + yhat_train_mean = yhat_train.mean( + axis=-1 + ) # (mc_cores, nsamples_per_chain, k_dim) + max_rhats_yhat = [rhat(yhat_train_mean[:, :, j]) for j in range(k_dim)] + rhat_mean = jnp.max(jnp.stack(max_rhats_yhat)) + print('Rhat on mean(yhat_train) per response:', rhat_mean) + + global_max_rhat = jnp.max(jnp.array(max_rhats_yhat)) + assert global_max_rhat < 1.1 + + # Check Covariance Matrix Convergence + prec_trace = model._main_trace.error_cov_inv + if prec_trace.ndim == 3: + prec_trace = 
prec_trace.reshape(mc_cores, nsamples_per_chain, k_dim, k_dim)
+
+        prec_flat = prec_trace.reshape(mc_cores, nsamples_per_chain, -1)
+        assert jnp.all(jnp.std(prec_flat, axis=1) > 1e-8), 'Sigma is not updating!'
+
+        max_rhats_prec = [rhat(prec_flat[:, :, j]) for j in range(k_dim * k_dim)]
+        max_rhat_sigma = jnp.max(jnp.array(max_rhats_prec))
+        print(f'R-hat for precision matrix: {jnp.array(max_rhats_prec)}')
+        print(f'Max R-hat for precision matrix: {max_rhat_sigma}')
+
+        # fix: `all(max_rhats_prec) < 1.1` compares a bool to 1.1 (True < 1.1 is
+        # always True), so the assertion could never fail; check each R-hat
+        assert all(r < 1.1 for r in max_rhats_prec)
+        assert max_rhat_sigma < 1.1
diff --git a/tests/util.py b/tests/util.py
index 7f8ab8c..32690c1 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -28,11 +28,14 @@
 from dataclasses import replace
 from operator import ge, le
 from pathlib import Path
+from typing import Any
 
 import numpy as np
 import tomli
 from jax import numpy as jnp
-from jaxtyping import ArrayLike
+from jax import vmap
+from jax.scipy.linalg import solve_triangular
+from jaxtyping import Array, ArrayLike, Float, Real
 from scipy import linalg
 
 from bartz.debug import check_tree, describe_error
@@ -201,3 +204,79 @@ def get_old_python_tuple() -> tuple[int, int]:
     ver_str = get_old_python_str()
     major, minor = ver_str.split('.')
     return int(major), int(minor)
+
+
+def multivariate_rhat(chains: Real[Any, 'chain sample dim']) -> Float[Array, '']:
+    """
+    Compute the multivariate Gelman-Rubin R-hat.
+
+    Parameters
+    ----------
+    chains
+        Independent chains of samples of a vector.
+
+    Returns
+    -------
+    Multivariate R-hat statistic.
+
+    Raises
+    ------
+    ValueError
+        If there are not enough chains or samples.
+ """ + chains = jnp.asarray(chains) + m, n, p = chains.shape + + if m < 2: # pragma: no cover + msg = 'Need at least 2 chains' + raise ValueError(msg) + if n < 2: # pragma: no cover + msg = 'Need at least 2 samples per chain' + raise ValueError(msg) + + chain_means = jnp.mean(chains, axis=1) + + def compute_chain_cov(chain_samples, chain_mean): + centered = chain_samples - chain_mean + return jnp.dot(centered.T, centered) / (n - 1) + + within_chain_covs = vmap(compute_chain_cov)(chains, chain_means) + W = jnp.mean(within_chain_covs, axis=0) + + overall_mean = jnp.mean(chain_means, axis=0) + chain_mean_diffs = chain_means - overall_mean + B = (n / (m - 1)) * jnp.dot(chain_mean_diffs.T, chain_mean_diffs) + + V_hat = ((n - 1) / n) * W + ((m + 1) / (m * n)) * B + + # Add regularization to W for numerical stability + gershgorin = jnp.max(jnp.sum(jnp.abs(W), axis=1)) + regularization = jnp.finfo(W.dtype).eps * len(W) * gershgorin + W_reg = W + regularization * jnp.eye(p) + + # Compute max(eigvals(W^-1 V_hat)) + L = jnp.linalg.cholesky(W_reg) + # Solve L @ L.T @ x = V_hat @ x = λ @ W @ x + # This is equivalent to solving (L^-1 V_hat L^-T) @ y = λ @ y + L_1V = solve_triangular(L, V_hat, lower=True) + L_1VL_T = solve_triangular(L, L_1V.T, lower=True).T + eigenvals = jnp.linalg.eigvalsh(L_1VL_T) + + return jnp.max(eigenvals) + + +def rhat(chains: Real[Any, 'chain sample']) -> Float[Array, '']: + """ + Compute the univariate Gelman-Rubin R-hat. + + Parameters + ---------- + chains + Independent chains of samples of a scalar. + + Returns + ------- + Univariate R-hat statistic. + """ + chains = jnp.asarray(chains) + return multivariate_rhat(chains[:, :, None])