Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
540b3ba
Draft factorial filtering API
SamDuffield Feb 1, 2026
248fa12
Add smoothing API
SamDuffield Feb 2, 2026
1ff85c5
Fix typos
SamDuffield Feb 2, 2026
51fdc08
Add types.py
SamDuffield Feb 2, 2026
6b8efeb
Draft filtering
SamDuffield Feb 2, 2026
d5af70d
Add init and sort imports
SamDuffield Feb 2, 2026
ff0f085
Add cuthbert import to READMEs
SamDuffield Feb 2, 2026
5d4fb9a
Start test_filtering
SamDuffield Feb 2, 2026
64fbb37
Draft gaussian
SamDuffield Feb 3, 2026
76a85e0
Flesh out factorial gaussian
SamDuffield Feb 3, 2026
f98b875
Refactor factorial API with factorializer
SamDuffield Feb 4, 2026
f47c78b
Update init
SamDuffield Feb 4, 2026
01c7dbf
Fix gaussian and add block_marginal_sqrt_cov
SamDuffield Feb 5, 2026
29de6a0
Flesh out test
SamDuffield Feb 5, 2026
5ff9dbb
Add semi hack to extract local factorial states
SamDuffield Feb 5, 2026
e39565e
Rename
SamDuffield Feb 5, 2026
5b2a0f5
Fix test
SamDuffield Feb 5, 2026
026cc47
Change init_prepare nan check to all
SamDuffield Feb 5, 2026
ea41af2
Sort imports
SamDuffield Feb 5, 2026
c9c8a15
Merge branch 'main' into factorial
SamDuffield Feb 5, 2026
38bb0ee
Update cuthbert/gaussian/kalman.py
SamDuffield Feb 6, 2026
167b6f5
Add cond import
SamDuffield Feb 6, 2026
9fb2482
Merge branch 'main' into factorial
SamDuffield Feb 7, 2026
db9bec5
Fix imports
SamDuffield Feb 7, 2026
8d463e8
Ignore model_inputs and revert init_prepare
SamDuffield Feb 7, 2026
b7aafa6
Merge branch 'main' into factorial
SamDuffield Feb 24, 2026
8ecbab3
Remove cond import
SamDuffield Feb 24, 2026
7afb647
Remove readme edit
SamDuffield Feb 24, 2026
baf55f3
Fix filtering
SamDuffield Feb 24, 2026
7c808f8
Remove old comment
SamDuffield Feb 24, 2026
a8ac8c0
Update README
SamDuffield Feb 26, 2026
90591dd
Start testing smoother
SamDuffield Feb 26, 2026
dbae9ea
Merge branch 'main' into factorial
SamDuffield Feb 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions cuthbert/factorial/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Factorial State-Space Models

A factorial state-space model is a state-space model where the dynamics distribution
factors into a product of independent distributions across factors

$$
p(x_t \mid x_{t-1}) = \prod_{f=1}^F p(x_t^f \mid x_{t-1}^f),
$$
for factorial index $f \in \{1, \ldots, F\}$. We additionally assume that observations
act locally on some subset of factors $S_t \subseteq \{1, \ldots, F\}$.

$$
p(y_t \mid x_t) = p(y_t \mid x_t^{S_t}).
$$

This motivates a factored approximation of filtering and smoothing distributions, e.g.

$$
p(x_t \mid y_{1:t}) = \prod_{f=1}^F p(x_t^f \mid y_{1:t}).
$$

