-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
20 lines (19 loc) · 715 Bytes
/
main.py
File metadata and controls
20 lines (19 loc) · 715 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from config import *
from Environment import *
from AutoPath import *
if __name__ == '__main__':
    # Entry point: build the environment and AutoPath agent, then either restore
    # a saved checkpoint or train from scratch and save. Finally report node-type
    # accuracy and run the agent's planning step.
    # `args`, `tf` and `os` are expected to arrive via the star-imports above
    # (presumably from `config`) -- NOTE(review): confirm.

    # Restrict this process to the requested physical GPU. This must happen
    # before anything initialises CUDA, so do it before the environment (and any
    # TensorFlow work it may perform) is constructed.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device_id)

    environment = Environment(args)
    tf.reset_default_graph()

    # With CUDA_VISIBLE_DEVICES set to a single id, that GPU is re-numbered as
    # the only visible device, /gpu:0. Using '/gpu:<device_id>' here would name a
    # non-existent device for any device_id != 0, and soft placement would then
    # silently fall back instead of using the selected GPU.
    with tf.device('/gpu:0'):
        agent = AutoPath(environment.params, environment)
    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # A TF1 checkpoint written by Saver.save(path) produces '<path>.meta';
        # its presence is used as the "model already trained" signal.
        if os.path.exists(args.model_file + '.meta'):
            saver.restore(sess, args.model_file)
        else:
            sess.run(tf.global_variables_initializer())
            agent.train(sess)
            saver.save(sess, args.model_file)
        print('Node type accuracy: %f' % agent.accuracy(sess))
        # NOTE(review): the planning result is not used further in this script.
        rank_lists = agent.plan(sess)