想要保存训练之后得到的神经网络参数,一般有两种办法。
第一种,可以将tensor对象转换为numpy数组进行保存。
即,
numpy.savetxt( 'weight.txt', weight.eval())
第二种,是利用tensorflow自带的Saver对象。
import tensorflow as tf ##################################################3 w1 = tf.Variable(tf.constant( 1.0), name = 'w1') w2 = tf.Variable(tf.constant( 2.0), name = 'w2') tf.add_to_collection( 'vars', w1) tf.add_to_collection( 'vars', w2) saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) w1 = tf.add(w1, w2) saver.save(sess, './my-model.ckpt')
上面的代码中,创建了容器vars。它收集了tensor变量w1和w2。之后,tensorflow将这一容器保存。
在session中运行,就能将数据保存到tensorflow创建的几个文件中。
上面的代码运行结束后,当前目录下出现四个文件:
my-model.ckpt.meta
my-model.ckpt.data-*
my-model.ckpt.index
checkpoint
利用这四个文件就能恢复出 w1和w2这两个变量。
with tf.Session() as sess: new_saver = tf.train.import_meta_graph( 'my-model.ckpt.meta') new_saver.restore(sess, tf.train.latest_checkpoint( './')) all_vars = tf.get_collection( 'vars') print(all_vars) for v in all_vars: print(v) print(v.name) v_ = v.eval() # sess.run(v) print(v_)
运行结果为:
[ <tf.Tensor 'w1:0' shape=() dtype=float32_ref >, <tf.Tensor 'w2:0' shape=() dtype=float32_ref >] Tensor( "w1:0", shape =(), dtype =float32_ref) w1: 0 1.0 Tensor( "w2:0", shape =(), dtype =float32_ref) w2: 0 2.0