-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdisplay.py
More file actions
85 lines (67 loc) · 3.26 KB
/
display.py
File metadata and controls
85 lines (67 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import bem.Experiments as Exp
import bem.Logger as Logger
from bem.utils_exp import *
import matplotlib.pyplot as plt
from script_utils import *
import PDMP.PDMPExperiment as pdmp_exp
from PDMP.NeptuneLogger import NeptuneLogger
# Output directory for rendered animations. This is a relative path, so it is
# resolved against the current working directory when files are written.
SAVE_ANIMATION_PATH = "./animation"
def display_exp(config_path):
    """Load the latest checkpoint of an experiment and save its animation.

    Command-line arguments come from ``parse_args()``:

    * ``args.config`` names the ``.yml`` file (without extension) inside
      *config_path*.
    * ``args.name`` selects the checkpoint directory ``models/<name>``.
    * ``args.log`` enables a ``NeptuneLogger``.
    * ``args.generate``, if not None, overrides the number of displayed points
      (default 20000).

    The animation returned by ``exp.manager.display_plots`` is written to
    ``SAVE_ANIMATION_PATH/<dataset>_<method>_<reverse_steps>_<total_steps>[_<sampler>].mp4``.
    The ``_<sampler>`` suffix is only added when the method is ``'pdmp'``.
    The output directory is created if it does not exist.

    Args:
        config_path: Directory containing the experiment ``.yml`` config files.
    """
    args = parse_args()

    # Read the parameter dict from the config file, then apply CLI overrides.
    p = FileHandler.get_param_from_config(config_path, args.config + '.yml')
    update_parameters_before_loading(p, args)

    # The pdmp_exp hooks tell the Experiment how to hash the parameter dict
    # and how to initialize / reset methods and models.
    checkpoint_dir = os.path.join('models', args.name)
    exp = Exp.Experiment(
        checkpoint_dir=checkpoint_dir,
        p=p,
        logger=NeptuneLogger() if args.log else None,
        exp_hash=pdmp_exp.exp_hash,
        eval_hash=None,  # None -> the Experiment's default eval hash function
        init_method_by_parameter=pdmp_exp.init_method_by_parameter,
        init_models_by_parameter=pdmp_exp.init_models_by_parameter,
        reset_models=pdmp_exp.reset_models,
    )
    exp.prepare()
    # Everything after prepare() runs under try/finally. This releases the
    # figure and the experiment (and its optional logger) even when loading
    # or rendering fails.
    try:
        additional_logging(exp, args)
        exp.print_parameters()

        print('Loading latest model')
        exp.load()
        update_experiment_after_loading(exp, args)
        # Optional sampler override after loading (HMC, BPS, ZigZag):
        # exp.manager.method.sampler = 'BPS'
        # exp.p['pdmp']['sampler'] = 'BPS'

        method = exp.p['method']
        dataset = exp.p['data']['dataset']
        run_info = [dataset,
                    method,
                    exp.p['eval'][method]['reverse_steps'],
                    exp.manager.total_steps]
        # To title the plot, e.g.:
        # title = '{}, reverse_steps={}, training_steps={}'.format(
        #     run_info[0], run_info[2], run_info[3])
        title = ''

        # Shared x/y axis limits; the gmm_grid dataset uses a shifted window.
        limits = (-0.5, 1.3) if dataset == 'gmm_grid' else (-1.1, 1.1)

        anim = exp.manager.display_plots(
            ema_mu=None,  # can specify an EMA rate, if such a model was trained
            plot_original_data=False,
            title=title,
            # number of points to display
            nb_datapoints=20000 if args.generate is None else args.generate,
            marker='.',  # small point marker
            color='blue',  # color of the points
            xlim=limits,
            ylim=limits,
            alpha=1.0,
            forward=False,  # False -> display the backward trajectory
        )

        name_parts = [str(x) for x in run_info]
        if method == 'pdmp':
            # PDMP runs also record which sampler produced the animation.
            name_parts.append(exp.p['pdmp']['sampler'])
        path = os.path.join(SAVE_ANIMATION_PATH, '_'.join(name_parts))
        # Animation.save does not create missing parent directories.
        os.makedirs(SAVE_ANIMATION_PATH, exist_ok=True)
        anim.save(path + '.mp4')
        print('Animation saved in {}'.format(path + '.mp4'))
        # plt.show() is deliberately not called: it would block this thread
        # until the window is closed.
    finally:
        plt.close()
        exp.terminate()
if __name__ == '__main__':
    # Directory holding the experiment .yml configs. It is a relative path,
    # presumably resolved against the current working directory.
    config_path = 'PDMP/configs/'
    display_exp(config_path=config_path)