diff --git a/cuthbert/factorial/README.md b/cuthbert/factorial/README.md new file mode 100644 index 0000000..e176051 --- /dev/null +++ b/cuthbert/factorial/README.md @@ -0,0 +1,133 @@ +# Factorial State-Space Models + +A factorial state-space model is a state-space model where the dynamics distribution +factors into a product of independent distributions across factors + +$$ +p(x_t \mid x_{t-1}) = \prod_{f=1}^F p(x_t^f \mid x_{t-1}^f), +$$ +for factorial index $f \in \{1, \ldots, F\}$. We additionally assume that observations +act locally on some subset of factors $S_t \subseteq \{1, \ldots, F\}$. + +$$ +p(y_t \mid x_t) = p(y_t \mid x_t^{S_t}). +$$ + +This motivates a factored approximation of filtering and smoothing distributions, e.g. + +$$ +p(x_t \mid y_{1:t}) = \prod_{f=1}^F p(x_t^f \mid y_{1:t}). +$$ + +A tutorial on factorial state-space models can be found in [Duffield et al](https://doi.org/10.1093/jrsssc/qlae035). + +The factorial approximation allows us to exploit significant benefits in terms of +memory, compute and parallelization. + +Note that although the dynamics are factorized, `cuthbert` does not differentiate +between `predict` and `update` (instead favouring a unified filter operation +via `filter_prepare` and `filter_combine`). Thus the dynamics and model inputs +should be specified to act on the joint local state (i.e. block diagonal +where appropriate). + + +## Factorial filtering with `cuthbert` + +Filtering in a factorial state-space model is similar to standard filtering, but with +an additional step before the filtering operation to extract the relevant +factors as well as an additional step after the filtering operation to insert the +updated factors back into the factorial state. + + +```python +from jax import tree +import cuthbert + +# Define model_inputs +model_inputs = ... 
+ +# Define function to extract the factorial indices from model inputs +# Here we assume model_inputs is a NamedTuple with a field `factorial_inds` +get_factorial_indices = lambda mi: mi.factorial_inds + +# Build factorializer for the inference method +factorializer = cuthbert.factorial.gaussian.build_factorializer(get_factorial_indices) + +# Load inference method, with parameter extraction functions defined for factorial inference +kalman_filter = cuthbert.gaussian.kalman.build_filter( + get_init_params=get_init_params, # Init specified to generate factorial state + get_dynamics_params=get_dynamics_params, # Dynamics specified to act on joint local state + get_observation_params=get_observation_params, # Observation specified to act on joint local state +) + +# Online inference +factorial_state = kalman_filter.init_prepare(tree.map(lambda x: x[0], model_inputs)) + +for t in range(1, T): + model_inputs_t = tree.map(lambda x: x[t], model_inputs) + factorial_inds = get_factorial_indices(model_inputs_t) + local_state = factorializer.extract_and_join(factorial_state, factorial_inds) + prepare_state = kalman_filter.filter_prepare(model_inputs_t) + filtered_local_state = kalman_filter.filter_combine(local_state, prepare_state) + factorial_state = factorializer.marginalize_and_insert( + filtered_local_state, factorial_state, factorial_inds + ) +``` + +You can also use `cuthbert.factorial.filter` for convenient offline filtering. +Note that associative/parallel filtering is not supported for factorial filtering. 
+ +```python +init_factorial_state, local_filter_states = cuthbert.factorial.filter( + kalman_filter, factorializer, model_inputs, output_factorial=False +) +``` + +## Factorial smoothing with `cuthbert` + +Smoothing in factorial state-space models can be performed in an embarrassingly +parallel manner across factors since the dynamics and factorial approximation are independent +across factors (the observations are fully absorbed in the filtering and +are not accessed during smoothing). + +The model inputs and filter states require some preprocessing to convert from being +a single sequence with each state containing all factors into a sequence or multiple +sequences with each state corresponding to a single factor. This can be +fiddly but is left to the user for maximum freedom. Oftentimes, it is easiest to +specify different parameter functions for smoothing than filtering. + +After this preprocessing, smoothing can be performed as usual: + +```python +# Define model_inputs for a single factor +model_inputs_single_factor = ... + +# Similarly, we need to extract the filter states for the single factor we're smoothing. +filter_states_single_factor = ... 
+ +# Load smoother, with parameter extraction functions defined for a single factor +kalman_smoother = cuthbert.gaussian.kalman.build_smoother( + get_dynamics_params=get_dynamics_params, # Dynamics specified to act on a single factor +) + +smoother_state = kalman_smoother.convert_filter_to_smoother_state( + tree.map(lambda x: x[-1], filter_states_single_factor), + model_inputs=tree.map(lambda x: x[-1], model_inputs_single_factor), +) + +for t in range(T - 1, -1, -1): + model_inputs_single_factor_t = tree.map(lambda x: x[t], model_inputs_single_factor) + filter_state_single_factor_t = tree.map(lambda x: x[t], filter_states_single_factor) + prepare_state = kalman_smoother.smoother_prepare( + filter_state_single_factor_t, model_inputs_single_factor_t + ) + smoother_state = kalman_smoother.smoother_combine(prepare_state, smoother_state) +``` + +Or directly using the `cuthbert.smoother`: + +```python +smoother_states = cuthbert.smoother( + kalman_smoother, filter_states_single_factor, model_inputs_single_factor +) +``` \ No newline at end of file diff --git a/cuthbert/factorial/__init__.py b/cuthbert/factorial/__init__.py new file mode 100644 index 0000000..87a6a70 --- /dev/null +++ b/cuthbert/factorial/__init__.py @@ -0,0 +1,8 @@ +from cuthbert.factorial import gaussian +from cuthbert.factorial.filtering import filter +from cuthbert.factorial.types import ( + ExtractAndJoin, + Factorializer, + GetFactorialIndices, + MarginalizeAndInsert, +) diff --git a/cuthbert/factorial/filtering.py b/cuthbert/factorial/filtering.py new file mode 100644 index 0000000..d909f38 --- /dev/null +++ b/cuthbert/factorial/filtering.py @@ -0,0 +1,110 @@ +"""cuthbert factorial filtering interface.""" + +from jax import numpy as jnp +from jax import random, tree +from jax.lax import scan + +from cuthbert.factorial.types import Factorializer +from cuthbert.inference import Filter +from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray + + +def filter( + filter_obj: Filter, + 
factorializer: Factorializer, + model_inputs: ArrayTreeLike, + output_factorial: bool = False, + key: KeyArray | None = None, +) -> ( + ArrayTree | tuple[ArrayTree, ArrayTree] +): # TODO: Can overload this function so the type checker knows that the output is an ArrayTree if output_factorial is True and a tuple[ArrayTree, ArrayTree] if output_factorial is False + """Applies offline factorial filtering for given model inputs. + + `model_inputs` should have leading temporal dimension of length T + 1, + where T is the number of time steps excluding the initial state. + + Parallel associative filtering is not supported for factorial filtering. + + Note that if output_factorial is True, this function will output a factorial state + with first temporal dimension of length T + 1 and second factorial dimension of + length F. Many of the factors will be unchanged across timesteps where they aren't + relevant. + + Args: + filter_obj: The filter inference object. + factorializer: The factorializer object for the inference method. + model_inputs: The model inputs (with leading temporal dimension of length T + 1). + output_factorial: If True, return a single state with first temporal dimension + of length T + 1 and second factorial dimension of length F. + If False, return a tuple of states. The first being the initial state + with first dimension of length F and no temporal dimension. + The second being the local states for each time step, i.e. first + dimension of length T and no factorial dimension. + key: The key for the random number generator. + + Returns: + A factorial state if output_factorial is True, else a tuple of (init state, local states). 
+ """ + T = tree.leaves(model_inputs)[0].shape[0] - 1 + + if key is None: + # This will throw error if used as a key, which is desired behavior + # (albeit not a useful error, we could improve this) + prepare_keys = jnp.empty(T + 1) + else: + prepare_keys = random.split(key, T + 1) + + init_model_input = tree.map(lambda x: x[0], model_inputs) + init_factorial_state = filter_obj.init_prepare( + init_model_input, key=prepare_keys[0] + ) + + prep_model_inputs = tree.map(lambda x: x[1:], model_inputs) + + def body_local(prev_factorial_state, prep_inp_and_k): + prep_inp, k = prep_inp_and_k + factorial_inds = factorializer.get_factorial_indices(prep_inp) + local_state = factorializer.extract_and_join( + prev_factorial_state, factorial_inds + ) + prep_state = filter_obj.filter_prepare(prep_inp, key=k) + filtered_joint_state = filter_obj.filter_combine(local_state, prep_state) + factorial_state = factorializer.marginalize_and_insert( + filtered_joint_state, prev_factorial_state, factorial_inds + ) + + def extract(arr): + if arr.ndim >= 2: + return arr[factorial_inds] + else: + return arr + + factorial_state_fac_inds_only = tree.map(extract, factorial_state) + return factorial_state, factorial_state_fac_inds_only + + if output_factorial: + + def body_factorial(prev_factorial_state, prep_inp_and_k): + factorial_state, _ = body_local(prev_factorial_state, prep_inp_and_k) + return factorial_state, factorial_state + + _, factorial_states = scan( + body_factorial, + init_factorial_state, + (prep_model_inputs, prepare_keys[1:]), + ) + factorial_states = tree.map( + lambda x, y: jnp.concatenate([x[None], y]), + init_factorial_state, + factorial_states, + ) + + return factorial_states + + else: + _, local_states = scan( + body_local, + init_factorial_state, + (prep_model_inputs, prepare_keys[1:]), + ) + return init_factorial_state, local_states diff --git a/cuthbert/factorial/gaussian.py b/cuthbert/factorial/gaussian.py new file mode 100644 index 0000000..f20c9be --- /dev/null +++ 
b/cuthbert/factorial/gaussian.py @@ -0,0 +1,179 @@ +"""Factorial utilities for Kalman states.""" + +from typing import TypeVar + +from jax import numpy as jnp +from jax import tree +from jax.scipy.linalg import block_diag + +from cuthbert.factorial.types import Factorializer, GetFactorialIndices +from cuthbert.gaussian.kalman import KalmanFilterState +from cuthbert.gaussian.types import LinearizedKalmanFilterState +from cuthbertlib.linalg import block_marginal_sqrt_cov +from cuthbertlib.types import Array, ArrayLike + +KalmanState = TypeVar("KalmanState", KalmanFilterState, LinearizedKalmanFilterState) + + +def build_factorializer( + get_factorial_indices: GetFactorialIndices, +) -> Factorializer: + """Build a factorializer for Kalman states. + + Args: + get_factorial_indices: Function to extract the factorial indices + from model inputs. + + Returns: + Factorializer object for Kalman states with functions to extract and join + the relevant factors and marginalize and insert the updated factors. + """ + return Factorializer( + get_factorial_indices=get_factorial_indices, + extract_and_join=extract_and_join, + marginalize_and_insert=marginalize_and_insert, + ) + + +def extract_and_join( + factorial_state: KalmanState, factorial_inds: ArrayLike +) -> KalmanState: + """Convert a factorial Kalman state into a joint local Kalman state. + + Single dimensional arrays will be treated as scalars e.g. log normalizing constants. + This means univariate problems still need to be stored with a dimension array + (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)). + Two dimensional arrays will be treated as means with shape (F, d). + In this case the factorial_inds indices will be extracted from the first + dimension and then stacked into a single array. + Three dimensional arrays will be treated as chol_covs with shape (F, d, d). + In this case the factorial_inds indices will be extracted from the first + dimension and then stacked into a block diagonal array. 
+ + Here F is the number of factors and d is the dimension of the state. + + Args: + factorial_state: Factorial Kalman state storing means and chol_covs + with shape (F, d) and (F, d, d) respectively. + factorial_inds: Indices of the factors to extract. Integer array. + + Returns: + Joint local Kalman state with no factorial index dimension. + """ + factorial_inds = jnp.asarray(factorial_inds) + new_elem = tree.map( + lambda x: _extract_and_join_arr(x, factorial_inds), factorial_state.elem + ) + new_state = factorial_state._replace(elem=new_elem) + + if isinstance(factorial_state, LinearizedKalmanFilterState): + new_mean_prev = _extract_and_join_arr(factorial_state.mean_prev, factorial_inds) + new_state = new_state._replace(mean_prev=new_mean_prev) + + return new_state + + +def _extract_and_join_arr(arr: Array, factorial_inds: Array) -> Array: + if arr.ndim == 0 or arr.ndim == 1: + return arr + elif arr.ndim == 2: + return _extract_and_join_means(arr, factorial_inds) + elif arr.ndim == 3: + return _extract_and_join_chol_covs(arr, factorial_inds) + else: + raise ValueError(f"Array must be 3D or lower, got {arr.ndim}D") + + +def _extract_and_join_means(means: Array, factorial_inds: Array) -> Array: + return means[factorial_inds].reshape(-1) + + +def _extract_and_join_chol_covs(chol_covs: Array, factorial_inds: Array) -> Array: + selected_chol_covs = chol_covs[factorial_inds] + return block_diag(*selected_chol_covs) + + +def marginalize_and_insert( + local_state: KalmanState, + factorial_state: KalmanState, + factorial_inds: ArrayLike, +) -> KalmanState: + """Marginalize and insert a joint local Kalman state into a factorial Kalman state. + + Single dimensional arrays will be treated as scalars e.g. log normalizing constants. + This means univariate problems still need to be stored with a dimension array + (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)). + Two dimensional arrays will be treated as means with shape (F, d). 
+ In this case the dimension d will be inferred and then the array split into + len(factorial_inds) arrays of shape (d,) then inserted into the factorial array + at the factorial_inds indices. + Three dimensional arrays will be treated as chol_covs with shape (F, d, d). + In this case the dimension d will be inferred and then the array split into + len(factorial_inds) arrays of shape (d, d) (noting that the marginal_sqrt_cov + function is called to preserve the lower triangular structure) then inserted + into the factorial array at the factorial_inds indices. + + Here F is the number of factors and d is the dimension of the state. + + Args: + local_state: Joint local Kalman state to marginalize and insert. + With means and chol_covs with shape (d * len(factorial_inds),) + and (d * len(factorial_inds), d * len(factorial_inds)) respectively. + factorial_state: Factorial Kalman state storing means and chol_covs + with shape (F, d) and (F, d, d) respectively. + factorial_inds: Indices of the factors to insert. Integer array. + + Returns: + Factorial Kalman state with factorial index as the first dimension. 
+ """ + factorial_inds = jnp.asarray(factorial_inds) + new_elem = tree.map( + lambda loc, fac: _marginalize_and_insert_arr(loc, fac, factorial_inds), + local_state.elem, + factorial_state.elem, + ) + new_state = local_state._replace(elem=new_elem) + if isinstance(local_state, LinearizedKalmanFilterState) and isinstance( + factorial_state, LinearizedKalmanFilterState + ): + new_mean_prev = _marginalize_and_insert_arr( + local_state.mean_prev, factorial_state.mean_prev, factorial_inds + ) + new_state = new_state._replace(mean_prev=new_mean_prev) + + return new_state + + +def _marginalize_and_insert_arr( + local_arr: Array, factorial_arr: Array, factorial_inds: ArrayLike +) -> Array: + factorial_inds = jnp.asarray(factorial_inds) + if factorial_arr.ndim == 0 or factorial_arr.ndim == 1: + return local_arr + elif factorial_arr.ndim == 2: + return _marginalize_and_insert_mean(local_arr, factorial_arr, factorial_inds) + elif factorial_arr.ndim == 3: + return _marginalize_and_insert_chol_cov( + local_arr, factorial_arr, factorial_inds + ) + else: + raise ValueError(f"Array must be 3D or lower, got {local_arr.ndim}D") + + +def _marginalize_and_insert_mean( + local_mean: Array, + factorial_means: Array, + factorial_inds: Array, +) -> Array: + local_mean_with_factorial_dimension = local_mean.reshape(len(factorial_inds), -1) + return factorial_means.at[factorial_inds].set(local_mean_with_factorial_dimension) + + +def _marginalize_and_insert_chol_cov( + local_chol_cov: Array, + factorial_chol_covs: Array, + factorial_inds: Array, +) -> Array: + d = factorial_chol_covs.shape[-1] + marginal_chol_covs = block_marginal_sqrt_cov(local_chol_cov, d) + return factorial_chol_covs.at[factorial_inds].set(marginal_chol_covs) diff --git a/cuthbert/factorial/types.py b/cuthbert/factorial/types.py new file mode 100644 index 0000000..dfeeeed --- /dev/null +++ b/cuthbert/factorial/types.py @@ -0,0 +1,96 @@ +"""Provides types for factorial state-space models.""" + +from typing import NamedTuple, 
Protocol + +from cuthbertlib.types import ArrayLike, ArrayTree, ArrayTreeLike + + +class GetFactorialIndices(Protocol): + """Protocol for getting the factorial indices.""" + + def __call__(self, model_inputs: ArrayTreeLike) -> ArrayLike: + """Extract the factorial indices from model inputs. + + Args: + model_inputs: Model inputs. + + Returns: + Indices of the factors to extract. Integer array. + """ + ... + + +class ExtractAndJoin(Protocol): + """Protocol for extracting and joining the relevant factors.""" + + def __call__( + self, + factorial_state: ArrayTreeLike, + factorial_inds: ArrayLike, + ) -> ArrayTree: + """Extract factors from factorial state and combine into a joint local state. + + E.g. factorial_state might encode factorial `means` with shape (F, d) and + `chol_covs` with shape (F, d, d). Then `model_inputs` tells us factors `i` and + `j` are relevant, so we extract `means[i]` and `means[j]` and `chol_covs[i]` and + `chol_covs[j]`. Then combine them into `joint_mean` with shape (2 * d,) + and block diagonal `joint_chol_cov` with shape (2 * d, 2 * d). + + Args: + factorial_state: Factorial state with factorial index as the first dimension. + factorial_inds: Indices of the factors to extract. Integer array. + + Returns: + Joint local state with no factorial index dimension. + """ + ... + + +class MarginalizeAndInsert(Protocol): + """Protocol for marginalizing and inserting the updated factors.""" + + def __call__( + self, + local_state: ArrayTree, + factorial_state: ArrayTree, + factorial_inds: ArrayLike, + ) -> ArrayTree: + """Marginalize joint state into factored state and insert into factorial state. + + E.g. `local_state` might have shape (2 * d,) and `joint_chol_cov` + with shape (2 * d, 2 * d). Then we marginalize out the joint local state into + two factorial `means` with shape (2, d) and `chol_covs` with shape (2, d, d). 
+ If `model_inputs` tells us we're working with factors `i` and `j`, then we + insert `means[0]` and `means[1]` into `state[i]` and `state[j]` respectively. + Similarly, we insert `chol_covs[0]` and `chol_covs[1]`. In both cases, we + overwrite the existing factors in the factorial state for `i` and `j`, + leaving the other factors unchanged. + + Args: + local_state: Joint local state with no factorial index dimension. + factorial_state: Factorial state with factorial index as the first dimension. + factorial_inds: Indices of the factors to insert. Integer array. + + Returns: + Factorial state with factorial index as the first dimension. + The updated factors are inserted into the factorial state. + The remaining factors are left unchanged. + """ + ... + + +class Factorializer(NamedTuple): + """Factorializer object. + + Attributes: + get_factorial_indices: Function to get the factorial indices. + Model inputs dependent. + extract_and_join: Function to extract and join the relevant factors. + Inference method dependent (e.g. Gaussian/SMC etc.). + marginalize_and_insert: Function to marginalize and insert the updated factors. + Inference method dependent (e.g. Gaussian/SMC etc.). 
+ """ + + get_factorial_indices: GetFactorialIndices + extract_and_join: ExtractAndJoin + marginalize_and_insert: MarginalizeAndInsert diff --git a/tests/cuthbert/factorial/__init__.py b/tests/cuthbert/factorial/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cuthbert/factorial/gaussian_utils.py b/tests/cuthbert/factorial/gaussian_utils.py new file mode 100644 index 0000000..e357911 --- /dev/null +++ b/tests/cuthbert/factorial/gaussian_utils.py @@ -0,0 +1,44 @@ +import jax.numpy as jnp +from jax import random, vmap + +from cuthbertlib.kalman import generate + + +def generate_factorial_kalman_model( + seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps +): + # T = num_time_steps, F = num_factors + + key = random.key(seed) + init_key, factorial_indices_key = random.split(key, 2) + + # m0 with shape (F, x_dim) + # chol_P0 with shape (F, x_dim, x_dim) + init_keys_factorial = random.split(init_key, num_factors) + m0s, chol_P0s = vmap(generate.generate_init_model, in_axes=(0, None))( + init_keys_factorial, x_dim + ) + + # Fs with shape (T, num_factors_local * x_dim, num_factors_local * x_dim) + # cs with shape (T, num_factors_local * x_dim) + # chol_Qs with shape (T, num_factors_local * x_dim, num_factors_local * x_dim) + # Hs with shape (T, d_y, num_factors_local * x_dim) + # ds with shape (T, y_dim) + # chol_Rs with shape (T, num_factors_local * y_dim, num_factors_local * y_dim) + # ys with shape (T, d_y) + _, _, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys = generate.generate_lgssm( + seed + 1, num_factors_local * x_dim, y_dim, num_time_steps + ) + + # factorial_indices with shape (T, num_factors_local) + # Each entry is a random integer in {0, ..., num_factors - 1} + # But each row must have unique entries + def rand_unique_indices(key): + indices = random.choice( + key, jnp.arange(num_factors), (num_factors_local,), replace=False + ) + return indices + + factorial_indices_keys = random.split(factorial_indices_key, num_time_steps) + 
factorial_indices = vmap(rand_unique_indices)(factorial_indices_keys) + return m0s, chol_P0s, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys, factorial_indices diff --git a/tests/cuthbert/factorial/test_kalman.py b/tests/cuthbert/factorial/test_kalman.py new file mode 100644 index 0000000..7de0042 --- /dev/null +++ b/tests/cuthbert/factorial/test_kalman.py @@ -0,0 +1,258 @@ +import itertools +from typing import cast, Callable + +import chex +import jax +import jax.numpy as jnp +import pytest +from jax import Array +from jax.scipy.linalg import block_diag + +from cuthbert import factorial +from cuthbert.gaussian import kalman +from cuthbert.inference import Filter, Smoother +from cuthbertlib.types import ArrayTree +from cuthbertlib.linalg import block_marginal_sqrt_cov +from tests.cuthbert.factorial.gaussian_utils import generate_factorial_kalman_model +from tests.cuthbertlib.kalman.test_filtering import std_predict, std_update + + +@pytest.fixture(scope="module", autouse=True) +def config(): + jax.config.update("jax_enable_x64", True) + yield + jax.config.update("jax_enable_x64", False) + + +def load_kalman_pairwise_factorial_inference( + m0: Array, # (F, d) + chol_P0: Array, # (F, d, d) + Fs: Array, # (T, 2 * d, 2 * d) + cs: Array, # (T, 2 * d) + chol_Qs: Array, # (T, 2 * d, 2 * d) + Hs: Array, # (T, d_y, 2 * d) + ds: Array, # (T, d_y) + chol_Rs: Array, # (T, d_y, d_y) + ys: Array, # (T, d_y) + factorial_indices: Array, # (T, 2) + smoother_factorial_index: int, +) -> tuple[Filter, factorial.Factorializer, Array, Smoother, Array]: + """Builds factorial Kalman filter and smoother objects and model_inputs for a linear-Gaussian SSM.""" + + def get_init_params(model_inputs: int) -> tuple[Array, Array]: + return m0, chol_P0 + + def get_dynamics_params(model_inputs: int) -> tuple[Array, Array, Array]: + return Fs[model_inputs - 1], cs[model_inputs - 1], chol_Qs[model_inputs - 1] + + def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Array]: + return ( + 
Hs[model_inputs - 1], + ds[model_inputs - 1], + chol_Rs[model_inputs - 1], + ys[model_inputs - 1], + ) + + filter = kalman.build_filter( + get_init_params, get_dynamics_params, get_observation_params + ) + + factorializer = factorial.gaussian.build_factorializer( + get_factorial_indices=lambda model_inputs: factorial_indices[model_inputs - 1] + ) + filter_model_inputs = jnp.arange(len(ys) + 1) + + # Some processing to get smoothing for a single factor + num_factors = len(m0) + d_x = m0.shape[1] + Fs_per_factor = [[] for _ in range(num_factors)] + cs_per_factor = [[] for _ in range(num_factors)] + chol_Qs_per_factor = [[] for _ in range(num_factors)] + + for i in range(1, len(ys) + 1): + h, a = factorial_indices[i - 1] + + F_h = Fs[i - 1][:d_x, :d_x] + F_a = Fs[i - 1][-d_x:, -d_x:] + c_h = cs[i - 1][:d_x] + c_a = cs[i - 1][-d_x:] + chol_Q_h, chol_Q_a = block_marginal_sqrt_cov(chol_Qs[i - 1], d_x) + Fs_per_factor[h].append(F_h) + cs_per_factor[h].append(c_h) + chol_Qs_per_factor[h].append(chol_Q_h) + + Fs_per_factor[a].append(F_a) + cs_per_factor[a].append(c_a) + chol_Qs_per_factor[a].append(chol_Q_a) + + def get_dynamics_params_single_factor( + model_inputs: int, + ) -> tuple[Array, Array, Array]: + return ( + Fs_per_factor[smoother_factorial_index][model_inputs - 1], + cs_per_factor[smoother_factorial_index][model_inputs - 1], + chol_Qs_per_factor[smoother_factorial_index][model_inputs - 1], + ) + + smoother = kalman.build_smoother( + get_dynamics_params_single_factor, + store_gain=True, + store_chol_cov_given_next=True, + ) + smoother_model_inputs = jnp.arange(len(Fs_per_factor[smoother_factorial_index]) + 1) + + return filter, factorializer, filter_model_inputs, smoother, smoother_model_inputs + + +seeds = [1, 43] +x_dims = [1, 3] +y_dims = [1, 2] +num_factors = [10, 20] +num_factors_local = [2] # number of factors to interact at each time step +num_time_steps = [1, 25] + +common_params = list( + itertools.product( + seeds, x_dims, y_dims, num_factors, 
num_factors_local, num_time_steps + ) +) + + +@pytest.mark.parametrize( + "seed,x_dim,y_dim,num_factors,num_factors_local,num_time_steps", common_params +) +def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps): + model_params = generate_factorial_kalman_model( + seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps + ) + filter_obj, factorializer, model_inputs, _, _ = ( + load_kalman_pairwise_factorial_inference( + *model_params, smoother_factorial_index=0 + ) + ) + + # True means, covs and log norm constants + fac_means = model_params[0] + fac_chol_covs = model_params[1] + fac_covs = fac_chol_covs @ fac_chol_covs.transpose(0, 2, 1) + ell = jnp.array(0.0) + + local_means = [] + local_covs = [] + ells = [] + fac_means_t_all = [fac_means] + fac_covs_t_all = [fac_covs] + for i in model_inputs[1:]: + F, c, chol_Q = ( + model_params[2][i - 1], + model_params[3][i - 1], + model_params[4][i - 1], + ) + H, d, chol_R, y = ( + model_params[5][i - 1], + model_params[6][i - 1], + model_params[7][i - 1], + model_params[8][i - 1], + ) + fac_inds = model_params[9][i - 1] + + joint_mean = fac_means[fac_inds].reshape(-1) + joint_cov = block_diag(*fac_covs[fac_inds]) + Q = chol_Q @ chol_Q.T + R = chol_R @ chol_R.T + pred_mean, pred_cov = std_predict(joint_mean, joint_cov, F, c, Q) + upd_mean, upd_cov, upd_ell = std_update(pred_mean, pred_cov, H, d, R, y) + marginal_means = upd_mean.reshape(len(fac_inds), -1) + marginal_covs = jnp.array( + [ + upd_cov[i * x_dim : (i + 1) * x_dim, i * x_dim : (i + 1) * x_dim] + for i in range(len(fac_inds)) + ] + ) + ell += upd_ell + local_means.append(marginal_means) + local_covs.append(marginal_covs) + ells.append(ell) + fac_means = fac_means.at[fac_inds].set(marginal_means) + fac_covs = fac_covs.at[fac_inds].set(marginal_covs) + fac_means_t_all.append(fac_means) + fac_covs_t_all.append(fac_covs) + + local_means = jnp.stack(local_means) + local_covs = jnp.stack(local_covs) + ells = jnp.stack(ells) + 
fac_means_t_all = jnp.stack(fac_means_t_all) + fac_covs_t_all = jnp.stack(fac_covs_t_all) + + # Check output_factorial = False + init_state, local_filter_states = factorial.filter( + filter_obj, factorializer, model_inputs, output_factorial=False + ) + local_filter_covs = ( + local_filter_states.chol_cov + @ local_filter_states.chol_cov.transpose(0, 1, 3, 2) + ) + chex.assert_trees_all_close( + (init_state.mean, init_state.chol_cov), (model_params[0], model_params[1]) + ) + chex.assert_trees_all_close( + (local_means, local_covs, ells), + ( + local_filter_states.mean, + local_filter_covs, + local_filter_states.log_normalizing_constant, + ), + ) + + # Check output_factorial = True + factorial_filtering_states = factorial.filter( + filter_obj, factorializer, model_inputs, output_factorial=True + ) + + factorial_filtering_states = cast(ArrayTree, factorial_filtering_states) + factorial_filtering_covs = ( + factorial_filtering_states.chol_cov + @ factorial_filtering_states.chol_cov.transpose(0, 1, 3, 2) + ) + chex.assert_trees_all_close( + (fac_means_t_all, fac_covs_t_all), + (factorial_filtering_states.mean, factorial_filtering_covs), + ) + chex.assert_trees_all_close( + ells, factorial_filtering_states.log_normalizing_constant[1:] + ) + + +smoother_indices = [0, 1, 5] + +common_smoother_params = list(itertools.product(common_params, smoother_indices)) + + +@pytest.mark.parametrize( + "seed,x_dim,y_dim,num_factors,num_factors_local,num_time_steps,smoother_factorial_index", + common_smoother_params, +) +def test_smoother( + seed, + x_dim, + y_dim, + num_factors, + num_factors_local, + num_time_steps, + smoother_factorial_index, +): + model_params = generate_factorial_kalman_model( + seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps + ) + filter_obj, factorializer, filter_model_inputs, smoother, smoother_model_inputs = ( + load_kalman_pairwise_factorial_inference( + *model_params, smoother_factorial_index=smoother_factorial_index + ) + ) + + # Check 
output_factorial = False + init_state, local_filter_states = factorial.filter( + filter_obj, factorializer, filter_model_inputs, output_factorial=False + ) + + # Convert to local smoother states