在tensorflow中保存模型参数

    xiaoxiao2021-03-26  20

    想要保存训练之后得到的神经网络参数,一般有两种办法。

    第一种,可以将tensor对象转换为numpy数组进行保存。

    即,

    numpy.savetxt( 'weight.txt', weight.eval())

    第二种,是利用tensorflow自带的Saver对象。

    import tensorflow  as tf ##################################################3 w1  = tf.Variable(tf.constant( 1.0),  name = 'w1') w2  = tf.Variable(tf.constant( 2.0),  name = 'w2') tf.add_to_collection( 'vars', w1) tf.add_to_collection( 'vars', w2) saver  = tf.train.Saver() with tf.Session()  as sess:     sess.run(tf.global_variables_initializer())     w1  = tf.add(w1, w2)     saver.save(sess,  './my-model.ckpt')

    上面的代码中,创建了容器vars。它收集了tensor变量w1和w2。之后,tensorflow将这一容器保存。

    在session中运行,就能将数据保存到tensorflow创建的几个文件中。

    上面的代码运行结束后,当前目录下出现四个文件:

    my-model.ckpt.meta

    my-model.ckpt.data-*

    my-model.ckpt.index

    checkpoint

    利用这四个文件就能恢复出 w1和w2这两个变量。

    with tf.Session()  as sess:     new_saver  = tf.train.import_meta_graph( 'my-model.ckpt.meta')     new_saver.restore(sess, tf.train.latest_checkpoint( './'))     all_vars  = tf.get_collection( 'vars')      print(all_vars)      for v  in all_vars:          print(v)          print(v.name)         v_  = v.eval()  # sess.run(v)          print(v_)

    运行结果为:

    [ <tf.Tensor  'w1:0' shape=() dtype=float32_ref ><tf.Tensor  'w2:0' shape=() dtype=float32_ref >] Tensor( "w1:0"shape =(),  dtype =float32_ref) w1: 0 1.0 Tensor( "w2:0"shape =(),  dtype =float32_ref) w2: 0 2.0

    转载请注明原文地址: https://ju.6miu.com/read-600349.html

    最新回复(0)