Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions cuthbert/smc/backward_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,19 @@ def convert_filter_to_smoother_state(
dummy_model_inputs = dummy_tree_like(model_inputs)

key, resampling_key = random.split(key)
indices = resampling(resampling_key, filter_state.log_weights, n_smoother_particles)
indices, log_weights, particles = resampling(
resampling_key,
filter_state.log_weights,
filter_state.particles,
n_smoother_particles,
)

return ParticleSmootherState(
key=cast(KeyArray, key),
particles=jax.tree.map(lambda z: z[indices], filter_state.particles),
particles=particles,
ancestor_indices=filter_state.ancestor_indices[indices],
model_inputs=dummy_model_inputs,
log_weights=-jnp.log(n_smoother_particles) * jnp.ones(n_smoother_particles),
log_weights=log_weights,
)


Expand Down
23 changes: 7 additions & 16 deletions cuthbert/smc/marginal_particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from cuthbert.smc.types import InitSample, LogPotential, PropagateSample
from cuthbert.utils import dummy_tree_like
from cuthbertlib.resampling import Resampling
from cuthbertlib.smc.ess import log_ess
from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray
from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray, ScalarArray


class MarginalParticleFilterState(NamedTuple):
Expand All @@ -35,7 +34,6 @@ def build_filter(
log_potential: LogPotential,
n_filter_particles: int,
resampling_fn: Resampling,
ess_threshold: float,
) -> Filter:
r"""Builds a marginal particle filter object.

Expand All @@ -45,9 +43,9 @@ def build_filter(
log_potential: Function to compute the log potential $\log G_t(x_{t-1}, x_t)$.
n_filter_particles: Number of particles for the filter.
resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial).
ess_threshold: Fraction of particle count specifying when to resample.
Resampling is triggered when the
effective sample size (ESS) < ess_threshold * n_filter_particles.
The resampling function may be decorated with adaptive behaviour
(using cuthbertlib.resampling.adaptive.adaptive_resampling_decorator)
before being passed to the filter.

Returns:
Filter object for the particle filter.
Expand All @@ -66,7 +64,6 @@ def build_filter(
propagate_sample=propagate_sample,
log_potential=log_potential,
resampling_fn=resampling_fn,
ess_threshold=ess_threshold,
),
associative=False,
)
Expand Down Expand Up @@ -161,7 +158,6 @@ def filter_combine(
propagate_sample: PropagateSample,
log_potential: LogPotential,
resampling_fn: Resampling,
ess_threshold: float,
) -> MarginalParticleFilterState:
"""Combine previous filter state with the state prepared for the current step.

Expand All @@ -175,8 +171,6 @@ def filter_combine(
propagate_sample: Function to sample from the Markov kernel M_t(x_t | x_{t-1}).
log_potential: Function to compute the log potential log G_t(x_{t-1}, x_t).
resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial).
ess_threshold: Fraction of particle count specifying when to resample.
Resampling is triggered when the effective sample size (ESS) < ess_threshold * N.

Returns:
The filtered state at the current time step.
Expand All @@ -188,13 +182,10 @@ def filter_combine(
prev_log_weights = state_1.log_weights - jax.nn.logsumexp(
state_1.log_weights
) # Ensure normalized
ancestor_indices, log_weights = jax.lax.cond(
log_ess(state_1.log_weights) < jnp.log(ess_threshold * N),
lambda: (resampling_fn(keys[0], state_1.log_weights, N), jnp.zeros(N)),
lambda: (jnp.arange(N), state_1.log_weights),
)
ancestors = tree.map(lambda x: x[ancestor_indices], state_1.particles)

ancestor_indices, log_weights, ancestors = resampling_fn(
keys[0], state_1.log_weights, state_1.particles, N
)
# Propagate
next_particles = jax.vmap(propagate_sample, (0, 0, None))(
keys[1:], ancestors, state_2.model_inputs
Expand Down
25 changes: 8 additions & 17 deletions cuthbert/smc/particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@

import jax
import jax.numpy as jnp
from jax import Array, random, tree
from jax import random, tree

from cuthbert.inference import Filter
from cuthbert.smc.types import InitSample, LogPotential, PropagateSample
from cuthbert.utils import dummy_tree_like
from cuthbertlib.resampling import Resampling
from cuthbertlib.smc.ess import log_ess
from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray
from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray, ScalarArray


class ParticleFilterState(NamedTuple):
Expand All @@ -40,7 +39,6 @@ def build_filter(
log_potential: LogPotential,
n_filter_particles: int,
resampling_fn: Resampling,
ess_threshold: float,
) -> Filter:
r"""Builds a particle filter object.

Expand All @@ -50,9 +48,9 @@ def build_filter(
log_potential: Function to compute the log potential $\log G_t(x_{t-1}, x_t)$.
n_filter_particles: Number of particles for the filter.
resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial).
ess_threshold: Fraction of particle count specifying when to resample.
Resampling is triggered when the
effective sample size (ESS) < ess_threshold * n_filter_particles.
The resampling function may be decorated with adaptive behaviour
(using cuthbertlib.resampling.adaptive.adaptive_resampling_decorator)
before being passed to the filter.

Returns:
Filter object for the particle filter.
Expand All @@ -73,7 +71,6 @@ def build_filter(
propagate_sample=propagate_sample,
log_potential=log_potential,
resampling_fn=resampling_fn,
ess_threshold=ess_threshold,
),
associative=False,
)
Expand Down Expand Up @@ -170,7 +167,6 @@ def filter_combine(
propagate_sample: PropagateSample,
log_potential: LogPotential,
resampling_fn: Resampling,
ess_threshold: float,
) -> ParticleFilterState:
"""Combine previous filter state with the state prepared for the current step.

Expand All @@ -183,22 +179,17 @@ def filter_combine(
propagate_sample: Function to sample from the Markov kernel M_t(x_t | x_{t-1}).
log_potential: Function to compute the log potential log G_t(x_{t-1}, x_t).
resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial).
ess_threshold: Fraction of particle count specifying when to resample.
Resampling is triggered when the effective sample size (ESS) < ess_threshold * N.

Returns:
The filtered state at the current time step.
"""
N = state_1.log_weights.shape[0]
keys = random.split(state_1.key, N + 1)

# Resample
ancestor_indices, log_weights = jax.lax.cond(
log_ess(state_1.log_weights) < jnp.log(ess_threshold * N),
lambda: (resampling_fn(keys[0], state_1.log_weights, N), jnp.zeros(N)),
lambda: (jnp.arange(N), state_1.log_weights),
# Resample - resampling_fn is expected to handle adaptivity if desired
ancestor_indices, log_weights, ancestors = resampling_fn(
keys[0], state_1.log_weights, state_1.particles, N
)
ancestors = tree.map(lambda x: x[ancestor_indices], state_1.particles)

# Propagate
next_particles = jax.vmap(propagate_sample, (0, 0, None))(
Expand Down
21 changes: 16 additions & 5 deletions cuthbertlib/resampling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,27 @@ sampling_key, resampling_key = jax.random.split(jax.random.key(0))
particles = jax.random.normal(sampling_key, (100, 2))
logits = jax.vmap(lambda x: jnp.where(jnp.all(x > 0), 0, -jnp.inf))(particles)

resampled_indices = resampling.multinomial.resampling(resampling_key, logits, 100)
resampled_particles = particles[resampled_indices]
resampled_indices, _, resampled_particles = resampling.multinomial.resampling(resampling_key, logits, particles, 100)
```

Or for conditional resampling:

```python
# Here we resample but keep particle at index 0 fixed
conditional_resampled_indices = resampling.multinomial.conditional_resampling(
resampling_key, logits, 100, pivot_in=0, pivot_out=0
conditional_resampled_indices, _, conditional_resampled_particles = resampling.multinomial.conditional_resampling(
resampling_key, logits, particles, 100, pivot_in=0, pivot_out=0
)
conditional_resampled_particles = particles[conditional_resampled_indices]
```

Adaptive resampling (i.e. resampling only when the effective sample size is below a
threshold) is also supported via a decorator:

```python
adaptive_resampling = resampling.adaptive.ess_decorator(
resampling.multinomial.resampling,
threshold=0.5,
)
adaptive_resampled_indices, _, adaptive_resampled_particles = adaptive_resampling(
resampling_key, logits, particles, 100
)
```
3 changes: 2 additions & 1 deletion cuthbertlib/resampling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from cuthbertlib.resampling import killing, multinomial, systematic
from cuthbertlib.resampling import adaptive, killing, multinomial, systematic
from cuthbertlib.resampling.adaptive import ess_decorator
from cuthbertlib.resampling.protocols import ConditionalResampling, Resampling
from cuthbertlib.resampling.utils import inverse_cdf
72 changes: 72 additions & 0 deletions cuthbertlib/resampling/adaptive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Adaptive resampling decorator.

Provides a decorator to turn any Resampling function into an adaptive resampling
function which performs resampling only when the effective sample size (ESS)
falls below a threshold.
"""

from functools import wraps

import jax
import jax.numpy as jnp

from cuthbertlib.resampling.protocols import Resampling
from cuthbertlib.smc.ess import log_ess
from cuthbertlib.types import Array, ArrayLike, ArrayTree, ArrayTreeLike


def ess_decorator(func: Resampling, threshold: float) -> Resampling:
    """Wrap a Resampling function so that it only resamples when ESS is low.

    The returned function is jitted with `n` as a static argument. The wrapped
    resampler's docstring is appended to this wrapper's docstring so IDEs and
    users can see the underlying algorithm documentation.

    Args:
        func: A resampling function with signature
            (key, logits, positions, n) -> (indices, logits_out, positions_out).
        threshold: Fraction of particle count specifying when to resample.
            Resampling is triggered when ESS < threshold * n.

    Returns:
        A Resampling function implementing adaptive resampling.

    Raises:
        AssertionError: If `n` differs from the number of particles `N`; the
            identity (no-resampling) branch requires both sizes to match so
            the two `lax.cond` branches have identical output shapes.
    """
    # Compose a descriptive docstring that keeps the wrapped resampler's
    # documentation visible on the returned callable.
    wrapped_doc = func.__doc__ or ""
    doc = f"""
    Adaptive resampling decorator (threshold={threshold}).

    This wrapper will call the provided resampling function only when the
    effective sample size (ESS) is below `threshold * n`.

    Wrapped resampler documentation:
    {wrapped_doc}
    """

    @wraps(func)
    def _wrapped(
        key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int
    ) -> tuple[Array, Array, ArrayTree]:
        logits_arr = jnp.asarray(logits)
        N = logits_arr.shape[0]
        # n and N are both static under jit, so this raises at trace time.
        if n != N:
            raise AssertionError(
                "The number of sampled indices must be equal to the number of "
                f"particles for `adaptive` resampling. Got {n} instead of {N}."
            )

        def _do_resample():
            return func(key, logits_arr, positions, n)

        def _no_resample():
            # Keep weights and particles unchanged; indices are the identity.
            return jnp.arange(n), logits_arr, positions

        # Compare in log space: log_ess(logits) < log(threshold * n)
        # is equivalent to ESS < threshold * n but avoids exponentiation.
        return jax.lax.cond(
            log_ess(logits_arr) < jnp.log(threshold * n),
            _do_resample,
            _no_resample,
        )

    # Attach the composed docstring and return a jitted version
    _wrapped.__doc__ = doc
    return jax.jit(_wrapped, static_argnames=("n",))
34 changes: 24 additions & 10 deletions cuthbertlib/resampling/killing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
conditional_resampling_decorator,
resampling_decorator,
)
from cuthbertlib.types import Array, ArrayLike, ScalarArrayLike
from cuthbertlib.resampling.utils import apply_resampling_indices
from cuthbertlib.types import (
Array,
ArrayLike,
ArrayTree,
ArrayTreeLike,
ScalarArrayLike,
)

_DESCRIPTION = """
The Killing resampling is a simple resampling mechanism that checks if
Expand All @@ -28,7 +35,9 @@


@partial(resampling_decorator, name="Killing", desc=_DESCRIPTION)
def resampling(key: Array, logits: ArrayLike, n: int) -> Array:
def resampling(
key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int
) -> tuple[Array, Array, ArrayTree]:
logits = jnp.asarray(logits)
key_1, key_2 = random.split(key)
N = logits.shape[0]
Expand All @@ -43,27 +52,30 @@ def resampling(key: Array, logits: ArrayLike, n: int) -> Array:

survived = log_uniforms <= logits - max_logit
if_survived = jnp.arange(N) # If the particle survives, it keeps its index
otherwise = multinomial.resampling(
key_2, logits, N
otherwise_idx, _, _ = multinomial.resampling(
key_2, logits, positions, N
) # otherwise, it is replaced by another particle
idx = jnp.where(survived, if_survived, otherwise)
return idx
idx = jnp.where(survived, if_survived, otherwise_idx)
# After resampling, all particles have equal weight
logits_out = jnp.zeros_like(logits)
return idx, logits_out, apply_resampling_indices(positions, idx)


@partial(conditional_resampling_decorator, name="Killing", desc=_DESCRIPTION)
def conditional_resampling(
key: Array,
logits: ArrayLike,
positions: ArrayTreeLike,
n: int,
pivot_in: ScalarArrayLike,
pivot_out: ScalarArrayLike,
) -> Array:
) -> tuple[Array, Array, ArrayTree]:
pivot_in = jnp.asarray(pivot_in)
pivot_out = jnp.asarray(pivot_out)

# Unconditional resampling
key_resample, key_shuffle = random.split(key)
idx = resampling(key_resample, logits, n)
idx_uncond, _, _ = resampling(key_resample, logits, positions, n)

# Conditional rolling pivot
max_logit = jnp.max(logits)
Expand All @@ -76,9 +88,11 @@ def conditional_resampling(

pivot_weights = jnp.exp(pivot_logits - logsumexp(pivot_logits))
pivot = random.choice(key_shuffle, n, p=pivot_weights)
idx = jnp.roll(idx, pivot_in - pivot)
idx = jnp.roll(idx_uncond, pivot_in - pivot)
idx = idx.at[pivot_in].set(pivot_out)
return idx
# After resampling, all particles have equal weight
logits_out = jnp.zeros_like(logits)
return idx, logits_out, apply_resampling_indices(positions, idx)


def _log1mexp(x: ArrayLike) -> Array:
Expand Down
Loading
Loading