tensorflow训练时将模型保存为ckpt文件,它包含了网络结构、网络权重、训练过程中间变量等等信息。而网络部署一般是使用pb文件,它将变量保存为常量,以及网络前向传播的所有必要结构。如何将ckpt文件导出为pb文件?
首先,使用tfrecord训练的ckpt一般包含读取训练tfrecord文件的结构,而这是pb文件所不需要的。pb文件通常使用placeholder接受输入。因此,要以placeholder为输入重新定义一遍网络结构(通常就是调用一次网络构建函数)。假设为
output = xxnet(input_placeholder)
要获取输出节点的名称
output_nd_name = output.op.name
然后,载入ckpt的权重
saver = tf.train.Saver()
saver.restore(sess, “xxnet.ckpt”)
然后,将其中的变量转化为常量,保存模型
out_graph_def = tf.graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=sess.graph_def,
output_node_names=[output_nd_name]
)
with tf.gfile.GFile("xxnet.pb","wb") as f:
f.write(out_graph_def.SerializeToString())
