-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
122 lines (105 loc) · 4.2 KB
/
utils.py
File metadata and controls
122 lines (105 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import tensorflow as tf
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
sns.set_color_codes()
class GlobalCounter:
    """Monotonic training-step counter with save/log/stop scheduling.

    Tracks the current global step and answers three scheduling questions:
    whether enough steps have elapsed since the last checkpoint
    (``should_save``), whether the step lands on a logging interval
    (``should_log``), and whether training has reached its budget
    (``should_stop``).
    """

    def __init__(self, total_step, save_step, log_step):
        # itertools.count(1) yields 1, 2, 3, ... on successive next() calls.
        self.counter = itertools.count(1)
        self.cur_step = 0
        # Step at which the most recent save was triggered.
        self.cur_save_step = 0
        self.total_step = total_step
        self.save_step = save_step
        self.log_step = log_step

    def next(self):
        """Advance the global step by one and return its new value."""
        self.cur_step = next(self.counter)
        return self.cur_step

    def should_save(self):
        """Return True once every ``save_step`` steps (and record the save)."""
        if self.cur_step - self.cur_save_step < self.save_step:
            return False
        self.cur_save_step = self.cur_step
        return True

    def should_log(self):
        """Return True when the current step is a multiple of ``log_step``."""
        return self.cur_step % self.log_step == 0

    def should_stop(self):
        """Return True once the step budget ``total_step`` is exhausted."""
        return self.cur_step >= self.total_step
def signal_handler(signal, frame):
    """SIGINT handler: announce the interrupt on stdout.

    Matches the ``signal.signal`` handler signature (signal number, current
    stack frame); both arguments are ignored and nothing is returned.
    """
    print('You pressed Ctrl+C!')
def init_out_dir(base_dir, mode):
    """Create the output directory layout for a run.

    Parameters
    ----------
    base_dir : str
        Root directory for this experiment; created (with parents) if missing.
    mode : str
        'train' puts logs under ``<base_dir>/log/``; 'evaluate' puts them
        under ``<base_dir>/evaluate/``.

    Returns
    -------
    tuple of str
        ``(save_path, log_path)`` — the model-checkpoint directory and the
        log directory, both guaranteed to exist on return.

    Raises
    ------
    ValueError
        If *mode* is neither 'train' nor 'evaluate'. (Previously this case
        crashed with UnboundLocalError at the return statement because
        ``log_path`` was never assigned.)
    """
    # Path strings are kept as '+'-concatenation so the returned values are
    # byte-identical to the original implementation for existing callers.
    save_path = base_dir + '/model/'
    if mode == 'train':
        log_path = base_dir + '/log/'
    elif mode == 'evaluate':
        log_path = base_dir + '/evaluate/'
    else:
        raise ValueError("mode must be 'train' or 'evaluate', got %r" % mode)
    # makedirs(..., exist_ok=True) replaces the racy exists()+mkdir() pair
    # and also creates any missing parent directories (including base_dir).
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(log_path, exist_ok=True)
    return save_path, log_path
def init_model_summary(algo):
    """Build TF-1.x placeholders and a merged scalar-summary op for training stats.

    Parameters
    ----------
    algo : str
        One of 'a2c', 'ppo' or 'ddpg'.

    Returns
    -------
    tuple
        The merged summary op followed by the placeholders to feed, in an
        algorithm-specific order:
        - 'a2c':  (summary, policy_loss, value_loss, total_loss, lr,
                   gradnorm, entropy_loss, beta)
        - 'ppo':  the 'a2c' tuple extended with (policy_kl, clip_rate)
        - 'ddpg': (summary, policy_loss, value_loss, total_loss, lr,
                   gradnorm, gradnorm_v)

    Raises
    ------
    ValueError
        For an unsupported *algo*. (Previously the function silently fell
        through every branch and returned None.)
    """
    if algo not in ('a2c', 'ppo', 'ddpg'):
        raise ValueError("unsupported algo: %r" % algo)
    # Scalars shared by every algorithm.
    policy_loss = tf.placeholder(tf.float32, [])
    value_loss = tf.placeholder(tf.float32, [])
    total_loss = tf.placeholder(tf.float32, [])
    lr = tf.placeholder(tf.float32, [])
    gradnorm = tf.placeholder(tf.float32, [])
    summaries = []
    summaries.append(tf.summary.scalar('loss/policy', policy_loss))
    summaries.append(tf.summary.scalar('loss/value', value_loss))
    summaries.append(tf.summary.scalar('loss/total', total_loss))
    summaries.append(tf.summary.scalar('train/lr', lr))
    summaries.append(tf.summary.scalar('train/gradnorm', gradnorm))
    if algo in ('a2c', 'ppo'):
        # Entropy bonus and its coefficient only exist for the on-policy algos.
        entropy_loss = tf.placeholder(tf.float32, [])
        beta = tf.placeholder(tf.float32, [])
        summaries.append(tf.summary.scalar('loss/entropy', entropy_loss))
        summaries.append(tf.summary.scalar('train/beta', beta))
        if algo == 'a2c':
            summary = tf.summary.merge(summaries)
            return (summary, policy_loss, value_loss,
                    total_loss, lr, gradnorm, entropy_loss, beta)
        # PPO additionally tracks the policy KL and the clip-fraction.
        policy_kl = tf.placeholder(tf.float32, [])
        clip_rate = tf.placeholder(tf.float32, [])
        summaries.append(tf.summary.scalar('train/policy_kl', policy_kl))
        summaries.append(tf.summary.scalar('train/clip_rate', clip_rate))
        summary = tf.summary.merge(summaries)
        return (summary, policy_loss, value_loss,
                total_loss, lr, gradnorm, entropy_loss,
                beta, policy_kl, clip_rate)
    # DDPG: separate gradient norm for the value (critic) network.
    gradnorm_v = tf.placeholder(tf.float32, [])
    summaries.append(tf.summary.scalar('train/gradnorm_value', gradnorm_v))
    summary = tf.summary.merge(summaries)
    return (summary, policy_loss, value_loss,
            total_loss, lr, gradnorm, gradnorm_v)
def plot_episode(actions, states, rewards, run, plot_path):
    """Save a three-panel figure (states, actions, rewards) for one episode.

    Parameters
    ----------
    actions : array-like
        Per-step actions, plotted in the middle panel.
    states : ndarray
        Per-step states; assumed 2-D, (time, state_dim) — one line is drawn
        per column. TODO(review): confirm shape against callers.
    rewards : array-like
        Per-step rewards, plotted in the bottom panel.
    run : int
        Episode index, used in the title and the output filename.
    plot_path : str
        Directory the image 'RUN<run>' is written into.

    Returns
    -------
    None — the figure is written to disk as a side effect.
    """
    fig = plt.figure(figsize=(12, 18))
    title = fig.suptitle('EPISODE RUN: %d' % run, fontsize='x-large')
    # Panel 1: every state dimension as its own labelled line.
    plt.subplot(3, 1, 1)
    for i in range(states.shape[1]):
        plt.plot(states[:,i], 'o-', markersize=6, markeredgewidth=0,
                 label=('state_%d' % i))
    plt.legend(fontsize=15, loc='best')
    plt.yticks(fontsize=15)
    plt.ylabel('Normalized states', fontsize=15)
    # Panel 2: actions.
    plt.subplot(3, 1, 2)
    plt.plot(actions, 'o-', markersize=12, markeredgewidth=0, linewidth=3)
    plt.ylabel('Actions', fontsize=15)
    plt.yticks(fontsize=15)
    # Panel 3: rewards.
    plt.subplot(3, 1, 3)
    plt.plot(rewards, 'o-', markersize=12, markeredgewidth=0, linewidth=3)
    plt.ylabel('Rewards', fontsize=15)
    plt.yticks(fontsize=15)
    fig.tight_layout()
    title.set_y(0.95)
    fig.subplots_adjust(top=0.9)
    fig.savefig(plot_path + '/RUN' + str(run))
    # Close the figure: without this, repeated calls accumulate open figures
    # and matplotlib warns/leaks memory over a long evaluation run.
    plt.close(fig)