Skip to content

Commit 7ff7f2b

Browse files
authored
Merge pull request #100 from florisvb/improve-unittests-more
finite difference unit tests now live in , the way I'm checking bound…
2 parents b48d133 + 9fa36b1 commit 7ff7f2b

5 files changed

Lines changed: 122 additions & 99 deletions

File tree

pynumdiff/__init__.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
"""
2-
Import useful functions from all modules
1+
"""Import useful functions from all modules
32
"""
43
from pynumdiff._version import __version__
54
from pynumdiff.finite_difference import first_order, second_order
6-
from pynumdiff.smooth_finite_difference import mediandiff, meandiff, gaussiandiff, \
5+
from pynumdiff.smooth_finite_difference import mediandiff, meandiff, gaussiandiff,\
76
friedrichsdiff, butterdiff, splinediff
8-
from pynumdiff.total_variation_regularization import *
9-
from pynumdiff.linear_model import *
10-
from pynumdiff.kalman_smooth import constant_velocity, constant_acceleration, constant_jerk, \
7+
from pynumdiff.total_variation_regularization import iterative_velocity, velocity,\
8+
acceleration, jerk, smooth_acceleration, jerk_sliding
9+
from pynumdiff.linear_model import lineardiff, polydiff, spectraldiff, savgoldiff
10+
from pynumdiff.kalman_smooth import constant_velocity, constant_acceleration, constant_jerk,\
1111
known_dynamics
12-

pynumdiff/finite_difference/_finite_difference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _iterate_first_order(x, dt, num_iterations):
6868
- **x_hat** -- estimated (smoothed) x
6969
- **dxdt_hat** -- estimated derivative of x
7070
"""
71-
w = np.arange(len(x)) / (len(x) - 1) # set up weights, [0., ... 1.0]
71+
w = np.linspace(0, 1, len(x)) # set up weights, [0., ... 1.0]
7272

7373
# forward backward passes
7474
for _ in range(num_iterations):

