From 540b3ba3ab51900504f162ba51375141614f4af2 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Sun, 1 Feb 2026 17:37:45 +0000 Subject: [PATCH 01/29] Draft factorial filtering API --- cuthbert/factorial/README.md | 62 ++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 cuthbert/factorial/README.md diff --git a/cuthbert/factorial/README.md b/cuthbert/factorial/README.md new file mode 100644 index 00000000..885332f4 --- /dev/null +++ b/cuthbert/factorial/README.md @@ -0,0 +1,62 @@ +# Factorial State-Space Models + +A factorial state-space model is a state-space model where the dynamics distribution +factors into a product of independent distributions across factors + +$$ +p(x_t \mid x_{t-1}) = \prod_{f=1}^F p(x_t^f \mid x_{t-1}^f), +$$ +for factorial index $f \in \{1, \ldots, F\}$. We additionally assume that observations +act locally on some subset of factors $S_t \subseteq \{1, \ldots, F\}$. + +$$ +p(y_t \mid x_t) = p(y_t \mid x_t^{S_t}). +$$ + +This motivates a factored approximation of filtering and smoothing distributions, e.g. + +$$ +p(x_t \mid y_{0:t}) = \prod_{f=1}^F p(x_t^f \mid y_{0:t}). +$$ + +A tutorial on factorial state-space models can be found in [Duffield et al](https://doi.org/10.1093/jrsssc/qlae035). + + +## Factorial filtering with `cuthbert` + + + +```python +from jax import tree + +# Define model_inputs +model_inputs = ... + +# Define factorial function to extract relevant factors and combine into a joint local state +def extract_and_join(state, model_inputs): + .... + +# Define factorial function to marginalize joint local state into a factored state +# and insert into factorial state +def factorial_marginalize_and_insert(state, local_state, model_inputs): + .... + +# Load inference method, with parameter extraction functions defined for factorial inference +kalman_filter = cuthbert.gaussian.kalman.build_filter( + get_init_params=get_init_params, # Init specified to generate factorial state + get_dynamics_params=get_dynamics_params, # Dynamics specified to act on joint local state + get_observation_params=get_observation_params, # Observation specified to act on joint local state +) + +# Online inference +factorial_state = kalman_filter.init_prepare(tree.map(lambda x: x[0], model_inputs)) + +for t in range(1, T): + model_inputs_t = tree.map(lambda x: x[t], model_inputs) + local_state = extract_and_join(factorial_state, model_inputs_t) + prepare_state = kalman_filter.filter_prepare(model_inputs_t) + filtered_local_state = kalman_filter.filter_combine(local_state, prepare_state) + factorial_state = factorial_marginalize_and_insert(factorial_state, filtered_local_state, model_inputs_t) +``` + + From 248fa12871109a923f434ff857daa53b8d5e1aec Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Mon, 2 Feb 2026 11:49:05 +0000 Subject: [PATCH 02/29] Add smoothing API --- cuthbert/factorial/README.md | 49 ++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/cuthbert/factorial/README.md b/cuthbert/factorial/README.md index 885332f4..9ddb528a 100644 --- a/cuthbert/factorial/README.md +++ b/cuthbert/factorial/README.md @@ -21,9 +21,16 @@ $$ A tutorial on factorial state-space models can be found in [Duffield et al](https://doi.org/10.1093/jrsssc/qlae035). +The factorial approximation allows us to exploit significant benefits in terms of +memory, compute and parallelization. + ## Factorial filtering with `cuthbert` +Filtering in a factorial state-space model is similar to standard filtering, but with +additional an additional step before the filtering operation to extract the relevant +factors as well as an additional step after the filtering operation to insert the +updated factors back into the factorial state. ```python @@ -60,3 +67,45 @@ for t in range(1, T): ``` +## Factorial smoothing with `cuthbert` + +Smoothing in factorial state-space models can be performed embarassingly parallel +along the factors since the dynamics and factorial approximation are independent +across factors (the observations are fully absorbed in the filtering and +are not accessed during smoothing). + +The model inputs and filter states require some preprocessing to convert from being +single sequence with each state containing all factors into a sequence or multiple +sequences with each state corresponding to a single factor. This can be quite +fiddly but is left to the user for maximum freedom. + +TODO: Document some use cases in the examples. + +After this preprocessing, smoothing can be performed as usual: + +```python +# Define model_inputs for a single factor +model_inputs_single_factor = ... + +# Similarly, we need to extract the filter states for the single factor we're smoothing. +filter_states_single_factor = ... + +# Load smoother, with parameter extraction functions defined for factorial inference +kalman_smoother = cuthbert.gaussian.kalman.build_smoother( + get_dynamics_params=get_dynamics_params, # Dynamics specified to act on joint local state +) + +smoother_state = kalman_smoother.convert_filter_to_smoother_state( + tree.map(lambda x: x[-1], filter_states_single_factor), + model_inputs=tree.map(lambda x: x[-1], model_inputs_single_factor), +) + +for t in range(T - 1, -1, -1): + model_inputs_single_factor_t = tree.map(lambda x: x[t], model_inputs_single_factor) + filter_state_single_factor_t = tree.map(lambda x: x[t], filter_states_single_factor) + prepare_state = kalman_smoother.smoother_prepare( + filter_state_single_factor_t, model_inputs_single_factor_t + ) + smoother_state = kalman_smoother.smoother_combine(prepare_state, smoother_state) +``` + From 1ff85c51f82b59b4bfefd110ff9b7aa0d2c7760c Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Mon, 2 Feb 2026 12:07:56 +0000 Subject: [PATCH 03/29] Fix typos --- cuthbert/factorial/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuthbert/factorial/README.md b/cuthbert/factorial/README.md index 9ddb528a..10b939fe 100644 --- a/cuthbert/factorial/README.md +++ b/cuthbert/factorial/README.md @@ -28,7 +28,7 @@ memory, compute and parallelization. ## Factorial filtering with `cuthbert` Filtering in a factorial state-space model is similar to standard filtering, but with -additional an additional step before the filtering operation to extract the relevant +an additional step before the filtering operation to extract the relevant factors as well as an additional step after the filtering operation to insert the updated factors back into the factorial state. @@ -69,7 +69,7 @@ for t in range(1, T): ## Factorial smoothing with `cuthbert` -Smoothing in factorial state-space models can be performed embarassingly parallel +Smoothing in factorial state-space models can be performed embarrassingly parallel along the factors since the dynamics and factorial approximation are independent across factors (the observations are fully absorbed in the filtering and are not accessed during smoothing). From 51fdc0855ccd34074f65ebb436902ff1707f35cc Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Mon, 2 Feb 2026 12:26:00 +0000 Subject: [PATCH 04/29] Add types.py --- cuthbert/factorial/README.md | 27 ++++++++++++--- cuthbert/factorial/types.py | 64 ++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 5 deletions(-) create mode 100644 cuthbert/factorial/types.py diff --git a/cuthbert/factorial/README.md b/cuthbert/factorial/README.md index 10b939fe..cd8fe590 100644 --- a/cuthbert/factorial/README.md +++ b/cuthbert/factorial/README.md @@ -45,14 +45,14 @@ def extract_and_join(state, model_inputs): # Define factorial function to marginalize joint local state into a factored state # and insert into factorial state -def factorial_marginalize_and_insert(state, local_state, model_inputs): +def marginalize_and_insert(state, local_state, model_inputs): .... # Load inference method, with parameter extraction functions defined for factorial inference kalman_filter = cuthbert.gaussian.kalman.build_filter( - get_init_params=get_init_params, # Init specified to generate factorial state - get_dynamics_params=get_dynamics_params, # Dynamics specified to act on joint local state - get_observation_params=get_observation_params, # Observation specified to act on joint local state + get_init_params=get_init_params, # Init specified to generate factorial state + get_dynamics_params=get_dynamics_params, # Dynamics specified to act on joint local state + get_observation_params=get_observation_params, # Observation specified to act on joint local state ) # Online inference @@ -63,9 +63,19 @@ for t in range(1, T): local_state = extract_and_join(factorial_state, model_inputs_t) prepare_state = kalman_filter.filter_prepare(model_inputs_t) filtered_local_state = kalman_filter.filter_combine(local_state, prepare_state) - factorial_state = factorial_marginalize_and_insert(factorial_state, filtered_local_state, model_inputs_t) + factorial_state = factorial_marginalize_and_insert( + factorial_state, filtered_local_state, model_inputs_t + ) ``` +You can also use `cuthbert.factorial.filter` for convenient offline filtering. +Note that associative/parallel filtering is not supported for factorial filtering. + +```python +filter_states = cuthbert.factorial.filter( + kalman_filter, extract_and_join, marginalize_and_insert, model_inputs +) +``` ## Factorial smoothing with `cuthbert` @@ -109,3 +119,10 @@ for t in range(T - 1, -1, -1): smoother_state = kalman_smoother.smoother_combine(prepare_state, smoother_state) ``` +Or directly using the `cuthbert.smoother`: + +```python +smoother_states = cuthbert.smoother( + kalman_smoother, filter_states_single_factor, model_inputs_single_factor +) +``` \ No newline at end of file diff --git a/cuthbert/factorial/types.py b/cuthbert/factorial/types.py new file mode 100644 index 00000000..be8aa04a --- /dev/null +++ b/cuthbert/factorial/types.py @@ -0,0 +1,64 @@ +"""Provides types for factorial state-space models.""" + +from typing import Protocol + +from cuthbertlib.types import ArrayTree, ArrayTreeLike + + +class ExtractAndJoin(Protocol): + """Protocol for extracting and joining the relevant factors.""" + + def __call__( + self, factorial_state: ArrayTree, model_inputs: ArrayTreeLike + ) -> ArrayTree: + """Extract factors from factorial state and combine into a joint local state. + + E.g. state might encode factorial `means` with shape (F, d) and `chol_covs` + with shape (F, d, d). Then `model_inputs` tells us factors `i` and `j` are + relevant, so we extract `means[i]` and `means[j]` and `chol_covs[i]` and + `chol_covs[j]`. Then combine them into `joint_mean` with shape (2 * d,) + and block diagonal `joint_chol_cov` with shape (2 * d, 2 * d). + + Args: + factorial_state: Factorial state with factorial index as the first dimension. + model_inputs: Model inputs including information required to determine + the relevant factors (e.g. factor indices). + + Returns: + Joint local state with no factorial index dimension. + """ + ... + + +class MarginalizeAndInsert(Protocol): + """Protocol for marginalizing and inserting the updated factors.""" + + def __call__( + self, + factorial_state: ArrayTree, + local_state: ArrayTree, + model_inputs: ArrayTreeLike, + ) -> ArrayTree: + """Marginalize joint state into factored state and insert into factorial state. + + E.g. `local_state` might have shape (2 * d,) and `joint_chol_cov` + with shape (2 * d, 2 * d). Then we marginalize out the joint local state into + two factorial `means` with shape (2, d) and `chol_covs` with shape (2, d, d). + If `model_inputs` tells us we're working with factors `i` and `j`, then we + insert `means[0]` and `means[1]` into `state[i]` and `state[j]` respectively. + Similarly, we insert `chol_covs[0]` and `chol_covs[1]`. In both cases, we + overwrite the existing factors in the factorial state for `i` and `j`, + leaving the other factors unchanged. + + Args: + factorial_state: Factorial state with factorial index as the first dimension. + local_state: Joint local state with no factorial index dimension. + model_inputs: Model inputs including information required to determine + the relevant factors (e.g. factor indices). + + Returns: + Factorial state with factorial index as the first dimension. + The updated factors are inserted into the factorial state. + The remaining factors are left unchanged. + """ + ... From 6b8efebed916797b5ed4533f914aa7c24cc0379c Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Mon, 2 Feb 2026 12:37:39 +0000 Subject: [PATCH 05/29] Draft filtering --- cuthbert/factorial/filtering.py | 82 +++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 cuthbert/factorial/filtering.py diff --git a/cuthbert/factorial/filtering.py b/cuthbert/factorial/filtering.py new file mode 100644 index 00000000..f940f805 --- /dev/null +++ b/cuthbert/factorial/filtering.py @@ -0,0 +1,82 @@ +"""cuthbert factorial filtering interface.""" + +from jax import numpy as jnp +from jax import random, tree, vmap +from jax.lax import associative_scan, scan + +from cuthbert.inference import Filter +from cuthbert.factorial.types import ExtractAndJoin, MarginalizeAndInsert +from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray + + +def filter( + filter_obj: Filter, + extract_and_join: ExtractAndJoin, + marginalize_and_insert: MarginalizeAndInsert, + model_inputs: ArrayTreeLike, + key: KeyArray | None = None, +) -> ArrayTree: + """Applies offline factorial filtering for given model inputs. + + `model_inputs` should have leading temporal dimension of length T + 1, + where T is the number of time steps excluding the initial state. + + Parallel associative filtering is not supported for factorial filtering. + + Note that this function will output a factorial state with first temporal dimension + of length T + 1 and second factorial dimension of length F. Many of the factors + will be unchanged across timesteps where they aren't relevant. So some memory + can be saved with more sophisticated data structures although this is left to the + user for maximum flexibility (and jax.lax.scan can be hard to work with varliable + sized arrays). + + Args: + filter_obj: The filter inference object. + extract_and_join: Function to extract and join the relevant factors into + a single joint state. + marginalize_and_insert: Function to marginalize and insert the updated factors + back into the factorial state. + model_inputs: The model inputs (with leading temporal dimension of length T + 1). + key: The key for the random number generator. + + Returns: + The filtered states (NamedTuple with leading temporal dimension of length T + 1). + """ + T = tree.leaves(model_inputs)[0].shape[0] - 1 + + if key is None: + # This will throw error if used as a key, which is desired behavior + # (albeit not a useful error, we could improve this) + prepare_keys = jnp.empty(T + 1) + else: + prepare_keys = random.split(key, T + 1) + + init_model_input = tree.map(lambda x: x[0], model_inputs) + init_factorial_state = filter_obj.init_prepare( + init_model_input, key=prepare_keys[0] + ) + + prep_model_inputs = tree.map(lambda x: x[1:], model_inputs) + + def body(prev_factorial_state, prep_inp_and_k): + prep_inp, k = prep_inp_and_k + local_state = extract_and_join(prev_factorial_state, prep_inp) + prep_state = filter_obj.filter_prepare(prep_inp, key=k) + filtered_joint_state = filter_obj.filter_combine(local_state, prep_state) + factorial_state = marginalize_and_insert( + prev_factorial_state, filtered_joint_state, prep_inp + ) + return factorial_state, factorial_state + + _, factorial_states = scan( + body, + init_factorial_state, + (prep_model_inputs, prepare_keys[1:]), + ) + factorial_states = tree.map( + lambda x, y: jnp.concatenate([x[None], y]), + init_factorial_state, + factorial_states, + ) + + return factorial_states From d5af70d026b79eede7398367cd15491d35143545 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Mon, 2 Feb 2026 14:27:25 +0000 Subject: [PATCH 06/29] Add init and sort imports --- cuthbert/factorial/__init__.py | 2 ++ cuthbert/factorial/filtering.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 cuthbert/factorial/__init__.py diff --git a/cuthbert/factorial/__init__.py b/cuthbert/factorial/__init__.py new file mode 100644 index 00000000..fd55e23a --- /dev/null +++ b/cuthbert/factorial/__init__.py @@ -0,0 +1,2 @@ +from cuthbert.factorial.filtering import filter +from cuthbert.factorial.types import ExtractAndJoin, MarginalizeAndInsert diff --git a/cuthbert/factorial/filtering.py b/cuthbert/factorial/filtering.py index f940f805..7107b46d 100644 --- a/cuthbert/factorial/filtering.py +++ b/cuthbert/factorial/filtering.py @@ -4,8 +4,8 @@ from jax import random, tree, vmap from jax.lax import associative_scan, scan -from cuthbert.inference import Filter from cuthbert.factorial.types import ExtractAndJoin, MarginalizeAndInsert +from cuthbert.inference import Filter from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray From ff0f085ae8900dd9508d216aa6b13731683872f3 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Mon, 2 Feb 2026 14:36:24 +0000 Subject: [PATCH 07/29] Add cuthbert import to READMEs --- cuthbert/README.md | 1 + cuthbert/factorial/README.md | 1 + 2 files changed, 2 insertions(+) diff --git a/cuthbert/README.md b/cuthbert/README.md index 92b9e384..6cdc6307 100644 --- a/cuthbert/README.md +++ b/cuthbert/README.md @@ -7,6 +7,7 @@ All inference methods are implemented with the following unified interface: ```python from jax import tree +import cuthbert # Define model_inputs model_inputs = ... diff --git a/cuthbert/factorial/README.md b/cuthbert/factorial/README.md index cd8fe590..4db6421c 100644 --- a/cuthbert/factorial/README.md +++ b/cuthbert/factorial/README.md @@ -35,6 +35,7 @@ updated factors back into the factorial state. ```python from jax import tree +import cuthbert # Define model_inputs model_inputs = ... From 5d4fb9a9d88366113b7923ed44cb7989497376e3 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Mon, 2 Feb 2026 21:15:08 +0000 Subject: [PATCH 08/29] Start test_filtering --- cuthbert/factorial/README.md | 6 ++ tests/cuthbert/factorial/test_filtering.py | 64 ++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 tests/cuthbert/factorial/test_filtering.py diff --git a/cuthbert/factorial/README.md b/cuthbert/factorial/README.md index 4db6421c..557b7c5d 100644 --- a/cuthbert/factorial/README.md +++ b/cuthbert/factorial/README.md @@ -24,6 +24,12 @@ A tutorial on factorial state-space models can be found in [Duffield et al](http The factorial approximation allows us to exploit significant benefits in terms of memory, compute and parallelization. +Note that although the dynamics are factorized, `cuthbert` does not differentiate +between `predict` and `update` (instead favouring a unified filter operation +via `filter_prepare` and `filter_combine`). Thus the dynamics and model inputs +should be specified to act on the joint local state (i.e. block diagonal +where appropriate). + ## Factorial filtering with `cuthbert` diff --git a/tests/cuthbert/factorial/test_filtering.py b/tests/cuthbert/factorial/test_filtering.py new file mode 100644 index 00000000..27b69c05 --- /dev/null +++ b/tests/cuthbert/factorial/test_filtering.py @@ -0,0 +1,64 @@ +import jax +import jax.numpy as jnp +import pytest +from jax import Array, vmap + + +from cuthbert import factorial +from cuthbert.gaussian import kalman +from cuthbert.inference import Filter, Smoother +from cuthbertlib.kalman.generate import generate_lgssm +from tests.cuthbertlib.kalman.test_filtering import std_predict, std_update +from tests.cuthbertlib.kalman.test_smoothing import std_kalman_smoother + + +@pytest.fixture(scope="module", autouse=True) +def config(): + jax.config.update("jax_enable_x64", True) + yield + jax.config.update("jax_enable_x64", False) + + +def load_kalman_pairwise_factorial_inference( + m0: Array, # (F, d) + chol_P0: Array, # (F, d, d) + Fs: Array, # (T, 2 * d, 2 * d) + cs: Array, # (T, 2 * d) + chol_Qs: Array, # (T, 2 * d, 2 * d) + Hs: Array, # (T, 2 * d, d_y) + ds: Array, # (T, d_y) + chol_Rs: Array, # (T, d_y, d_y) + ys: Array, # (T + 1, d_y) + factorial_indices: Array, # (T, 2) +) -> tuple[Filter, Smoother, Array]: + """Builds Kalman filter and smoother objects and model_inputs for a linear-Gaussian SSM.""" + + def get_init_params(model_inputs: int) -> tuple[Array, Array]: + return m0, chol_P0 + + def get_dynamics_params(model_inputs: int) -> tuple[Array, Array, Array]: + return Fs[model_inputs - 1], cs[model_inputs - 1], chol_Qs[model_inputs - 1] + + def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Array]: + return ( + Hs[model_inputs], + ds[model_inputs], + chol_Rs[model_inputs], + ys[model_inputs], + ) + + def extract_and_join(factorial_state, model_inputs): + fac_inds = factorial_indices[model_inputs - 1] + + means = + + + + filter = kalman.build_filter( + get_init_params, get_dynamics_params, get_observation_params + ) + smoother = kalman.build_smoother( + get_dynamics_params, store_gain=True, store_chol_cov_given_next=True + ) + model_inputs = jnp.arange(len(ys)) + return filter, smoother, model_inputs From 64fbb37d7333a852230734d5453e51e9d595b1e0 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 3 Feb 2026 17:01:45 +0000 Subject: [PATCH 09/29] Draft gaussian --- cuthbert/factorial/gaussian.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 cuthbert/factorial/gaussian.py diff --git a/cuthbert/factorial/gaussian.py b/cuthbert/factorial/gaussian.py new file mode 100644 index 00000000..0d559537 --- /dev/null +++ b/cuthbert/factorial/gaussian.py @@ -0,0 +1,21 @@ +from jax import numpy as jnp, Array + +from cuthbert.gaussian.kalman import KalmanFilterState +from cuthbert.gaussian.types import LinearizedKalmanFilterState + + +def extract_and_join( + factorial_slice: list[int], state: KalmanFilterState | LinearizedKalmanFilterState +) -> KalmanFilterState | LinearizedKalmanFilterState: + """Extract factors from a Gaussian factorial state and combine into a joint local state.""" + ... + + +def extract_and_join_arr(factorial_inds: Array, mean: Array) -> Array: + """Extract factors from a Gaussian factorial state and combine into a joint local state.""" + return mean[factorial_inds].reshape(-1) + + +def extract_and_join_chol_cov(factorial_inds: Array, chol_cov: Array) -> Array: + """Extract factors from a Gaussian factorial state and combine into a joint local state.""" + ... From 76a85e0b0f76c1aded5bbb4846476679d9bc8000 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 3 Feb 2026 17:56:37 +0000 Subject: [PATCH 10/29] Flesh out factorial gaussian --- cuthbert/factorial/gaussian.py | 139 ++++++++++++++++++++++++++++++--- 1 file changed, 127 insertions(+), 12 deletions(-) diff --git a/cuthbert/factorial/gaussian.py b/cuthbert/factorial/gaussian.py index 0d559537..aeeb1eb3 100644 --- a/cuthbert/factorial/gaussian.py +++ b/cuthbert/factorial/gaussian.py @@ -1,21 +1,136 @@ -from jax import numpy as jnp, Array +"""Factorial utilities for Kalman states.""" +from typing import TypeVar + +from jax import tree, numpy as jnp, vmap +from jax.scipy.linalg import block_diag + +from cuthbertlib.linalg import marginal_sqrt_cov from cuthbert.gaussian.kalman import KalmanFilterState from cuthbert.gaussian.types import LinearizedKalmanFilterState +from cuthbertlib.types import Array, ArrayLike + + +KalmanState = TypeVar("KalmanState", KalmanFilterState, LinearizedKalmanFilterState) + + +def extract_and_join(factorial_inds: ArrayLike, state: KalmanState) -> KalmanState: + """Convert a factorial Kalman state into a joint local Kalman state. + + Single dimensional arrays will be treated as scalars e.g. log normalizing constants. + This means univariate problems still need to be stored with a dimension array + (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)). + Two dimensional arrays will be treated as means with shape (F, d). + In this case the factorial_inds indices will be extracted from the first + dimension and then stacked into a single array. + Three dimensional arrays will be treated as chol_covs with shape (F, d, d). + In this case the factorial_inds indices will be extracted from the first + dimension and then stacked into a block diagonal array. + + Here F is the number of factors and d is the dimension of the state. + + Args: + factorial_inds: Indices of the factors to extract. Integer array. + state: Factorial Kalman state storing means and chol_covs + with shape (F, d) and (F, d, d) respectively. + + Returns: + Joint local Kalman state with no factorial index dimension. + """ + factorial_inds = jnp.asarray(factorial_inds) + return tree.map(lambda x: _extract_and_join_arr(factorial_inds, x), state) + + +def _extract_and_join_arr(factorial_inds: Array, arr: Array) -> Array: + if arr.ndim == 1: + return arr + elif arr.ndim == 2: + return _extract_and_join_means(factorial_inds, arr) + elif arr.ndim == 3: + return _extract_and_join_chol_covs(factorial_inds, arr) + else: + raise ValueError(f"Array must be 1D, 2D or 3D, got {arr.ndim}D") + + +def _extract_and_join_means(factorial_inds: Array, means: Array) -> Array: + return means[factorial_inds].reshape(-1) + + +def _extract_and_join_chol_covs(factorial_inds: Array, chol_covs: Array) -> Array: + selected_chol_covs = chol_covs[factorial_inds] + return block_diag(*selected_chol_covs) + + +def marginalize_and_insert( + factorial_inds: Array, + local_state: KalmanState, + factorial_state: KalmanState, +) -> KalmanState: + """Marginalize and insert a joint local Kalman state into a factorial Kalman state. + + Single dimensional arrays will be treated as scalars e.g. log normalizing constants. + This means univariate problems still need to be stored with a dimension array + (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)). + Two dimensional arrays will be treated as means with shape (F, d). + In this case the dimension d will be inferred and then the array split into + len(factorial_inds) arrays of shape (d,) then inserted into the factorial array + at the factorial_inds indices. + Three dimensional arrays will be treated as chol_covs with shape (F, d, d). + In this case the dimension d will be inferred and then the array split into + len(factorial_inds) arrays of shape (d, d) (noting that the marginal_sqrt_cov + function is called to preserve the lower triangular structure) then inserted + into the factorial array at the factorial_inds indices. + + Here F is the number of factors and d is the dimension of the state. + + Args: + factorial_inds: Indices of the factors to insert. Integer array. + local_state: Joint local Kalman state to marginalize and insert. + With means and chol_covs with shape (d * len(factorial_inds),) + and (d * len(factorial_inds), d * len(factorial_inds)) respectively. + factorial_state: Factorial Kalman state storing means and chol_covs + with shape (F, d) and (F, d, d) respectively. + + Returns: + Joint local Kalman state with no factorial index dimension. + """ + return tree.map( + lambda loc, fac: _marginalize_and_insert_arr(factorial_inds, loc, fac), + local_state, + factorial_state, + ) -def extract_and_join( - factorial_slice: list[int], state: KalmanFilterState | LinearizedKalmanFilterState -) -> KalmanFilterState | LinearizedKalmanFilterState: - """Extract factors from a Gaussian factorial state and combine into a joint local state.""" - ... +def _marginalize_and_insert_arr( + factorial_inds: ArrayLike, local_arr: Array, factorial_arr: Array +) -> Array: + factorial_inds = jnp.asarray(factorial_inds) + if local_arr.ndim == 1: + return local_arr + elif local_arr.ndim == 2: + return _marginalize_and_insert_mean(factorial_inds, local_arr, factorial_arr) + elif local_arr.ndim == 3: + return _marginalize_and_insert_chol_cov( + factorial_inds, local_arr, factorial_arr + ) + else: + raise ValueError(f"Array must be 1D, 2D or 3D, got {local_arr.ndim}D") -def extract_and_join_arr(factorial_inds: Array, mean: Array) -> Array: - """Extract factors from a Gaussian factorial state and combine into a joint local state.""" - return mean[factorial_inds].reshape(-1) +def _marginalize_and_insert_mean( + factorial_inds: Array, local_mean: Array, factorial_means: Array +) -> Array: + local_mean_with_factorial_dimension = local_mean.reshape(len(factorial_inds), -1) + return factorial_means.at[factorial_inds].set(local_mean_with_factorial_dimension) -def extract_and_join_chol_cov(factorial_inds: Array, chol_cov: Array) -> Array: - """Extract factors from a Gaussian factorial state and combine into a joint local state.""" - ... +def _marginalize_and_insert_chol_cov( + factorial_inds: Array, local_chol_cov: Array, factorial_chol_covs: Array +) -> Array: + d = factorial_chol_covs.shape[-1] + starts = jnp.arange(0, len(factorial_inds)) * d + ends = starts + d + marginal_chol_covs = vmap(lambda s, e: marginal_sqrt_cov(local_chol_cov, s, e))( + starts, ends + ) + return factorial_chol_covs.at[factorial_inds].set(marginal_chol_covs) From f98b875a3d195ca50217181d990f3b4764b1b6cd Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 4 Feb 2026 14:42:11 +0000 Subject: [PATCH 11/29] Refactor factorial API with factorializer --- cuthbert/factorial/README.md | 25 +++++----- cuthbert/factorial/filtering.py | 82 ++++++++++++++++++++------------- cuthbert/factorial/gaussian.py | 64 +++++++++++++++++-------- cuthbert/factorial/types.py | 56 +++++++++++++++++----- 4 files changed, 152 insertions(+), 75 deletions(-) diff --git a/cuthbert/factorial/README.md b/cuthbert/factorial/README.md index 557b7c5d..a57c06ea 100644 --- a/cuthbert/factorial/README.md +++ b/cuthbert/factorial/README.md @@ -46,14 +46,12 @@ import cuthbert # Define model_inputs model_inputs = ... -# Define factorial function to extract relevant factors and combine into a joint local state -def extract_and_join(state, model_inputs): - .... +# Define function to extract the factorial indices from model inputs +# Here we assume model_inputs is a NamedTuple with a field `factorial_inds` +get_factorial_indices = lambda mi: mi.factorial_inds -# Define factorial function to marginalize joint local state into a factored state -# and insert into factorial state -def marginalize_and_insert(state, local_state, model_inputs): - .... +# Build factorializer for the inference method +factorializer = cuthbert.factorial.gaussian.build_factorializer(get_factorial_indices) # Load inference method, with parameter extraction functions defined for factorial inference kalman_filter = cuthbert.gaussian.kalman.build_filter( @@ -67,11 +65,12 @@ factorial_state = kalman_filter.init_prepare(tree.map(lambda x: x[0], model_inpu for t in range(1, T): model_inputs_t = tree.map(lambda x: x[t], model_inputs) - local_state = extract_and_join(factorial_state, model_inputs_t) + factorial_inds = get_factorial_indices(model_inputs_t) + local_state = factorializer.extract_and_join(factorial_state, factorial_inds) prepare_state = kalman_filter.filter_prepare(model_inputs_t) filtered_local_state = kalman_filter.filter_combine(local_state, prepare_state) - factorial_state = factorial_marginalize_and_insert( - factorial_state, filtered_local_state, model_inputs_t + factorial_state = factorializer.marginalize_and_insert( + filtered_local_state, factorial_state, factorial_inds ) ``` @@ -79,8 +78,8 @@ You can also use `cuthbert.factorial.filter` for convenient offline filtering. Note that associative/parallel filtering is not supported for factorial filtering. ```python -filter_states = cuthbert.factorial.filter( - kalman_filter, extract_and_join, marginalize_and_insert, model_inputs +init_factorial_state, local_filter_states = cuthbert.factorial.filter( + kalman_filter, factorializer, model_inputs, output_factorial=False ) ``` @@ -93,7 +92,7 @@ are not accessed during smoothing). The model inputs and filter states require some preprocessing to convert from being single sequence with each state containing all factors into a sequence or multiple -sequences with each state corresponding to a single factor. This can be quite +sequences with each state corresponding to a single factor. This can be fiddly but is left to the user for maximum freedom. TODO: Document some use cases in the examples. diff --git a/cuthbert/factorial/filtering.py b/cuthbert/factorial/filtering.py index 7107b46d..ff57cb9e 100644 --- a/cuthbert/factorial/filtering.py +++ b/cuthbert/factorial/filtering.py @@ -1,21 +1,21 @@ """cuthbert factorial filtering interface.""" from jax import numpy as jnp -from jax import random, tree, vmap -from jax.lax import associative_scan, scan +from jax import random, tree +from jax.lax import scan -from cuthbert.factorial.types import ExtractAndJoin, MarginalizeAndInsert +from cuthbert.factorial.types import Factorializer from cuthbert.inference import Filter from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray def filter( filter_obj: Filter, - extract_and_join: ExtractAndJoin, - marginalize_and_insert: MarginalizeAndInsert, + factorializer: Factorializer, model_inputs: ArrayTreeLike, + output_factorial: bool = False, key: KeyArray | None = None, -) -> ArrayTree: +) -> ArrayTree | tuple[ArrayTree, ArrayTree]: """Applies offline factorial filtering for given model inputs. `model_inputs` should have leading temporal dimension of length T + 1, @@ -23,20 +23,21 @@ def filter( Parallel associative filtering is not supported for factorial filtering. - Note that this function will output a factorial state with first temporal dimension - of length T + 1 and second factorial dimension of length F. Many of the factors - will be unchanged across timesteps where they aren't relevant. So some memory - can be saved with more sophisticated data structures although this is left to the - user for maximum flexibility (and jax.lax.scan can be hard to work with varliable - sized arrays). + Note that if output_factorial is True, this function will output a factorial state + with first temporal dimension of length T + 1 and second factorial dimension of + length F. Many of the factors will be unchanged across timesteps where they aren't + relevant. Args: filter_obj: The filter inference object. - extract_and_join: Function to extract and join the relevant factors into - a single joint state. - marginalize_and_insert: Function to marginalize and insert the updated factors - back into the factorial state. + factorializer: The factorializer object for the inference method. model_inputs: The model inputs (with leading temporal dimension of length T + 1). + output_factorial: If True, return a single state with first temporal dimension + of length T + 1 and second factorial dimension of length F. + If False, return a tuple of states. The first being the initial state + with first dimension of length F and temporal dimension. + The second being the local states for each time step, i.e. first + dimension of length T and no factorial dimension. key: The key for the random number generator. Returns: @@ -58,25 +59,42 @@ def filter( prep_model_inputs = tree.map(lambda x: x[1:], model_inputs) - def body(prev_factorial_state, prep_inp_and_k): + def body_local(prev_factorial_state, prep_inp_and_k): prep_inp, k = prep_inp_and_k - local_state = extract_and_join(prev_factorial_state, prep_inp) + factorial_inds = factorializer.get_factorial_indices(prep_inp) + local_state = factorializer.extract_and_join( + prev_factorial_state, factorial_inds + ) prep_state = filter_obj.filter_prepare(prep_inp, key=k) filtered_joint_state = filter_obj.filter_combine(local_state, prep_state) - factorial_state = marginalize_and_insert( - prev_factorial_state, filtered_joint_state, prep_inp + factorial_state = factorializer.marginalize_and_insert( + filtered_joint_state, prev_factorial_state, factorial_inds ) - return factorial_state, factorial_state + return factorial_state, filtered_joint_state - _, factorial_states = scan( - body, - init_factorial_state, - (prep_model_inputs, prepare_keys[1:]), - ) - factorial_states = tree.map( - lambda x, y: jnp.concatenate([x[None], y]), - init_factorial_state, - factorial_states, - ) + if output_factorial: + + def body_factorial(prev_factorial_state, prep_inp_and_k): + factorial_state, _ = body_local(prev_factorial_state, prep_inp_and_k) + return factorial_state, factorial_state + + _, factorial_states = scan( + body_factorial, + init_factorial_state, + (prep_model_inputs, prepare_keys[1:]), + ) + factorial_states = tree.map( + lambda x, y: jnp.concatenate([x[None], y]), + init_factorial_state, + factorial_states, + ) + + return factorial_states - return factorial_states + else: + _, local_states = scan( + body_local, + init_factorial_state, + (prep_model_inputs, prepare_keys[1:]), + ) + return init_factorial_state, local_states diff --git a/cuthbert/factorial/gaussian.py b/cuthbert/factorial/gaussian.py index aeeb1eb3..6c772a3f 100644 --- a/cuthbert/factorial/gaussian.py +++ b/cuthbert/factorial/gaussian.py @@ -9,12 +9,35 @@ from cuthbert.gaussian.kalman import KalmanFilterState from cuthbert.gaussian.types import LinearizedKalmanFilterState from cuthbertlib.types import Array, ArrayLike +from cuthbert.factorial.types import Factorializer, GetFactorialIndices KalmanState = TypeVar("KalmanState", KalmanFilterState, LinearizedKalmanFilterState) -def extract_and_join(factorial_inds: ArrayLike, state: KalmanState) -> KalmanState: +def build_factorializer( + get_factorial_indices: GetFactorialIndices, +) -> Factorializer: + """Build a factorializer for Kalman states. + + Args: + get_factorial_indices: Function to extract the factorial indices + from model inputs. + + Returns: + Factorializer object for Kalman states with functions to extract and join + the relevant factors and marginalize and insert the updated factors. + """ + return Factorializer( + get_factorial_indices=get_factorial_indices, + extract_and_join=extract_and_join, + marginalize_and_insert=marginalize_and_insert, + ) + + +def extract_and_join( + factorial_state: KalmanState, factorial_inds: ArrayLike +) -> KalmanState: """Convert a factorial Kalman state into a joint local Kalman state. Single dimensional arrays will be treated as scalars e.g. log normalizing constants. @@ -30,41 +53,41 @@ def extract_and_join(factorial_inds: ArrayLike, state: KalmanState) -> KalmanSta Here F is the number of factors and d is the dimension of the state. Args: - factorial_inds: Indices of the factors to extract. Integer array. - state: Factorial Kalman state storing means and chol_covs + factorial_state: Factorial Kalman state storing means and chol_covs with shape (F, d) and (F, d, d) respectively. + factorial_inds: Indices of the factors to extract. Integer array. Returns: Joint local Kalman state with no factorial index dimension. """ factorial_inds = jnp.asarray(factorial_inds) - return tree.map(lambda x: _extract_and_join_arr(factorial_inds, x), state) + return tree.map(lambda x: _extract_and_join_arr(x, factorial_inds), factorial_state) -def _extract_and_join_arr(factorial_inds: Array, arr: Array) -> Array: +def _extract_and_join_arr(arr: Array, factorial_inds: Array) -> Array: if arr.ndim == 1: return arr elif arr.ndim == 2: - return _extract_and_join_means(factorial_inds, arr) + return _extract_and_join_means(arr, factorial_inds) elif arr.ndim == 3: - return _extract_and_join_chol_covs(factorial_inds, arr) + return _extract_and_join_chol_covs(arr, factorial_inds) else: raise ValueError(f"Array must be 1D, 2D or 3D, got {arr.ndim}D") -def _extract_and_join_means(factorial_inds: Array, means: Array) -> Array: +def _extract_and_join_means(means: Array, factorial_inds: Array) -> Array: return means[factorial_inds].reshape(-1) -def _extract_and_join_chol_covs(factorial_inds: Array, chol_covs: Array) -> Array: +def _extract_and_join_chol_covs(chol_covs: Array, factorial_inds: Array) -> Array: selected_chol_covs = chol_covs[factorial_inds] return block_diag(*selected_chol_covs) def marginalize_and_insert( - factorial_inds: Array, local_state: KalmanState, factorial_state: KalmanState, + factorial_inds: ArrayLike, ) -> KalmanState: """Marginalize and insert a joint local Kalman state into a factorial Kalman state. @@ -84,48 +107,53 @@ def marginalize_and_insert( Here F is the number of factors and d is the dimension of the state. Args: - factorial_inds: Indices of the factors to insert. Integer array. local_state: Joint local Kalman state to marginalize and insert. With means and chol_covs with shape (d * len(factorial_inds),) and (d * len(factorial_inds), d * len(factorial_inds)) respectively. factorial_state: Factorial Kalman state storing means and chol_covs - with shape (F, d) and (F, d, d) respectively. + with shape (F, d) and (F, d, d) respectively.\ + factorial_inds: Indices of the factors to insert. Integer array. Returns: Joint local Kalman state with no factorial index dimension. """ + factorial_inds = jnp.asarray(factorial_inds) return tree.map( - lambda loc, fac: _marginalize_and_insert_arr(factorial_inds, loc, fac), + lambda loc, fac: _marginalize_and_insert_arr(loc, fac, factorial_inds), local_state, factorial_state, ) def _marginalize_and_insert_arr( - factorial_inds: ArrayLike, local_arr: Array, factorial_arr: Array + local_arr: Array, factorial_arr: Array, factorial_inds: ArrayLike ) -> Array: factorial_inds = jnp.asarray(factorial_inds) if local_arr.ndim == 1: return local_arr elif local_arr.ndim == 2: - return _marginalize_and_insert_mean(factorial_inds, local_arr, factorial_arr) + return _marginalize_and_insert_mean(local_arr, factorial_arr, factorial_inds) elif local_arr.ndim == 3: return _marginalize_and_insert_chol_cov( - factorial_inds, local_arr, factorial_arr + local_arr, factorial_arr, factorial_inds ) else: raise ValueError(f"Array must be 1D, 2D or 3D, got {local_arr.ndim}D") def _marginalize_and_insert_mean( - factorial_inds: Array, local_mean: Array, factorial_means: Array + local_mean: Array, + factorial_means: Array, + factorial_inds: Array, ) -> Array: local_mean_with_factorial_dimension = local_mean.reshape(len(factorial_inds), -1) return factorial_means.at[factorial_inds].set(local_mean_with_factorial_dimension) def _marginalize_and_insert_chol_cov( - factorial_inds: Array, local_chol_cov: Array, factorial_chol_covs: Array + local_chol_cov: Array, + factorial_chol_covs: Array, + factorial_inds: Array, ) -> Array: d = factorial_chol_covs.shape[-1] starts = jnp.arange(0, len(factorial_inds)) * d diff --git a/cuthbert/factorial/types.py b/cuthbert/factorial/types.py index be8aa04a..3158356b 100644 --- a/cuthbert/factorial/types.py +++ b/cuthbert/factorial/types.py @@ -1,28 +1,44 @@ """Provides types for factorial state-space models.""" -from typing import Protocol +from typing import Protocol, NamedTuple -from cuthbertlib.types import ArrayTree, ArrayTreeLike +from cuthbertlib.types import ArrayTree, ArrayTreeLike, ArrayLike + + +class GetFactorialIndices(Protocol): + """Protocol for getting the factorial indices.""" + + def __call__(self, model_inputs: ArrayTreeLike) -> ArrayLike: + """Extract the factorial indices from model inputs. + + Args: + model_inputs: Model inputs. + + Returns: + Indices of the factors to extract. Integer array. + """ + ... class ExtractAndJoin(Protocol): """Protocol for extracting and joining the relevant factors.""" def __call__( - self, factorial_state: ArrayTree, model_inputs: ArrayTreeLike + self, + factorial_state: ArrayTreeLike, + factorial_inds: ArrayLike, ) -> ArrayTree: """Extract factors from factorial state and combine into a joint local state. - E.g. state might encode factorial `means` with shape (F, d) and `chol_covs` - with shape (F, d, d). Then `model_inputs` tells us factors `i` and `j` are - relevant, so we extract `means[i]` and `means[j]` and `chol_covs[i]` and + E.g. factorial_state might encode factorial `means` with shape (F, d) and + `chol_covs` with shape (F, d, d). Then `model_inputs` tells us factors `i` and + `j` are relevant, so we extract `means[i]` and `means[j]` and `chol_covs[i]` and `chol_covs[j]`. Then combine them into `joint_mean` with shape (2 * d,) and block diagonal `joint_chol_cov` with shape (2 * d, 2 * d). Args: factorial_state: Factorial state with factorial index as the first dimension. - model_inputs: Model inputs including information required to determine - the relevant factors (e.g. factor indices). + factorial_inds: Indices of the factors to extract. Integer array. Returns: Joint local state with no factorial index dimension. @@ -35,9 +51,9 @@ class MarginalizeAndInsert(Protocol): def __call__( self, - factorial_state: ArrayTree, local_state: ArrayTree, - model_inputs: ArrayTreeLike, + factorial_state: ArrayTree, + factorial_inds: ArrayLike, ) -> ArrayTree: """Marginalize joint state into factored state and insert into factorial state. @@ -53,8 +69,7 @@ def __call__( Args: factorial_state: Factorial state with factorial index as the first dimension. local_state: Joint local state with no factorial index dimension. - model_inputs: Model inputs including information required to determine - the relevant factors (e.g. factor indices). + factorial_inds: Indices of the factors to insert. Integer array. Returns: Factorial state with factorial index as the first dimension. @@ -62,3 +77,20 @@ def __call__( The remaining factors are left unchanged. """ ... + + +class Factorializer(NamedTuple): + """Factorializer object. + + Attributes: + get_factorial_indices: Function to get the factorial indices. + Model inputs dependent. + extract_and_join: Function to extract and join the relevant factors. + Inference method dependent (e.g. Gaussian/SMC etc) + marginalize_and_insert: Function to marginalize and insert the updated factors. + Inference method dependent (e.g. Gaussian/SMC etc). + """ + + get_factorial_indices: GetFactorialIndices + extract_and_join: ExtractAndJoin + marginalize_and_insert: MarginalizeAndInsert From f47c78b573f4463b7422fbbd8d1b15b6a095817b Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 4 Feb 2026 15:03:20 +0000 Subject: [PATCH 12/29] Update init --- cuthbert/factorial/__init__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/cuthbert/factorial/__init__.py b/cuthbert/factorial/__init__.py index fd55e23a..f2fa3a29 100644 --- a/cuthbert/factorial/__init__.py +++ b/cuthbert/factorial/__init__.py @@ -1,2 +1,9 @@ from cuthbert.factorial.filtering import filter -from cuthbert.factorial.types import ExtractAndJoin, MarginalizeAndInsert +from cuthbert.factorial.types import ( + Factorializer, + GetFactorialIndices, + ExtractAndJoin, + MarginalizeAndInsert, +) + +from cuthbert.factorial import gaussian From 01c7dbf69aeeca5dca4c3412b054903be023376a Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 5 Feb 2026 16:57:19 +0000 Subject: [PATCH 13/29] Fix gaussian and add block_marginal_sqrt_cov --- cuthbert/factorial/gaussian.py | 22 ++++---- cuthbert/gaussian/kalman.py | 7 ++- cuthbertlib/linalg/__init__.py | 5 +- cuthbertlib/linalg/marginal_sqrt_cov.py | 44 +++++++++++----- tests/cuthbert/factorial/gaussian_utils.py | 50 ++++++++++++++++++ tests/cuthbert/factorial/test_filtering.py | 52 ++++++++++++++----- .../linalg/test_marginal_sqrt_cov.py | 41 ++++++++++++++- 7 files changed, 177 insertions(+), 44 deletions(-) create mode 100644 tests/cuthbert/factorial/gaussian_utils.py diff --git a/cuthbert/factorial/gaussian.py b/cuthbert/factorial/gaussian.py index 6c772a3f..7f17db26 100644 --- a/cuthbert/factorial/gaussian.py +++ b/cuthbert/factorial/gaussian.py @@ -2,10 +2,10 @@ from typing import TypeVar -from jax import tree, numpy as jnp, vmap +from jax import tree, numpy as jnp from jax.scipy.linalg import block_diag -from cuthbertlib.linalg import marginal_sqrt_cov +from cuthbertlib.linalg import block_marginal_sqrt_cov from cuthbert.gaussian.kalman import KalmanFilterState from cuthbert.gaussian.types import LinearizedKalmanFilterState from cuthbertlib.types import Array, ArrayLike @@ -65,14 +65,14 @@ def extract_and_join( def _extract_and_join_arr(arr: Array, factorial_inds: Array) -> Array: - if arr.ndim == 1: + if arr.ndim == 0 or arr.ndim == 1: return arr elif arr.ndim == 2: return _extract_and_join_means(arr, factorial_inds) elif arr.ndim == 3: return _extract_and_join_chol_covs(arr, factorial_inds) else: - raise ValueError(f"Array must be 1D, 2D or 3D, got {arr.ndim}D") + raise ValueError(f"Array must be 3D or lower, got {arr.ndim}D") def _extract_and_join_means(means: Array, factorial_inds: Array) -> Array: @@ -129,16 +129,16 @@ def _marginalize_and_insert_arr( local_arr: Array, factorial_arr: Array, factorial_inds: ArrayLike ) -> Array: factorial_inds = jnp.asarray(factorial_inds) - if local_arr.ndim == 1: + if factorial_arr.ndim == 0 or factorial_arr.ndim == 1: return local_arr - elif local_arr.ndim == 2: + elif factorial_arr.ndim == 2: return _marginalize_and_insert_mean(local_arr, factorial_arr, factorial_inds) - elif local_arr.ndim == 3: + elif factorial_arr.ndim == 3: return _marginalize_and_insert_chol_cov( local_arr, factorial_arr, factorial_inds ) else: - raise ValueError(f"Array must be 1D, 2D or 3D, got {local_arr.ndim}D") + raise ValueError(f"Array must be 3D or lower, got {local_arr.ndim}D") def _marginalize_and_insert_mean( @@ -156,9 +156,5 @@ def _marginalize_and_insert_chol_cov( factorial_inds: Array, ) -> Array: d = factorial_chol_covs.shape[-1] - starts = jnp.arange(0, len(factorial_inds)) * d - ends = starts + d - marginal_chol_covs = vmap(lambda s, e: marginal_sqrt_cov(local_chol_cov, s, e))( - starts, ends - ) + marginal_chol_covs = block_marginal_sqrt_cov(local_chol_cov, d) return factorial_chol_covs.at[factorial_inds].set(marginal_chol_covs) diff --git a/cuthbert/gaussian/kalman.py b/cuthbert/gaussian/kalman.py index e79642df..efb1e5a0 100644 --- a/cuthbert/gaussian/kalman.py +++ b/cuthbert/gaussian/kalman.py @@ -154,7 +154,12 @@ def init_prepare( m0, chol_P0 = get_init_params(model_inputs) H, d, chol_R, y = get_observation_params(model_inputs) - (m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y) + if jnp.isnan(y).any(): + m, chol_P = m0, chol_P0 + ell = jnp.array(0.0) + else: + (m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y) + elem = filtering.FilterScanElement( A=jnp.zeros_like(chol_P), b=m, diff --git a/cuthbertlib/linalg/__init__.py b/cuthbertlib/linalg/__init__.py index f63dfc9a..84aa73ec 100644 --- a/cuthbertlib/linalg/__init__.py +++ b/cuthbertlib/linalg/__init__.py @@ -1,5 +1,8 @@ from cuthbertlib.linalg.collect_nans_chol import collect_nans_chol -from cuthbertlib.linalg.marginal_sqrt_cov import marginal_sqrt_cov +from cuthbertlib.linalg.marginal_sqrt_cov import ( + marginal_sqrt_cov, + block_marginal_sqrt_cov, +) from cuthbertlib.linalg.symmetric_inv_sqrt import ( chol_cov_with_nans_to_cov, symmetric_inv_sqrt, diff --git a/cuthbertlib/linalg/marginal_sqrt_cov.py b/cuthbertlib/linalg/marginal_sqrt_cov.py index 1c8285e7..c5a43c22 100644 --- a/cuthbertlib/linalg/marginal_sqrt_cov.py +++ b/cuthbertlib/linalg/marginal_sqrt_cov.py @@ -1,34 +1,52 @@ """Extract marginal square root covariance from a joint square root covariance.""" -from typing import Sequence +from functools import partial from jax import numpy as jnp +from jax import jit, vmap +from jax.lax import dynamic_slice from cuthbertlib.linalg.tria import tria from cuthbertlib.types import Array, ArrayLike -def marginal_sqrt_cov(chol_cov: ArrayLike, start: int, end: int) -> Array: +@partial(jit, static_argnums=(2,)) +def marginal_sqrt_cov(chol_cov: ArrayLike, start: int | Array, size: int) -> Array: """Extracts square root submatrix from a joint square root matrix. Specifically, returns B such that - B @ B.T = (chol_cov @ chol_cov.T)[start:end, start:end] + B @ B.T = (chol_cov @ chol_cov.T)[start:start+size, start:start+size] Args: chol_cov: Generalized Cholesky factor of the covariance matrix. - start: Start index of the submatrix. - end: End index of the submatrix. + start: Start index of the submatrix (int or 0-d array for use under vmap). + size: Number of rows/columns of the marginal block. Must be a Python int + so that the function can be JIT-compiled. Returns: Lower triangular square root matrix of the marginal covariance matrix. """ chol_cov = jnp.asarray(chol_cov) - assert chol_cov.ndim == 2, "chol_cov must be a 2D array" - assert chol_cov.shape[0] == chol_cov.shape[1], "chol_cov must be square" - assert start >= 0 and end <= chol_cov.shape[0], ( - "start and end must be within the bounds of chol_cov" - ) - assert start < end, "start must be less than end" - - chol_cov_select_rows = chol_cov[start:end, :] + slice_sizes = (size, chol_cov.shape[1]) + chol_cov_select_rows = dynamic_slice(chol_cov, (start, 0), slice_sizes) return tria(chol_cov_select_rows) + + +@partial(jit, static_argnums=(1,)) +def block_marginal_sqrt_cov(chol_cov: ArrayLike, subdim: int) -> Array: + """Extracts all square root submatrices of specified size from joint square root matrix. + + Args: + chol_cov: Generalized Cholesky factor of the covariance matrix. + subdim: Size of the square root submatrices to extract. + Must be a divisor of the number of rows in chol_cov. + + Returns: + Array of shape (chol_cov.shape[0] // subdim, subdim, subdim) + containing the square root submatrices. + """ + chol_cov = jnp.asarray(chol_cov) + n_blocks = chol_cov.shape[0] // subdim + return vmap(lambda i: marginal_sqrt_cov(chol_cov, i * subdim, subdim))( + jnp.arange(n_blocks) + ) diff --git a/tests/cuthbert/factorial/gaussian_utils.py b/tests/cuthbert/factorial/gaussian_utils.py new file mode 100644 index 00000000..48722d4f --- /dev/null +++ b/tests/cuthbert/factorial/gaussian_utils.py @@ -0,0 +1,50 @@ +import jax.numpy as jnp +from jax import random, vmap + +from cuthbertlib.kalman import generate + + +def generate_factorial_kalman_model( + seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps +): + # T = num_time_steps, F = num_factors + + key = random.key(seed) + init_key, factorial_indices_key = random.split(key, 2) + + # m0 with shape (F, x_dim) + # chol_P0 with shape (F, x_dim, x_dim) + init_keys_factorial = random.split(init_key, num_factors) + m0s, chol_P0s = vmap(generate.generate_init_model, in_axes=(0, None))( + init_keys_factorial, x_dim + ) + + # Fs with shape (T, num_factors_local * x_dim, num_factors_local * x_dim) + # cs with shape (T, num_factors_local * x_dim) + # chol_Qs with shape (T, num_factors_local * x_dim, num_factors_local * x_dim) + # Hs with shape (T, d_y, num_factors_local * x_dim) + # ds with shape (T, y_dim) + # chol_Rs with shape (T, num_factors_local * y_dim, num_factors_local * y_dim) + # ys with shape (T, d_y) + _, _, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys = generate.generate_lgssm( + seed + 1, num_factors_local * x_dim, y_dim, num_time_steps + ) + # Remove the first time step from observation parameters (set as nan) + # no initial observation for factorial models + Hs = Hs.at[0].set(jnp.full_like(Hs[0], jnp.nan)) + ds = ds.at[0].set(jnp.full_like(ds[0], jnp.nan)) + chol_Rs = chol_Rs.at[0].set(jnp.full_like(chol_Rs[0], jnp.nan)) + ys = ys.at[0].set(jnp.full_like(ys[0], jnp.nan)) + + # factorial_indices with shape (T, num_factors_local) + # Each entry is a random integer in {0, ..., num_factors - 1} + # But each row must have unique entries + def rand_unique_indices(key): + indices = random.choice( + key, jnp.arange(num_factors), (num_factors_local,), replace=False + ) + return indices + + factorial_indices_keys = random.split(factorial_indices_key, num_time_steps) + factorial_indices = vmap(rand_unique_indices)(factorial_indices_keys) + return m0s, chol_P0s, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys, factorial_indices diff --git a/tests/cuthbert/factorial/test_filtering.py b/tests/cuthbert/factorial/test_filtering.py index 27b69c05..a7b3ad0b 100644 --- a/tests/cuthbert/factorial/test_filtering.py +++ b/tests/cuthbert/factorial/test_filtering.py @@ -1,15 +1,17 @@ +import itertools + import jax import jax.numpy as jnp import pytest from jax import Array, vmap - from cuthbert import factorial from cuthbert.gaussian import kalman from cuthbert.inference import Filter, Smoother from cuthbertlib.kalman.generate import generate_lgssm from tests.cuthbertlib.kalman.test_filtering import std_predict, std_update from tests.cuthbertlib.kalman.test_smoothing import std_kalman_smoother +from tests.cuthbert.factorial.gaussian_utils import generate_factorial_kalman_model @pytest.fixture(scope="module", autouse=True) @@ -25,12 +27,12 @@ def load_kalman_pairwise_factorial_inference( Fs: Array, # (T, 2 * d, 2 * d) cs: Array, # (T, 2 * d) chol_Qs: Array, # (T, 2 * d, 2 * d) - Hs: Array, # (T, 2 * d, d_y) - ds: Array, # (T, d_y) - chol_Rs: Array, # (T, d_y, d_y) - ys: Array, # (T + 1, d_y) + Hs: Array, # (T+1, d_y, 2 * d) with nans for initial time step + ds: Array, # (T+1, d_y) with nans for initial time step + chol_Rs: Array, # (T+1, d_y, d_y) with nans for initial time step + ys: Array, # (T+1, d_y) with nans for initial time step factorial_indices: Array, # (T, 2) -) -> tuple[Filter, Smoother, Array]: +) -> tuple[Filter, Smoother, factorial.Factorializer, Array]: """Builds Kalman filter and smoother objects and model_inputs for a linear-Gaussian SSM.""" def get_init_params(model_inputs: int) -> tuple[Array, Array]: @@ -47,18 +49,40 @@ def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Arra ys[model_inputs], ) - def extract_and_join(factorial_state, model_inputs): - fac_inds = factorial_indices[model_inputs - 1] - - means = - - - filter = kalman.build_filter( get_init_params, get_dynamics_params, get_observation_params ) smoother = kalman.build_smoother( get_dynamics_params, store_gain=True, store_chol_cov_given_next=True ) + + factorializer = factorial.gaussian.build_factorializer( + get_factorial_indices=lambda model_inputs: factorial_indices[model_inputs - 1] + ) model_inputs = jnp.arange(len(ys)) - return filter, smoother, model_inputs + return filter, smoother, factorializer, model_inputs + + +seeds = [1, 43] +x_dims = [1, 3] +y_dims = [1, 2] +num_factors = [10, 20] +num_factors_local = [2] # number of factors to interact at each time step +num_time_steps = [1, 25] + +common_params = list( + itertools.product(seeds, x_dims, y_dims, num_factors, num_time_steps) +) + + +def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps): + model_params = generate_factorial_kalman_model( + seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps + ) + filter_obj, smoother_obj, factorializer, model_inputs = ( + load_kalman_pairwise_factorial_inference(*model_params) + ) + + init_state, local_filter_states = factorial.filter( + filter_obj, factorializer, model_inputs, output_factorial=False + ) diff --git a/tests/cuthbertlib/linalg/test_marginal_sqrt_cov.py b/tests/cuthbertlib/linalg/test_marginal_sqrt_cov.py index b39e82d2..f54674f3 100644 --- a/tests/cuthbertlib/linalg/test_marginal_sqrt_cov.py +++ b/tests/cuthbertlib/linalg/test_marginal_sqrt_cov.py @@ -3,7 +3,10 @@ import pytest from jax import random -from cuthbertlib.linalg.marginal_sqrt_cov import marginal_sqrt_cov +from cuthbertlib.linalg.marginal_sqrt_cov import ( + block_marginal_sqrt_cov, + marginal_sqrt_cov, +) @pytest.fixture(scope="module", autouse=True) @@ -30,7 +33,7 @@ def test_marginal_sqrt_cov(seed, n, start, end): L = jnp.tril(random.normal(key, (n, n))) # Extract marginal square root - B = marginal_sqrt_cov(L, start, end) + B = marginal_sqrt_cov(L, start, end - start) # Expected marginal covariance block Sigma = L @ L.T @@ -41,3 +44,37 @@ def test_marginal_sqrt_cov(seed, n, start, end): # Check B B^T reproduces marginal covariance assert jnp.allclose(B @ B.T, Sigma_block) + + +@pytest.mark.parametrize("seed", [0, 42]) +@pytest.mark.parametrize( + "n,subdim", + [ + (6, 2), + (6, 3), + (8, 4), + (9, 3), + ], +) +def test_block_marginal_sqrt_cov(seed, n, subdim): + key = random.key(seed) + L = jnp.tril(random.normal(key, (n, n))) + + blocks = block_marginal_sqrt_cov(L, subdim) + + n_blocks = n // subdim + assert blocks.shape == (n_blocks, subdim, subdim) + + Sigma = L @ L.T + for i in range(n_blocks): + start, end = i * subdim, (i + 1) * subdim + Sigma_block = Sigma[start:end, start:end] + assert jnp.allclose( + blocks[i], jnp.tril(blocks[i]) + ) # Check that blocks are lower triangular + assert jnp.allclose( + blocks[i] @ blocks[i].T, Sigma_block + ) # Check that blocks reproduce the marginal covariance + assert jnp.allclose( + blocks[i], marginal_sqrt_cov(L, start, subdim) + ) # Check that blocks are the same as the marginal square root covariance From 29de6a0c6401865959ff51a1ec95678ef4894f06 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 5 Feb 2026 18:13:59 +0000 Subject: [PATCH 14/29] Flesh out test --- tests/cuthbert/factorial/test_filtering.py | 88 ++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/tests/cuthbert/factorial/test_filtering.py b/tests/cuthbert/factorial/test_filtering.py index a7b3ad0b..02152d23 100644 --- a/tests/cuthbert/factorial/test_filtering.py +++ b/tests/cuthbert/factorial/test_filtering.py @@ -4,6 +4,8 @@ import jax.numpy as jnp import pytest from jax import Array, vmap +from jax.scipy.linalg import block_diag +import chex from cuthbert import factorial from cuthbert.gaussian import kalman @@ -75,6 +77,9 @@ def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Arra ) +@pytest.mark.parametrize( + "seed,x_dim,y_dim,num_factors,num_factors_local,num_time_steps", common_params +) def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps): model_params = generate_factorial_kalman_model( seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps @@ -83,6 +88,89 @@ def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_ste load_kalman_pairwise_factorial_inference(*model_params) ) + # True means, covs and log norm constants + fac_means = model_params[0] + fac_chol_covs = model_params[1] + fac_covs = fac_chol_covs @ fac_chol_covs.transpose(0, 2, 1) + ell = jnp.array(0.0) + + local_means = [] + local_covs = [] + ells = [] + fac_means_t_all = [fac_means] + fac_covs_t_all = [fac_covs] + for i in model_inputs[1:]: + F, c, chol_Q = ( + model_params[2][i - 1], + model_params[3][i - 1], + model_params[4][i - 1], + ) + H, d, chol_R, y = ( + model_params[5][i], + model_params[6][i], + model_params[7][i], + model_params[8][i], + ) + fac_inds = model_params[9][i - 1] + + joint_mean = fac_means[fac_inds].reshape(-1) + joint_cov = block_diag(*fac_covs[fac_inds]) + Q = chol_Q @ chol_Q.T + R = chol_R @ chol_R.T + pred_mean, pred_cov = std_predict(joint_mean, joint_cov, F, c, Q) + upd_mean, upd_cov, upd_ell = std_update(pred_mean, pred_cov, H, d, R, y) + marginal_means = upd_mean.reshape(len(fac_inds), -1) + marginal_covs = jnp.array( + [ + upd_cov[i * x_dim : (i + 1) * x_dim, i * x_dim : (i + 1) * x_dim] + for i in range(len(fac_inds)) + ] + ) + ell += upd_ell + local_means.append(marginal_means) + local_covs.append(marginal_covs) + ells.append(ell) + fac_means = fac_means.at[fac_inds].set(marginal_means) + fac_covs = fac_covs.at[fac_inds].set(marginal_covs) + fac_means_t_all.append(fac_means) + fac_covs_t_all.append(fac_covs) + + local_means = jnp.stack(local_means) + local_covs = jnp.stack(local_covs) + ells = jnp.stack(ells) + fac_means_t_all = jnp.stack(fac_means_t_all) + fac_covs_t_all = jnp.stack(fac_covs_t_all) + + # Check output_factorial = False init_state, local_filter_states = factorial.filter( filter_obj, factorializer, model_inputs, output_factorial=False ) + local_filter_covs = ( + local_filter_states.chol_cov @ local_filter_states.chol_cov.transpose(0, 2, 1) + ) + chex.assert_trees_all_close( + (init_state.mean, init_state.chol_cov), (model_params[0], model_params[1]) + ) + chex.assert_trees_all_close( + (local_means, local_covs, ells), + ( + local_filter_states.mean, + local_filter_covs, + local_filter_states.log_normalizing_constant, + ), + ) + + # Check output_factorial = False + factorial_filtering_states = factorial.filter( + filter_obj, factorializer, model_inputs, output_factorial=True + ) + local_filter_covs = ( + local_filter_states.chol_cov @ local_filter_states.chol_cov.transpose(0, 2, 1) + ) + chex.assert_trees_all_close( + (fac_means_t_all, fac_covs_t_all), + (factorial_filtering_states.mean, factorial_filtering_states.chol_cov), + ) + chex.assert_trees_all_close( + ells, factorial_filtering_states.log_normalizing_constant[1:] + ) From 5ff9dbbac914d0ab5d258d5ae3497e3d131b1aef Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 5 Feb 2026 18:28:59 +0000 Subject: [PATCH 15/29] Add semi hack to extract local factorial states --- cuthbert/factorial/filtering.py | 14 ++++++++++++-- tests/cuthbert/factorial/test_filtering.py | 14 ++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/cuthbert/factorial/filtering.py b/cuthbert/factorial/filtering.py index ff57cb9e..7dc67518 100644 --- a/cuthbert/factorial/filtering.py +++ b/cuthbert/factorial/filtering.py @@ -15,7 +15,9 @@ def filter( model_inputs: ArrayTreeLike, output_factorial: bool = False, key: KeyArray | None = None, -) -> ArrayTree | tuple[ArrayTree, ArrayTree]: +) -> ( + ArrayTree | tuple[ArrayTree, ArrayTree] +): # TODO: Can overload this function so the type checker knows that the output is a ArrayTree if output_factorial is True and a tuple[ArrayTree, ArrayTree] if output_factorial is False """Applies offline factorial filtering for given model inputs. `model_inputs` should have leading temporal dimension of length T + 1, @@ -70,7 +72,15 @@ def body_local(prev_factorial_state, prep_inp_and_k): factorial_state = factorializer.marginalize_and_insert( filtered_joint_state, prev_factorial_state, factorial_inds ) - return factorial_state, filtered_joint_state + + def extract(arr): + if arr.ndim >= 2: + return arr[factorial_inds] + else: + return arr + + factorial_state_fac_inds = tree.map(extract, factorial_state) + return factorial_state, factorial_state_fac_inds if output_factorial: diff --git a/tests/cuthbert/factorial/test_filtering.py b/tests/cuthbert/factorial/test_filtering.py index 02152d23..814059de 100644 --- a/tests/cuthbert/factorial/test_filtering.py +++ b/tests/cuthbert/factorial/test_filtering.py @@ -1,4 +1,5 @@ import itertools +from typing import cast import jax import jax.numpy as jnp @@ -7,6 +8,7 @@ from jax.scipy.linalg import block_diag import chex +from cuthbertlib.types import ArrayTree from cuthbert import factorial from cuthbert.gaussian import kalman from cuthbert.inference import Filter, Smoother @@ -146,7 +148,8 @@ def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_ste filter_obj, factorializer, model_inputs, output_factorial=False ) local_filter_covs = ( - local_filter_states.chol_cov @ local_filter_states.chol_cov.transpose(0, 2, 1) + local_filter_states.chol_cov + @ local_filter_states.chol_cov.transpose(0, 1, 3, 2) ) chex.assert_trees_all_close( (init_state.mean, init_state.chol_cov), (model_params[0], model_params[1]) @@ -164,12 +167,15 @@ def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_ste factorial_filtering_states = factorial.filter( filter_obj, factorializer, model_inputs, output_factorial=True ) - local_filter_covs = ( - local_filter_states.chol_cov @ local_filter_states.chol_cov.transpose(0, 2, 1) + + factorial_filtering_states = cast(ArrayTree, factorial_filtering_states) + factorial_filtering_covs = ( + factorial_filtering_states.chol_cov + @ factorial_filtering_states.chol_cov.transpose(0, 1, 3, 2) ) chex.assert_trees_all_close( (fac_means_t_all, fac_covs_t_all), - (factorial_filtering_states.mean, factorial_filtering_states.chol_cov), + (factorial_filtering_states.mean, factorial_filtering_covs), ) chex.assert_trees_all_close( ells, factorial_filtering_states.log_normalizing_constant[1:] From e39565e441cab4c5849c7d40e526f75872687444 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 5 Feb 2026 18:29:28 +0000 Subject: [PATCH 16/29] Rename --- cuthbert/factorial/filtering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuthbert/factorial/filtering.py b/cuthbert/factorial/filtering.py index 7dc67518..d909f386 100644 --- a/cuthbert/factorial/filtering.py +++ b/cuthbert/factorial/filtering.py @@ -79,8 +79,8 @@ def extract(arr): else: return arr - factorial_state_fac_inds = tree.map(extract, factorial_state) - return factorial_state, factorial_state_fac_inds + factorial_state_fac_inds_only = tree.map(extract, factorial_state) + return factorial_state, factorial_state_fac_inds_only if output_factorial: From 5b2a0f50fa7e9efb15cb6ac3fede5c430d44ace5 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 5 Feb 2026 18:32:32 +0000 Subject: [PATCH 17/29] Fix test --- tests/cuthbert/factorial/__init__.py | 0 tests/cuthbert/factorial/test_filtering.py | 4 +++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 tests/cuthbert/factorial/__init__.py diff --git a/tests/cuthbert/factorial/__init__.py b/tests/cuthbert/factorial/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/cuthbert/factorial/test_filtering.py b/tests/cuthbert/factorial/test_filtering.py index 814059de..7e3da5ce 100644 --- a/tests/cuthbert/factorial/test_filtering.py +++ b/tests/cuthbert/factorial/test_filtering.py @@ -75,7 +75,9 @@ def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Arra num_time_steps = [1, 25] common_params = list( - itertools.product(seeds, x_dims, y_dims, num_factors, num_time_steps) + itertools.product( + seeds, x_dims, y_dims, num_factors, num_factors_local, num_time_steps + ) ) From 026cc4778548fe07d519bc4ef22ffea78f99f67d Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 5 Feb 2026 18:35:28 +0000 Subject: [PATCH 18/29] Change init_prepare nan check to all --- cuthbert/gaussian/kalman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuthbert/gaussian/kalman.py b/cuthbert/gaussian/kalman.py index efb1e5a0..2f67d576 100644 --- a/cuthbert/gaussian/kalman.py +++ b/cuthbert/gaussian/kalman.py @@ -154,7 +154,7 @@ def init_prepare( m0, chol_P0 = get_init_params(model_inputs) H, d, chol_R, y = get_observation_params(model_inputs) - if jnp.isnan(y).any(): + if jnp.isnan(y).all(): m, chol_P = m0, chol_P0 ell = jnp.array(0.0) else: From ea41af2e156e5e514a3e5bf19b59e7dd8105fe2e Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 5 Feb 2026 18:38:44 +0000 Subject: [PATCH 19/29] Sort imports --- cuthbert/factorial/__init__.py | 5 ++--- cuthbert/factorial/gaussian.py | 8 ++++---- cuthbert/factorial/types.py | 4 ++-- cuthbertlib/linalg/__init__.py | 2 +- cuthbertlib/linalg/marginal_sqrt_cov.py | 2 +- tests/cuthbert/factorial/test_filtering.py | 6 +++--- 6 files changed, 13 insertions(+), 14 deletions(-) diff --git a/cuthbert/factorial/__init__.py b/cuthbert/factorial/__init__.py index f2fa3a29..87a6a70f 100644 --- a/cuthbert/factorial/__init__.py +++ b/cuthbert/factorial/__init__.py @@ -1,9 +1,8 @@ +from cuthbert.factorial import gaussian from cuthbert.factorial.filtering import filter from cuthbert.factorial.types import ( + ExtractAndJoin, Factorializer, GetFactorialIndices, - ExtractAndJoin, MarginalizeAndInsert, ) - -from cuthbert.factorial import gaussian diff --git a/cuthbert/factorial/gaussian.py b/cuthbert/factorial/gaussian.py index 7f17db26..895c0288 100644 --- a/cuthbert/factorial/gaussian.py +++ b/cuthbert/factorial/gaussian.py @@ -2,15 +2,15 @@ from typing import TypeVar -from jax import tree, numpy as jnp +from jax import numpy as jnp +from jax import tree from jax.scipy.linalg import block_diag -from cuthbertlib.linalg import block_marginal_sqrt_cov +from cuthbert.factorial.types import Factorializer, GetFactorialIndices from cuthbert.gaussian.kalman import KalmanFilterState from cuthbert.gaussian.types import LinearizedKalmanFilterState +from cuthbertlib.linalg import block_marginal_sqrt_cov from cuthbertlib.types import Array, ArrayLike -from cuthbert.factorial.types import Factorializer, GetFactorialIndices - KalmanState = TypeVar("KalmanState", KalmanFilterState, LinearizedKalmanFilterState) diff --git a/cuthbert/factorial/types.py b/cuthbert/factorial/types.py index 3158356b..dfeeeed3 100644 --- a/cuthbert/factorial/types.py +++ b/cuthbert/factorial/types.py @@ -1,8 +1,8 @@ """Provides types for factorial state-space models.""" -from typing import Protocol, NamedTuple +from typing import NamedTuple, Protocol -from cuthbertlib.types import ArrayTree, ArrayTreeLike, ArrayLike +from cuthbertlib.types import ArrayLike, ArrayTree, ArrayTreeLike class GetFactorialIndices(Protocol): diff --git a/cuthbertlib/linalg/__init__.py b/cuthbertlib/linalg/__init__.py index 84aa73ec..503a1bc5 100644 --- a/cuthbertlib/linalg/__init__.py +++ b/cuthbertlib/linalg/__init__.py @@ -1,7 +1,7 @@ from cuthbertlib.linalg.collect_nans_chol import collect_nans_chol from cuthbertlib.linalg.marginal_sqrt_cov import ( - marginal_sqrt_cov, block_marginal_sqrt_cov, + marginal_sqrt_cov, ) from cuthbertlib.linalg.symmetric_inv_sqrt import ( chol_cov_with_nans_to_cov, diff --git a/cuthbertlib/linalg/marginal_sqrt_cov.py b/cuthbertlib/linalg/marginal_sqrt_cov.py index c5a43c22..d62fd067 100644 --- a/cuthbertlib/linalg/marginal_sqrt_cov.py +++ b/cuthbertlib/linalg/marginal_sqrt_cov.py @@ -2,8 +2,8 @@ from functools import partial -from jax import numpy as jnp from jax import jit, vmap +from jax import numpy as jnp from jax.lax import dynamic_slice from cuthbertlib.linalg.tria import tria diff --git a/tests/cuthbert/factorial/test_filtering.py b/tests/cuthbert/factorial/test_filtering.py index 7e3da5ce..e2d4085e 100644 --- a/tests/cuthbert/factorial/test_filtering.py +++ b/tests/cuthbert/factorial/test_filtering.py @@ -1,21 +1,21 @@ import itertools from typing import cast +import chex import jax import jax.numpy as jnp import pytest from jax import Array, vmap from jax.scipy.linalg import block_diag -import chex -from cuthbertlib.types import ArrayTree from cuthbert import factorial from cuthbert.gaussian import kalman from cuthbert.inference import Filter, Smoother from cuthbertlib.kalman.generate import generate_lgssm +from cuthbertlib.types import ArrayTree +from tests.cuthbert.factorial.gaussian_utils import generate_factorial_kalman_model from tests.cuthbertlib.kalman.test_filtering import std_predict, std_update from tests.cuthbertlib.kalman.test_smoothing import std_kalman_smoother -from tests.cuthbert.factorial.gaussian_utils import generate_factorial_kalman_model @pytest.fixture(scope="module", autouse=True) From 38bb0eeb71021faeacf61b302ce4d9569ffff678 Mon Sep 17 00:00:00 2001 From: SamDuffield <34280297+SamDuffield@users.noreply.github.com> Date: Fri, 6 Feb 2026 09:57:16 +0000 Subject: [PATCH 20/29] Update cuthbert/gaussian/kalman.py Co-authored-by: Sahel Iqbal --- cuthbert/gaussian/kalman.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cuthbert/gaussian/kalman.py b/cuthbert/gaussian/kalman.py index 2f67d576..9e60c763 100644 --- a/cuthbert/gaussian/kalman.py +++ b/cuthbert/gaussian/kalman.py @@ -154,11 +154,12 @@ def init_prepare( m0, chol_P0 = get_init_params(model_inputs) H, d, chol_R, y = get_observation_params(model_inputs) - if jnp.isnan(y).all(): - m, chol_P = m0, chol_P0 - ell = jnp.array(0.0) - else: - (m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y) + (m, chol_P), ell = lax.cond( + jnp.isnan(y).all(), + lambda _: ((m0, chol_P0), jnp.zeros((), dtype=m0.dtype)), + lambda _: filtering.update(m0, chol_P0, H, d, chol_R, y), + operand=None, + ) elem = filtering.FilterScanElement( A=jnp.zeros_like(chol_P), From 167b6f501baecbd8ae38b544abc9704f2eb9124b Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Fri, 6 Feb 2026 12:48:51 +0000 Subject: [PATCH 21/29] Add cond import --- cuthbert/gaussian/kalman.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cuthbert/gaussian/kalman.py b/cuthbert/gaussian/kalman.py index 9e60c763..0241bce1 100644 --- a/cuthbert/gaussian/kalman.py +++ b/cuthbert/gaussian/kalman.py @@ -8,6 +8,7 @@ from jax import numpy as jnp from jax import tree +from jax.lax import cond from cuthbert.gaussian.types import ( GetDynamicsParams, @@ -154,12 +155,12 @@ def init_prepare( m0, chol_P0 = get_init_params(model_inputs) H, d, chol_R, y = get_observation_params(model_inputs) - (m, chol_P), ell = lax.cond( + (m, chol_P), ell = cond( jnp.isnan(y).all(), lambda _: ((m0, chol_P0), jnp.zeros((), dtype=m0.dtype)), lambda _: filtering.update(m0, chol_P0, H, d, chol_R, y), operand=None, - ) + ) elem = filtering.FilterScanElement( A=jnp.zeros_like(chol_P), From db9bec59f4727157f30ac836fdcaf8539ed274e6 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Sat, 7 Feb 2026 13:31:11 +0000 Subject: [PATCH 22/29] Fix imports --- cuthbertlib/linalg/marginal_sqrt_cov.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cuthbertlib/linalg/marginal_sqrt_cov.py b/cuthbertlib/linalg/marginal_sqrt_cov.py index 95eed92c..da231259 100644 --- a/cuthbertlib/linalg/marginal_sqrt_cov.py +++ b/cuthbertlib/linalg/marginal_sqrt_cov.py @@ -1,6 +1,5 @@ """Extract marginal square root covariance(s) from a joint square root covariance.""" -from jax import jit, vmap from jax import numpy as jnp from jax import vmap from jax.lax import dynamic_slice From 8d463e8dc9d2681decdfced12bc063dd49e2ff44 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Sat, 7 Feb 2026 13:40:34 +0000 Subject: [PATCH 23/29] Ignore model_inputs and revert init_prepare --- cuthbert/factorial/gaussian.py | 27 +++++++++++++++++++++++---- cuthbert/gaussian/kalman.py | 10 ++++------ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/cuthbert/factorial/gaussian.py b/cuthbert/factorial/gaussian.py index 895c0288..f20c9be2 100644 --- a/cuthbert/factorial/gaussian.py +++ b/cuthbert/factorial/gaussian.py @@ -61,7 +61,16 @@ def extract_and_join( Joint local Kalman state with no factorial index dimension. """ factorial_inds = jnp.asarray(factorial_inds) - return tree.map(lambda x: _extract_and_join_arr(x, factorial_inds), factorial_state) + new_elem = tree.map( + lambda x: _extract_and_join_arr(x, factorial_inds), factorial_state.elem + ) + new_state = factorial_state._replace(elem=new_elem) + + if isinstance(factorial_state, LinearizedKalmanFilterState): + new_mean_prev = _extract_and_join_arr(factorial_state.mean_prev, factorial_inds) + new_state = new_state._replace(mean_prev=new_mean_prev) + + return new_state def _extract_and_join_arr(arr: Array, factorial_inds: Array) -> Array: @@ -118,11 +127,21 @@ def marginalize_and_insert( Joint local Kalman state with no factorial index dimension. """ factorial_inds = jnp.asarray(factorial_inds) - return tree.map( + new_elem = tree.map( lambda loc, fac: _marginalize_and_insert_arr(loc, fac, factorial_inds), - local_state, - factorial_state, + local_state.elem, + factorial_state.elem, ) + new_state = local_state._replace(elem=new_elem) + if isinstance(local_state, LinearizedKalmanFilterState) and isinstance( + factorial_state, LinearizedKalmanFilterState + ): + new_mean_prev = _marginalize_and_insert_arr( + local_state.mean_prev, factorial_state.mean_prev, factorial_inds + ) + new_state = new_state._replace(mean_prev=new_mean_prev) + + return new_state def _marginalize_and_insert_arr( diff --git a/cuthbert/gaussian/kalman.py b/cuthbert/gaussian/kalman.py index 0241bce1..51a3b9c0 100644 --- a/cuthbert/gaussian/kalman.py +++ b/cuthbert/gaussian/kalman.py @@ -155,12 +155,10 @@ def init_prepare( m0, chol_P0 = get_init_params(model_inputs) H, d, chol_R, y = get_observation_params(model_inputs) - (m, chol_P), ell = cond( - jnp.isnan(y).all(), - lambda _: ((m0, chol_P0), jnp.zeros((), dtype=m0.dtype)), - lambda _: filtering.update(m0, chol_P0, H, d, chol_R, y), - operand=None, - ) + if jnp.isnan(y).all(): + (m, chol_P), ell = ((m0, chol_P0), jnp.zeros((), dtype=m0.dtype)) + else: + (m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y) elem = filtering.FilterScanElement( A=jnp.zeros_like(chol_P), From 8ecbab34c1d6c67b172bcf83fac78eae867752af Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 24 Feb 2026 11:18:37 +0000 Subject: [PATCH 24/29] Remove cond import --- cuthbert/gaussian/kalman.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cuthbert/gaussian/kalman.py b/cuthbert/gaussian/kalman.py index 0b01bc71..7ee2fd9c 100644 --- a/cuthbert/gaussian/kalman.py +++ b/cuthbert/gaussian/kalman.py @@ -8,7 +8,6 @@ from jax import numpy as jnp from jax import tree -from jax.lax import cond from cuthbert.gaussian.types import ( GetDynamicsParams, From 7afb6479fe1c7d95ff63ec27a94ad8f9b5cb879b Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 24 Feb 2026 11:20:53 +0000 Subject: [PATCH 25/29] Remove readme edit --- cuthbert/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/cuthbert/README.md b/cuthbert/README.md index 6cdc6307..92b9e384 100644 --- a/cuthbert/README.md +++ b/cuthbert/README.md @@ -7,7 +7,6 @@ All inference methods are implemented with the following unified interface: ```python from jax import tree -import cuthbert # Define model_inputs model_inputs = ... From baf55f39010abf0556b4f3d327ccc8bf3a53828d Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 24 Feb 2026 11:42:36 +0000 Subject: [PATCH 26/29] Fix filtering --- tests/cuthbert/factorial/gaussian_utils.py | 6 ---- tests/cuthbert/factorial/test_filtering.py | 32 ++++++++++------------ 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/tests/cuthbert/factorial/gaussian_utils.py b/tests/cuthbert/factorial/gaussian_utils.py index 48722d4f..e3579118 100644 --- a/tests/cuthbert/factorial/gaussian_utils.py +++ b/tests/cuthbert/factorial/gaussian_utils.py @@ -29,12 +29,6 @@ def generate_factorial_kalman_model( _, _, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys = generate.generate_lgssm( seed + 1, num_factors_local * x_dim, y_dim, num_time_steps ) - # Remove the first time step from observation parameters (set as nan) - # no initial observation for factorial models - Hs = Hs.at[0].set(jnp.full_like(Hs[0], jnp.nan)) - ds = ds.at[0].set(jnp.full_like(ds[0], jnp.nan)) - chol_Rs = chol_Rs.at[0].set(jnp.full_like(chol_Rs[0], jnp.nan)) - ys = ys.at[0].set(jnp.full_like(ys[0], jnp.nan)) # factorial_indices with shape (T, num_factors_local) # Each entry is a random integer in {0, ..., num_factors - 1} diff --git a/tests/cuthbert/factorial/test_filtering.py b/tests/cuthbert/factorial/test_filtering.py index e2d4085e..c4fdb969 100644 --- a/tests/cuthbert/factorial/test_filtering.py +++ b/tests/cuthbert/factorial/test_filtering.py @@ -5,17 +5,15 @@ import jax import jax.numpy as jnp import pytest -from jax import Array, vmap +from jax import Array from jax.scipy.linalg import block_diag from cuthbert import factorial from cuthbert.gaussian import kalman from cuthbert.inference import Filter, Smoother -from cuthbertlib.kalman.generate import generate_lgssm from cuthbertlib.types import ArrayTree from tests.cuthbert.factorial.gaussian_utils import generate_factorial_kalman_model from tests.cuthbertlib.kalman.test_filtering import std_predict, std_update -from tests.cuthbertlib.kalman.test_smoothing import std_kalman_smoother @pytest.fixture(scope="module", autouse=True) @@ -31,10 +29,10 @@ def load_kalman_pairwise_factorial_inference( Fs: Array, # (T, 2 * d, 2 * d) cs: Array, # (T, 2 * d) chol_Qs: Array, # (T, 2 * d, 2 * d) - Hs: Array, # (T+1, d_y, 2 * d) with nans for initial time step - ds: Array, # (T+1, d_y) with nans for initial time step - chol_Rs: Array, # (T+1, d_y, d_y) with nans for initial time step - ys: Array, # (T+1, d_y) with nans for initial time step + Hs: Array, # (T, d_y, 2 * d) with nans for initial time step + ds: Array, # (T, d_y) with nans for initial time step + chol_Rs: Array, # (T, d_y, d_y) with nans for initial time step + ys: Array, # (T, d_y) with nans for initial time step factorial_indices: Array, # (T, 2) ) -> tuple[Filter, Smoother, factorial.Factorializer, Array]: """Builds Kalman filter and smoother objects and model_inputs for a linear-Gaussian SSM.""" @@ -47,10 +45,10 @@ def get_dynamics_params(model_inputs: int) -> tuple[Array, Array, Array]: def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Array]: return ( - Hs[model_inputs], - ds[model_inputs], - chol_Rs[model_inputs], - ys[model_inputs], + Hs[model_inputs - 1], + ds[model_inputs - 1], + chol_Rs[model_inputs - 1], + ys[model_inputs - 1], ) filter = kalman.build_filter( @@ -63,7 +61,7 @@ def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Arra factorializer = factorial.gaussian.build_factorializer( get_factorial_indices=lambda model_inputs: factorial_indices[model_inputs - 1] ) - model_inputs = jnp.arange(len(ys)) + model_inputs = jnp.arange(len(ys) + 1) return filter, smoother, factorializer, model_inputs @@ -88,7 +86,7 @@ def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_ste model_params = generate_factorial_kalman_model( seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps ) - filter_obj, smoother_obj, factorializer, model_inputs = ( + filter_obj, _, factorializer, model_inputs = ( load_kalman_pairwise_factorial_inference(*model_params) ) @@ -110,10 +108,10 @@ def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_ste model_params[4][i - 1], ) H, d, chol_R, y = ( - model_params[5][i], - model_params[6][i], - model_params[7][i], - model_params[8][i], + model_params[5][i - 1], + model_params[6][i - 1], + model_params[7][i - 1], + model_params[8][i - 1], ) fac_inds = model_params[9][i - 1] From 7c808f8fe29bbe9cf6ed8e76d484938d16e57e30 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 24 Feb 2026 11:43:42 +0000 Subject: [PATCH 27/29] Remove old comment --- tests/cuthbert/factorial/test_filtering.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/cuthbert/factorial/test_filtering.py b/tests/cuthbert/factorial/test_filtering.py index c4fdb969..90920094 100644 --- a/tests/cuthbert/factorial/test_filtering.py +++ b/tests/cuthbert/factorial/test_filtering.py @@ -29,10 +29,10 @@ def load_kalman_pairwise_factorial_inference( Fs: Array, # (T, 2 * d, 2 * d) cs: Array, # (T, 2 * d) chol_Qs: Array, # (T, 2 * d, 2 * d) - Hs: Array, # (T, d_y, 2 * d) with nans for initial time step - ds: Array, # (T, d_y) with nans for initial time step - chol_Rs: Array, # (T, d_y, d_y) with nans for initial time step - ys: Array, # (T, d_y) with nans for initial time step + Hs: Array, # (T, d_y, 2 * d) + ds: Array, # (T, d_y) + chol_Rs: Array, # (T, d_y, d_y) + ys: Array, # (T, d_y) factorial_indices: Array, # (T, 2) ) -> tuple[Filter, Smoother, factorial.Factorializer, Array]: """Builds Kalman filter and smoother objects and model_inputs for a linear-Gaussian SSM.""" From a8ac8c0964bc363ae8430cf9f14e42447e544760 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 26 Feb 2026 11:11:19 +0000 Subject: [PATCH 28/29] Update README --- cuthbert/factorial/README.md | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cuthbert/factorial/README.md b/cuthbert/factorial/README.md index a57c06ea..e1760514 100644 --- a/cuthbert/factorial/README.md +++ b/cuthbert/factorial/README.md @@ -16,7 +16,7 @@ $$ This motivates a factored approximation of filtering and smoothing distributions, e.g. $$ -p(x_t \mid y_{0:t}) = \prod_{f=1}^F p(x_t^f \mid y_{0:t}). +p(x_t \mid y_{1:t}) = \prod_{f=1}^F p(x_t^f \mid y_{1:t}). $$ A tutorial on factorial state-space models can be found in [Duffield et al](https://doi.org/10.1093/jrsssc/qlae035). @@ -86,16 +86,15 @@ init_factorial_state, local_filter_states = cuthbert.factorial.filter( ## Factorial smoothing with `cuthbert` Smoothing in factorial state-space models can be performed embarrassingly parallel -along the factors since the dynamics and factorial approximation are independent +across factors since the dynamics and factorial approximation are independent across factors (the observations are fully absorbed in the filtering and are not accessed during smoothing). The model inputs and filter states require some preprocessing to convert from being single sequence with each state containing all factors into a sequence or multiple sequences with each state corresponding to a single factor. This can be -fiddly but is left to the user for maximum freedom. - -TODO: Document some use cases in the examples. +fiddly but is left to the user for maximum freedom. Oftentimes, it is easiest to +specify different parameter functions for smoothing than filtering. After this preprocessing, smoothing can be performed as usual: @@ -106,9 +105,9 @@ model_inputs_single_factor = ... # Similarly, we need to extract the filter states for the single factor we're smoothing. filter_states_single_factor = ... -# Load smoother, with parameter extraction functions defined for factorial inference +# Load smoother, with parameter extraction functions defined for a single factor kalman_smoother = cuthbert.gaussian.kalman.build_smoother( - get_dynamics_params=get_dynamics_params, # Dynamics specified to act on joint local state + get_dynamics_params=get_dynamics_params, # Dynamics specified to act on a single factor ) smoother_state = kalman_smoother.convert_filter_to_smoother_state( From 90591dd2f988572728bf7be3d227b996b1d4e31d Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 26 Feb 2026 12:00:25 +0000 Subject: [PATCH 29/29] Start testing smoother --- .../{test_filtering.py => test_kalman.py} | 98 ++++++++++++++++--- 1 file changed, 87 insertions(+), 11 deletions(-) rename tests/cuthbert/factorial/{test_filtering.py => test_kalman.py} (65%) diff --git a/tests/cuthbert/factorial/test_filtering.py b/tests/cuthbert/factorial/test_kalman.py similarity index 65% rename from tests/cuthbert/factorial/test_filtering.py rename to tests/cuthbert/factorial/test_kalman.py index 90920094..7de0042d 100644 --- a/tests/cuthbert/factorial/test_filtering.py +++ b/tests/cuthbert/factorial/test_kalman.py @@ -1,5 +1,5 @@ import itertools -from typing import cast +from typing import cast, Callable import chex import jax @@ -12,6 +12,7 @@ from cuthbert.gaussian import kalman from cuthbert.inference import Filter, Smoother from cuthbertlib.types import ArrayTree +from cuthbertlib.linalg import block_marginal_sqrt_cov from tests.cuthbert.factorial.gaussian_utils import generate_factorial_kalman_model from tests.cuthbertlib.kalman.test_filtering import std_predict, std_update @@ -34,8 +35,9 @@ def load_kalman_pairwise_factorial_inference( chol_Rs: Array, # (T, d_y, d_y) ys: Array, # (T, d_y) factorial_indices: Array, # (T, 2) -) -> tuple[Filter, Smoother, factorial.Factorializer, Array]: - """Builds Kalman filter and smoother objects and model_inputs for a linear-Gaussian SSM.""" + smoother_factorial_index: int, +) -> tuple[Filter, factorial.Factorializer, Array, Smoother, Array]: + """Builds factorial Kalman filter and smoother objects and model_inputs for a linear-Gaussian SSM.""" def get_init_params(model_inputs: int) -> tuple[Array, Array]: return m0, chol_P0 @@ -54,15 +56,52 @@ def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Arra filter = kalman.build_filter( get_init_params, get_dynamics_params, get_observation_params ) - smoother = kalman.build_smoother( - get_dynamics_params, store_gain=True, store_chol_cov_given_next=True - ) factorializer = factorial.gaussian.build_factorializer( get_factorial_indices=lambda model_inputs: factorial_indices[model_inputs - 1] ) - model_inputs = jnp.arange(len(ys) + 1) - return filter, smoother, factorializer, model_inputs + filter_model_inputs = jnp.arange(len(ys) + 1) + + # Some processing to get smoothing for a single factor + num_factors = len(m0) + d_x = m0.shape[1] + Fs_per_factor = [[] for _ in range(num_factors)] + cs_per_factor = [[] for _ in range(num_factors)] + chol_Qs_per_factor = [[] for _ in range(num_factors)] + + for i in range(1, len(ys) + 1): + h, a = factorial_indices[i - 1] + + F_h = Fs[i - 1][:d_x, :d_x] + F_a = Fs[i - 1][-d_x:, -d_x:] + c_h = cs[i - 1][:d_x] + c_a = cs[i - 1][-d_x:] + chol_Q_h, chol_Q_a = block_marginal_sqrt_cov(chol_Qs[i - 1], d_x) + Fs_per_factor[h].append(F_h) + cs_per_factor[h].append(c_h) + chol_Qs_per_factor[h].append(chol_Q_h) + + Fs_per_factor[a].append(F_a) + cs_per_factor[a].append(c_a) + chol_Qs_per_factor[a].append(chol_Q_a) + + def get_dynamics_params_single_factor( + model_inputs: int, + ) -> tuple[Array, Array, Array]: + return ( + Fs_per_factor[smoother_factorial_index][model_inputs - 1], + cs_per_factor[smoother_factorial_index][model_inputs - 1], + chol_Qs_per_factor[smoother_factorial_index][model_inputs - 1], + ) + + smoother = kalman.build_smoother( + get_dynamics_params_single_factor, + store_gain=True, + store_chol_cov_given_next=True, + ) + smoother_model_inputs = jnp.arange(len(Fs_per_factor[smoother_factorial_index]) + 1) + + return filter, factorializer, filter_model_inputs, smoother, smoother_model_inputs seeds = [1, 43] @@ -86,8 +125,10 @@ def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_ste model_params = generate_factorial_kalman_model( seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps ) - filter_obj, _, factorializer, model_inputs = ( - load_kalman_pairwise_factorial_inference(*model_params) + filter_obj, factorializer, model_inputs, _, _ = ( + load_kalman_pairwise_factorial_inference( + *model_params, smoother_factorial_index=0 + ) ) # True means, covs and log norm constants @@ -163,7 +204,7 @@ def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_ste ), ) - # Check output_factorial = False + # Check output_factorial = True factorial_filtering_states = factorial.filter( filter_obj, factorializer, model_inputs, output_factorial=True ) @@ -180,3 +221,38 @@ def test_filter(seed, x_dim, y_dim, num_factors, num_factors_local, num_time_ste chex.assert_trees_all_close( ells, factorial_filtering_states.log_normalizing_constant[1:] ) + + +smoother_indices = [0, 1, 5] + +common_smoother_params = list(itertools.product(common_params, smoother_indices)) + + +@pytest.mark.parametrize( + "seed,x_dim,y_dim,num_factors,num_factors_local,num_time_steps,smoother_factorial_index", + common_smoother_params, +) +def test_smoother( + seed, + x_dim, + y_dim, + num_factors, + num_factors_local, + num_time_steps, + smoother_factorial_index, +): + model_params = generate_factorial_kalman_model( + seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps + ) + filter_obj, factorializer, filter_model_inputs, smoother, smoother_model_inputs = ( + load_kalman_pairwise_factorial_inference( + *model_params, smoother_factorial_index=smoother_factorial_index + ) + ) + + # Check output_factorial = False + init_state, local_filter_states = factorial.filter( + filter_obj, factorializer, filter_model_inputs, output_factorial=False + ) + + # Convert to local smoother states