A tutorial on factorial state-space models can be found in [Duffield et al](https://doi.org/10.1093/jrsssc/qlae035).

The factorial approximation allows us to exploit significant benefits in terms of
memory, compute and parallelization.

Note that although the dynamics are factorized, `cuthbert` does not differentiate
between `predict` and `update` (instead favouring a unified filter operation
via `filter_prepare` and `filter_combine`). Thus the dynamics and model inputs
should be specified to act on the joint local state (i.e. block diagonal
where appropriate).


## Factorial filtering with `cuthbert`

Filtering in a factorial state-space model is similar to standard filtering, but with
an additional step before the filtering operation to extract the relevant
factors as well as an additional step after the filtering operation to insert the
updated factors back into the factorial state.


```python
from jax import tree
import cuthbert

# Define model_inputs
model_inputs = ...

# Define function to extract the factorial indices from model inputs
# Here we assume model_inputs is a NamedTuple with a field `factorial_inds`
get_factorial_indices = lambda mi: mi.factorial_inds

# Build factorializer for the inference method
factorializer = cuthbert.factorial.gaussian.build_factorializer(get_factorial_indices)

# Load inference method, with parameter extraction functions defined for factorial inference
kalman_filter = cuthbert.gaussian.kalman.build_filter(
get_init_params=get_init_params, # Init specified to generate factorial state
get_dynamics_params=get_dynamics_params, # Dynamics specified to act on joint local state
get_observation_params=get_observation_params, # Observation specified to act on joint local state
)

# Online inference
factorial_state = kalman_filter.init_prepare(tree.map(lambda x: x[0], model_inputs))

for t in range(1, T):
model_inputs_t = tree.map(lambda x: x[t], model_inputs)
factorial_inds = get_factorial_indices(model_inputs_t)
local_state = factorializer.extract_and_join(factorial_state, factorial_inds)
prepare_state = kalman_filter.filter_prepare(model_inputs_t)
filtered_local_state = kalman_filter.filter_combine(local_state, prepare_state)
factorial_state = factorializer.marginalize_and_insert(
filtered_local_state, factorial_state, factorial_inds
)
```

You can also use `cuthbert.factorial.filter` for convenient offline filtering.
Note that associative/parallel filtering is not supported for factorial filtering.

```python
init_factorial_state, local_filter_states = cuthbert.factorial.filter(
kalman_filter, factorializer, model_inputs, output_factorial=False
)
```

## Factorial smoothing with `cuthbert`

Smoothing in factorial state-space models can be performed embarrassingly parallel
across factors since the dynamics and factorial approximation are independent
across factors (the observations are fully absorbed in the filtering and
are not accessed during smoothing).

The model inputs and filter states require some preprocessing to convert from being
single sequence with each state containing all factors into a sequence or multiple
sequences with each state corresponding to a single factor. This can be
fiddly but is left to the user for maximum freedom. Oftentimes, it is easiest to
specify different parameter functions for smoothing than filtering.

After this preprocessing, smoothing can be performed as usual:

```python
# Define model_inputs for a single factor
model_inputs_single_factor = ...

# Similarly, we need to extract the filter states for the single factor we're smoothing.
filter_states_single_factor = ...

# Load smoother, with parameter extraction functions defined for a single factor
kalman_smoother = cuthbert.gaussian.kalman.build_smoother(
get_dynamics_params=get_dynamics_params, # Dynamics specified to act on a single factor
)

smoother_state = kalman_smoother.convert_filter_to_smoother_state(
tree.map(lambda x: x[-1], filter_states_single_factor),
model_inputs=tree.map(lambda x: x[-1], model_inputs_single_factor),
)

for t in range(T - 1, -1, -1):
model_inputs_single_factor_t = tree.map(lambda x: x[t], model_inputs_single_factor)
filter_state_single_factor_t = tree.map(lambda x: x[t], filter_states_single_factor)
prepare_state = kalman_smoother.smoother_prepare(
filter_state_single_factor_t, model_inputs_single_factor_t
)
smoother_state = kalman_smoother.smoother_combine(prepare_state, smoother_state)
```

Or directly using the `cuthbert.smoother`:

```python
smoother_states = cuthbert.smoother(
kalman_smoother, filter_states_single_factor, model_inputs_single_factor
)
```
8 changes: 8 additions & 0 deletions cuthbert/factorial/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from cuthbert.factorial import gaussian
from cuthbert.factorial.filtering import filter
from cuthbert.factorial.types import (
ExtractAndJoin,
Factorializer,
GetFactorialIndices,
MarginalizeAndInsert,
)
110 changes: 110 additions & 0 deletions cuthbert/factorial/filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""cuthbert factorial filtering interface."""

from jax import numpy as jnp
from jax import random, tree
from jax.lax import scan

from cuthbert.factorial.types import Factorializer
from cuthbert.inference import Filter
from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray


def filter(
filter_obj: Filter,
factorializer: Factorializer,
model_inputs: ArrayTreeLike,
output_factorial: bool = False,
key: KeyArray | None = None,
) -> (
ArrayTree | tuple[ArrayTree, ArrayTree]
): # TODO: Can overload this function so the type checker knows that the output is a ArrayTree if output_factorial is True and a tuple[ArrayTree, ArrayTree] if output_factorial is False
"""Applies offline factorial filtering for given model inputs.

`model_inputs` should have leading temporal dimension of length T + 1,
where T is the number of time steps excluding the initial state.

Parallel associative filtering is not supported for factorial filtering.

Note that if output_factorial is True, this function will output a factorial state
with first temporal dimension of length T + 1 and second factorial dimension of
length F. Many of the factors will be unchanged across timesteps where they aren't
relevant.

Args:
filter_obj: The filter inference object.
factorializer: The factorializer object for the inference method.
model_inputs: The model inputs (with leading temporal dimension of length T + 1).
output_factorial: If True, return a single state with first temporal dimension
of length T + 1 and second factorial dimension of length F.
If False, return a tuple of states. The first being the initial state
with first dimension of length F and temporal dimension.
The second being the local states for each time step, i.e. first
dimension of length T and no factorial dimension.
key: The key for the random number generator.

Returns:
The filtered states (NamedTuple with leading temporal dimension of length T + 1).
"""
T = tree.leaves(model_inputs)[0].shape[0] - 1

if key is None:
# This will throw error if used as a key, which is desired behavior
# (albeit not a useful error, we could improve this)
prepare_keys = jnp.empty(T + 1)
else:
prepare_keys = random.split(key, T + 1)

init_model_input = tree.map(lambda x: x[0], model_inputs)
init_factorial_state = filter_obj.init_prepare(
init_model_input, key=prepare_keys[0]
)

prep_model_inputs = tree.map(lambda x: x[1:], model_inputs)

def body_local(prev_factorial_state, prep_inp_and_k):
prep_inp, k = prep_inp_and_k
factorial_inds = factorializer.get_factorial_indices(prep_inp)
local_state = factorializer.extract_and_join(
prev_factorial_state, factorial_inds
)
prep_state = filter_obj.filter_prepare(prep_inp, key=k)
filtered_joint_state = filter_obj.filter_combine(local_state, prep_state)
factorial_state = factorializer.marginalize_and_insert(
filtered_joint_state, prev_factorial_state, factorial_inds
)

def extract(arr):
if arr.ndim >= 2:
return arr[factorial_inds]
else:
return arr

factorial_state_fac_inds_only = tree.map(extract, factorial_state)
return factorial_state, factorial_state_fac_inds_only

if output_factorial:

def body_factorial(prev_factorial_state, prep_inp_and_k):
factorial_state, _ = body_local(prev_factorial_state, prep_inp_and_k)
return factorial_state, factorial_state

_, factorial_states = scan(
body_factorial,
init_factorial_state,
(prep_model_inputs, prepare_keys[1:]),
)
factorial_states = tree.map(
lambda x, y: jnp.concatenate([x[None], y]),
init_factorial_state,
factorial_states,
)

return factorial_states

else:
_, local_states = scan(
body_local,
init_factorial_state,
(prep_model_inputs, prepare_keys[1:]),
)
return init_factorial_state, local_states
Loading
Loading