问题描述
限时送ChatGPT账号..我有一个已保存的 Tensorflow 图,它通过带有 feed_dict
参数的 placeholder
消耗输入.
I have a saved Tensorflow graph that consumes input through a placeholder
with a feed_dict
param.
sess.run(my_tensor, feed_dict={input_image: image})
因为使用 Dataset
Iterator
提供数据是 更高效,我想加载保存的图形,用 Iterator
替换 input_image
placeholder
并运行.我怎样才能做到这一点?有没有更好的方法来做到这一点?非常感谢提供代码示例的答案.
Because feeding data with a Dataset
Iterator
is more efficient, I want to load the saved graph, replace the input_image
placeholder
with an Iterator
and run. How can I do that? Is there a better way to do it? An answer with code example would be highly appreciated.
推荐答案
您可以通过序列化您的图形并使用 tf.import_graph_def
重新导入它来实现这一点,它有一个 input_map
> 用于在所需位置插入输入的参数.
You can achieve that by serializing your graph and reimport it using tf.import_graph_def
, which has an input_map
argument used to plug-in inputs at the desired places.
要做到这一点,您至少需要知道要替换的输入的名称以及要执行的输出的名称(在我的示例中分别是 x
和 y
).
To do that you need at least to know the name of the inputs you replace and of the outputs you wish to execute (resp. x
and y
in my examples).
import tensorflow as tf
# restore graph (built from scratch here for the example)
x = tf.placeholder(tf.int64, shape=(), name='x')
y = tf.square(x, name='y')
# just for display -- you don't need to create a Session for serialization
with tf.Session() as sess:
print("with placeholder:")
for i in range(10):
print(sess.run(y, {x: i}))
# serialize the graph
graph_def = tf.get_default_graph().as_graph_def()
tf.reset_default_graph()
# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'x:0': batch}, return_elements=['y:0'])
# enjoy Dataset inputs!
with tf.Session() as sess:
print('with Dataset:')
try:
while True:
print(sess.run(y))
except tf.errors.OutOfRangeError:
pass
请注意,占位符节点仍然存在,因为我没有费心在这里解析 graph_def
以将其删除 - 您可以将其删除以作为改进,尽管我认为保留它也可以在这里.
Note that the placeholder node is still there as I did not bother here to parse graph_def
to remove it -- you could remove it as an improvement, although I think it is also OK to leave it here.
根据您恢复图形的方式,输入替换可能已经内置在加载程序中,这使事情变得更简单(无需回到 GraphDef
).例如,如果您从 .meta
文件加载图形,则可以使用接受相同 input_map
参数的 tf.train.import_meta_graph
.
Depending on how you restore your graph, the input replacement may be already built-in in the loader, which makes things simpler (no need to go back to a GraphDef
). For example, if you load your graph from a .meta
file, you can use tf.train.import_meta_graph
which accepts the same input_map
argument.
import tensorflow as tf
# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# load your net and plug in new pipeline
# you need to know the name of the tensor where to plug-in your input
restorer = tf.train.import_meta_graph(graph_filepath, input_map={'x:0': batch})
y = tf.get_default_graph().get_tensor_by_name('y:0')
# enjoy Dataset inputs!
with tf.Session() as sess:
# not needed here, but in practice you would also need to restore weights
# restorer.restore(sess, weights_filepath)
print('with Dataset:')
try:
while True:
print(sess.run(y))
except tf.errors.OutOfRangeError:
pass
这篇关于如何替换已保存图形的输入,例如数据集迭代器的占位符?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
更多推荐
[db:关键词]
发布评论