tensorflow 初始化新增变量,保持载入的预训练模型权重不变

我们经常会遇到这样一些问题,想要使用一些预训练好的模型,然后在其基础上进行一些增减,以适应新的任务。在开始训练之前,要对所有的新增变量进行初始化,但是要保持预训练模型中已有的权重不变,即只初始化新增变量。而如果使用tensorflow中提供的saver.restore(sess, ckpt_path),会报找不到新增节点的错误!记录一下这种情况要如何处理。

首先,在对预训练模型进行增减之前,先进行saver.restore(sess, ckpt_path)载入预训练权重,然后再进行对网络结构的增减(这里可以使用tf.get_default_graph().get_tensor_by_name(tensor_name)获取到原网络中的tensor,来进行新增节点)

在增加完新增节点之后,要初始化这些新增节点权重变量。接下来就是最关键的一步,获取网络中所有未初始化的权重变量。

def get_uninitialized_variables(sess):
global_vars = tf.global_variables()
is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
print([str(i.name) for i in not_initialized_vars])
return not_initialized_vars

然后在会话中,初始化这些变量

sess.run(tf.variables_initializer(get_uninitialized_variables(sess)))

接下来就可以愉快地进行训练了

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注