diff --git a/cuthbert/smc/backward_sampler.py b/cuthbert/smc/backward_sampler.py index a8e4da3d..9913404b 100644 --- a/cuthbert/smc/backward_sampler.py +++ b/cuthbert/smc/backward_sampler.py @@ -108,14 +108,19 @@ def convert_filter_to_smoother_state( dummy_model_inputs = dummy_tree_like(model_inputs) key, resampling_key = random.split(key) - indices = resampling(resampling_key, filter_state.log_weights, n_smoother_particles) + indices, log_weights, particles = resampling( + resampling_key, + filter_state.log_weights, + filter_state.particles, + n_smoother_particles, + ) return ParticleSmootherState( key=cast(KeyArray, key), - particles=jax.tree.map(lambda z: z[indices], filter_state.particles), + particles=particles, ancestor_indices=filter_state.ancestor_indices[indices], model_inputs=dummy_model_inputs, - log_weights=-jnp.log(n_smoother_particles) * jnp.ones(n_smoother_particles), + log_weights=log_weights, ) diff --git a/cuthbert/smc/marginal_particle_filter.py b/cuthbert/smc/marginal_particle_filter.py index 840f304d..afdfbb93 100644 --- a/cuthbert/smc/marginal_particle_filter.py +++ b/cuthbert/smc/marginal_particle_filter.py @@ -14,8 +14,7 @@ from cuthbert.smc.types import InitSample, LogPotential, PropagateSample from cuthbert.utils import dummy_tree_like from cuthbertlib.resampling import Resampling -from cuthbertlib.smc.ess import log_ess -from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray +from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray, ScalarArray class MarginalParticleFilterState(NamedTuple): @@ -35,7 +34,6 @@ def build_filter( log_potential: LogPotential, n_filter_particles: int, resampling_fn: Resampling, - ess_threshold: float, ) -> Filter: r"""Builds a marginal particle filter object. @@ -45,9 +43,9 @@ def build_filter( log_potential: Function to compute the log potential $\log G_t(x_{t-1}, x_t)$. n_filter_particles: Number of particles for the filter. 
resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial). - ess_threshold: Fraction of particle count specifying when to resample. - Resampling is triggered when the - effective sample size (ESS) < ess_threshold * n_filter_particles. + The resampling function may be decorated with adaptive behaviour + (using cuthbertlib.resampling.adaptive.ess_decorator) + before being passed to the filter. Returns: Filter object for the particle filter. @@ -66,7 +64,6 @@ def build_filter( propagate_sample=propagate_sample, log_potential=log_potential, resampling_fn=resampling_fn, - ess_threshold=ess_threshold, ), associative=False, ) @@ -161,7 +158,6 @@ def filter_combine( propagate_sample: PropagateSample, log_potential: LogPotential, resampling_fn: Resampling, - ess_threshold: float, ) -> MarginalParticleFilterState: """Combine previous filter state with the state prepared for the current step. @@ -175,8 +171,6 @@ def filter_combine( propagate_sample: Function to sample from the Markov kernel M_t(x_t | x_{t-1}). log_potential: Function to compute the log potential log G_t(x_{t-1}, x_t). resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial). - ess_threshold: Fraction of particle count specifying when to resample. - Resampling is triggered when the effective sample size (ESS) < ess_threshold * N. Returns: The filtered state at the current time step. 
@@ -188,13 +182,10 @@ def filter_combine( prev_log_weights = state_1.log_weights - jax.nn.logsumexp( state_1.log_weights ) # Ensure normalized - ancestor_indices, log_weights = jax.lax.cond( - log_ess(state_1.log_weights) < jnp.log(ess_threshold * N), - lambda: (resampling_fn(keys[0], state_1.log_weights, N), jnp.zeros(N)), - lambda: (jnp.arange(N), state_1.log_weights), - ) - ancestors = tree.map(lambda x: x[ancestor_indices], state_1.particles) + ancestor_indices, log_weights, ancestors = resampling_fn( + keys[0], state_1.log_weights, state_1.particles, N + ) # Propagate next_particles = jax.vmap(propagate_sample, (0, 0, None))( keys[1:], ancestors, state_2.model_inputs diff --git a/cuthbert/smc/particle_filter.py b/cuthbert/smc/particle_filter.py index 07776d7b..d445af44 100644 --- a/cuthbert/smc/particle_filter.py +++ b/cuthbert/smc/particle_filter.py @@ -8,14 +8,13 @@ import jax import jax.numpy as jnp -from jax import Array, random, tree +from jax import random, tree from cuthbert.inference import Filter from cuthbert.smc.types import InitSample, LogPotential, PropagateSample from cuthbert.utils import dummy_tree_like from cuthbertlib.resampling import Resampling -from cuthbertlib.smc.ess import log_ess -from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray +from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray, ScalarArray class ParticleFilterState(NamedTuple): @@ -40,7 +39,6 @@ def build_filter( log_potential: LogPotential, n_filter_particles: int, resampling_fn: Resampling, - ess_threshold: float, ) -> Filter: r"""Builds a particle filter object. @@ -50,9 +48,9 @@ def build_filter( log_potential: Function to compute the log potential $\log G_t(x_{t-1}, x_t)$. n_filter_particles: Number of particles for the filter. resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial). - ess_threshold: Fraction of particle count specifying when to resample. 
- Resampling is triggered when the - effective sample size (ESS) < ess_threshold * n_filter_particles. + The resampling function may be decorated with adaptive behaviour + (using cuthbertlib.resampling.adaptive.ess_decorator) + before being passed to the filter. Returns: Filter object for the particle filter. @@ -73,7 +71,6 @@ def build_filter( propagate_sample=propagate_sample, log_potential=log_potential, resampling_fn=resampling_fn, - ess_threshold=ess_threshold, ), associative=False, ) @@ -170,7 +167,6 @@ def filter_combine( propagate_sample: PropagateSample, log_potential: LogPotential, resampling_fn: Resampling, - ess_threshold: float, ) -> ParticleFilterState: """Combine previous filter state with the state prepared for the current step. @@ -183,8 +179,6 @@ def filter_combine( propagate_sample: Function to sample from the Markov kernel M_t(x_t | x_{t-1}). log_potential: Function to compute the log potential log G_t(x_{t-1}, x_t). resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial). - ess_threshold: Fraction of particle count specifying when to resample. - Resampling is triggered when the effective sample size (ESS) < ess_threshold * N. Returns: The filtered state at the current time step. 
@@ -192,13 +186,10 @@ def filter_combine( N = state_1.log_weights.shape[0] keys = random.split(state_1.key, N + 1) - # Resample - ancestor_indices, log_weights = jax.lax.cond( - log_ess(state_1.log_weights) < jnp.log(ess_threshold * N), - lambda: (resampling_fn(keys[0], state_1.log_weights, N), jnp.zeros(N)), - lambda: (jnp.arange(N), state_1.log_weights), + # Resample - resampling_fn is expected to handle adaptivity if desired + ancestor_indices, log_weights, ancestors = resampling_fn( + keys[0], state_1.log_weights, state_1.particles, N ) - ancestors = tree.map(lambda x: x[ancestor_indices], state_1.particles) # Propagate next_particles = jax.vmap(propagate_sample, (0, 0, None))( diff --git a/cuthbertlib/resampling/README.md b/cuthbertlib/resampling/README.md index b49d8d28..b532744c 100644 --- a/cuthbertlib/resampling/README.md +++ b/cuthbertlib/resampling/README.md @@ -11,16 +11,27 @@ sampling_key, resampling_key = jax.random.split(jax.random.key(0)) particles = jax.random.normal(sampling_key, (100, 2)) logits = jax.vmap(lambda x: jnp.where(jnp.all(x > 0), 0, -jnp.inf))(particles) -resampled_indices = resampling.multinomial.resampling(resampling_key, logits, 100) -resampled_particles = particles[resampled_indices] +resampled_indices, _, resampled_particles = resampling.multinomial.resampling(resampling_key, logits, particles, 100) ``` Or for conditional resampling: ```python # Here we resample but keep particle at index 0 fixed -conditional_resampled_indices = resampling.multinomial.conditional_resampling( - resampling_key, logits, 100, pivot_in=0, pivot_out=0 +conditional_resampled_indices, _, conditional_resampled_particles = resampling.multinomial.conditional_resampling( + resampling_key, logits, particles, 100, pivot_in=0, pivot_out=0 ) -conditional_resampled_particles = particles[conditional_resampled_indices] ``` + +Adaptive resampling (i.e. 
resampling only when the effective sample size is below a +threshold) is also supported via a decorator: + +```python +adaptive_resampling = resampling.adaptive.ess_decorator( + resampling.multinomial.resampling, + threshold=0.5, +) +adaptive_resampled_indices, _, adaptive_resampled_particles = adaptive_resampling( + resampling_key, logits, particles, 100 +) +``` \ No newline at end of file diff --git a/cuthbertlib/resampling/__init__.py b/cuthbertlib/resampling/__init__.py index bafa097e..b23d83a5 100644 --- a/cuthbertlib/resampling/__init__.py +++ b/cuthbertlib/resampling/__init__.py @@ -1,3 +1,4 @@ -from cuthbertlib.resampling import killing, multinomial, systematic +from cuthbertlib.resampling import adaptive, killing, multinomial, systematic +from cuthbertlib.resampling.adaptive import ess_decorator from cuthbertlib.resampling.protocols import ConditionalResampling, Resampling from cuthbertlib.resampling.utils import inverse_cdf diff --git a/cuthbertlib/resampling/adaptive.py b/cuthbertlib/resampling/adaptive.py new file mode 100644 index 00000000..3dc1cc75 --- /dev/null +++ b/cuthbertlib/resampling/adaptive.py @@ -0,0 +1,72 @@ +"""Adaptive resampling decorator. + +Provides a decorator to turn any Resampling function into an adaptive resampling +function which performs resampling only when the effective sample size (ESS) +falls below a threshold. +""" + +from functools import wraps + +import jax +import jax.numpy as jnp + +from cuthbertlib.resampling.protocols import Resampling +from cuthbertlib.smc.ess import log_ess +from cuthbertlib.types import Array, ArrayLike, ArrayTree, ArrayTreeLike + + +def ess_decorator(func: Resampling, threshold: float) -> Resampling: + """Wrap a Resampling function so that it only resamples when ESS < threshold. + + The returned function is jitted and has `n` as a static argument. The + original resampler's docstring is appended to this wrapper's docstring so + IDEs and users can see the underlying algorithm documentation. 
+ + Args: + func: A resampling function with signature + (key, logits, positions, n) -> (indices, logits_out, positions_out). + threshold: Fraction of particle count specifying when to resample. + Resampling is triggered when ESS < ess_threshold * n. + + Returns: + A Resampling function implementing adaptive resampling. + """ + # Build a descriptive docstring that includes the wrapped function doc + wrapped_doc = func.__doc__ or "" + doc = f""" + Adaptive resampling decorator (threshold={threshold}). + + This wrapper will call the provided resampling function only when the + effective sample size (ESS) is below `ess_threshold * n`. + + Wrapped resampler documentation: + {wrapped_doc} + """ + + @wraps(func) + def _wrapped( + key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int + ) -> tuple[Array, Array, ArrayTree]: + logits_arr = jnp.asarray(logits) + N = logits_arr.shape[0] + if n != N: + raise AssertionError( + "The number of sampled indices must be equal to the number of " + f"particles for `adaptive` resampling. Got {n} instead of {N}." 
+ ) + + def _do_resample(): + return func(key, logits_arr, positions, n) + + def _no_resample(): + return jnp.arange(n), logits_arr, positions + + return jax.lax.cond( + log_ess(logits_arr) < jnp.log(threshold * n), + _do_resample, + _no_resample, + ) + + # Attach the composed docstring and return a jitted version + _wrapped.__doc__ = doc + return jax.jit(_wrapped, static_argnames=("n",)) diff --git a/cuthbertlib/resampling/killing.py b/cuthbertlib/resampling/killing.py index 11b025ad..131cd095 100644 --- a/cuthbertlib/resampling/killing.py +++ b/cuthbertlib/resampling/killing.py @@ -11,7 +11,14 @@ conditional_resampling_decorator, resampling_decorator, ) -from cuthbertlib.types import Array, ArrayLike, ScalarArrayLike +from cuthbertlib.resampling.utils import apply_resampling_indices +from cuthbertlib.types import ( + Array, + ArrayLike, + ArrayTree, + ArrayTreeLike, + ScalarArrayLike, +) _DESCRIPTION = """ The Killing resampling is a simple resampling mechanism that checks if @@ -28,7 +35,9 @@ @partial(resampling_decorator, name="Killing", desc=_DESCRIPTION) -def resampling(key: Array, logits: ArrayLike, n: int) -> Array: +def resampling( + key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int +) -> tuple[Array, Array, ArrayTree]: logits = jnp.asarray(logits) key_1, key_2 = random.split(key) N = logits.shape[0] @@ -43,27 +52,30 @@ def resampling(key: Array, logits: ArrayLike, n: int) -> Array: survived = log_uniforms <= logits - max_logit if_survived = jnp.arange(N) # If the particle survives, it keeps its index - otherwise = multinomial.resampling( - key_2, logits, N + otherwise_idx, _, _ = multinomial.resampling( + key_2, logits, positions, N ) # otherwise, it is replaced by another particle - idx = jnp.where(survived, if_survived, otherwise) - return idx + idx = jnp.where(survived, if_survived, otherwise_idx) + # After resampling, all particles have equal weight + logits_out = jnp.zeros_like(logits) + return idx, logits_out, 
apply_resampling_indices(positions, idx) @partial(conditional_resampling_decorator, name="Killing", desc=_DESCRIPTION) def conditional_resampling( key: Array, logits: ArrayLike, + positions: ArrayTreeLike, n: int, pivot_in: ScalarArrayLike, pivot_out: ScalarArrayLike, -) -> Array: +) -> tuple[Array, Array, ArrayTree]: pivot_in = jnp.asarray(pivot_in) pivot_out = jnp.asarray(pivot_out) # Unconditional resampling key_resample, key_shuffle = random.split(key) - idx = resampling(key_resample, logits, n) + idx_uncond, _, _ = resampling(key_resample, logits, positions, n) # Conditional rolling pivot max_logit = jnp.max(logits) @@ -76,9 +88,11 @@ def conditional_resampling( pivot_weights = jnp.exp(pivot_logits - logsumexp(pivot_logits)) pivot = random.choice(key_shuffle, n, p=pivot_weights) - idx = jnp.roll(idx, pivot_in - pivot) + idx = jnp.roll(idx_uncond, pivot_in - pivot) idx = idx.at[pivot_in].set(pivot_out) - return idx + # After resampling, all particles have equal weight + logits_out = jnp.zeros_like(logits) + return idx, logits_out, apply_resampling_indices(positions, idx) def _log1mexp(x: ArrayLike) -> Array: diff --git a/cuthbertlib/resampling/multinomial.py b/cuthbertlib/resampling/multinomial.py index 26080c26..dd93a585 100644 --- a/cuthbertlib/resampling/multinomial.py +++ b/cuthbertlib/resampling/multinomial.py @@ -10,8 +10,14 @@ conditional_resampling_decorator, resampling_decorator, ) -from cuthbertlib.resampling.utils import inverse_cdf -from cuthbertlib.types import Array, ArrayLike, ScalarArrayLike +from cuthbertlib.resampling.utils import apply_resampling_indices, inverse_cdf +from cuthbertlib.types import ( + Array, + ArrayLike, + ArrayTree, + ArrayTreeLike, + ScalarArrayLike, +) _DESCRIPTION = """ This has higher variance than other resampling schemes as it samples from @@ -21,7 +27,9 @@ @partial(resampling_decorator, name="Multinomial", desc=_DESCRIPTION) -def resampling(key: Array, logits: ArrayLike, n: int) -> Array: +def resampling( + key: 
Array, logits: ArrayLike, positions: ArrayTreeLike, n: int +) -> tuple[Array, Array, ArrayTree]: # In practice we don't have to sort the generated uniforms, but searchsorted # works faster and is more stable if both inputs are sorted, so we use the # _sorted_uniforms from N. Chopin, but still use searchsorted instead of his @@ -32,23 +40,26 @@ def resampling(key: Array, logits: ArrayLike, n: int) -> Array: key_uniforms, key_shuffle = random.split(key) sorted_uniforms = _sorted_uniforms(key_uniforms, n) idx = inverse_cdf(sorted_uniforms, logits) - return random.permutation(key_shuffle, idx) + idx = random.permutation(key_shuffle, idx) + logits_out = jnp.zeros_like(sorted_uniforms) + return idx, logits_out, apply_resampling_indices(positions, idx) @partial(conditional_resampling_decorator, name="Multinomial", desc=_DESCRIPTION) def conditional_resampling( key: Array, logits: ArrayLike, + positions: ArrayTreeLike, n: int, pivot_in: ScalarArrayLike, pivot_out: ScalarArrayLike, -) -> Array: +) -> tuple[Array, Array, ArrayTree]: pivot_in = jnp.asarray(pivot_in) pivot_out = jnp.asarray(pivot_out) - idx = resampling(key, logits, n) + idx, logits_out, _ = resampling(key, logits, positions, n) idx = idx.at[pivot_in].set(pivot_out) - return idx + return idx, logits_out, apply_resampling_indices(positions, idx) @partial(jax.jit, static_argnames=("n",)) diff --git a/cuthbertlib/resampling/protocols.py b/cuthbertlib/resampling/protocols.py index 85cd3901..68312224 100644 --- a/cuthbertlib/resampling/protocols.py +++ b/cuthbertlib/resampling/protocols.py @@ -4,23 +4,53 @@ import jax -from cuthbertlib.types import Array, ArrayLike, KeyArray, ScalarArrayLike +from cuthbertlib.types import ( + Array, + ArrayLike, + ArrayTree, + ArrayTreeLike, + KeyArray, + ScalarArrayLike, +) + +_RESAMPLING_DOC = """ +Args: + key: JAX PRNG key. + logits: Logits. + positions: ArrayTreeLike + n: Number of indices to sample. + +Returns: + ancestors: Array of resampling indices. 
+ logits: Array of log-weights after resampling. + positions: ArrayTreeLike of resampled positions. +""" + +_CONDITIONAL_RESAMPLING_DOC = """ +Args: + key: JAX PRNG key. + logits: Log-weights, possibly unnormalized. + positions: ArrayTreeLike + n: Number of indices to sample. + pivot_in: Index of the particle to keep. + pivot_out: Value of the output at index `pivot_in`. + +Returns: + ancestors: Array of size n with indices to use for resampling. + logits: Array of log-weights after resampling. + positions: ArrayTreeLike of resampled positions. +""" @runtime_checkable class Resampling(Protocol): """Protocol for resampling operations.""" - def __call__(self, key: KeyArray, logits: ArrayLike, n: int) -> Array: - """Computes resampling indices according to given logits. - - Args: - key: JAX PRNG key. - logits: Logits. - n: Number of indices to sample. - - Returns: - Array of resampling indices. + def __call__( + self, key: KeyArray, logits: ArrayLike, positions: ArrayTreeLike, n: int + ) -> tuple[Array, Array, ArrayTree]: + f"""Computes resampling indices according to given logits. + {_RESAMPLING_DOC} """ ... @@ -33,21 +63,13 @@ def __call__( self, key: KeyArray, logits: ArrayLike, + positions: ArrayTreeLike, n: int, pivot_in: ScalarArrayLike, pivot_out: ScalarArrayLike, - ) -> Array: - """Conditional resampling. - - Args: - key: JAX PRNG key. - logits: Log-weights, possibly unnormalized. - n: Number of indices to sample. - pivot_in: Index of the particle to keep. - pivot_out: Value of the output at index `pivot_in`. - - Returns: - Array of size n with indices to use for resampling. + ) -> tuple[Array, Array, ArrayTree]: + f"""Conditional resampling. + {_CONDITIONAL_RESAMPLING_DOC} """ ... @@ -56,14 +78,7 @@ def resampling_decorator(func: Resampling, name: str, desc: str = "") -> Resampl """Decorate Resampling function with unified docstring.""" doc = f""" {name} resampling. 
{desc} - - Args: - key: PRNGKey to use in resampling - logits: Log-weights, possibly unnormalized. - n: Number of indices to sample. - - Returns: - Array of size n with indices to use for resampling. + {_RESAMPLING_DOC} """ func.__doc__ = doc @@ -76,16 +91,7 @@ def conditional_resampling_decorator( """Decorate ConditionalResampling function with unified docstring.""" doc = f""" {name} conditional resampling. {desc} - - Args: - key: PRNGKey to use in resampling - logits: Log-weights, possibly unnormalized. - n: Number of indices to sample - pivot_in: Index of the particle to keep - pivot_out: Value of the output at index `pivot_in` - - Returns: - Array of size n with indices to use for resampling. + {_CONDITIONAL_RESAMPLING_DOC} """ func.__doc__ = doc diff --git a/cuthbertlib/resampling/systematic.py b/cuthbertlib/resampling/systematic.py index c01f9025..7cb1c51b 100644 --- a/cuthbertlib/resampling/systematic.py +++ b/cuthbertlib/resampling/systematic.py @@ -11,8 +11,14 @@ conditional_resampling_decorator, resampling_decorator, ) -from cuthbertlib.resampling.utils import inverse_cdf -from cuthbertlib.types import Array, ArrayLike, ScalarArrayLike +from cuthbertlib.resampling.utils import apply_resampling_indices, inverse_cdf +from cuthbertlib.types import ( + Array, + ArrayLike, + ArrayTree, + ArrayTreeLike, + ScalarArrayLike, +) _DESCRIPTION = """ The Systematic resampling is a variance reduction which places marginally @@ -21,19 +27,24 @@ @partial(resampling_decorator, name="Systematic", desc=_DESCRIPTION) -def resampling(key: Array, logits: ArrayLike, n: int) -> Array: +def resampling( + key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int +) -> tuple[Array, Array, ArrayTree]: us = (random.uniform(key, ()) + jnp.arange(n)) / n - return inverse_cdf(us, logits) + idx = inverse_cdf(us, logits) + logits_out = jnp.zeros_like(us) + return idx, logits_out, apply_resampling_indices(positions, idx) @partial(conditional_resampling_decorator, name="Systematic", 
desc=_DESCRIPTION) def conditional_resampling( key: Array, logits: ArrayLike, + positions: ArrayTreeLike, n: int, pivot_in: ScalarArrayLike, pivot_out: ScalarArrayLike, -) -> Array: +) -> tuple[Array, Array, ArrayTree]: logits = jnp.asarray(logits) pivot_in = jnp.asarray(pivot_in) pivot_out = jnp.asarray(pivot_out) @@ -46,17 +57,17 @@ def conditional_resampling( logits = jnp.roll(logits, -pivot_out) arange = jnp.roll(arange, -pivot_out) - idx = conditional_resampling_0_to_0(key, logits, n) + idx, logits_out = conditional_resampling_0_to_0(key, logits, n) idx = arange[idx] idx = jnp.roll(idx, pivot_in) - return idx + return idx, logits_out, apply_resampling_indices(positions, idx) def conditional_resampling_0_to_0( key: Array, logits: ArrayLike, n: int, -) -> Array: +) -> tuple[Array, Array]: logits = jnp.asarray(logits) N = logits.shape[0] @@ -81,4 +92,4 @@ def _otherwise(): roll_idx = jnp.floor(n_zero * W).astype(int) idx = select(n_zero == 1, idx, jnp.roll(idx, -zero_loc[roll_idx])) - return jnp.clip(idx, 0, N - 1) + return jnp.clip(idx, 0, N - 1), jnp.zeros_like(linspace) diff --git a/cuthbertlib/resampling/utils.py b/cuthbertlib/resampling/utils.py index 4bfb743a..afbdea40 100644 --- a/cuthbertlib/resampling/utils.py +++ b/cuthbertlib/resampling/utils.py @@ -6,8 +6,9 @@ import numpy as np from jax.lax import platform_dependent from jax.scipy.special import logsumexp +from jax.tree_util import tree_map -from cuthbertlib.types import Array, ArrayLike +from cuthbertlib.types import Array, ArrayLike, ArrayTree, ArrayTreeLike @jax.jit @@ -80,3 +81,8 @@ def _inverse_cdf_numba(su, ws, idx): j += 1 s += ws[j] idx[n] = j + + +def apply_resampling_indices(positions: ArrayTreeLike, idx: Array) -> ArrayTree: + """Apply resampling indices to positions.""" + return tree_map(lambda x: x[idx], positions) diff --git a/cuthbertlib/smc/smoothing/mcmc.py b/cuthbertlib/smc/smoothing/mcmc.py index 219460cb..15d63058 100644 --- a/cuthbertlib/smc/smoothing/mcmc.py +++ 
b/cuthbertlib/smc/smoothing/mcmc.py @@ -55,8 +55,9 @@ def body(carry, keys_t): idx, x0_res, idx_log_p = carry key_prop, key_acc = keys_t - prop_idx = multinomial.resampling(key_prop, log_weight_x0_all, n_samples) - x0_prop = jax.tree.map(lambda z: z[prop_idx], x0_all) + prop_idx, _, x0_prop = multinomial.resampling( + key_prop, log_weight_x0_all, x0_all, n_samples + ) prop_log_p = jax.vmap(log_density)(x0_prop, x1_all) log_alpha = prop_log_p - idx_log_p diff --git a/docs/assets/online_stoch_vol_filter.png b/docs/assets/online_stoch_vol_filter.png index 05ffe4e8..ca1023e6 100644 Binary files a/docs/assets/online_stoch_vol_filter.png and b/docs/assets/online_stoch_vol_filter.png differ diff --git a/docs/assets/online_stoch_vol_predict.png b/docs/assets/online_stoch_vol_predict.png index 219aa607..41e30644 100644 Binary files a/docs/assets/online_stoch_vol_predict.png and b/docs/assets/online_stoch_vol_predict.png differ diff --git a/docs/examples/online_stoch_vol.md b/docs/examples/online_stoch_vol.md index 69ad605d..5b19c98e 100644 --- a/docs/examples/online_stoch_vol.md +++ b/docs/examples/online_stoch_vol.md @@ -27,7 +27,7 @@ import yfinance as yf from cuthbert import filter from cuthbert.smc import particle_filter -from cuthbertlib.resampling import systematic +from cuthbertlib.resampling import adaptive, systematic ``` We'll use a simple bootstrap particle filter for inference since our model is @@ -194,7 +194,9 @@ on how often to resample before constructing the filter object. 
```{.python #online-stoch-vol-particle-filter-setup} n_particles = 1000 -ess_threshold = 0.5 +resampling = adaptive.ess_decorator( + systematic.resampling, 0.5 +) # Resample when ESS drops below 50% pf = particle_filter.build_filter( init_sample=init_sample, @@ -202,7 +204,6 @@ pf = particle_filter.build_filter( log_potential=log_potential, n_filter_particles=n_particles, - resampling_fn=systematic.resampling, - ess_threshold=ess_threshold, + resampling_fn=resampling, ) ``` diff --git a/pyproject.toml b/pyproject.toml index dec55ac8..ac4ff6cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ override-dependencies = [ [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401", "F821", "E402", "D104"] "tests/**/*.py" = ["D"] # no docstring checking for tests -"cuthbertlib/resampling/**/*.py" = ["D103"] # resampling uses decorator for docstrings +"cuthbertlib/resampling/**/*.py" = ["D102", "D103"] # resampling uses decorator for docstrings [tool.ruff.lint] select = ["D"] [tool.ruff.lint.pydocstyle] @@ -24,4 +24,4 @@ markers = "examples: Run tangled example scripts as tests" [tool.entangled] version = "2.3.0" -ignore_list = ["CONTRIBUTING.md"] +ignore_list = ["CONTRIBUTING.md"] \ No newline at end of file diff --git a/tests/cuthbert/smc/test_backward_sampler.py b/tests/cuthbert/smc/test_backward_sampler.py index cbcbc8b2..226af1f6 100644 --- a/tests/cuthbert/smc/test_backward_sampler.py +++ b/tests/cuthbert/smc/test_backward_sampler.py @@ -11,7 +11,7 @@ from cuthbert.smc.backward_sampler import build_smoother from cuthbert.smc.particle_filter import build_filter from cuthbertlib.kalman.generate import generate_lgssm -from cuthbertlib.resampling import systematic +from cuthbertlib.resampling import ess_decorator, systematic from cuthbertlib.smc.smoothing.exact_sampling import simulate as exact from cuthbertlib.smc.smoothing.mcmc import simulate as mcmc from cuthbertlib.smc.smoothing.tracing import simulate as tracing @@ -45,13 +45,13 @@ def log_potential(state_prev, state, 
model_inputs: int): n_filter_particles = 5000 resampling_fn = systematic.resampling ess_threshold = 0.7 + adaptive_resampler = ess_decorator(resampling_fn, ess_threshold) filter_obj = build_filter( init_sample, propagate_sample, log_potential, n_filter_particles, - resampling_fn, - ess_threshold, + adaptive_resampler, ) model_inputs = jnp.arange(len(ys) + 1) return filter_obj, model_inputs, log_potential @@ -136,13 +136,13 @@ def log_potential(state_prev, state, model_inputs): n_filter_particles = 1000 resampling_fn = systematic.resampling ess_threshold = 0.7 + adaptive_resampler = ess_decorator(resampling_fn, ess_threshold) filter_obj = build_filter( init_sample, propagate_sample, log_potential, n_filter_particles, - resampling_fn, - ess_threshold, + adaptive_resampler, ) if method == "tracing": diff --git a/tests/cuthbert/smc/test_particle_filters.py b/tests/cuthbert/smc/test_particle_filters.py index 03e68360..f2dda632 100644 --- a/tests/cuthbert/smc/test_particle_filters.py +++ b/tests/cuthbert/smc/test_particle_filters.py @@ -11,7 +11,7 @@ from cuthbert.inference import Filter from cuthbert.smc import marginal_particle_filter, particle_filter from cuthbertlib.kalman.generate import generate_lgssm -from cuthbertlib.resampling import systematic +from cuthbertlib.resampling import ess_decorator, systematic from cuthbertlib.stats.multivariate_normal import logpdf from tests.cuthbert.gaussian.test_kalman import std_kalman_filter @@ -62,6 +62,9 @@ def log_potential(state_prev, state, model_inputs: int): ) ess_threshold = 0.7 + # Decorate the resampling with adaptive behaviour and pass that to the filter + adaptive_systematic = ess_decorator(systematic.resampling, ess_threshold) + inference = Filter( init_prepare=partial( algo.init_prepare, @@ -77,8 +80,7 @@ def log_potential(state_prev, state, model_inputs: int): algo.filter_combine, propagate_sample=propagate_sample, log_potential=log_potential, - resampling_fn=systematic.resampling, - ess_threshold=ess_threshold, 
+ resampling_fn=adaptive_systematic, ), associative=False, ) @@ -158,6 +160,8 @@ def log_potential(state_prev, state, model_inputs): return jnp.zeros(()) ess_threshold = 0.7 + # Decorate the resampler for adaptive resampling behaviour + adaptive_systematic = ess_decorator(systematic.resampling, ess_threshold) if method == "bootstrap": n_filter_particles = 1_000 @@ -183,8 +187,7 @@ def log_potential(state_prev, state, model_inputs): algo.filter_combine, propagate_sample=propagate_sample, log_potential=log_potential, - resampling_fn=systematic.resampling, - ess_threshold=ess_threshold, + resampling_fn=adaptive_systematic, ), associative=False, ) diff --git a/tests/cuthbertlib/resampling/test_resamplings.py b/tests/cuthbertlib/resampling/test_resamplings.py index aa9f1431..73651f4e 100644 --- a/tests/cuthbertlib/resampling/test_resamplings.py +++ b/tests/cuthbertlib/resampling/test_resamplings.py @@ -86,7 +86,12 @@ def test_resampling(self, seed, test_case): method = get_resampling(test_case["method"]) for M, N in MNs: - resampling = self.variant(lambda k_, lw_: method(k_, lw_, M)) + # create dummy positions in the wrapper; accept the positions arg from tester and ignore it + resampling = self.variant( + lambda k_, lw_, positions: method( + k_, lw_, jax.random.normal(jax.random.key(0), (N,)), M + ) + ) log_weights = jax.random.uniform(key_weights, (N,)) resampling_tester(key_test, log_weights, resampling, M, self.K) @@ -102,8 +107,14 @@ def test_conditional_resampling(self, seed, test_case): conditional_method = get_conditional_resampling(test_case["method"]) for M in Ms: conditional_resampling = self.variant( - lambda k_, lw_, pivot_in, pivot_out: conditional_method( - k_, lw_, M, pivot_in, pivot_out + # accept positions arg from tester and ignore it + lambda k_, lw_, positions, pivot_in, pivot_out: conditional_method( + k_, + lw_, + jax.random.normal(jax.random.key(0), (M,)), + M, + pivot_in, + pivot_out, ) ) diff --git a/tests/cuthbertlib/resampling/utils.py 
b/tests/cuthbertlib/resampling/utils.py index 8f5a265e..b1b060db 100644 --- a/tests/cuthbertlib/resampling/utils.py +++ b/tests/cuthbertlib/resampling/utils.py @@ -7,7 +7,16 @@ def resampling_tester(rng_key, log_weights, resampling, m, k): keys = jax.random.split(rng_key, k) - indices = jax.vmap(resampling, [0, None])(keys, log_weights) + + # create dummy positions matching number of particles + n_particles = log_weights.shape[0] + positions = jax.random.normal(jax.random.key(0), (n_particles,)) + + def call_one(key): + idx, logits_out, _ = resampling(key, log_weights, positions) + return idx + + indices = jax.vmap(call_one)(keys) _check_bincounts(indices, log_weights, m, k) @@ -24,8 +33,11 @@ def do_one(key): p = jnp.exp(log_weights - logsumexp(log_weights)) pivot_out = jax.random.choice(key_resampling, m, shape=(), p=p) - conditional_indices = conditional_resampling( - key_conditional, log_weights, pivot_in, pivot_out + # create dummy positions + positions = jax.random.normal(jax.random.key(0), (m,)) + + conditional_indices, _, _ = conditional_resampling( + key_conditional, log_weights, positions, pivot_in, pivot_out ) return conditional_indices, pivot_out, conditional_indices[pivot_in]