Skip to content

Commit e8d8075

Browse files
authored
Merge pull request #166 from florisvb/improve-robust
Incrementing robust with different cost function
2 parents db9d6c1 + a7685cc commit e8d8075

4 files changed

Lines changed: 18 additions & 18 deletions

File tree

pynumdiff/finite_difference/_finite_difference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from warnings import warn
55

66

7-
def finitediff(x, dt, num_iterations, order):
7+
def finitediff(x, dt, num_iterations=1, order=2):
88
"""Perform iterated finite difference of a given order. This serves as the common backing function for
99
all other methods in this module.
1010

pynumdiff/kalman_smooth/_kalman_smooth.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def kalman_filter(y, _t, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
5353
else:
5454
if not equispaced:
5555
dt = _t[n] - _t[n-1]
56-
eM = expm(M * dt) # form discrete-time matrices TODO doesn't work at n=0
56+
eM = expm(M * dt) # form discrete-time matrices
5757
An = eM[:m,:m] # upper left block
5858
Qn = eM[:m,m:] @ An.T # upper right block
5959
if dt < 0: Qn = np.abs(Qn) # eigenvalues go negative if reverse time, but noise shouldn't shrink
@@ -314,20 +314,20 @@ def convex_smooth(y, A, Q, C, R, huberM=0):
314314
N = len(y)
315315
x_states = cvxpy.Variable((N, A.shape[0])) # each row is [position, velocity, acceleration, ...] at step n
316316

317-
R_sqrt_inv = np.linalg.inv(sqrtm(R))
318317
Q_sqrt_inv = np.linalg.inv(sqrtm(Q))
319-
objective = cvxpy.sum([cvxpy.norm(R_sqrt_inv @ (y[n] - C @ x_states[n]), 1) if huberM < 1e-3 # Measurement terms: sum of ||R^(-1/2)(y_n - C x_n)||_1
320-
else cvxpy.sum(cvxpy.huber(R_sqrt_inv @ (y[n] - C @ x_states[n]), huberM)) for n in range(N)])
321-
objective += cvxpy.sum([cvxpy.norm(Q_sqrt_inv @ (x_states[n] - A @ x_states[n-1]), 1) if huberM < 1e-3 # Process terms: sum of ||Q^(-1/2)(x_n - A x_{n-1})||_1
322-
else cvxpy.sum(cvxpy.huber(Q_sqrt_inv @ (x_states[n] - A @ x_states[n-1]), huberM)) for n in range(1, N)])
323-
318+
R_sqrt_inv = np.linalg.inv(sqrtm(R))
319+
# Process terms: sum of 1/2||Q^(-1/2)(x_n - A x_{n-1})||_2^2
320+
objective = 0.5*cvxpy.sum([cvxpy.sum_squares(Q_sqrt_inv @ (x_states[n] - A @ x_states[n-1])) for n in range(1, N)])
321+
# Measurement terms: sum of sqrt(2)||R^(-1/2)(y_n - C x_n)||_1, per https://jmlr.org/papers/volume14/aravkin13a/aravkin13a.pdf section 6
322+
objective += np.sqrt(2)*cvxpy.sum([cvxpy.norm(R_sqrt_inv @ (y[n] - C @ x_states[n]), 1) if huberM < 1e-3
323+
else cvxpy.sum(cvxpy.huber(R_sqrt_inv @ (y[n] - C @ x_states[n]), huberM)) for n in range(N)])
324+
324325
problem = cvxpy.Problem(cvxpy.Minimize(objective))
325326
try:
326327
problem.solve(solver=cvxpy.CLARABEL)
327328
except cvxpy.error.SolverError:
328329
warn(f"CLARABEL failed. Retrying with SCS.")
329330
problem.solve(solver=cvxpy.SCS) # SCS is a lot slower but pretty bulletproof even with big condition numbers
330-
331331
if x_states.value is None: # There is occasional solver failure with huber as opposed to 1-norm
332332
warn("Convex solvers failed with status {problem.status}. Returning NaNs.")
333333
return np.full((N, A.shape[0]), np.nan)

pynumdiff/tests/test_diff_methods.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def spline_irreg_step(*args, **kwargs): return splinediff(*args, **kwargs)
5151
(constant_acceleration, {'r':1e-3, 'q':1e4}), (constant_acceleration, [1e-3, 1e4]),
5252
(constant_jerk, {'r':1e-4, 'q':1e5}), (constant_jerk, [1e-4, 1e5]),
5353
(rtsdiff, {'order':2, 'qr_ratio':1e7, 'forwardbackward':True}),
54-
(robustdiff, {'order':3, 'qr_ratio':1e6}),
54+
(robustdiff, {'order':3, 'qr_ratio':1e8}),
5555
(velocity, {'gamma':0.5}), (velocity, [0.5]),
5656
(acceleration, {'gamma':1}), (acceleration, [1]),
5757
(jerk, {'gamma':10}), (jerk, [10]),
@@ -223,11 +223,11 @@ def spline_irreg_step(*args, **kwargs): return splinediff(*args, **kwargs)
223223
[(-2, -3), (0, 0), (0, -1), (1, 1)],
224224
[(-1, -2), (1, 1), (0, -1), (1, 1)],
225225
[(0, 0), (3, 3), (0, 0), (3, 3)]],
226-
robustdiff: [[(-15, -15), (-14, -14), (0, -1), (0, 0)],
227-
[(-14, -14), (-13, -14), (0, -1), (0, 0)],
228-
[(-14, -14), (-13, -13), (0, -1), (0, 0)],
229-
[(-1, -1), (0, 0), (0, -1), (1, 0)],
230-
[(0, 0), (1, 1), (0, 0), (1, 1)],
226+
robustdiff: [[(-14, -15), (-17, -17), (0, -1), (1, 1)],
227+
[(-14, -14), (-13, -13), (0, -1), (1, 1)],
228+
[(-14, -14), (-13, -13), (0, -1), (1, 1)],
229+
[(-12, -12), (-2, -2), (0, -1), (1, 1)],
230+
[(0, 0), (2, 2), (0, 0), (2, 2)],
231231
[(1, 1), (3, 3), (1, 1), (3, 3)]],
232232
lineardiff: [[(-6, -6), (-5, -6), (0, -1), (0, 0)],
233233
[(0, 0), (2, 1), (0, 0), (2, 1)],

pynumdiff/utils/simulate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# local imports
88
from pynumdiff.utils.utility import peakdet
9-
from pynumdiff.finite_difference import second_order as finite_difference
9+
from pynumdiff.finite_difference import finitediff
1010

1111

1212
# pylint: disable-msg=too-many-locals, too-many-arguments, no-member
@@ -108,7 +108,7 @@ def triangle(duration=4, noise_type='normal', noise_parameters=(0, 0.5), outlier
108108
reversal_ts = t[reversal_idxs]
109109

110110
pos = np.interp(t, reversal_ts, reversal_vals)
111-
_, vel = finite_difference(pos, dt=simdt)
111+
_, vel = finitediff(pos, dt=simdt)
112112
noisy_pos = _add_noise(pos, random_seed, noise_type, noise_parameters, outliers)
113113

114114
idx = np.arange(0, len(t), int(dt/simdt))
@@ -182,7 +182,7 @@ def linear_autonomous(duration=4, noise_type='normal', noise_parameters=(0, 0.5)
182182
xs = np.vstack(xs).T
183183
pos = xs[0,:]
184184

185-
smooth_pos, vel = finite_difference(pos, simdt)
185+
smooth_pos, vel = finitediff(pos, simdt)
186186
noisy_pos = _add_noise(pos, random_seed, noise_type, noise_parameters, outliers)
187187

188188
idx = slice(0, len(t), int(dt/simdt)) # downsample so things are dt apart

0 commit comments

Comments
 (0)