-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
116 lines (109 loc) · 9.28 KB
/
plot.py
File metadata and controls
116 lines (109 loc) · 9.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# --- Fold 1 of the 5-fold cross validation ---
# rmse1 / mae1: RMSE and MAE recorded at 31 evaluation points (indices 0..30,
# later plotted against "iteration"); presumably measured on this fold's
# held-out split -- confirm against the training script.
# NOTE(review): rmse1 / mae1 are value-for-value identical to base_line_rmse /
# base_line_mae defined further down in this file -- confirm which is intended.
rmse1 = [1.1536759477860183, 1.0145441354119316, 0.9907376607420599, 0.9805790741342425, 0.9747356460525013,
0.9709272391866042, 0.9682625358214296, 0.9663078589960344, 0.9648243932930523, 0.9636692031548864,
0.9627513315174138, 0.962010094677313, 0.9614034543220883, 0.9609013839795104, 0.9604818868090252,
0.9601285017592183, 0.9598286850999416, 0.9595727275961089, 0.9593530108070353, 0.959163484551355,
0.9589992923999938, 0.9588564985344533, 0.9587318854335531, 0.9586228019466326, 0.9585270477889841,
0.9584427847450657, 0.9583684677099628, 0.958302790639187, 0.9582446438210717, 0.9581930798320716,
0.9581472862099933]
mae1 = [0.9680487749999698, 0.8177142683525378, 0.7892246400617884, 0.7782112446532694, 0.7722556195911724,
0.7684645181178005, 0.7658849112017, 0.7640524059456165, 0.7626896747563826, 0.7616417105894817,
0.760819914279763, 0.7601710520796444, 0.7596401556446317, 0.7591968020649307, 0.7588286556057601,
0.758520588206389, 0.758258886154169, 0.7580357859745338, 0.7578463803729238, 0.7576829702912536,
0.7575412647044989, 0.7574183228806125, 0.7573116896004383, 0.7572178800355506, 0.7571354880917661,
0.7570630176316707, 0.7569994025981024, 0.7569431195968803, 0.7568933871164825, 0.7568493020717962,
0.7568101575591671]
# Training and test time for fold 1, in seconds (printed with an "s" suffix in
# the timing summary).
train_t1 = 3976.579307794571
test_t1 = 257.3741672039032
# --- Fold 2 of the 5-fold cross validation ---
# rmse2 / mae2: RMSE and MAE recorded at 31 evaluation points (indices 0..30);
# presumably measured on this fold's held-out split -- confirm.
rmse2 = [1.130663819468847, 0.9967488940718281, 0.9751883411431467, 0.9657762484900871, 0.9604185049396,
0.9569732965129525, 0.9545971862217242, 0.952880670155494, 0.9515983431856234, 0.9506154031541947,
0.9498463155693709, 0.9492343219522207, 0.9487404035422512, 0.9483369526030206, 0.9480039596816403,
0.9477266226888987, 0.9474937966957029, 0.9472969599231781, 0.9471295069747149, 0.9469862553322463,
0.946863094233122, 0.9467567306512342, 0.9466645027709814, 0.9465842411753137, 0.9465141642842141,
0.9464527987206219, 0.9463989180473567, 0.9463514951985925, 0.9463096652269327, 0.9462726958961716,
0.9462399642939429]
mae2 = [0.9489109899999697, 0.8011300855060084, 0.7743572389475285, 0.7643676391464437, 0.7589674370804881,
0.7555241116261486, 0.7532066888038513, 0.7515618844626578, 0.7503445994262334, 0.7494092136381095,
0.7486746377229708, 0.7480988968597384, 0.7476407154186336, 0.747274346762701, 0.7469725056626797,
0.7467198553786225, 0.7465060698127461, 0.7463250804001469, 0.7461702163529744, 0.7460381133449641,
0.7459256757192527, 0.7458289783829136, 0.7457449142400786, 0.7456717194688487, 0.745607562807898,
0.7455516560943718, 0.7455023216577882, 0.7454589930245075, 0.7454205008168759, 0.7453864274297204,
0.7453565671813632]
# Training and test time for fold 2, in seconds.
train_t2 = 3887.5170197486877
test_t2 = 315.92051792144775
# --- Fold 3 of the 5-fold cross validation ---
# rmse3 / mae3: RMSE and MAE recorded at 31 evaluation points (indices 0..30);
# presumably measured on this fold's held-out split -- confirm.
rmse3 = [1.1115822805381261, 0.9872572053166876, 0.9660461586401607, 0.9570661775585745, 0.9521242432797,
0.9490478974641087, 0.9469867181496228, 0.9455343396750076, 0.9444720834254274, 0.943672395210918,
0.943056306525591, 0.9425726015942038, 0.9421867923176781, 0.9418749218001755, 0.9416199078367264,
0.9414092957563649, 0.9412338288131523, 0.9410865111297433, 0.9409619772939836, 0.9408560584484414,
0.9407654775338888, 0.9406876313714545, 0.9406204323268068, 0.9405621916051775, 0.940511532113644,
0.9404673226356804, 0.9404286275740161, 0.9403946682051855, 0.9403647925407412, 0.9403384516887215,
0.9403151811704664]
mae3 = [0.9306039450000336, 0.7933684972467707, 0.7684930077368486, 0.7587745019062715, 0.753763720608649,
0.750703125669192, 0.7486552601040863, 0.7472233080542106, 0.7461776256560871, 0.7453946682460785,
0.7447880465319254, 0.7443110678654029, 0.743922361190643, 0.7436029605064022, 0.7433413242918279,
0.7431245434801039, 0.7429445995531668, 0.7427928246540892, 0.7426636014907047, 0.7425537504962931,
0.7424602251627683, 0.742380770524795, 0.7423118852070887, 0.7422521170327343, 0.7421996435318209,
0.7421537059865947, 0.7421134779636839, 0.7420782868396668, 0.7420479028205452, 0.7420211180675816,
0.741997363406038]
# Training and test time for fold 3, in seconds.
train_t3 = 3867.6804897785187
test_t3 = 305.06651973724365
# --- Fold 4 of the 5-fold cross validation ---
# rmse4 / mae4: RMSE and MAE recorded at 31 evaluation points (indices 0..30);
# presumably measured on this fold's held-out split -- confirm.
rmse4 = [1.113293684369096, 0.9861307554784009, 0.9627210871525096, 0.953522402981341, 0.9486582437889812,
0.9456859157313968, 0.943715789665698, 0.9423399667708697, 0.9413429709497385, 0.9405999618761376,
0.9400337917436954, 0.939594445396018, 0.9392482639676394, 0.9389719025546196, 0.9387487603385484,
0.9385667808129617, 0.9384170482032272, 0.9382928651022217, 0.938189130778636, 0.9381019128075938,
0.9380281461469061, 0.9379654181018025, 0.9379118123148891, 0.9378657940342169, 0.9378261247036693,
0.9377917976782841, 0.937761989353627, 0.9377360216728704, 0.9377133331207405, 0.9376934561090994,
0.9376759992186364]
mae4 = [0.9361313950000011, 0.7956044769570476, 0.7689157583241076, 0.758786916542721, 0.7535078646216137,
0.7503126513738405, 0.748190584579093, 0.7467049567660874, 0.7456213453789367, 0.7447960816036492,
0.7441555545304599, 0.743648955723774, 0.7432359276197689, 0.7429006790418868, 0.7426307541937948,
0.7424103154233594, 0.7422250707536595, 0.7420678087936647, 0.741934717827743, 0.7418208845432527,
0.7417214586498396, 0.741635840731647, 0.7415616422344901, 0.7414966861688429, 0.7414399177987939,
0.7413904006297434, 0.7413471573174949, 0.7413092674239955, 0.7412756635229549, 0.7412459908130298,
0.7412197015635031]
# Training and test time for fold 4, in seconds.
train_t4 = 4041.0517807006836
test_t4 = 284.2231683731079
# --- Fold 5 of the 5-fold cross validation ---
# rmse5 / mae5: RMSE and MAE recorded at 31 evaluation points (indices 0..30);
# presumably measured on this fold's held-out split -- confirm.
rmse5 = [1.1186750576778326, 0.995739224983253, 0.9715736250314723, 0.9610100351786536, 0.9550025427127976,
0.951144429311716, 0.9484882804792256, 0.9465733762402003, 0.9451459345336305, 0.9440542264180649,
0.9432020439366742, 0.9425255894745074, 0.9419810291170392, 0.9415373624524471, 0.9411721320554572,
0.9408687352840317, 0.9406146819605454, 0.9404004320923725, 0.9402186010711646, 0.9400634043504668,
0.9399302620805075, 0.9398155129238936, 0.9397162038341114, 0.9396299335915805, 0.9395547349647224,
0.9394889850041829, 0.939431336081835, 0.9393806623980152, 0.9393360181406334, 0.9392966045037211,
0.9392617435002186]
mae5 = [0.9399372992262457, 0.805121886116225, 0.7772803621927094, 0.7662072548189801, 0.7602358331688083,
0.7565145242873201, 0.7539783550129365, 0.7521582895574418, 0.7508168147776526, 0.749789910117719,
0.7489890604491699, 0.7483501683430355, 0.7478279583339468, 0.7474017953035297, 0.7470522241576161,
0.7467610881949126, 0.7465165897397512, 0.7463099405242846, 0.7461334634163997, 0.7459803992864276,
0.7458480793598574, 0.745733022991386, 0.7456332823735723, 0.7455465072170416, 0.7454718863456318,
0.745407219618233, 0.7453506345781998, 0.7453009099719549, 0.7452572394324715, 0.7452188051022096,
0.7451849927292644]
# Training and test time for fold 5, in seconds.
train_t5 = 4064.9781572818756
test_t5 = 256.07301926612854
import matplotlib.pyplot as plt
import numpy as np
# Plot the fold-averaged SVD RMSE/MAE curves next to a baseline, then print the
# averaged curves and a 5-fold timing summary.
#
# Baseline curves, one value per evaluation point (same x axis as the SVD folds).
# NOTE(review): both lists are value-for-value identical to rmse1 / mae1 (SVD
# fold 1), so the "baseline" lines coincide with fold 1 rather than showing a
# separate model -- confirm the intended baseline numbers.
base_line_rmse = [
    1.1536759477860183, 1.0145441354119316, 0.9907376607420599, 0.9805790741342425, 0.9747356460525013,
    0.9709272391866042, 0.9682625358214296, 0.9663078589960344, 0.9648243932930523, 0.9636692031548864,
    0.9627513315174138, 0.962010094677313, 0.9614034543220883, 0.9609013839795104, 0.9604818868090252,
    0.9601285017592183, 0.9598286850999416, 0.9595727275961089, 0.9593530108070353, 0.959163484551355,
    0.9589992923999938, 0.9588564985344533, 0.9587318854335531, 0.9586228019466326, 0.9585270477889841,
    0.9584427847450657, 0.9583684677099628, 0.958302790639187, 0.9582446438210717, 0.9581930798320716,
    0.9581472862099933,
]
base_line_mae = [
    0.9680487749999698, 0.8177142683525378, 0.7892246400617884, 0.7782112446532694, 0.7722556195911724,
    0.7684645181178005, 0.7658849112017, 0.7640524059456165, 0.7626896747563826, 0.7616417105894817,
    0.760819914279763, 0.7601710520796444, 0.7596401556446317, 0.7591968020649307, 0.7588286556057601,
    0.758520588206389, 0.758258886154169, 0.7580357859745338, 0.7578463803729238, 0.7576829702912536,
    0.7575412647044989, 0.7574183228806125, 0.7573116896004383, 0.7572178800355506, 0.7571354880917661,
    0.7570630176316707, 0.7569994025981024, 0.7569431195968803, 0.7568933871164825, 0.7568493020717962,
    0.7568101575591671,
]

