55
66from pynumdiff .utils import utility
77
8-
98# pylint: disable-msg=too-many-locals, too-many-arguments
109def plot (x , dt , x_hat , dxdt_hat , x_truth , dxdt_truth , xlim = None , show_error = True , markersize = 5 ):
1110 """Make comparison plots of 'x (blue) vs x_truth (black) vs x_hat (red)' and 'dxdt_truth
@@ -27,36 +26,54 @@ def plot(x, dt, x_hat, dxdt_hat, x_truth, dxdt_truth, xlim=None, show_error=True
2726 if xlim is None :
2827 xlim = [t [0 ], t [- 1 ]]
2928
30- fig = plt .figure (figsize = (18 , 6 ))
31- ax_x = fig .add_subplot (121 )
32- ax_dxdt = fig .add_subplot (122 )
29+ fig , axes = plt .subplots (1 , 2 , figsize = (18 , 6 ))
3330
34- ax_x .plot (t , x_truth , '--' , color = 'black' , linewidth = 3 , label = r"true $x$" )
35- ax_x .plot (t , x , '.' , color = 'blue' , zorder = - 100 , markersize = markersize , label = r"noisy data" )
36- ax_x .plot (t , x_hat , color = 'red' , label = r"estimated $\hat{x}$" )
37- ax_x .set_ylabel ('Position' , fontsize = 18 )
38- ax_x .set_xlabel ('Time' , fontsize = 18 )
39- ax_x .set_xlim (xlim [0 ], xlim [- 1 ])
40- ax_x .tick_params (axis = 'x' , labelsize = 15 )
41- ax_x .tick_params (axis = 'y' , labelsize = 15 )
42- ax_x .legend (loc = 'lower right' , fontsize = 12 )
43- ax_x .set_rasterization_zorder (0 )
44-
45- ax_dxdt .plot (t , dxdt_truth , '--' , color = 'black' , linewidth = 3 , label = r"true $\frac{dx}{dt}$" )
46- ax_dxdt .plot (t , dxdt_hat , color = 'red' , label = r"est. $\hat{\frac{dx}{dt}}$" )
47- ax_dxdt .set_ylabel ('Velocity' , fontsize = 18 )
48- ax_dxdt .set_xlabel ('Time' , fontsize = 18 )
49- ax_dxdt .set_xlim (xlim [0 ], xlim [- 1 ])
50- ax_dxdt .tick_params (axis = 'x' , labelsize = 15 )
51- ax_dxdt .tick_params (axis = 'y' , labelsize = 15 )
52- ax_dxdt .legend (loc = 'lower right' , fontsize = 12 )
53- ax_dxdt .set_rasterization_zorder (0 )
31+ axes [ 0 ] .plot (t , x_truth , '--' , color = 'black' , linewidth = 3 , label = r"true $x$" )
32+ axes [ 0 ] .plot (t , x , '.' , color = 'blue' , zorder = - 100 , markersize = markersize , label = r"noisy data" )
33+ axes [ 0 ] .plot (t , x_hat , color = 'red' , label = r"estimated $\hat{x}$" )
34+ axes [ 0 ] .set_ylabel ('Position' , fontsize = 18 )
35+ axes [ 0 ] .set_xlabel ('Time' , fontsize = 18 )
36+ axes [ 0 ] .set_xlim (xlim [0 ], xlim [- 1 ])
37+ axes [ 0 ] .tick_params (axis = 'x' , labelsize = 15 )
38+ axes [ 0 ] .tick_params (axis = 'y' , labelsize = 15 )
39+ axes [ 0 ] .legend (loc = 'lower right' , fontsize = 12 )
40+ axes [ 0 ] .set_rasterization_zorder (0 )
41+
42+ axes [ 1 ] .plot (t , dxdt_truth , '--' , color = 'black' , linewidth = 3 , label = r"true $\frac{dx}{dt}$" )
43+ axes [ 1 ] .plot (t , dxdt_hat , color = 'red' , label = r"est. $\hat{\frac{dx}{dt}}$" )
44+ axes [ 1 ] .set_ylabel ('Velocity' , fontsize = 18 )
45+ axes [ 1 ] .set_xlabel ('Time' , fontsize = 18 )
46+ axes [ 1 ] .set_xlim (xlim [0 ], xlim [- 1 ])
47+ axes [ 1 ] .tick_params (axis = 'x' , labelsize = 15 )
48+ axes [ 1 ] .tick_params (axis = 'y' , labelsize = 15 )
49+ axes [ 1 ] .legend (loc = 'lower right' , fontsize = 12 )
50+ axes [ 1 ] .set_rasterization_zorder (0 )
5451
5552 fig .tight_layout ()
5653
5754 if show_error :
5855 _ , _ , rms_dxdt = metrics (x , dt , x_hat , dxdt_hat , x_truth , dxdt_truth )
56+ R_sqr = error_correlation (dxdt_hat , dxdt_truth )
5957 print ('RMS error in velocity: ' , rms_dxdt )
58+ print ('Error correlation: ' , R_sqr )
59+
60+
61+ def plot_comparison (dt , dxdt_truth , dxdt_hat1 , title1 , dxdt_hat2 , title2 , dxdt_hat3 , title3 ):
62+ """This is intended to show method performances with different choices of parameter"""
63+ t = np .arange (0 , dt * len (dxdt_truth ), dt )
64+ fig , axes = plt .subplots (1 , 3 , figsize = (22 ,6 ))
65+
66+ for i ,(dxdt_hat ,title ) in enumerate (zip ([dxdt_hat1 , dxdt_hat2 , dxdt_hat3 ], [title1 , title2 , title3 ])):
67+ axes [i ].plot (t , dxdt_truth , '--' , color = 'black' , linewidth = 3 , label = r"true $\frac{dx}{dt}$" )
68+ axes [i ].plot (t , dxdt_hat , color = 'red' , label = r"est. $\hat{\frac{dx}{dt}}$" )
69+ if i == 0 : axes [i ].set_ylabel ('Velocity' , fontsize = 18 )
70+ axes [i ].set_xlabel ('Time' , fontsize = 18 )
71+ axes [i ].tick_params (axis = 'x' , labelsize = 15 )
72+ axes [i ].tick_params (axis = 'y' , labelsize = 15 )
73+ axes [i ].set_title (title , fontsize = 18 )
74+ if i == 2 : axes [i ].legend (loc = 'lower right' , fontsize = 12 )
75+
76+ fig .tight_layout ()
6077
6178
6279def metrics (x , dt , x_hat , dxdt_hat , x_truth = None , dxdt_truth = None , padding = 0 ):
0 commit comments