This repository was archived by the owner on Sep 20, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsave_model.py
More file actions
49 lines (35 loc) · 1.58 KB
/
save_model.py
File metadata and controls
49 lines (35 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""Convert checkpoint variables into Const ops in a standalone GraphDef file."""
import os
import tensorflow as tf
def freeze_graph(model_dir, output_node_names,
                 output_graph="cnn_files/frozen_model.pb", clear_devices=True):
    """Extract the sub-graph defined by the output nodes and convert all of
    its variables into constants.

    Args:
        model_dir: Directory looked up with ``tf.train.get_checkpoint_state``
            to find the latest checkpoint.
        output_node_names: Comma-separated names of the nodes to keep; they
            select which part of the graph survives in the frozen output.
        output_graph: Path the serialized frozen ``GraphDef`` is written to.
            Defaults to the location this script has always used.
        clear_devices: Passed to ``tf.train.import_meta_graph`` so device
            placements are stripped and TF can choose devices itself.

    Returns:
        The frozen ``GraphDef`` that was written to ``output_graph``.

    Raises:
        ValueError: If no checkpoint is found in ``model_dir`` or
            ``output_node_names`` contains no node names.
    """
    # get_checkpoint_state returns None when model_dir has no checkpoint
    # state file; fail with a clear message instead of an AttributeError.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise ValueError("No checkpoint found in directory %r" % model_dir)
    input_ckpt = ckpt.model_checkpoint_path

    # Tolerate whitespace around names and stray/trailing commas.
    node_names = [name.strip() for name in output_node_names.split(",")
                  if name.strip()]
    if not node_names:
        raise ValueError("output_node_names must name at least one node")

    # Start the session on a fresh temporary Graph.
    with tf.Session(graph=tf.Graph()) as sess:
        # Import the meta graph into the session's graph.
        saver = tf.train.import_meta_graph(input_ckpt + '.meta',
                                           clear_devices=clear_devices)
        # Restore the checkpointed weights into that graph.
        saver.restore(sess, input_ckpt)
        # Built-in TF helper: the session supplies variable values, the
        # GraphDef supplies the nodes, and the output node names select
        # which nodes are kept.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(),
            node_names,
        )
        # Serialize and dump the frozen graph to the filesystem.
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
    return output_graph_def
def main():
    """Script entry point: run freeze_graph on the checkpoint in 'cnn_files'."""
    model_dir = 'cnn_files'
    output_nodes = 'softmax_linear/softmax_linear,shuffle_batch'
    freeze_graph(model_dir, output_nodes)


# Only freeze the graph when executed directly, not when imported.
if __name__ == "__main__":
    main()