Implement the Stop-Gradient Differentiable Particle Filter#209
AdrienCorenflos merged 8 commits into state-space-models:main
Conversation
Previously, computing the gradient of the MLL would result in runtime errors, as the default `_inverse_cdf` implementation on CPU used a `jax.pure_callback`. Since we don't expect the `inverse_cdf` function to be differentiable anyway, we can safely wrap its inputs in a `lax.stop_gradient` call to ensure that at least biased gradients are available.
This commit introduces the stop-gradient resampler of [Scibior and Wood (2021)](https://arxiv.org/abs/2106.10314) as a resampling decorator. Wrapping a base resampling method in the `stop_gradient_decorator` results in automatic-differentiation estimates of the score that match the classical estimates obtained via Fisher's identity.
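The decorator pattern can be sketched as follows, assuming a multinomial base resampler; the signatures are illustrative and the actual API in this PR may differ. The ancestor draw only sees stop-gradient'ed weights, while the ratio `log w - stop_gradient(log w)` keeps a gradient path through the selected weights, which is what makes the autodiff score match Fisher's identity:

```python
import jax
import jax.numpy as jnp

def multinomial_resample_indices(key, log_weights):
    # Hypothetical base resampler: draws ancestor indices from the
    # normalized weights via inverse-CDF sampling.
    w = jnp.exp(log_weights - jax.scipy.special.logsumexp(log_weights))
    u = jax.random.uniform(key, (log_weights.shape[0],))
    return jnp.searchsorted(jnp.cumsum(w), u)

def stop_gradient_decorator(resample_indices):
    # Wraps a resampler so the (non-differentiable) index draw sees only
    # stop-gradient'ed weights, while the selected log-weights retain a
    # gradient path (Scibior and Wood, 2021).
    def resample(key, particles, log_weights):
        idx = resample_indices(key, jax.lax.stop_gradient(log_weights))
        lw = log_weights[idx]
        n = log_weights.shape[0]
        # Uniform in value (-log n); the gradient flows through lw.
        new_log_weights = lw - jax.lax.stop_gradient(lw) - jnp.log(float(n))
        return particles[idx], new_log_weights
    return resample
```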
Adds an example to the documentation showing the score estimates of an SG-DPF vs. a non-differentiable PF on a linear-Gaussian task.
Unit testing in this setting is somewhat difficult, for a few reasons:

1. The estimator is still stochastic, and requires many trials or particles to be accurate, which seems rather expensive for CI.
2. The existing utilities for generating test cases are time-varying, and provide somewhat ill-conditioned systems.

The resolution here is to choose an extremely simple system (random walk on x, identity Gaussian observation on y) and run with 10,000 particles, taking the median over 10 runs with a 20% relative error tolerance. This seems to run fast enough, at least locally, but I'm happy to explore other options here.
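For reference, the simple test system described above can be generated in a few lines (an illustrative sketch, not the test's actual code):

```python
import jax
import jax.numpy as jnp

def simulate(key, num_steps=25):
    # Random walk on x, identity Gaussian observation on y:
    #   x_t = x_{t-1} + N(0, 1),   y_t = x_t + N(0, 1)
    key_x, key_y = jax.random.split(key)
    x = jnp.cumsum(jax.random.normal(key_x, (num_steps,)))
    y = x + jax.random.normal(key_y, (num_steps,))
    return x, y
```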
I don't understand why it's hard to unit test this: it can be done almost without randomness.
Clarified the impact of non-differentiable resampling on gradient estimates and provided math for the gradient it computes.
Previously, a bootstrap PF was run in its entirety on a linear-Gaussian system, and its parameter gradients were checked against the Kalman filter. This was expensive and stochastic. Per @AdrienCorenflos's suggestion (state-space-models#209 (comment)), the decorator is now tested much more directly.
Hi @AdrienCorenflos! Thanks for the suggestion! Indeed, that's a nice unit test -- I had "testing the stop-gradient DPF" in my head instead of "testing the stop-gradient decorator," which is why I had written the test that way before. I've updated the unit tests to perform your version of the test (and reverted the
Enhance explanation of differentiable resampling methods
Responded to this round of feedback. Namely, I migrated the submodule name to
This PR implements the stop-gradient differentiable PF ([Scibior and Wood, 2021](https://arxiv.org/abs/2106.10314)) as a resampling decorator. It supersedes #202 after the merging of #207.
The following changes were made:
- Previously, computing the gradient of the MLL would result in runtime errors, as the default `_inverse_cdf` implementation on CPU used a `jax.pure_callback`. Since we don't expect the `inverse_cdf` function to be differentiable anyway, we can safely wrap its inputs in a `lax.stop_gradient` call to ensure at least biased gradients are available.
- The resulting DPF is shown to provide strong score estimates in a toy example, now in the documentation.


To use the stop-gradient resampling, one simply uses the `stop_gradient_decorator`.

Unit testing in this setting is somewhat difficult, for a few reasons:

1. The estimator is still stochastic, and requires many trials or particles to be accurate, which seems rather expensive for CI.
2. The existing utilities for generating test cases are time-varying and somewhat ill-conditioned; the condition number of systems from `cuthbertlib.kalman.generate` can often be on the order of …

The resolution here is to choose an extremely simple system (random walk on x, identity Gaussian observation on y) and run with 10,000 particles. The score is estimated at the true parameter for the dynamics and tested against the score from the ground-truth Kalman filter, taking the median over 10 runs with a 20% relative error tolerance. This seems to run fast enough, at least locally, but I'm happy to explore other options here.
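The median-over-runs tolerance check described above can be sketched on a toy problem; the estimator, parameter values, and noise scale below are all illustrative stand-ins, not the PR's actual test code:

```python
import jax
import jax.numpy as jnp

def noisy_score_estimate(key, theta, y):
    # Toy stand-in for a particle-filter score estimator: a noisy
    # estimate of d/dtheta log N(y | theta, 1) = y - theta.
    return (y - theta) + 0.05 * jax.random.normal(key)

# Median over 10 independent runs, compared to the exact score
# with a 20% relative error tolerance.
keys = jax.random.split(jax.random.PRNGKey(0), 10)
estimates = jnp.stack([noisy_score_estimate(k, 0.5, 1.2) for k in keys])
median_est = jnp.median(estimates)
exact = 1.2 - 0.5
rel_err = jnp.abs(median_est - exact) / jnp.abs(exact)
```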