tf.function and tf.while loops in TensorFlow 2.0


Problem Description


I am trying to parallelize a loop using tf.while_loop. As suggested elsewhere, the parallel_iterations argument makes no difference in eager mode, so I attempted to wrap tf.while_loop with tf.function. However, after adding the decorator, the behavior of the iteration variable changes.

For example, this code works:

import numpy as np
import tensorflow as tf

result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)

def print_fun(iteration):
    result[iteration] = iteration  # eager tensors can index a numpy array
    iteration += 1
    return (iteration,)

tf.while_loop(c, print_fun, [iteration])

If I add the decorator, the error occurs:

import numpy as np
import tensorflow as tf

result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)

def print_fun(iteration):
    result[iteration] = iteration  # fails: a symbolic tensor cannot index a numpy array
    iteration += 1
    return (iteration,)

@tf.function
def run_graph():
    iteration = tf.constant(0)
    tf.while_loop(c, print_fun, [iteration])

run_graph()  # raises an error

From my debugging, I found that the variable iteration changes from a tensor to a placeholder. Why is that? How should I modify the code to eliminate the error?
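The "placeholder" observed here is a consequence of tracing: inside a tf.function, the Python body runs once in graph mode with symbolic tensors rather than concrete eager values. A minimal sketch illustrating this (the function name demo and the seen list are illustrative, not from the question):

```python
import tensorflow as tf

seen = []

@tf.function
def demo(x):
    # During tracing, eager execution is off: x is a symbolic graph
    # tensor (placeholder-like), not a concrete number.
    seen.append(tf.executing_eagerly())
    return x + 1

demo(tf.constant(0))
print(seen)  # [False] -- the body was traced in graph mode
```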

Thanks.

Answer

The code in your first snippet (the one without @tf.function) takes advantage of TensorFlow 2's eager execution to manipulate a numpy array (i.e., your outer result object) directly. With @tf.function, this doesn't work, because @tf.function tries to compile your code into a tf.Graph, which cannot operate on a numpy array directly (it can only process TensorFlow tensors). To get around this, use a tf.Variable and assign values into its slices.
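As a sketch of that first suggestion, assuming TensorFlow 2.x: the explicit tf.while_loop can be kept as-is if the numpy array is replaced by a tf.Variable, since slice assignment on a variable is a graph-compatible stateful op. The names below follow the question's snippet, and parallel_iterations is passed through now that the loop runs in graph mode:

```python
import numpy as np
import tensorflow as tf

result = tf.Variable(np.zeros([10], dtype=np.int32))

def body(i):
    # Writing into a tf.Variable slice works inside a graph,
    # unlike writing into a numpy array.
    result[i].assign(i)
    return (i + 1,)

@tf.function
def run_graph():
    tf.while_loop(lambda i: tf.less(i, 10), body, [tf.constant(0)],
                  parallel_iterations=10)

run_graph()
print(result.numpy())  # [0 1 2 3 4 5 6 7 8 9]
```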

With @tf.function, what you are trying to do is actually achievable with simpler code by taking advantage of @tf.function's automatic Python-to-graph transformation feature (known as AutoGraph). You just write a normal Python while loop (using tf.less() in lieu of the < operator), and AutoGraph will compile the while loop into a tf.while_loop under the hood.

The code looks like:

import numpy as np
import tensorflow as tf

result = tf.Variable(np.zeros([10], dtype=np.int32))

@tf.function
def run_graph():
  i = tf.constant(0, dtype=tf.int32)
  while tf.less(i, 10):  # AutoGraph converts this into a tf.while_loop
    result[i].assign(i)  # Performance may require tuning here.
    i += 1

run_graph()
print(result.read_value())


Published: 2023-05-01 01:58:16
Link: https://www.elefans.com/category/jswz/34/1402471.html