-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluate_joint.py
More file actions
123 lines (103 loc) · 4.65 KB
/
evaluate_joint.py
File metadata and controls
123 lines (103 loc) · 4.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
import inspect
import os

import torch
from torch import distributions
from torch.utils.data import DataLoader

from lib.dataloader import dataloader
from datasets import tabular_data
from src.plotter import plot4_tabular
from src.icnn import FICNN, PICNN
from src.mapficnn import MapFICNN
from src.pcpmap import PCPMap
from src.mmd import mmd
from lib.utils import AverageMeter
# Command-line interface: path of the checkpoint to evaluate.
parser = argparse.ArgumentParser('PCP-Map')
parser.add_argument('--resume', type=str, default="experiments/tabjoint/...")
args = parser.parse_args()

# GPU Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load best model
# The checkpoint stores more than tensors (checkpt['args'] is read as an object
# with attributes), which PyTorch >= 2.6 refuses to load under its default
# ``weights_only=True``. Opt out explicitly on versions that know the flag;
# older versions (no such parameter) already unpickle everything.
# NOTE(security): this unpickles arbitrary objects -- only load trusted checkpoints.
_TORCH_LOAD_KWARGS = (
    {'weights_only': False}
    if 'weights_only' in inspect.signature(torch.load).parameters
    else {}
)
# map_location keeps every storage on the CPU regardless of where it was saved.
checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage,
                     **_TORCH_LOAD_KWARGS)
def load_data(dataset):
    """Return the raw tabular data for one of the supported datasets.

    Args:
        dataset: dataset name; one of ``'wt_wine'``, ``'rd_wine'`` or
            ``'parkinson'``.

    Returns:
        Whatever the matching ``tabular_data.get_*`` loader returns.

    Raises:
        ValueError: if ``dataset`` is not a supported name. ``ValueError`` is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers keep working.
    """
    if dataset == 'wt_wine':
        return tabular_data.get_wt_wine()
    if dataset == 'rd_wine':
        return tabular_data.get_rd_wine()
    if dataset == 'parkinson':
        return tabular_data.get_parkinson()
    raise ValueError(
        "Dataset is Incorrect: {!r} (expected 'wt_wine', 'rd_wine' or "
        "'parkinson')".format(dataset)
    )
if __name__ == '__main__':
    # Hyper-parameters saved alongside the model at training time.
    cargs = checkpt['args']

    # Rebuild the data split with the same settings that were used in training.
    dataset = cargs.data
    batch_size = cargs.batch_size
    test_ratio = cargs.test_ratio
    valid_ratio = cargs.valid_ratio
    random_state = cargs.random_state
    _, _, test_data, _ = dataloader(dataset, batch_size, test_ratio, valid_ratio, random_state)

    # Whole processed + normalized dataset; used below as the MMD reference sample.
    data = load_data(dataset)
    data = tabular_data.process_data(data)
    data = tabular_data.normalize_data(data)
    dat = torch.tensor(data, dtype=torch.float32)

    # Load Best Models
    print(cargs)
    input_x_dim = cargs.input_x_dim
    input_y_dim = cargs.input_y_dim
    feature_dim = cargs.feature_dim
    feature_y_dim = cargs.feature_y_dim
    out_dim = cargs.out_dim
    num_layers_fi = cargs.num_layers_fi
    num_layers_pi = cargs.num_layers_pi
    # reparam is enabled unless the saved `clip` flag is exactly True.
    reparam = cargs.clip is not True

    # Standard-normal reference distributions for the y-map and the x|y-map.
    prior_ficnn = distributions.MultivariateNormal(torch.zeros(input_y_dim).to(device),
                                                   torch.eye(input_y_dim).to(device))
    prior_picnn = distributions.MultivariateNormal(torch.zeros(input_x_dim).to(device),
                                                   torch.eye(input_x_dim).to(device))
    ficnn = FICNN(input_y_dim, feature_dim, out_dim, num_layers_fi, reparam=reparam).to(device)
    picnn = PICNN(input_x_dim, input_y_dim, feature_dim, feature_y_dim, out_dim, num_layers_pi,
                  reparam=reparam).to(device)
    map_ficnn = MapFICNN(prior_ficnn, ficnn)
    map_picnn = PCPMap(prior_picnn, picnn)
    map_ficnn.load_state_dict(checkpt["state_dict_ficnn"])
    map_picnn.load_state_dict(checkpt["state_dict_picnn"])
    map_ficnn = map_ficnn.to(device)
    map_picnn = map_picnn.to(device)

    # load test data; no shuffling so the evaluation is reproducible
    # (presumably AverageMeter takes a sample-weighted mean, so order is irrelevant).
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Obtain Test Metrics Numbers
    # Columns [:input_y_dim] are fed to the y-map, the rest to the x|y-map.
    # requires_grad_ is kept: the log-likelihood calls may differentiate w.r.t.
    # their inputs, so no torch.no_grad() here.
    testLossMeter = AverageMeter()
    for test_sample in test_loader:
        x_test = test_sample[:, input_y_dim:].requires_grad_(True).to(device)
        y_test = test_sample[:, :input_y_dim].requires_grad_(True).to(device)
        log_prob1 = map_ficnn.loglik_ficnn(y_test)
        log_prob2 = map_picnn.loglik_picnn(x_test, y_test)
        pb_mean_NLL = -(log_prob1 + log_prob2).mean()
        testLossMeter.update(pb_mean_NLL.item(), test_sample.shape[0])
    print('Mean Negative Log Likelihood: {:.3e}'.format(testLossMeter.avg))

    # Gaussian Pullback: push the whole test set through the inverse maps.
    x_test_tot = test_data[:, input_y_dim:].requires_grad_(True).to(device)
    y_test_tot = test_data[:, :input_y_dim].requires_grad_(True).to(device)
    zy_approx = map_ficnn.gyinv(y_test_tot).detach()
    zx_approx = map_picnn.gxinv(x_test_tot, y_test_tot).detach()
    z = torch.cat((zy_approx, zx_approx), dim=1)

    # Test Generated Samples: as many draws as rows in the full dataset.
    sample_size = dat.shape[0]
    zy = torch.randn(sample_size, input_y_dim).to(device)
    zx = torch.randn(sample_size, input_x_dim).to(device)
    y_generated, _ = map_ficnn.gy(zy, tol=cargs.tol)
    y_generated = y_generated.detach().to(device)
    x_generated, _ = map_picnn.gx(zx, y_generated, tol=cargs.tol)
    x_generated = x_generated.detach().to(device)
    sample = torch.cat((y_generated, x_generated), dim=1)

    # calculate MMD statistic
    # `dat` is a CPU tensor while `sample` lives on `device`; compare both on the
    # CPU so a CUDA run does not hit a device mismatch (no-op on CPU-only runs).
    mean_max_dis = mmd(sample.cpu(), dat)
    print('Maximum Mean Discrepancy: {:.3e}'.format(mean_max_dis))

    # plot figures and save
    # Plain '<dataset>.png' name (no unfilled format placeholder); create the
    # output folder up front in case it does not exist yet.
    fig_dir = os.path.join(cargs.save, 'figs')
    os.makedirs(fig_dir, exist_ok=True)
    sPath = os.path.join(fig_dir, dataset + '.png')
    plot4_tabular(dataset, z, sample, sPath, sTitle=dataset + '_visualizations', hidevals=True)