tensorflow恢复模型

编程入门 行业动态 更新时间:2024-10-25 12:28:53

tensorflow恢复<a href=https://www.elefans.com/category/jswz/34/1771358.html style=模型"/>

tensorflow恢复模型

由于恢复模型时,按网上的操作存在很多问题,所以自己总结了一下:

问题:训练时,加载测试数据,测试很正常,但训练完,重新恢复模型进行测试时存在很大的偏差,就像随机的结果。这是因为恢复时一些图设置(具体什么原因也没正清楚,但找到了正确的恢复方法)

一.ckpt模型文件的恢复
1.保存模型ckpt

saver = tf.train.Saver(max_to_keep=1)
with tf.Session() as sess:sess.run( tf.global_variables_initializer())for i in range(iterations):# ....训练。。。。if i%1500==0 or (i+1)==iterations: #每迭代1500次保存一次模型或者迭代结束时保存checkpoint_path = save_model_path+'/model.ckpt'saver.save(sess, checkpoint_path,global_step=i)

2.恢复模型
1)如果有网络结构代码,可以直接创建新的占位符


import tensorflow as tf
import numpy as np
from tensorflow.python.platform import gfile
graph = tf.Graph()
with graph.as_default():input_op = tf.placeholder(tf.float32, shape=[batch_size, data_size, data_size, 3],name='input_image')pred = net(input_op)#net结果为网络前向输出sess = tf.Session()saver = tf.train.Saver()saver.restore(sess, tf.train.latest_checkpoint(model_path))pred = sess.run(self.pred,feed_dict={input_op: data})

2)没有网络结构图(需要知道输入tensor名字和输出tensor名字)

import tensorflow as tf
import numpy as np
from tensorflow.python.platform import gfile
graph = tf.Graph()
with graph.as_default():saver = tf.train.import_meta_graph(model_path+'/model.ckpt-5362.meta')saver.restore(self.sess, tf.train.latest_checkpoint(model_path))graph = self.sess.graphinput_op = graph.get_tensor_by_name('inputs:0')pred = graph.get_tensor_by_name('output:0')predict = sess.run(pred,feed_dict={input_op: data})

训练时设置输入输出tensor名

    x = tf.placeholder(shape=[None, config.SIZE,config.SIZE, 3], dtype=tf.float32,name='inputs')y= tf.placeholder(shape=[None,config.OUTPUT_NUM], dtype=tf.float32,name='labels')training= tf.placeholder(tf.bool)result=net(x,training)result_ = tf.identity(result, name='output')#这个为输出

二、pb模型
1)保存为pb模型

x = tf.placeholder(shape=[None, config.SIZE,config.SIZE, 3], dtype=tf.float32,name='inputs')
y= tf.placeholder(shape=[None,config.OUTPUT_NUM], dtype=tf.float32,name='labels')
training= tf.placeholder(tf.bool)
result=net(x,training)
result_ = tf.identity(result, name='output')#这个为输出
....
#训练完后加上下面代码
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])
with tf.gfile.FastGFile('model.pb', mode='wb') as f:f.write(constant_graph.SerializeToString())

2)恢复pb模型

model_path='model.pb'
with tf.Graph().as_default():output_graph_def = tf.GraphDef()with open(model_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name="")sess = tf.Session()sess.run(tf.global_variables_initializer())input_op = self.sess.graph.get_tensor_by_name('inputs:0')out_op = self.sess.graph.get_tensor_by_name('output:0')pred = self.sess.run(out_op,feed_dict={input_op: data})

更多推荐

tensorflow恢复模型

本文发布于:2024-02-10 23:13:27,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1677856.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:模型   tensorflow

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!