tensorflow 训练网络的一般步骤

本文不针对tensorflow2.0。首先要构建数据的输入,一般是将数据转化为pb格式

然后构建自己的网络,并构建损失函数的节点。构建网络有多种方式,可以用代码构建(利用slim、keras等高级api,或者基础的api,或者已有的代码),也可以从ckpt.meta中载入网络结构(断点继续训练等情况)tf.train.import_meta_graph(“xxx.ckpt.meta”)。这里要注意,一般训练时会同时进行网络在验证集上的测试,比如每训练n步后在训练集上进行测试。因此构建网络需要同时构建一个验证网络,共享训练网络的变量权重。构建验证网络时要在variable_scope中设置reuse=True。

定义优化器,如opt=tf.train.AdamOptimizer()
将优化器应用在损失节点上计算梯度。grads=opt.compute_gradients(L)
梯度下降优化节点 apply_grad_op = opt.apply_gradients(grads)

训练模型需要保存,定义一个saver
saver = tf.train.Saver(max_to_keep=10) 最多保留10个ckpt
在训练时,使用saver.save(sess, “xxx.ckpt”, global_step=step)保存ckpt文件

希望在训练时看到训练过程, 使用tf.summary.scalar 添加想要的变量到训练过程日志中。
如 tf.summary.scalar(“training loss”, L)添加训练损失到训练过程。然后定义summary_op = tf.summary.merge_all()
然后要定义一个summary_writer
summary_writer = tf.summary.FileWriter(logdir, sess.graph)
训练时,每隔n步,使用summary_writer.add_summary(sess.run(summary_op), step)保存训练过程日志
训练开始后,就可以使用tensorboard查看训练过程了

训练过程一般在一个for循环中进行,
sess.run([apply_grad_op])进行网络的训练
在这个循环中,还要进行上面所说的保存ckpt文件、训练日志

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注