-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathautoencoder_utils.py
More file actions
135 lines (100 loc) · 3.33 KB
/
autoencoder_utils.py
File metadata and controls
135 lines (100 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import matplotlib.pyplot as plt
import numpy as np
import random
def plot_acc(x):
    """
    Draw components 0, 1 and 2 of ``x`` (x/y/z accelerometer, per the
    component names in ``get_name``) as three lines on the current axes.
    """
    for axis in range(3):
        plt.plot(x[axis])
def plot_gyro(x):
    """
    Draw components 3, 4 and 5 of ``x`` (x/y/z gyroscope, per the
    component names in ``get_name``) as three lines on the current axes.
    """
    for axis in range(3, 6):
        plt.plot(x[axis])
def plot_sensors(x):
    """
    Draw ``x`` as two side-by-side panels: components 0-2 on the left
    (via ``plot_acc``) and components 3-5 on the right (via ``plot_gyro``).
    """
    for position, draw in enumerate((plot_acc, plot_gyro), start=1):
        plt.subplot(1, 2, position)
        draw(x)
def show_samples(X: np.ndarray, y: np.ndarray, n=10, is_random=False):
    """
    Print the label and plot the sensor components of ``n`` samples,
    one figure per sample.

    :param X: samples; ``X[i]`` is passed to ``plot_sensors``.
    :param y: labels aligned with ``X``; ``y[i]`` is printed for each sample.
    :param n: number of samples to show. In sequential mode this is capped
        at ``X.shape[0]``.
    :param is_random: if True, draw ``n`` indices uniformly (with
        replacement) from ``[0, X.shape[0])`` instead of taking the first
        ``n`` samples.
    """
    count = X.shape[0]
    if count == 0:
        # Nothing to show; avoids IndexError / empty-range errors below.
        return
    if is_random:
        # randrange excludes the upper bound. The previous
        # randint(0, count) is inclusive and could return ``count``,
        # which is out of range for X.
        indices = [random.randrange(count) for _ in range(n)]
    else:
        indices = list(range(min(n, count)))
    for i in indices:
        x = X[i]
        print(f"X[{i}]: {y[i]}")
        plot_sensors(x)
        plt.show()
def get_name(i):
    """
    Map a sensor component index (0-5) to its human-readable name.

    Indices 0-2 are the accelerometer axes and 3-5 the gyroscope axes.
    Any other value fails the assertion (AssertionError).
    """
    assert 0 <= i <= 5, f"Component {i} is not supported, must be between 0 and 5"
    component_names = ("x_acc", "y_acc", "z_acc", "x_gyro", "y_gyro", "z_gyro")
    return component_names[i]
def plot_reconstruction_error(sample, reconstruction):
    """
    Plot reconstruction error by diff for sensors.

    Draws each of the 6 sensor components on a 2x3 grid: the original in
    blue, the reconstruction in red, and the gap between them shaded.

    :param sample: original signal indexed by component; ``sample[i]`` is
        the 1-D series for component ``i`` (0..5).
    :param reconstruction: decoder output with the same layout as ``sample``.
    """
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.title(get_name(i))
        plt.plot(sample[i], 'b')
        plt.plot(reconstruction[i], 'r')
        # x positions must match the implicit 0..T-1 used by plt.plot above.
        # This was previously hard-coded to 125, which fails for any other
        # window length.
        steps = np.arange(len(sample[i]))
        plt.fill_between(
            steps, reconstruction[i], sample[i], color='lightcoral')
    # Legend is drawn on the current (last) subplot.
    plt.legend(labels=["sample", "reconstruction", "error"])
def plot_reconstructed_signal(sample, code, reconstruction):
    """
    Show the original and reconstructed signal, split into accelerometer and
    gyroscope panels, with the latent code rendered as an image between them.

    Produces five panels in one row: acc original, gyro original, code
    (``plt.imshow``), acc reconstructed, gyro reconstructed.
    """
    panels = (
        ("acc original", plot_acc, sample),
        ("gyro original", plot_gyro, sample),
        ("code", plt.imshow, code),
        ("acc reconstructed", plot_acc, reconstruction),
        ("gyro reconstructed", plot_gyro, reconstruction),
    )
    for position, (title, draw, data) in enumerate(panels, start=1):
        plt.subplot(1, 5, position)
        plt.title(title)
        draw(data)
def show_loss(history):
    """
    Plot the training curves stored in ``history.history`` and show them.

    Each of "loss", "val_loss", "accuracy" and "val_accuracy" is plotted,
    in that order and labelled by its key, only if it is present.
    """
    recorded = history.history
    for metric in ("loss", "val_loss", "accuracy", "val_accuracy"):
        if metric in recorded:
            plt.plot(recorded[metric], label=metric)
    plt.legend()
    plt.show()
def show_mse(autoencoder, X_test):
    """
    Evaluate ``autoencoder`` on reconstructing ``X_test`` (input == target)
    and print whatever ``evaluate`` returns, prefixed with "MSE =".

    NOTE(review): the printed value is only an MSE if the model was compiled
    with an MSE loss — confirm at the call site.
    """
    score = autoencoder.evaluate(X_test, X_test, verbose=0)
    print("MSE =", score)
def show_reconstructed_signals(X, encoder, decoder, n=10):
    """
    For the first ``n`` samples of ``X``, show a figure with the original
    signal, the latent code (as a zero-padded square image) and the
    reconstruction.

    :param X: input samples; ``X[:n]`` is fed to ``encoder.predict``.
    :param encoder: model whose ``predict`` maps a batch of samples to codes.
    :param decoder: model whose ``predict`` maps a batch of codes to
        reconstructions.
    :param n: number of samples to show (effectively capped at ``len(X)``).
    """
    samples = X[:n]
    if len(samples) == 0:
        return
    # One batched predict per model (as in show_reconstruction_errors)
    # instead of one call per sample.
    codes = encoder.predict(samples)
    reconstructions = decoder.predict(codes)
    # zip stops at the real number of samples, so n > len(X) is safe.
    for sample, code, reconstruction in zip(samples, codes, reconstructions):
        # Flatten the code (any rank) and zero-pad it up to the next perfect
        # square so imshow can render it. Using .size instead of len() avoids
        # silently truncating multi-dimensional codes.
        flat = np.ravel(code)
        side = int(np.ceil(np.sqrt(flat.size)))
        grid = np.zeros(side * side, dtype=flat.dtype)
        grid[:flat.size] = flat
        plt.figure(figsize=(12, 3))
        plot_reconstructed_signal(
            sample=sample,
            code=grid.reshape((side, side)),
            reconstruction=reconstruction,
        )
        plt.show()
def show_reconstruction_errors(X, encoder, decoder, n=10):
    """
    For the first ``n`` samples of ``X``, encode and decode them and show
    one figure per sample comparing the original with its reconstruction
    (see ``plot_reconstruction_error``).

    :param X: input samples; ``X[:n]`` is fed to ``encoder.predict``.
    :param encoder: model whose ``predict`` maps a batch of samples to codes.
    :param decoder: model whose ``predict`` maps a batch of codes to
        reconstructions.
    :param n: number of samples to show (effectively capped at ``len(X)``).
    """
    samples = X[:n]
    if len(samples) == 0:
        return
    codes = encoder.predict(samples)
    reconstructions = decoder.predict(codes)
    # Iterate over what was actually sliced: range(n) raised IndexError
    # whenever n > len(X) because X[:n] silently clamps.
    for sample, reconstruction in zip(samples, reconstructions):
        plt.figure(figsize=(12, 8))
        plot_reconstruction_error(sample, reconstruction)
        plt.show()