# Per-fold curves stacked into arrays of shape (n_folds, n_points).
rmse_list = np.array([rmse1, rmse2, rmse3, rmse4, rmse5])
mae_list = np.array([mae1, mae2, mae3, mae4, mae5])
n_folds, n_points = rmse_list.shape
# Mean curve across folds. Derived from the data instead of a hard-coded "/ 5",
# which would silently mis-scale the average if the fold count changed.
rmse = rmse_list.mean(axis=0)
mae = mae_list.mean(axis=0)
# One x value per recorded evaluation point (0..n_points-1), derived from the
# data instead of a hard-coded 31 that had to match the list lengths.
x = np.arange(n_points)

plt.plot(x, rmse, color="red", label="SVD - RMSE loss", marker=".")
plt.plot(x, mae, color="green", label='SVD - MAE loss', marker="P")
# Pass x explicitly so every curve shares the same x data. This call previously
# relied on matplotlib's implicit 0..N-1 index.
plt.plot(x, base_line_rmse, color="yellow", label="baseline - RMSE loss")
plt.plot(x, base_line_mae, color="blue", label="baseline - MAE loss")
plt.legend()
plt.grid(ls='--')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title("curve of loss relative to training iteration")
plt.show()
print(mae, rmse)

# Results including MAE/RMSE/Training Time/Test Time by 5-fold cross validation
train_times = [train_t1, train_t2, train_t3, train_t4, train_t5]
test_times = [test_t1, test_t2, test_t3, test_t4, test_t5]
total_train_t = sum(train_times)
total_test_t = sum(test_times)
# Training epochs summed over all folds. This replaces a hard-coded 150, which
# equals n_folds * (n_points - 1) = 5 * 30; i.e. it assumes the first recorded
# point precedes any training epoch.
# NOTE(review): confirm that assumption. Also check whether the test-time
# average should divide by the number of evaluations (n_folds * n_points).
n_epochs_total = n_folds * (n_points - 1)
print("total training time\t=\t{}s\t\taverage training time for one epoch\t= {}s\n"
      "total test time\t\t=\t{}s\t\taverage test time for MAE and RMSE\t= {}s".format(
          round(total_train_t, 3), round(total_train_t / n_epochs_total, 3),
          round(total_test_t, 3), round(total_test_t / n_epochs_total, 3)))