tensorflow:保存和恢复会话

编程入门 行业动态 更新时间:2024-10-27 11:17:54
本文介绍了tensorflow:保存和恢复会话的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧! 问题描述

我正在尝试实施来自答案的建议:Tensorflow:如何保存/恢复模型?

I am trying to implement a suggestion from answers: Tensorflow: how to save/restore a model?

我有一个以 sklearn 样式包装 tensorflow 模型的对象.

I have an object which wraps a tensorflow model in a sklearn style.

import tensorflow as tf class tflasso(): saver = tf.train.Saver() def __init__(self, learning_rate = 2e-2, training_epochs = 5000, display_step = 50, BATCH_SIZE = 100, ALPHA = 1e-5, checkpoint_dir = "./", ): ... def _create_network(self): ... def _load_(self, sess, checkpoint_dir = None): if checkpoint_dir: self.checkpoint_dir = checkpoint_dir print("loading a session") ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: self.saver.restore(sess, ckpt.model_checkpoint_path) else: raise Exception("no checkpoint found") return def fit(self, train_X, train_Y , load = True): self.X = train_X self.xlen = train_X.shape[1] # n_samples = y.shape[0] self._create_network() tot_loss = self._create_loss() optimizer = tf.train.AdagradOptimizer( self.learning_rate).minimize(tot_loss) # Initializing the variables init = tf.initialize_all_variables() " training per se" getb = batchgen( self.BATCH_SIZE) yvar = train_Y.var() print(yvar) # Launch the graph NUM_CORES = 3 # Choose how many cores to use. sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES, intra_op_parallelism_threads=NUM_CORES) with tf.Session(config= sess_config) as sess: sess.run(init) if load: self._load_(sess) # Fit all training data for epoch in range( self.training_epochs): for (_x_, _y_) in getb(train_X, train_Y): _y_ = np.reshape(_y_, [-1, 1]) sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_}) # Display logs per epoch step if (1+epoch) % self.display_step == 0: cost = sess.run(tot_loss, feed_dict={ self.vars.xx: train_X, self.vars.yy: np.reshape(train_Y, [-1, 1])}) rsq = 1 - cost / yvar logstr = "Epoch: {:4d}\tcost = {:.4f}\tR^2 = {:.4f}".format((epoch+1), cost, rsq) print(logstr ) self.saver.save(sess, self.checkpoint_dir + 'model.ckpt', global_step= 1+ epoch) print("Optimization Finished!") return self

当我跑步时:

tfl = tflasso() tfl.fit( train_X, train_Y , load = False)

我得到输出:

Epoch: 50 cost = 38.4705 R^2 = -1.2036 b1: 0.118122 Epoch: 100 cost = 26.4506 R^2 = -0.5151 b1: 0.133597 Epoch: 150 cost = 22.4330 R^2 = -0.2850 b1: 0.142261 Epoch: 200 cost = 20.0361 R^2 = -0.1477 b1: 0.147998

但是,当我尝试恢复参数时(即使没有杀死对象):tfl.fit(train_X, train_Y, load = True)

However, when I try to recover the parameters (even without killing the object): tfl.fit( train_X, train_Y , load = True)

我得到了奇怪的结果.首先,加载的值与保存的值不对应.

I get strange results. First of all, the loaded value does not correspond the saved one.

loading a session loaded b1: 0.1 <------- Loaded another value than saved Epoch: 50 cost = 30.8483 R^2 = -0.7670 b1: 0.137484

加载的正确方法是什么,可能首先检查保存的变量?

What is the right way to load, and probably first inspect the saved variables?

推荐答案

TL;DR: 你应该尝试重新编写这个类,以便 self.create_network()(i) 只调用一次,并且 (ii) 在 tf.train.Saver() 构造之前调用.

TL;DR: You should try to rework this class so that self.create_network() is called (i) only once, and (ii) before the tf.train.Saver() is constructed.

这里有两个微妙的问题,这是由于代码结构和 tf.train.Saver 构造函数.当您构造一个没有参数的保护程序(如您的代码中)时,它会收集您程序中的当前变量集,并将操作添加到图形中以保存和恢复它们.在你的代码中,当你调用tflasso() 时,它会构造一个saver,并且不会有任何变量(因为create_network() 还没有被调用).结果,检查点应该是空的.

There are two subtle issues here, which are due to the code structure, and the default behavior of the tf.train.Saver constructor. When you construct a saver with no arguments (as in your code), it collects the current set of variables in your program, and adds ops to the graph for saving and restoring them. In your code, when you call tflasso(), it will construct a saver, and there will be no variables (because create_network() has not yet been called). As a result, the checkpoint should be empty.

第二个问题是——默认情况下——保存的检查点的格式是来自变量的name 属性 为其当前值.如果您创建两个同名变量,它们将被 TensorFlow 自动统一化":

The second issue is that—by default—the format of a saved checkpoint is a map from the name property of a variable to its current value. If you create two variables with the same name, they will be automatically "uniquified" by TensorFlow:

v = tf.Variable(..., name="weights") assert v.name == "weights" w = tf.Variable(..., name="weights") assert v.name == "weights_1" # The "_1" is added by TensorFlow.

这样做的结果是,当您在第二次调用 tfl.fit() 时调用 self.create_network() 时,变量都会有不同的名称来自存储在检查点中的名称——或者如果保护程序是在网络之后构建的.(您可以通过将 name-Variable 字典传递给保护程序构造函数来避免这种行为,但这通常很尴尬.)

The consequence of this is that, when you call self.create_network() in the second call to tfl.fit(), the variables will all have different names from the names that are stored in the checkpoint—or would have been if the saver had been constructed after the network. (You can avoid this behavior by passing a name-Variable dictionary to the saver constructor, but this is usually quite awkward.)

有两种主要的解决方法:

There are two main workarounds:

  • 在每次调用 tflasso.fit() 时,通过定义新的 tf.Graph 重新创建整个模型,然后在该图中构建网络并创建一个 tf.train.Saver.

  • In each call to tflasso.fit(), create the whole model afresh, by defining a new tf.Graph, then in that graph building the network and creating a tf.train.Saver.

    RECOMMENDED 创建网络,然后在 tflasso 构造函数中创建 tf.train.Saver,并在每个调用 tflasso.fit().请注意,您可能需要做更多的工作来重新组织事物(特别是,我不确定您对 self.X 和 self.xlen 做了什么)但它应该可以通过 占位符 和喂食来实现这一点.

    RECOMMENDED Create the network, then the tf.train.Saver in the tflasso constructor, and reuse this graph on each call to tflasso.fit(). Note that you might need to do some more work to reorganize things (in particular, I'm not sure what you do with self.X and self.xlen) but it should be possible to achieve this with placeholders and feeding.

  • 更多推荐

    tensorflow:保存和恢复会话

    本文发布于:2023-10-11 22:18:40,感谢您对本站的认可!
    本文链接:https://www.elefans.com/category/jswz/34/1482998.html
    版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
    本文标签:tensorflow

    发布评论

    评论列表 (有 0 条评论)
    草根站长

    >www.elefans.com

    编程频道|电子爱好者 - 技术资讯及电子产品介绍!