# evaluate_cond.py -- evaluate a trained conditional PCP-Map model (test NLL and MMD).
import argparse
import torch
from torch import distributions
from lib.dataloader import dataloader
from src.icnn import PICNN
from src.pcpmap import PCPMap
from src.mmd import mmd
parser = argparse.ArgumentParser('PCP-Map')
parser.add_argument(
    '--resume',
    type=str,
    default="experiments/tabcond/...",
    help="path of the checkpoint file to evaluate",
)
args = parser.parse_args()

# GPU Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_checkpoint(path):
    """Load a checkpoint file with all tensors mapped onto the CPU.

    The checkpoint holds more than tensors: ``checkpt['args']`` is an arguments
    object read through attribute access (presumably an ``argparse.Namespace``).
    PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which refuses
    to unpickle such objects, so full unpickling is requested explicitly.

    SECURITY: full unpickling can execute arbitrary code -- only load
    checkpoints from trusted sources.

    Args:
        path: Filesystem path of the checkpoint.

    Returns:
        The deserialized checkpoint object (used below as a dict).
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the ``weights_only`` keyword.
        return torch.load(path, map_location="cpu")


# load best model
checkpt = _load_checkpoint(args.resume)
if __name__ == '__main__':
    ckpt_args = checkpt['args']

    # Rebuild the test split from the data settings recorded in the checkpoint
    # (assumes dataloader splits deterministically for a fixed random_state --
    # TODO confirm).
    dataset = ckpt_args.data
    batch_size = ckpt_args.batch_size
    test_ratio = ckpt_args.test_ratio
    valid_ratio = ckpt_args.valid_ratio
    random_state = ckpt_args.random_state
    _, _, test_data, _ = dataloader(dataset, batch_size, test_ratio, valid_ratio, random_state)

    # Load Best Models
    print(ckpt_args)
    input_x_dim = ckpt_args.input_x_dim
    input_y_dim = ckpt_args.input_y_dim
    feature_dim = ckpt_args.feature_dim
    feature_y_dim = ckpt_args.feature_y_dim
    out_dim = ckpt_args.out_dim
    num_layers_pi = ckpt_args.num_layers_pi
    clip = ckpt_args.clip
    # Re-parameterization is enabled unless `clip` is exactly True.
    reparam = clip is not True

    # Standard-normal reference distribution over the x-space.
    prior_picnn = distributions.MultivariateNormal(
        torch.zeros(input_x_dim).to(device), torch.eye(input_x_dim).to(device)
    )
    picnn = PICNN(
        input_x_dim, input_y_dim, feature_dim, feature_y_dim, out_dim, num_layers_pi,
        reparam=reparam,
    ).to(device)
    pcpmap = PCPMap(prior_picnn, picnn)
    pcpmap.load_state_dict(checkpt["state_dict_picnn"])
    pcpmap = pcpmap.to(device)

    # Obtain test metrics numbers.
    # Column layout of test_data: the first input_y_dim columns are the
    # conditioning variables y, the remaining columns are the targets x.
    # Move to the device *before* enabling gradients so both tensors are leaf
    # tensors on `device` (loglik_picnn presumably differentiates w.r.t. its
    # inputs -- verify in PCPMap).
    x_test = test_data[:, input_y_dim:].to(device).requires_grad_(True)
    y_test = test_data[:, :input_y_dim].to(device).requires_grad_(True)
    log_prob_picnn = pcpmap.loglik_picnn(x_test, y_test)
    pb_mean_NLL = -log_prob_picnn.mean()
    print('Mean Conditional Negative Log Likelihood: {:.3e}'.format(pb_mean_NLL.item()))

    # Calculate MMD between generated samples and the held-out targets.
    # zx is drawn on the default (CPU) generator and then moved, as before.
    zx = torch.randn(test_data.shape[0], input_x_dim).to(device)
    x_generated, _ = pcpmap.gx(zx, y_test.detach(), tol=ckpt_args.tol)
    x_generated = x_generated.detach()
    # Both MMD inputs must live on the same device; test_data itself is
    # presumably on the CPU, so use the on-device copy of the targets.
    mean_max_dis = mmd(x_generated, x_test.detach())
    print('Maximum Mean Discrepancy: {:.3e}'.format(mean_max_dis))