Restoring TensorFlow models
I ran into many problems following the model-restoration instructions found online, so I wrote up my own summary.
The problem: during training, loading the test data and evaluating works fine, but after training, restoring the model and testing again gives wildly wrong predictions, almost as if they were random. This comes from how the graph is set up during restoration (I never pinned down the exact cause, but I did find a restore procedure that works).
I. Restoring from ckpt model files
1. Saving a ckpt model
saver = tf.train.Saver(max_to_keep=1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(iterations):
        # .... training ....
        if i % 1500 == 0 or (i + 1) == iterations:  # save every 1500 iterations, and once more at the end
            checkpoint_path = save_model_path + '/model.ckpt'
            saver.save(sess, checkpoint_path, global_step=i)
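For reference, a single call to saver.save produces several files on disk, not one. This is a minimal self-contained sketch (using tf.compat.v1 so it runs under TensorFlow 2; the one-variable graph is just a stand-in for a real network):

```python
import os
import tempfile
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.reset_default_graph()

# A stand-in graph: one variable is enough to produce checkpoint files.
v = tf.Variable(1.0, name='v')
saver = tf.train.Saver(max_to_keep=1)
model_dir = tempfile.mkdtemp()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, model_dir + '/model.ckpt', global_step=0)

# saver.save wrote: a 'checkpoint' index file, plus .meta (graph),
# .index and .data-* (variable values) files for this step.
files = sorted(os.listdir(model_dir))
print(files)
```

The .meta file is what tf.train.import_meta_graph reads later, and the 'checkpoint' file is how tf.train.latest_checkpoint finds the newest save.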
2. Restoring the model
1) If you have the network-definition code, you can create fresh placeholders directly
import tensorflow as tf
import numpy as np

graph = tf.Graph()
with graph.as_default():
    input_op = tf.placeholder(tf.float32,
                              shape=[batch_size, data_size, data_size, 3],
                              name='input_image')
    pred = net(input_op)  # net() builds the network's forward pass
    sess = tf.Session()
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    predict = sess.run(pred, feed_dict={input_op: data})
2) Without the network-definition code (you need to know the names of the input and output tensors)
import tensorflow as tf
import numpy as np

graph = tf.Graph()
with graph.as_default():
    sess = tf.Session()
    saver = tf.train.import_meta_graph(model_path + '/model.ckpt-5362.meta')
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    input_op = sess.graph.get_tensor_by_name('inputs:0')
    pred = sess.graph.get_tensor_by_name('output:0')
    predict = sess.run(pred, feed_dict={input_op: data})
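The meta-graph route above can be verified end to end with a tiny self-contained round trip (again using tf.compat.v1 under TensorFlow 2; the y = x·w graph is a hypothetical stand-in for the real net). One detail worth stressing: after saver.restore, do not run tf.global_variables_initializer() again, since that overwrites the restored weights with fresh random values, which is a classic cause of the "random results" symptom described at the top of this post:

```python
import tempfile
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.reset_default_graph()

# Tiny stand-in graph: y = x @ w, with named input and output tensors.
x = tf.placeholder(tf.float32, shape=[None, 2], name='inputs')
w = tf.Variable([[1.0], [2.0]], name='w')
out = tf.identity(tf.matmul(x, w), name='output')

saver = tf.train.Saver()
model_dir = tempfile.mkdtemp()
data = np.array([[1.0, 1.0]], dtype=np.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(out, feed_dict={x: data})
    saver.save(sess, model_dir + '/model.ckpt', global_step=0)

# Restore in a completely fresh graph from the .meta file alone.
# Note: no call to tf.global_variables_initializer() here.
tf.reset_default_graph()
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(model_dir + '/model.ckpt-0.meta')
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))
    input_op = sess.graph.get_tensor_by_name('inputs:0')
    out_op = sess.graph.get_tensor_by_name('output:0')
    after = sess.run(out_op, feed_dict={input_op: data})

print(before, after)  # the two predictions should match exactly
```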
Set the input and output tensor names at training time:
x = tf.placeholder(shape=[None, config.SIZE, config.SIZE, 3], dtype=tf.float32, name='inputs')
y = tf.placeholder(shape=[None, config.OUTPUT_NUM], dtype=tf.float32, name='labels')
training = tf.placeholder(tf.bool)
result = net(x, training)
result_ = tf.identity(result, name='output')  # this is the output tensor
II. pb models
1) Saving a pb model
from tensorflow.python.framework import graph_util

x = tf.placeholder(shape=[None, config.SIZE, config.SIZE, 3], dtype=tf.float32, name='inputs')
y = tf.placeholder(shape=[None, config.OUTPUT_NUM], dtype=tf.float32, name='labels')
training = tf.placeholder(tf.bool)
result = net(x, training)
result_ = tf.identity(result, name='output')  # this is the output tensor
....
# after training finishes, add the code below
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])
with tf.gfile.FastGFile('model.pb', mode='wb') as f:
    f.write(constant_graph.SerializeToString())
2) Restoring a pb model
model_path = 'model.pb'
with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    tf.import_graph_def(output_graph_def, name="")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    input_op = sess.graph.get_tensor_by_name('inputs:0')
    out_op = sess.graph.get_tensor_by_name('output:0')
    pred = sess.run(out_op, feed_dict={input_op: data})
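The pb freeze-and-restore cycle can also be checked with a small self-contained round trip (tf.compat.v1 under TensorFlow 2; the y = x·w graph is a hypothetical stand-in for the real net). Since convert_variables_to_constants bakes every variable into the graph as a constant, the frozen graph has no variables left, so the initializer call in the snippet above is not actually needed:

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.reset_default_graph()

# Stand-in graph with named input/output tensors.
x = tf.placeholder(tf.float32, shape=[None, 2], name='inputs')
w = tf.Variable([[3.0], [4.0]], name='w')
out = tf.identity(tf.matmul(x, w), name='output')

data = np.array([[1.0, 2.0]], dtype=np.float32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(out, feed_dict={x: data})
    # Freeze: variables reachable from 'output' become constants.
    constant_graph = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['output'])
pb_bytes = constant_graph.SerializeToString()  # same bytes a .pb file would hold

# Restore from the serialized pb into a fresh graph.
tf.reset_default_graph()
graph_def = tf.GraphDef()
graph_def.ParseFromString(pb_bytes)
tf.import_graph_def(graph_def, name="")
with tf.Session() as sess:
    input_op = sess.graph.get_tensor_by_name('inputs:0')
    out_op = sess.graph.get_tensor_by_name('output:0')
    after = sess.run(out_op, feed_dict={input_op: data})

print(before, after)  # both are [[11.]] : 1*3 + 2*4
```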