如何替换已保存图形的输入,例如数据集迭代器的占位符?

编程入门 行业动态 更新时间:2024-10-13 06:13:17
本文介绍了如何替换已保存图形的输入,例如数据集迭代器的占位符?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

限时送ChatGPT账号..

我有一个已保存的 Tensorflow 图,它​​通过带有 feed_dict 参数的 placeholder 消耗输入.

I have a saved Tensorflow graph that consumes input through a placeholder with a feed_dict param.

sess.run(my_tensor, feed_dict={input_image: image})

因为使用 Dataset Iterator 提供数据是 更高效,我想加载保存的图形,用 Iterator 替换 input_image placeholder 并运行.我怎样才能做到这一点?有没有更好的方法来做到这一点?非常感谢提供代码示例的答案.

Because feeding data with a Dataset Iterator is more efficient, I want to load the saved graph, replace the input_image placeholder with an Iterator and run. How can I do that? Is there a better way to do it? An answer with code example would be highly appreciated.

推荐答案

您可以通过序列化您的图形并使用 tf.import_graph_def 重新导入它来实现这一点,它有一个 input_map> 用于在所需位置插入输入的参数.

You can achieve that by serializing your graph and reimport it using tf.import_graph_def, which has an input_map argument used to plug-in inputs at the desired places.

要做到这一点,您至少需要知道要替换的输入的名称以及要执行的输出的名称(在我的示例中分别是 xy).

To do that you need at least to know the name of the inputs you replace and of the outputs you wish to execute (resp. x and y in my examples).

import tensorflow as tf

# restore graph (built from scratch here for the example)
x = tf.placeholder(tf.int64, shape=(), name='x')
y = tf.square(x, name='y')

# just for display -- you don't need to create a Session for serialization
with tf.Session() as sess:
  print("with placeholder:")
  for i in range(10):
    print(sess.run(y, {x: i}))

# serialize the graph
graph_def = tf.get_default_graph().as_graph_def()

tf.reset_default_graph()

# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'x:0': batch}, return_elements=['y:0'])

# enjoy Dataset inputs!
with tf.Session() as sess:
  print('with Dataset:')
  try:
    while True:
      print(sess.run(y))
  except tf.errors.OutOfRangeError:
    pass        

请注意,占位符节点仍然存在,因为我没有费心在这里解析 graph_def 以将其删除 - 您可以将其删除以作为改进,尽管我认为保留它也可以在这里.

Note that the placeholder node is still there as I did not bother here to parse graph_def to remove it -- you could remove it as an improvement, although I think it is also OK to leave it here.

根据您恢复图形的方式,输入替换可能已经内置在加载程序中,这使事情变得更简单(无需回到 GraphDef).例如,如果您从 .meta 文件加载图形,则可以使用接受相同 input_map 参数的 tf.train.import_meta_graph.

Depending on how you restore your graph, the input replacement may be already built-in in the loader, which makes things simpler (no need to go back to a GraphDef). For example, if you load your graph from a .meta file, you can use tf.train.import_meta_graph which accepts the same input_map argument.

import tensorflow as tf

# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# load your net and plug in new pipeline
# you need to know the name of the tensor where to plug-in your input
restorer = tf.train.import_meta_graph(graph_filepath, input_map={'x:0': batch})
y = tf.get_default_graph().get_tensor_by_name('y:0')

# enjoy Dataset inputs!
with tf.Session() as sess:
  # not needed here, but in practice you would also need to restore weights
  # restorer.restore(sess, weights_filepath)
  print('with Dataset:')
  try:
    while True:
      print(sess.run(y))
  except tf.errors.OutOfRangeError:
    pass        

这篇关于如何替换已保存图形的输入,例如数据集迭代器的占位符?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

更多推荐

[db:关键词]

本文发布于:2023-05-01 05:23:05,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1405331.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:图形   迭代   数据

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!