Skip to content

Commit ba67ab3

Browse files
authored
Merge pull request #1585 from FlorianPfaff/copilot/add-filters-and-tests
Add tests for AbstractDummyFilter and HypersphericalDummyFilter
2 parents c3471d3 + 8514498 commit ba67ab3

4 files changed

Lines changed: 165 additions & 0 deletions

File tree

pyrecest/filters/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
1+
from .abstract_dummy_filter import AbstractDummyFilter
12
from .abstract_axial_filter import AbstractAxialFilter
23
from .abstract_filter import AbstractFilter
34
from .abstract_particle_filter import AbstractParticleFilter
45
from .euclidean_particle_filter import EuclideanParticleFilter
6+
from .hyperspherical_dummy_filter import HypersphericalDummyFilter
57
from .hypertoroidal_particle_filter import HypertoroidalParticleFilter
68
from .kalman_filter import KalmanFilter
79
from .manifold_mixins import EuclideanFilterMixin, HypertoroidalFilterMixin
810

911
__all__ = [
12+
"AbstractDummyFilter",
1013
"AbstractAxialFilter",
1114
"AbstractFilter",
1215
"EuclideanFilterMixin",
1316
"HypertoroidalFilterMixin",
1417
"AbstractParticleFilter",
1518
"HypertoroidalParticleFilter",
19+
"HypersphericalDummyFilter",
1620
"KalmanFilter",
1721
"EuclideanParticleFilter",
1822
]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from .abstract_filter import AbstractFilter
2+
3+
4+
class AbstractDummyFilter(AbstractFilter):
5+
"""Abstract dummy filter that does nothing on predictions and updates.
6+
7+
Subclasses should call super().__init__ with the initial distribution.
8+
"""
9+
10+
def __init__(self, initial_filter_state):
11+
AbstractFilter.__init__(self, initial_filter_state)
12+
13+
@property
14+
def dist(self):
15+
return self._filter_state
16+
17+
@property
18+
def filter_state(self):
19+
return self._filter_state
20+
21+
@filter_state.setter
22+
def filter_state(self, new_state):
23+
# Do nothing - the dummy filter state is fixed at initialization
24+
pass
25+
26+
def set_state(self, dist):
27+
assert dist.dim == self.dim
28+
# Do nothing
29+
30+
def predict_identity(self, noise_distribution):
31+
pass
32+
33+
def predict_nonlinear(self, f, *args, **kwargs):
34+
pass
35+
36+
def predict_nonlinear_via_transition_density(self, transition_density, *args):
37+
pass
38+
39+
def update_identity(self, noise_distribution, measurement):
40+
pass
41+
42+
def update_nonlinear(self, likelihood, measurement=None):
43+
pass
44+
45+
def get_estimate(self):
46+
return self.dist
47+
48+
def get_point_estimate(self):
49+
return self.dist.sample(1)[0]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from pyrecest.distributions.hypersphere_subset.hyperspherical_uniform_distribution import (
2+
HypersphericalUniformDistribution,
3+
)
4+
5+
from .abstract_dummy_filter import AbstractDummyFilter
6+
from .manifold_mixins import HypersphericalFilterMixin
7+
8+
9+
class HypersphericalDummyFilter(AbstractDummyFilter, HypersphericalFilterMixin):
10+
"""Hyperspherical dummy filter initialized with a uniform distribution.
11+
12+
This filter does nothing on predictions and updates, always returning
13+
samples from the initial uniform distribution as point estimates.
14+
"""
15+
16+
def __init__(self, dim):
17+
"""Initialize HypersphericalDummyFilter.
18+
19+
Parameters:
20+
dim (int >= 2): Manifold dimension of the hypersphere (e.g. 2 for S^2).
21+
"""
22+
assert dim >= 2, "dim must be at least 2"
23+
HypersphericalFilterMixin.__init__(self)
24+
AbstractDummyFilter.__init__(self, HypersphericalUniformDistribution(dim))
25+
26+
def get_point_estimate(self):
27+
return AbstractDummyFilter.get_point_estimate(self)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import unittest
2+
3+
import numpy.testing as npt
4+
5+
# pylint: disable=no-name-in-module,no-member
6+
from pyrecest.backend import array, linalg
7+
from pyrecest.distributions.hypersphere_subset.hyperspherical_uniform_distribution import (
8+
HypersphericalUniformDistribution,
9+
)
10+
from pyrecest.distributions import VonMisesFisherDistribution
11+
from pyrecest.filters.hyperspherical_dummy_filter import HypersphericalDummyFilter
12+
13+
14+
class HypersphericalDummyFilterTest(unittest.TestCase):
15+
def setUp(self):
16+
self.filter_s2 = HypersphericalDummyFilter(2)
17+
self.filter_s3 = HypersphericalDummyFilter(3)
18+
19+
def test_dim_s2(self):
20+
self.assertEqual(self.filter_s2.dim, 2)
21+
22+
def test_dim_s3(self):
23+
self.assertEqual(self.filter_s3.dim, 3)
24+
25+
def test_assert_dim_too_small(self):
26+
with self.assertRaises(AssertionError):
27+
HypersphericalDummyFilter(1)
28+
29+
def test_filter_state_is_uniform(self):
30+
self.assertIsInstance(
31+
self.filter_s2.filter_state, HypersphericalUniformDistribution
32+
)
33+
34+
def test_get_point_estimate_unit_norm_s2(self):
35+
est = self.filter_s2.get_point_estimate()
36+
self.assertEqual(est.shape, (3,))
37+
npt.assert_allclose(linalg.norm(est), 1.0, atol=1e-10)
38+
39+
def test_get_point_estimate_unit_norm_s3(self):
40+
est = self.filter_s3.get_point_estimate()
41+
self.assertEqual(est.shape, (4,))
42+
npt.assert_allclose(linalg.norm(est), 1.0, atol=1e-10)
43+
44+
def test_predict_identity_is_noop(self):
45+
noise = VonMisesFisherDistribution(array([0.0, 0.0, 1.0]), 1.0)
46+
state_before = self.filter_s2.filter_state
47+
self.filter_s2.predict_identity(noise)
48+
self.assertIs(self.filter_s2.filter_state, state_before)
49+
50+
def test_predict_nonlinear_is_noop(self):
51+
state_before = self.filter_s2.filter_state
52+
self.filter_s2.predict_nonlinear(lambda x: x)
53+
self.assertIs(self.filter_s2.filter_state, state_before)
54+
55+
def test_update_identity_is_noop(self):
56+
noise = VonMisesFisherDistribution(array([0.0, 0.0, 1.0]), 1.0)
57+
measurement = array([0.0, 0.0, 1.0])
58+
state_before = self.filter_s2.filter_state
59+
self.filter_s2.update_identity(noise, measurement)
60+
self.assertIs(self.filter_s2.filter_state, state_before)
61+
62+
def test_update_nonlinear_is_noop(self):
63+
state_before = self.filter_s2.filter_state
64+
self.filter_s2.update_nonlinear(lambda z, x: x)
65+
self.assertIs(self.filter_s2.filter_state, state_before)
66+
67+
def test_filter_state_setter_is_noop(self):
68+
original_state = self.filter_s2.filter_state
69+
new_dist = HypersphericalUniformDistribution(2)
70+
self.filter_s2.filter_state = new_dist
71+
self.assertIs(self.filter_s2.filter_state, original_state)
72+
73+
def test_set_state_is_noop(self):
74+
original_state = self.filter_s2.filter_state
75+
new_dist = HypersphericalUniformDistribution(2)
76+
self.filter_s2.set_state(new_dist)
77+
self.assertIs(self.filter_s2.filter_state, original_state)
78+
79+
def test_get_estimate_returns_distribution(self):
80+
est = self.filter_s2.get_estimate()
81+
self.assertIsInstance(est, HypersphericalUniformDistribution)
82+
83+
84+
if __name__ == "__main__":
85+
unittest.main()

0 commit comments

Comments
 (0)