forked from nnzhan/Graph-WaveNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
123 lines (98 loc) · 4.54 KB
/
Copy pathtest.py
File metadata and controls
123 lines (98 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F

import util
from model import *
# Command-line options for evaluating a trained Graph WaveNet checkpoint.
# The model-structure flags (--gcn_bool, --aptonly, --addaptadj, --randomadj,
# --in_dim, --num_nodes) are passed to gwnet in main(), so they presumably must
# match the training run for load_state_dict to succeed.
# NOTE(review): --nhid, --seq_length, --learning_rate and --weight_decay are
# parsed but not used by this script (kept, presumably, for CLI parity with
# the training script).
parser = argparse.ArgumentParser()
parser.add_argument('--device',type=str,default='cuda:0',help='')
parser.add_argument('--data',type=str,default='data/METR-LA',help='data path')
parser.add_argument('--adjdata',type=str,default='data/sensor_graph/adj_mx.pkl',help='adj data path')
parser.add_argument('--adjtype',type=str,default='doubletransition',help='adj type')
parser.add_argument('--gcn_bool',action='store_true',help='whether to add graph convolution layer')
parser.add_argument('--aptonly',action='store_true',help='whether only adaptive adj')
parser.add_argument('--addaptadj',action='store_true',help='whether add adaptive adj')
parser.add_argument('--randomadj',action='store_true',help='whether random initialize adaptive adj')
parser.add_argument('--seq_length',type=int,default=12,help='')
parser.add_argument('--nhid',type=int,default=32,help='')
parser.add_argument('--in_dim',type=int,default=2,help='inputs dimension')
parser.add_argument('--num_nodes',type=int,default=207,help='number of nodes')
parser.add_argument('--batch_size',type=int,default=64,help='batch size')
parser.add_argument('--learning_rate',type=float,default=0.001,help='learning rate')
parser.add_argument('--dropout',type=float,default=0.3,help='dropout rate')
parser.add_argument('--weight_decay',type=float,default=0.0001,help='weight decay rate')
# The checkpoint is mandatory: main() has no way to evaluate without weights,
# so fail at argument-parsing time with a clear message instead of later in torch.load(None).
parser.add_argument('--checkpoint',type=str,required=True,help='path to trained model state_dict')
# Compared as a string in main(): only the exact value "True" enables the heatmap.
parser.add_argument('--plotheatmap',type=str,default='True',help='')
args = parser.parse_args()
def _build_model(device):
    """Build gwnet from the CLI flags, load the checkpoint and switch to eval mode.

    Args:
        device: torch.device that the support matrices, model and weights are placed on.

    Returns:
        The loaded model, in eval mode.
    """
    _, _, adj_mx = util.load_adj(args.adjdata, args.adjtype)
    supports = [torch.tensor(i).to(device) for i in adj_mx]
    # --randomadj: do not seed the adaptive adjacency (aptinit) from the first support.
    adjinit = None if args.randomadj else supports[0]
    if args.aptonly:
        supports = None
    # NOTE(review): --nhid is parsed but not passed here, so gwnet uses its own
    # default channel widths; confirm that matches how the checkpoint was trained.
    model = gwnet(device, args.num_nodes, args.dropout, supports=supports,
                  gcn_bool=args.gcn_bool, addaptadj=args.addaptadj,
                  aptinit=adjinit, in_dim=args.in_dim)
    model.to(device)
    # map_location lets a checkpoint saved on a GPU be loaded onto any --device (e.g. cpu).
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.eval()
    print('model load successfully')
    return model


def _predict(model, dataloader, device):
    """Run ``model`` over the test loader and concatenate the (still scaled) outputs.

    Returns:
        A tensor stacked along dim 0 with one row per sample the loader yielded.
    """
    outputs = []
    with torch.no_grad():
        for x, _ in dataloader['test_loader'].get_iterator():
            testx = torch.Tensor(x).to(device).transpose(1, 3)
            preds = model(testx).transpose(1, 3)
            # Drop only the singleton axis 1. A bare squeeze() would also drop the
            # batch axis for a one-sample batch and break the torch.cat below.
            outputs.append(preds.squeeze(1))
    return torch.cat(outputs, dim=0)


def _evaluate(yhat, realy, scaler):
    """Print MAE/MAPE/RMSE for every prediction horizon, then their averages.

    Args:
        yhat: scaled predictions indexed as [sample, node, horizon].
        realy: targets, same indexing as ``yhat``; compared directly against the
            inverse-transformed predictions (so presumably already unscaled).
        scaler: dataset scaler whose ``inverse_transform`` undoes the scaling.
    """
    # Derive the horizon count from the data instead of hard-coding 12.
    horizons = min(yhat.size(-1), realy.size(-1))
    log = 'Evaluate best model on test data for horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'
    amae, amape, armse = [], [], []
    for h in range(horizons):
        pred = scaler.inverse_transform(yhat[:, :, h])
        real = realy[:, :, h]
        metrics = util.metric(pred, real)
        print(log.format(h + 1, metrics[0], metrics[1], metrics[2]))
        amae.append(metrics[0])
        amape.append(metrics[1])
        armse.append(metrics[2])
    log = 'On average over {:d} horizons, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'
    print(log.format(horizons, np.mean(amae), np.mean(amape), np.mean(armse)))


def _plot_adaptive_adj(model):
    """Save a heatmap of softmax(relu(nodevec1 @ nodevec2)), max-normalised, to ./emb.pdf."""
    # nodevec1/nodevec2 presumably exist only when the model was built with an
    # adaptive adjacency (--gcn_bool --addaptadj); skip instead of raising AttributeError.
    if not (hasattr(model, 'nodevec1') and hasattr(model, 'nodevec2')):
        print('skip heatmap: model has no nodevec1/nodevec2 (adaptive adjacency disabled)')
        return
    with torch.no_grad():
        adp = F.softmax(F.relu(torch.mm(model.nodevec1, model.nodevec2)), dim=1)
    adp = adp.cpu().numpy()
    adp = adp * (1 / np.max(adp))
    sns.heatmap(pd.DataFrame(adp), cmap="RdYlBu")
    plt.savefig("./emb.pdf")
    plt.close()


def _export_horizon_csv(yhat, realy, scaler, num_nodes, data_path, horizon=11):
    """Write per-node targets and predictions for one horizon to ./<dataset>-wave.csv.

    Columns alternate ``real<i>``, ``pred<i>`` for i in range(num_nodes).
    ``horizon`` is a 0-based index into the last axis (11 = the 12th step).
    The dataset name is the data directory's basename up to the first '-',
    e.g. 'data/METR-LA' -> 'METR'.
    """
    # One device->host transfer for all nodes instead of one per node. Assumes
    # scaler.inverse_transform is elementwise, as its use on 2-D slices in
    # _evaluate suggests.
    real = realy[:, :, horizon].detach().cpu().numpy()
    pred = scaler.inverse_transform(yhat[:, :, horizon]).detach().cpu().numpy()
    columns = {}
    for node in range(num_nodes):
        columns['real' + str(node)] = real[:, node]
        columns['pred' + str(node)] = pred[:, node]
    # basename() instead of slicing off a literal 'data/' prefix, so paths like
    # './data/METR-LA' or '/abs/METR-LA' still yield 'METR'.
    map_name = os.path.basename(os.path.normpath(data_path)).split('-')[0]
    pd.DataFrame(columns).to_csv(f'./{map_name}-wave.csv', index=False)


def main():
    """Evaluate a trained Graph WaveNet checkpoint on the test split.

    Prints per-horizon and averaged MAE/MAPE/RMSE. If --plotheatmap is "True",
    saves a heatmap of the learned adaptive adjacency to ./emb.pdf. Finally writes
    the 12th-step targets and predictions for every node to ./<dataset>-wave.csv.
    """
    device = torch.device(args.device)
    model = _build_model(device)

    dataloader = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size)
    scaler = dataloader['scaler']
    # Keep only feature 0 of the targets; afterwards realy is indexed
    # [sample, node, horizon], matching how it is sliced below.
    realy = torch.Tensor(dataloader['y_test']).to(device)
    realy = realy.transpose(1, 3)[:, 0, :, :]

    yhat = _predict(model, dataloader, device)
    # The loader can yield more rows than y_test has (presumably it pads the
    # last batch); trim back to the real sample count.
    yhat = yhat[:realy.size(0), ...]

    _evaluate(yhat, realy, scaler)
    if args.plotheatmap == "True":
        _plot_adaptive_adj(model)
    _export_horizon_csv(yhat, realy, scaler, args.num_nodes, args.data)
# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()