Skip to content

Commit 18e78d2

Browse files
authored
Merge pull request #154 from florisvb/efficient-slide-function
more efficient slide function
2 parents e03f5f8 + 1354dd3 commit 18e78d2

2 files changed

Lines changed: 21 additions & 20 deletions

File tree

pynumdiff/tests/test_diff_methods.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def spline_irreg_step(*args, **kwargs): return splinediff(*args, **kwargs)
137137
[(0, 0), (1, 1), (0, 0), (1, 1)],
138138
[(1, 0), (2, 2), (1, 0), (2, 2)],
139139
[(1, 0), (3, 3), (1, 0), (3, 3)]],
140-
polydiff: [[(-14, -15), (-14, -14), (0, -1), (1, 1)],
140+
polydiff: [[(-14, -15), (-13, -14), (0, -1), (1, 1)],
141141
[(-14, -14), (-13, -13), (0, -1), (1, 1)],
142142
[(-14, -14), (-13, -13), (0, -1), (1, 1)],
143143
[(-2, -2), (0, 0), (0, -1), (1, 1)],
@@ -179,11 +179,11 @@ def spline_irreg_step(*args, **kwargs): return splinediff(*args, **kwargs)
179179
[(0, 0), (1, 0), (0, -1), (1, 0)],
180180
[(1, 1), (2, 2), (1, 1), (2, 2)],
181181
[(1, 1), (3, 3), (1, 1), (3, 3)]],
182-
jerk_sliding: [[(-15, -15), (-16, -16), (0, -1), (1, 0)],
182+
jerk_sliding: [[(-25, -25), (-16, -16), (0, -1), (1, 0)],
183183
[(-14, -14), (-14, -14), (0, -1), (0, 0)],
184184
[(-14, -14), (-14, -14), (0, -1), (0, 0)],
185-
[(-1, -1), (0, 0), (0, -1), (1, 0)],
186-
[(0, 0), (2, 2), (0, 0), (2, 2)],
185+
[(-1, -1), (0, 0), (0, -1), (0, 0)],
186+
[(1, 0), (2, 2), (1, 0), (2, 2)],
187187
[(1, 1), (3, 3), (1, 1), (3, 3)]],
188188
constant_velocity: [[(-25, -25), (-25, -25), (0, -1), (1, 1)],
189189
[(-4, -5), (-3, -3), (0, -1), (1, 1)],

pynumdiff/utils/utility.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -181,26 +181,27 @@ def slide_function(func, x, dt, kernel, *args, stride=1, pass_weights=False, **k
181181
if len(kernel) % 2 == 0: raise ValueError("Kernel window size should be odd.")
182182
half_window_size = (len(kernel) - 1)//2 # int because len(kernel) is always odd
183183

184-
weights = np.zeros((int(np.ceil(len(x)/stride)), len(x))) # Could be more space efficient
185-
x_hats = np.zeros(weights.shape)
186-
dxdt_hats = np.zeros(weights.shape)
184+
x_hat = np.zeros(x.shape)
185+
dxdt_hat = np.zeros(x.shape)
186+
weight_sum = np.zeros(x.shape)
187187

188188
for i,midpoint in enumerate(range(0, len(x), stride)): # iterate window midpoints
189189
# find where to index data and kernel, taking care at edges
190-
window = slice(max(0, midpoint - half_window_size),
191-
min(len(x), midpoint + half_window_size + 1)) # +1 because slicing is exclusive of end
192-
kslice = slice(max(0, half_window_size - midpoint),
193-
min(len(kernel), len(kernel) - (midpoint + half_window_size + 1 - len(x))))
190+
start = max(0, midpoint - half_window_size)
191+
end = min(len(x), midpoint + half_window_size + 1) # +1 because slicing is exclusive of end
192+
window = slice(start, end)
194193

195-
# weights need to be renormalized if running off an edge
196-
weights[i, window] = kernel if kslice.stop - kslice.stop == len(kernel) else kernel[kslice]/np.sum(kernel[kslice])
197-
if pass_weights: kwargs['weights'] = weights[i, window]
194+
kstart = max(0, half_window_size - midpoint)
195+
kend = kstart + (end - start)
196+
kslice = slice(kstart, kend)
198197

199-
# run the function on the window and save results
200-
x_hats[i,window], dxdt_hats[i,window] = func(x[window], dt, *args, **kwargs)
198+
w = kernel if (end-start) == len(kernel) else kernel[kslice]/np.sum(kernel[kslice])
199+
if pass_weights: kwargs['weights'] = w
201200

202-
weights /= weights.sum(axis=0, keepdims=True) # normalize the weights
203-
x_hat = np.sum(weights*x_hats, axis=0)
204-
dxdt_hat = np.sum(weights*dxdt_hats, axis=0)
201+
# run the function on the window and add weighted results to cumulative answers
202+
x_window_hat, dxdt_window_hat = func(x[window], dt, *args, **kwargs)
203+
x_hat[window] += w * x_window_hat
204+
dxdt_hat[window] += w * dxdt_window_hat
205+
weight_sum[window] += w # save sum of weights for normalization at the end
205206

206-
return x_hat, dxdt_hat
207+
return x_hat/weight_sum, dxdt_hat/weight_sum

0 commit comments

Comments
 (0)