问题描述
限时送ChatGPT账号..我正在使用 Tensorflow Dataset API 来准备我的数据用于输入到我的网络中.在此过程中,我使用 tf.py_function
将一些自定义 Python 函数映射到数据集.我希望能够调试进入这些函数的数据以及这些函数内的数据会发生什么.当 py_function
被调用时,这会回调到 Python 主进程(根据 this answer).由于此函数在 Python 中,并且在主进程中,我希望常规 IDE 断点能够在此进程中停止.但是,情况似乎并非如此(下面的示例中断点不会停止执行).有没有办法在数据集 map
使用的 py_function
中放入断点?
I'm using the Tensorflow Dataset API to prepare my data for input into my network. During this process, I have some custom Python functions which are mapped to the dataset using tf.py_function
. I want to be able to debug the data going into these functions and what happens to that data inside these functions. When a py_function
is called, this calls back to the main Python process (according to this answer). Since this function is in Python, and in the main process, I would expect a regular IDE breakpoint to be able stop in this process. However, this doesn't seem to be the case (example below where the breakpoint does not halt execution). Is there a way to drop into a breakpoint within a py_function
used by the Dataset map
?
断点不停止执行的示例
import tensorflow as tf
def add_ten(example, label):
example_plus_ten = example + 10 # Breakpoint here.
return example_plus_ten, label
examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels = [ 0, 0, 1, 1, 1, 1, 0, 0]
examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))
推荐答案
tf.data.Dataset 的 Tensorflow 2.0 实现会为每次调用打开一个 C 线程,而不会通知您的调试器.使用 pydevd
手动设置跟踪功能,该功能将连接到您的默认调试器服务器并开始向其提供调试数据.
Tensorflow 2.0 implementation of tf.data.Dataset opens a C threads for each call without notifying your debugger.
Use pydevd
's to manually set a tracing function that will connect to your default debugger server and start feeding it the debug data.
import pydevd
pydevd.settrace()
代码示例:
import tensorflow as tf
import pydevd
def add_ten(example, label):
pydevd.settrace(suspend=False)
example_plus_ten = example + 10 # Breakpoint here.
return example_plus_ten, label
examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels = [ 0, 0, 1, 1, 1, 1, 0, 0]
examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))
注意:如果您使用的 IDE 已经捆绑了 pydevd(例如 PyDev 或 PyCharm),则不必单独安装 pydevd
,它会在调试会话期间被选中.
Note: If you are using IDE which already bundles pydevd (such as PyDev or PyCharm) you do not have to install pydevd
separately, it will picked up during the debug session.
这篇关于TensorFlow 数据集 API 中的 IDE 断点映射 py_function?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
更多推荐
[db:关键词]
发布评论