-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathtest_model.py
More file actions
60 lines (44 loc) · 1.75 KB
/
test_model.py
File metadata and controls
60 lines (44 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
import tensorflow as tf
import scipy
import scipy.misc as misc
import matplotlib.pyplot as plt
from matplotlib import colors as mpl_colors
# --image: path of the image that main() reads and feeds through the model.
# Its default is None, so in practice it has to be given on the command line.
tf.app.flags.DEFINE_string('image', None, 'Test image')
FLAGS = tf.app.flags.FLAGS
def _label_color(label):
    """Return the RGB colour (floats in [0, 1)) used for class index *label*.

    Bits of the label are interleaved into the colour channels, most
    significant channel bit first: bit 0 -> red, bit 1 -> green,
    bit 2 -> blue, then bit 3 -> red's next bit, and so on. This is the
    same scheme as the PASCAL VOC label colormap; 8-bit values are scaled
    by 1/256, so e.g. 128 -> 0.5 and 64 -> 0.25.
    """
    red = green = blue = 0
    for shift in range(7, -1, -1):
        red |= (label & 1) << shift
        green |= ((label >> 1) & 1) << shift
        blue |= ((label >> 2) & 1) << shift
        label >>= 3
    return (red / 256, green / 256, blue / 256)


# One colour per class index 0..20, as (r, g, b) tuples of floats.
palette = [_label_color(label) for label in range(21)]
# Colormap built from `palette`. from_list spaces the given colours evenly,
# so choosing N equal to the number of colours makes the lookup table hold
# exactly the palette entries, in order, with no blended in-between colours.
my_cmap = mpl_colors.LinearSegmentedColormap.from_list(
    'Custom cmap', palette, N=len(palette))
def main(_):
    """Run the frozen graph on ``FLAGS.image``, time it and plot the output.

    Called by ``tf.app.run`` after flag parsing; the positional argument
    (leftover argv) is ignored.

    Steps:
      1. Load the frozen GraphDef from ``./train/skynet_v1_50_graph.pb``.
      2. Read ``FLAGS.image`` and resize it to 224x224.
      3. Run the ``predictions`` op twice, printing the wall-clock seconds
         of each call (presumably so the second figure excludes first-run
         warm-up cost).
      4. Show the original image next to the prediction.

    Raises:
        ValueError: if ``--image`` was not supplied.
    """
    import time

    # The flag defaults to None; fail early with a clear message instead of
    # an obscure error from the image reader.
    if not FLAGS.image:
        raise ValueError('--image must be set to the path of a test image')

    graph = tf.Graph()
    # `with` closes the session (releasing its resources) before the
    # blocking plot window is opened below.
    with tf.Session(graph=graph) as sess, graph.as_default():
        graph_def = tf.GraphDef()
        with open('./train/skynet_v1_50_graph.pb', 'rb') as f:
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name="")
        input_x = graph.get_operation_by_name('ph_input_x').outputs[0]
        pred = graph.get_operation_by_name('predictions').outputs[0]

        # NOTE(review): scipy.misc.imread / imresize were deprecated in
        # SciPy 1.0 and removed in later releases; this script needs an
        # older SciPy (it already requires TensorFlow 1.x).
        input_image_ori = scipy.misc.imread(FLAGS.image)
        input_image = scipy.misc.imresize(input_image_ori, (224, 224))

        # NOTE(review): the image is fed without a batch axis and only
        # element 0 of the output is kept -- confirm this matches the shapes
        # of 'ph_input_x' and 'predictions' in the frozen graph.
        for _run in range(2):
            before = time.time()
            p = sess.run(pred, feed_dict={input_x: input_image})[0]
            print(time.time() - before)

    fig = plt.figure()
    # Integer subplot position: a string such as '121' is rejected by
    # current Matplotlib versions.
    ax = fig.add_subplot(121)
    ax.imshow(input_image_ori)
    ax = fig.add_subplot(122)
    # Draw the prediction with the module-level palette colormap. Pinning
    # vmin/vmax to the palette's index range maps value k to palette[k];
    # nearest interpolation avoids blending neighbouring labels.
    # Assumes `p` is a 2-D map of integer class indices -- confirm (an RGB
    # array would ignore the colormap).
    ax.imshow(p, cmap=my_cmap, vmin=0, vmax=len(palette) - 1,
              interpolation='nearest')
    plt.show()
if __name__ == '__main__':
    # Parse command-line flags into FLAGS, then invoke main(argv).
    tf.app.run(main=main)