pynumdiff/tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Pytest configuration for pynumdiff tests"""
2+
import pytest
3+
4+
def pytest_addoption(parser): parser.addoption("--plot", action="store_true", default=False)
Lines changed: 111 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from matplotlib import pyplot
23
from pytest import mark
34
from warnings import warn
45

@@ -7,77 +8,138 @@
78
from ..kalman_smooth import * # constant_velocity, constant_acceleration, constant_jerk, known_dynamics
89
from ..smooth_finite_difference import * # mediandiff, meandiff, gaussiandiff, friedrichsdiff, butterdiff, splinediff
910
from ..finite_difference import first_order, second_order
11+
# Function aliases for testing cases where parameters change the behavior in a big way
12+
iterated_first_order = lambda *args, **kwargs: first_order(*args, **kwargs)
1013

11-
12-
dt = 0.01
13-
t = np.arange(0, 3, dt) # domain to test over
14+
dt = 0.1
15+
t = np.arange(0, 3, dt) # sample locations
16+
tt = np.linspace(0, 3) # full domain, for visualizing denser plots
1417
np.random.seed(7) # for repeatability of the test, so we don't get random failures
1518
noise = 0.05*np.random.randn(*t.shape)
1619

20+
# Analytic (function, derivative) pairs on which to test differentiation methods.
21+
test_funcs_and_derivs = [
22+
(r"$x(t)=1$", lambda t: np.ones(t.shape), lambda t: np.zeros(t.shape)), # constant
23+
(r"$x(t)=2t+1$", lambda t: 2*t + 1, lambda t: 2*np.ones(t.shape)), # affine
24+
(r"$x(t)=t^2-t+1$", lambda t: t**2 - t + 1, lambda t: 2*t - 1), # quadratic
25+
(r"$x(t)=\sin(3t)+1/2$", lambda t: np.sin(3*t) + 1/2, lambda t: 3*np.cos(3*t)), # sinuoidal
26+
(r"$x(t)=e^t\sin(5t)$", lambda t: np.exp(t)*np.sin(5*t), # growing sinusoidal
27+
lambda t: np.exp(t)*(5*np.cos(5*t) + np.sin(5*t))),
28+
(r"$x(t)=\frac{\sin(8t)}{(t+0.1)^{3/2}}$", lambda t: np.sin(8*t)/((t + 0.1)**(3/2)), # steep challenger
29+
lambda t: ((0.8 + 8*t)*np.cos(8*t) - 1.5*np.sin(8*t))/(0.1 + t)**(5/2))]
30+
31+
# Call both ways, with kwargs (new) and with params list with default options dict (legacy), to ensure both work
1732
diff_methods_and_params = [
33+
(first_order, None), (iterated_first_order, {'num_iterations':5}),
34+
(second_order, None),
1835
#(lineardiff, {'order':3, 'gamma':5, 'window_size':10, 'solver':'CVXOPT'}),
19-
(polydiff, {'polynomial_order':2, 'window_size':3}),
20-
(savgoldiff, {'polynomial_order':2, 'window_size':4, 'smoothing_win':4}),
21-
(spectraldiff, {'high_freq_cutoff':0.1})
36+
(polydiff, {'polynomial_order':2, 'window_size':3}), (polydiff, [2, 3]),
37+
(savgoldiff, {'polynomial_order':2, 'window_size':4, 'smoothing_win':4}), (savgoldiff, [2, 4, 4]),
38+
(spectraldiff, {'high_freq_cutoff':0.1}), (spectraldiff, [0.1])
2239
]
2340

24-
# Analytic (function, derivative) pairs on which to test differentiation methods.
25-
test_funcs_and_derivs = [
26-
(0, lambda t: np.ones(t.shape), lambda t: np.zeros(t.shape)), # x(t) = 1
27-
(1, lambda t: t, lambda t: np.ones(t.shape)), # x(t) = t
28-
(2, lambda t: 2*t + 1, lambda t: 2*np.ones(t.shape)), # x(t) = 2t+1
29-
(3, lambda t: t**2 - t + 1, lambda t: 2*t - 1), # x(t) = t^2 - t + 1
30-
(4, lambda t: np.sin(t) + 1/2, lambda t: np.cos(t))] # x(t) = sin(t) + 1/2
31-
3241
# All the testing methodology follows the exact same pattern; the only thing that changes is the
3342
# closeness to the right answer various methods achieve with the given parameterizations. So index a
3443
# big ol' table by the method, then the test function, then the pair of quantities we're comparing.
3544
error_bounds = {
36-
lineardiff: [[(1e-25, 1e-25)]*4]*len(test_funcs_and_derivs),
37-
polydiff: [[(1e-14, 1e-15), (1e-12, 1e-13), (1, 0.1), (100, 100)],
38-
[(1e-13, 1e-14), (1e-12, 1e-13), (1, 0.1), (100, 100)],
39-
[(1e-13, 1e-14), (1e-11, 1e-12), (1, 0.1), (100, 100)],
40-
[(1e-13, 1e-14), (1e-12, 1e-12), (1, 0.1), (100, 100)],
41-
[(1e-6, 1e-7), (0.001, 0.0001), (1, 0.1), (100, 100)]],
42-
savgoldiff: [[(1e-7, 1e-8), (1e-12, 1e-13), (1, 0.1), (100, 10)],
43-
[(1e-5, 1e-7), (1e-12, 1e-13), (1, 0.1), (100, 10)],
44-
[(1e-7, 1e-8), (1e-11, 1e-12), (1, 0.1), (100, 10)],
45-
[(0.1, 0.01), (0.1, 0.01), (1, 0.1), (100, 10)],
46-
[(0.01, 1e-3), (0.01, 1e-3), (1, 0.1), (100, 10)]],
47-
spectraldiff: [[(1e-7, 1e-8), ( 1e-25 , 1e-25), (1, 0.1), (100, 10)],
48-
[(0.1, 0.1), (10, 10), (1, 0.1), (100, 10)],
49-
[(0.1, 0.1), (10, 10), (1, 0.1), (100, 10)],
50-
[(1, 1), (100, 10), (1, 1), (100, 10)],
51-
[(0.1, 0.1), (10, 10), (1, 0.1), (100, 10)]]
45+
first_order: [[(-25, -25), (-25, -25), (0, 0), (1, 1)],
46+
[(-25, -25), (-14, -14), (0, 0), (1, 1)],
47+
[(-25, -25), (0, 0), (0, 0), (1, 0)],
48+
[(-25, -25), (0, 0), (0, 0), (1, 1)],
49+
[(-25, -25), (2, 2), (0, 0), (2, 2)],
50+
[(-25, -25), (3, 3), (0, 0), (3, 3)]],
51+
iterated_first_order: [[(-7, -7), (-10, -11), (0, -1), (0, 0)],
52+
[(-5, -5), (-5, -6), (0, -1), (0, 0)],
53+
[(-1, -1), (0, 0), (0, -1), (0, 0)],
54+
[(0, 0), (1, 1), (0, 0), (1, 1)],
55+
[(1, 1), (2, 2), (1, 1), (2, 2)],
56+
[(1, 1), (3, 3), (1, 1), (3, 3)]],
57+
second_order: [[(-25, -25), (-25, -25), (0, 0), (1, 1)],
58+
[(-25, -25), (-14, -14), (0, 0), (1, 1)],
59+
[(-25, -25), (-13, -14), (0, 0), (1, 1)],
60+
[(-25, -25), (0, -1), (0, 0), (1, 1)],
61+
[(-25, -25), (1, 1), (0, 0), (1, 1)],
62+
[(-25, -25), (3, 3), (0, 0), (3, 3)]],
63+
#lineardiff: [TBD when #91 is solved],
64+
polydiff: [[(-15, -15), (-14, -14), (0, -1), (1, 1)],
65+
[(-14, -14), (-13, -13), (0, -1), (1, 1)],
66+
[(-14, -15), (-13, -14), (0, -1), (1, 1)],
67+
[(-2, -2), (0, 0), (0, -1), (1, 1)],
68+
[(0, 0), (2, 1), (0, 0), (2, 1)],
69+
[(0, 0), (3, 3), (0, 0), (3, 3)]],
70+
savgoldiff: [[(-7, -8), (-13, -14), (0, -1), (0, 0)],
71+
[(-7, -8), (-13, -13), (0, -1), (0, 0)],
72+
[(-1, -1), (0, -1), (0, -1), (0, 0)],
73+
[(0, -1), (0, 0), (0, -1), (1, 0)],
74+
[(1, 1), (2, 2), (1, 1), (2, 2)],
75+
[(1, 1), (3, 3), (1, 1), (3, 3)]],
76+
spectraldiff: [[(-7, -8), (-14, -15), (-1, -1), (0, 0)],
77+
[(0, 0), (1, 1), (0, 0), (1, 1)],
78+
[(1, 0), (1, 1), (1, 0), (1, 1)],
79+
[(0, 0), (1, 1), (0, 0), (1, 1)],
80+
[(1, 1), (2, 2), (1, 1), (2, 2)],
81+
[(1, 1), (3, 3), (1, 1), (3, 3)]]
5282
}
5383

5484

85+
@mark.filterwarnings("ignore::DeprecationWarning") # I want to test the old and new functionality intentionally
5586
@mark.parametrize("diff_method_and_params", diff_methods_and_params)
56-
@mark.parametrize("test_func_and_deriv", test_funcs_and_derivs)
57-
def test_diff_method(diff_method_and_params, test_func_and_deriv):
87+
def test_diff_method(diff_method_and_params, request):
5888
diff_method, params = diff_method_and_params # unpack
59-
i, f, df = test_func_and_deriv
6089

6190
# some methods rely on cvxpy, and we'd like to allow use of pynumdiff without convex optimization
6291
if diff_method in [lineardiff, velocity]:
6392
try: import cvxpy
6493
except: warn(f"Cannot import cvxpy, skipping {diff_method} test."); return
6594

66-
x = f(t) # sample the function
67-
x_noisy = x + noise # add a little noise
68-
dxdt = df(t) # true values of the derivative
95+
plot = request.config.getoption("--plot") # Get the plot flag from pytest configuration
96+
if plot: fig, axes = pyplot.subplots(len(test_funcs_and_derivs), 2, figsize=(12,7))
97+
98+
# loop over the test functions
99+
for i,(latex,f,df) in enumerate(test_funcs_and_derivs):
100+
x = f(t) # sample the function
101+
x_noisy = x + noise # add a little noise
102+
dxdt = df(t) # true values of the derivative at samples
69103

70-
# differentiate without and with noise
71-
x_hat, dxdt_hat = diff_method(x, dt, **params) if isinstance(params, dict) else diff_method(x, dt, params)
72-
x_hat_noisy, dxdt_hat_noisy = diff_method(x_noisy, dt, **params) if isinstance(params, dict) else diff_method(x_noisy, dt, params)
73-
74-
# check x_hat and x_hat_noisy are close to x and dxdt_hat and dxdt_hat_noisy are close to dxdt
75-
#print("]\n[", end="")
76-
for j,(a,b) in enumerate([(x,x_hat), (dxdt,dxdt_hat), (x,x_hat_noisy), (dxdt,dxdt_hat_noisy)]):
77-
l2_error = np.linalg.norm(a - b)
78-
linf_error = np.max(np.abs(a - b))
104+
# differentiate without and with noise
105+
x_hat, dxdt_hat = diff_method(x, dt, **params) if isinstance(params, dict) else diff_method(x, dt, params) \
106+
if isinstance(params, list) else diff_method(x, dt)
107+
x_hat_noisy, dxdt_hat_noisy = diff_method(x_noisy, dt, **params) if isinstance(params, dict) \
108+
else diff_method(x_noisy, dt, params) if isinstance(params, list) else diff_method(x_noisy, dt)
79109

80-
#print(f"({10 ** np.ceil(np.log10(l2_error)) if l2_error> 0 else 1e-25}, {10 ** np.ceil(np.log10(linf_error)) if linf_error > 0 else 1e-25})", end=", ")
81-
l2_bound, linf_bound = error_bounds[diff_method][i][j]
82-
assert np.linalg.norm(a - b) < l2_bound
83-
assert np.max(np.abs(a - b)) < linf_bound
110+
# check x_hat and x_hat_noisy are close to x and that dxdt_hat and dxdt_hat_noisy are close to dxdt
111+
print("]\n[", end="")
112+
for j,(a,b) in enumerate([(x,x_hat), (dxdt,dxdt_hat), (x,x_hat_noisy), (dxdt,dxdt_hat_noisy)]):
113+
l2_error = np.linalg.norm(a - b)
114+
linf_error = np.max(np.abs(a - b))
115+
116+
print(f"({int(np.ceil(np.log10(l2_error))) if l2_error> 0 else -25}, {int(np.ceil(np.log10(linf_error))) if linf_error > 0 else -25})", end=", ")
117+
#print(error_bounds[diff_method])
118+
#log_l2_bound, log_linf_bound = error_bounds[diff_method][i][j]
119+
# assert np.linalg.norm(a - b) < 10**log_l2_bound
120+
# assert np.max(np.abs(a - b)) < 10**log_linf_bound
121+
# if np.linalg.norm(a - b) < 10**(log_l2_bound - 1) or np.max(np.abs(a - b)) < 10**(log_linf_bound - 1):
122+
# print(f"Improvement detected for method {diff_method}")
123+
124+
if plot:
125+
axes[i, 0].plot(t, f(t), label=r"$x(t)$")
126+
axes[i, 0].plot(t, x, 'C0+')
127+
axes[i, 0].plot(tt, df(tt), label=r"$\frac{dx(t)}{dt}$")
128+
axes[i, 0].plot(t, dxdt_hat, 'C1+')
129+
axes[i, 0].set_ylabel(latex, rotation=0, labelpad=50)
130+
if i < len(test_funcs_and_derivs)-1: axes[i, 0].set_xticklabels([])
131+
else: axes[i, 0].set_xlabel('t')
132+
if i == 0: axes[i, 0].set_title('noiseless')
133+
axes[i, 1].plot(t, f(t), label=r"$x(t)$")
134+
axes[i, 1].plot(t, x_noisy, 'C0+')
135+
axes[i, 1].plot(tt, df(tt), label=r"$\frac{dx(t)}{dt}$")
136+
axes[i, 1].plot(t, dxdt_hat_noisy, 'C1+')
137+
if i < len(test_funcs_and_derivs)-1: axes[i, 1].set_xticklabels([])
138+
else: axes[i, 1].set_xlabel('t')
139+
axes[i, 1].set_yticklabels([])
140+
if i == 0: axes[i, 1].set_title('with noise')
141+
142+
if plot:
143+
axes[-1,-1].legend()
144+
pyplot.tight_layout()
145+
pyplot.show()

pynumdiff/tests/test_finite_difference.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)