-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_data.py
More file actions
100 lines (76 loc) · 3.69 KB
/
plot_data.py
File metadata and controls
100 lines (76 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from src.data.general import build_dataset
from src.data import vector_field as vf
import numpy as np
import diffrax
import jax.random as jrandom
import matplotlib
import matplotlib.pyplot as plt
# font = {'size': 18}
#
# matplotlib.rc('font', **font)
def _species_names(n_species):
    """Return one panel/legend label per species: "A", "B", ... and "S26", "S27", ... beyond Z."""
    return [chr(ord("A") + i) if i < 26 else f"S{i}" for i in range(n_species)]


def _mosaic_layout(names, n_cols=2):
    """Arrange panel labels in rows of ``n_cols``, padding the last row with "." (an empty mosaic cell).

    For 2, 3 and 4 species this reproduces the hand-written layouts previously used
    (``[["A", "B"]]``, ``[["A", "B"], ["C", "."]]``, ``[["A", "B"], ["C", "D"]]``).

    Raises:
        ValueError: if ``names`` is empty (there is nothing to plot).
    """
    if not names:
        raise ValueError("Cannot build a plot layout for zero species.")
    rows = [list(names[start:start + n_cols]) for start in range(0, len(names), n_cols)]
    rows[-1] += ["."] * (n_cols - len(rows[-1]))
    return rows


def _simulate_modelled_field(reaction, y0, solver_name):
    """Integrate ``reaction.modelled_vector_field`` from ``y0`` and return the states saved at ``reaction.ts``."""
    ts = reaction.ts
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(reaction.modelled_vector_field),
        getattr(diffrax, solver_name)(),
        ts[0],
        ts[-1],
        ts[1] - ts[0],  # initial step size; the PID controller adapts it afterwards
        y0=y0,
        # NOTE(review): presumably the nominal rate constants (loc of the rate prior) — confirm.
        args=tuple(reaction.loc_k),
        stepsize_controller=diffrax.PIDController(rtol=1e-9, atol=1e-9),
        max_steps=16 ** 5,
        saveat=diffrax.SaveAt(ts=ts),
        throw=True,
    )
    return sol.ys


def _plot_comparison(reaction, names, ts_train, xs_train, ys_train, modelled_ys, noise_type):
    """Plot observations, ground truth and the modelled trajectory, one panel per species, and save a PDF."""
    layout = _mosaic_layout(names)
    # Two inches of height per row of panels (same as 2 * ceil(n_species / 2)).
    fig, axes = plt.subplot_mosaic(layout, figsize=(5, 2 * len(layout)))
    for i in range(min(ys_train.shape[-1], len(names))):
        ax = axes[names[i]]
        ax.scatter(ts_train[0, :], xs_train[0, :, i], s=2, color='tab:orange', label="Observation")
        ax.plot(reaction.ts, ys_train[0, :, i], label="Ground truth")
        ax.plot(reaction.ts, modelled_ys[:, i], color="crimson", label="Modelled vector field", linestyle='--')
        ax.set_title(f"Species {names[i]}")
        ax.grid()
    fig.supxlabel("Time [A.U.]")
    fig.supylabel("Concentration [A.U.]")
    fig.tight_layout()

    # Every panel carries the same three labels; build one shared figure legend from the first panel,
    # dropping any repeated label.
    handles, labels = axes[names[0]].get_legend_handles_labels()
    unique = [(handle, label) for i, (handle, label) in enumerate(zip(handles, labels)) if label not in labels[:i]]
    if len(names) % 2 == 0:
        # No free cell in the grid: reserve space below the panels for the legend.
        fig.subplots_adjust(bottom=0.5 / (len(names) // 2))
        fig.legend(*zip(*unique), loc='lower right')
    else:
        # Odd count leaves the bottom-right cell empty; this anchor was tuned for the 3-species layout.
        fig.legend(*zip(*unique), loc='center', bbox_to_anchor=(0.775, 0.3))
    fig.savefig(rf"./{reaction}_{noise_type}.pdf")
    plt.close(fig)


def _plot_ground_truth(reaction, names, ys_train):
    """Plot the first ground-truth training trajectory of every species on a single axis and save a PDF."""
    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    for i, name in enumerate(names):
        ax.plot(reaction.ts, ys_train[0, :, i], label=name)
    ax.set_xlabel("Time [A.U.]")
    ax.set_ylabel("Concentration [A.U.]")
    ax.legend(title="Chemical species")
    fig.savefig(rf"./{reaction}_ground_truth.pdf")
    plt.close(fig)


def main():
    """Simulate one BottleNeck reaction trajectory and save two diagnostic plots.

    Writes to the current working directory:
      * ``./{reaction}_{noise_type}.pdf``: noisy observations vs. ground truth vs. the
        trajectory of ``reaction.modelled_vector_field``, one panel per species;
      * ``./{reaction}_ground_truth.pdf``: all ground-truth species on one axis.
    """
    reaction = vf.BottleNeck(k_scale=0.0, open_index=1, flux=0.2)
    # One solver name for both dataset generation and the modelled trajectory, so they stay comparable.
    solver_name = 'Tsit5'

    key = jrandom.PRNGKey(5678)
    # The split counts (5, then 3) are kept so data_key/train_key, and hence the generated data,
    # match earlier runs; the other derived keys are not used in this script.
    data_key, *_ = jrandom.split(key, 5)
    train_key, *_ = jrandom.split(data_key, 3)

    # NOTE(review): presumably (loc, scale) of a prior over each reaction rate. The intent was noisy
    # training rates, but with k_scale=0.0 the scale is zero — confirm in build_dataset.
    k_prior_train = list(zip(reaction.loc_k, [reaction.k_scale] * len(reaction.loc_k)))
    # Observation-noise settings forwarded to build_dataset.
    noise_dict = {'noise_type': "Gaussian",
                  'noise_scale': 0.1,
                  'time_fraction': 1.0,
                  'sample_fraction': 0.7}
    ts_train, xs_train, ys_train = build_dataset(1, reaction, k_prior_train, **noise_dict, key=train_key, seed=42,
                                                 solver=solver_name)

    # Start the modelled trajectory from the initial state of the first training trajectory.
    modelled_ys = _simulate_modelled_field(reaction, ys_train[0, 0, :], solver_name)

    names = _species_names(reaction.n_species)
    _plot_comparison(reaction, names, ts_train, xs_train, ys_train, modelled_ys, noise_dict['noise_type'])
    _plot_ground_truth(reaction, names, ys_train)
# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()