forked from arturs-berzins/sniROM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path99_plot_errors_over_features.py
More file actions
56 lines (46 loc) · 1.79 KB
/
99_plot_errors_over_features.py
File metadata and controls
56 lines (46 loc) · 1.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
Plot standardized errors over features for all components.
All features are standardized.
Observe the behavior of model and projection errors over the parameter space.
The errors are likely to be worse at the extremes of the space.
"""
# Author: Arturs Berzins <berzins@cats.rwth-aachen.de>
# License: BSD 3 clause
import config
import utils
from matplotlib import pyplot
# Dataset split whose errors and features are plotted.
dataset = 'test'
# Suffixes selecting the error columns ``eps_pod<key>_sq`` of the error table.
# The empty key selects ``eps_pod_sq``, i.e. the POD projection error.
model_keys = ['RBF',
              'GPR',
              'FNN',
              '',  # projection error
              ]

df = utils.load_error_table(dataset)
# Number of parameters (features); one subplot column per parameter.
P = len(config.mu_names)

# Grid: one row per component, one column per parameter.
# squeeze=False keeps ``axes`` a 2-D array even when there is only one
# component or only one parameter, so ``axes[row, col]`` indexing below
# never breaks on a degenerate grid.
fig, axes = pyplot.subplots(len(config.components), P,
                            sharex='col', sharey='row', squeeze=False)
fig.suptitle('Standardized errors over standardized features')
# Qualitative colormap; each model key is drawn with color cmap(i).
cmap = pyplot.get_cmap('tab10')

features = utils.load_features(dataset)
for idx_ax, component in enumerate(config.components):
    # Keep only this component's rows, evaluated at its configured
    # number of basis vectors ``l``.
    L = config.num_basis[component]
    df_filtered = df.loc[(df['component'] == component) &
                         (df['l'] == L)]
    for i, model_key in enumerate(model_keys):
        # The table stores squared errors; plot their square root.
        # This depends only on the model, not on the parameter p, so it is
        # computed once per model instead of once per subplot column.
        ys = df_filtered[F'eps_pod{model_key.lower()}_sq'].values ** 0.5
        # The empty key is the projection error: label it 'POD' instead of
        # the dangling 'POD-' that the plain template would produce.
        label = F'POD-{model_key}' if model_key else 'POD'
        for p in range(P):
            # NOTE(review): assumes the rows of ``features`` are in the same
            # sample order as the rows of ``df_filtered`` -- confirm in utils.
            xs = features[:, p]
            # Plot transparent data due to bug in scatter limits
            axes[idx_ax, p].plot(xs, ys, alpha=0)
            axes[idx_ax, p].scatter(xs, ys, marker='o', s=2,
                                    color=cmap(i), label=label)
# x labels: parameter names, on the bottom row of subplots only
# (columns share their x axis).
for col, mu_name in enumerate(config.mu_names):
    axes[-1, col].set_xlabel(mu_name)
# y labels: component names, on the first column of subplots only
# (rows share their y axis).
for row, component in enumerate(config.components):
    axes[row, 0].set_ylabel(F'{component}')

# Every subplot carries the same labelled scatters, so the first subplot's
# handles are enough for one shared figure-level legend.
fig.legend(*axes[0, 0].get_legend_handles_labels(),
           loc='lower center', ncol=len(model_keys))
fig.set_size_inches(w=6.3, h=4.3)
# Leave room beneath the subplots for the legend.
fig.subplots_adjust(bottom=.2)
pyplot.show()