-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlag_detector.py
More file actions
122 lines (97 loc) · 3.96 KB
/
Copy pathlag_detector.py
File metadata and controls
122 lines (97 loc) · 3.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Functions for detecting lag between signals
"""
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as signal
from xarray import DataArray
from pandas import date_range
def find_time_lag(t1, f1, t2, f2, time_step = 0.25, trim = 3600, interpolation_method = 'linear', verbose = False):
"""
Find the time lag between the two signals f1(t1) and f2(t2). The sign of the lag is
so that removing the lag from t1 will result in synchronized signals.
The lag is found by first interpolating the two signals to a common time grid, and then taking
argmax of the cross-correlation between interpolated f1 and f2.
parameters:
time_step : time step in seconds of the common time grid
trim : trim seconds are removed from the start and end of the signals before finding the lag.
"""
# Make time grid
tmin = np.max([t1[0], t2[0]]) + np.timedelta64(trim,'s')
tmax = np.min([t1[-1], t2[-1]]) - np.timedelta64(trim,'s')
time = date_range(start=tmin, end=tmax, freq=f'{time_step}s')
# xarray deals well with time interpolation, so we make a DataArray first and interpolate that
s1 = DataArray(data = f1, dims = 'time', coords = {'time' : t1})
s2 = DataArray(data = f2, dims = 'time', coords = {'time' : t2})
# If there are several time steps with the same time, it will cause errors when interpolate. Remove them!
s1 = remove_time_step_duplicates(s1)
s2 = remove_time_step_duplicates(s2)
# Interpolate to common time grid
s1_interp = s1.interp(time=time, method=interpolation_method)
s2_interp = s2.interp(time=time, method=interpolation_method)
if verbose:
# Original signals
s1.plot(label = 'f1')
s2.plot(label = 'f2')
plt.title('Original signals')
plt.legend()
plt.grid()
plt.show()
# Trimmed and interpolated signals
_plot_signals(s1_interp,s2_interp, 'Preprocessed signals')
# Find lag (in samples)
lag__samples = detect_lag(s1_interp.values, s2_interp.values, verbose = verbose)
# Convert lag to time unit
lag = np.timedelta64(int(time_step*1e9), 'ns') * lag__samples
if verbose:
s1_corrected = s1_interp.copy()
s1_corrected['time'] = s1_corrected['time'] - lag
_plot_signals(s1_corrected, s2_interp, 'Lag compensated', s1_label='f1 corrected')
print(f'Detected time lag : {lag}')
return lag
def remove_time_step_duplicates(da):
"""
Remove time step duplicates from a DataArray (da) with a time dimension called 'time'
"""
has_unique_time = np.concatenate([[True], da.time.diff(dim='time').values > np.timedelta64(1,'ns')])
return da.isel(time=has_unique_time)
def detect_lag(f1, f2, mode='same', verbose=False):
x = signal.correlate(f1-np.mean(f1),f2-np.mean(f2), mode=mode)
lags = signal.correlation_lags(f1.size,f2.size, mode=mode)
detected_lag = lags[np.argmax(x)]
if verbose:
plt.plot(lags,x)
plt.title('cross-correlation')
plt.xlabel('lag (samples)')
plt.grid()
plt.show()
print(f'Detected lag : {detected_lag} samples')
return detected_lag
def _plot_signals(s1,s2, title, s1_label = 'f1', s2_label = 'f2'):
N1 = len(s1.time.values)
N2 = len(s2.time.values)
zoom = (min(N1,N2) > 4000)
if zoom:
plt.figure(figsize=(8,3))
plt.subplot(1,2,1)
s1.plot(label = s1_label)
s2.plot(label = s2_label)
plt.grid()
plt.title(title)
plt.legend()
plt.subplot(1,2,2)
index = slice(int(min(N1,N2)/4), int(min(N1,N2)/4) + 2000)
s1.isel(time=index).plot(label = s1_label)
s2.isel(time=index).plot(label = s2_label)
plt.grid()
plt.title('zoomed')
plt.legend()
plt.show()
else:
s1.plot(label = s1_label)
s2.plot(label = s2_label)
plt.grid()
plt.title(title)
plt.legend()
plt.show()
return