
Implement the Stop-Gradient Differentiable Particle Filter #209

Merged
AdrienCorenflos merged 8 commits into state-space-models:main from DanWaxman:dw-make-pf-diff on Mar 3, 2026

Conversation

@DanWaxman
Contributor

This PR implements the stop-gradient differentiable PF (https://arxiv.org/abs/2106.10314) as a resampling decorator. It supersedes #202 after the merging of #207.

The changes made are summarized in the commit messages below. The resulting DPF is shown to provide strong score estimates in a toy example, now included in the documentation.
[Figures: score estimates on the toy example]

To use the stop-gradient resampling, one simply wraps the base resampling function with the `stop_gradient_decorator`:

```python
resampling_fn = (
    stop_gradient.stop_gradient_decorator(systematic.resampling)
    if differentiable_resampling
    else systematic.resampling
)
# In either case, do adaptive resampling
resampling_fn = adaptive.ess_decorator(resampling_fn, ess_threshold)
```
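For intuition, the trick from Ścibior and Wood (2021) can be sketched as a decorator along the following lines. This is a hypothetical signature for illustration, not cuthbertlib's actual API: `resampling_fn(key, log_weights)` is assumed to return ancestor indices.

```python
import jax
import jax.numpy as jnp


def stop_gradient_decorator(resampling_fn):
    """Sketch of the stop-gradient resampling trick (assumed signature)."""

    def wrapped(key, log_weights):
        # No gradient flows through the discrete ancestor indices ...
        idx = resampling_fn(key, jax.lax.stop_gradient(log_weights))
        picked = log_weights[idx]
        # ... but this log-ratio has value 0 (i.e. weight 1), while its
        # derivative equals that of the selected log-weights, so autodiff
        # recovers score estimates matching Fisher's identity.
        new_log_weights = picked - jax.lax.stop_gradient(picked)
        return idx, new_log_weights

    return wrapped
```

The key point is that the discrete (non-differentiable) selection is hidden behind `stop_gradient`, while the weight ratio reattaches the dependence on the parameters.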

Unit testing in this setting is somewhat difficult, for a few reasons:

  1. The estimator is still stochastic, and requires many trials or particles to be accurate, which seems rather expensive for CI.
  2. The existing utilities for generating test cases are time-varying, and produce somewhat ill-conditioned systems. For example, the scale of the observation noise generated by cuthbertlib.kalman.generate can often be on the order of $10^{-3}$, which makes gradients poorly behaved.

The resolution here is to choose an extremely simple system (a random walk on x, identity Gaussian observations on y) and run with 10,000 particles. The score is estimated at the true dynamics parameter and tested against the score from the ground-truth Kalman filter, taking the median over 10 runs with a 20% relative error tolerance. This seems to run fast enough, at least locally, but I'm happy to explore other options here.

Previously, computing the gradient of the MLL would result in runtime errors, as the default _inverse_cdf implementation on CPU used a `jax.pure_callback`. Since we don't expect the `inverse_cdf` function to be differentiable anyway, we can safely wrap its inputs with a `lax.stop_gradient` call to ensure at least biased gradients are available.
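For illustration, the failure mode and the fix look roughly like the following (a hypothetical stand-in helper, not the actual `_inverse_cdf` in cuthbertlib). `jax.pure_callback` has no JVP rule, so nonzero tangents flowing into it raise under `jax.grad`; stopping gradients on its inputs lets differentiation of the surrounding computation proceed, at the cost of a zero (hence biased) gradient through the callback itself.

```python
import jax
import jax.numpy as jnp
import numpy as np


def searchsorted_host(u, cdf):
    # Host-side inverse-CDF lookup via a callback (hypothetical helper,
    # standing in for a CPU pure_callback-based implementation).
    def cb(u_, cdf_):
        return np.searchsorted(cdf_, u_).astype(np.int32)

    result_shape = jax.ShapeDtypeStruct(u.shape, jnp.int32)
    # Without stop_gradient, differentiating through this call raises,
    # because pure_callback cannot be differentiated; with it, the inputs
    # carry zero tangents and grad works around the callback.
    return jax.pure_callback(
        cb,
        result_shape,
        jax.lax.stop_gradient(u),
        jax.lax.stop_gradient(cdf),
    )
```

Downstream uses of the returned indices (e.g. gathering particles) remain differentiable with respect to everything outside the callback.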
This commit introduces the stop-gradient resampler of [Scibior and Wood (2021)](https://arxiv.org/abs/2106.10314) as a resampling decorator. Wrapping a base resampling method in the stop_gradient_decorator makes the automatic-differentiation estimates of the score match the classical estimates obtained via Fisher's identity.
Adds an example to the documentation showing the score estimates of a SG-DPF vs. a non-differentiable PF on a linear-Gaussian task.
@AdrienCorenflos
Contributor

I don't understand why it's hard to unit test this: it can be done almost without randomness.
Take Gaussian weights $w_n = \mathcal{N}(0; x_n, \sigma^2)$ evaluated on a grid xs = np.linspace(-1, 1, 100). Compute the gradient of logsumexp(logws) before and after resampling: it should be the same if the wrapper is there, and not otherwise. Am I missing something?
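Sketched concretely (the grid, function names, and inline ratio trick here are assumptions for illustration, not the repository's actual test), the proposed check might look like:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

xs = jnp.linspace(-1.0, 1.0, 100)


def systematic_indices(u0, log_w):
    # Systematic resampling driven by a single uniform offset u0 in [0, 1).
    p = jax.nn.softmax(log_w)
    n = log_w.shape[0]
    u = (jnp.arange(n) + u0) / n
    return jnp.searchsorted(jnp.cumsum(p), u)


def grad_logsumexp(mu, mode):
    # mode: "before" (no resampling), "wrapped" (stop-gradient trick),
    # or "plain" (non-differentiable resampling: weights become constants).
    def loss(m):
        log_w = -0.5 * ((xs - m) / 0.5) ** 2  # N(xs; m, 0.5^2) log-weights
        if mode == "before":
            return logsumexp(log_w)
        idx = systematic_indices(0.5, jax.lax.stop_gradient(log_w))
        picked = log_w[idx]
        if mode == "wrapped":
            # value 0 (weight 1), derivative of the selected log-weights
            picked = picked - jax.lax.stop_gradient(picked)
        else:
            picked = jax.lax.stop_gradient(picked)
        return logsumexp(picked)

    return jax.grad(loss)(mu)
```

With the wrapper, the post-resampling gradient agrees with the pre-resampling one up to the O(1/N) systematic-resampling discrepancy; without it, the gradient is identically zero.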

AdrienCorenflos and others added 2 commits February 27, 2026 10:53
Clarified the impact of non-differentiable resampling on gradient estimates and provided math for the gradient it computes
Previously, a bootstrap PF was run in its entirety on a linear Gaussian system, and its parameter gradients checked against the Kalman filter. This was expensive and stochastic. Per @AdrienCorenflos's suggestion (state-space-models#209 (comment)), this is tested much more directly.
@DanWaxman
Contributor Author

Hi @AdrienCorenflos! Thanks for the suggestion! Indeed, that's a nice unit test -- I had "testing the stop-gradient DPF" in my head instead of "testing the stop gradient decorator," which is why I had written the test that way before. I've updated the unit tests to perform your version of the test (and reverted the test_particle_filters file).

@DanWaxman
Contributor Author

Responded to this round of feedback. Namely, I migrated the submodule name to cuthbertlib.resampling.autodiff and propagated those changes, and removed the erroneous print statement. Comments regarding the end-to-end test in test_particle_filters.py were moot, since that test was replaced by the simpler unit test, but let me know if you think it's better to reinstate the end-to-end test as well.

@AdrienCorenflos AdrienCorenflos merged commit d1beb22 into state-space-models:main Mar 3, 2026
2 checks passed