tensorflow 导入pb模型进行前向推导

tensorflow一般使用pb文件进行前向推导(在非部署环境使用ckpt也可以)载入pb文件到图的函数

def load_pb_to_graph(sess, pb_file):
with tf.gfile.FastGFile(pb_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def,name="")

在会话中,调用该函数,并根据名称获取输入和输出的tensor,然后就可以sess.run进行前向推导

load_pb_to_graph(sess, "xxnet.pb")
inputs = tf.get_default_graph().get_tensor_by_name("xxnet/input:0")
outputs = tf.get_default_graph().get_tensor_by_name("xxnet/score:0")
scores = sess.run([outputs], feed_dict={inputs: input_data})

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注