在 Python 中编写和注册自定义 Tensorflow Op

编程入门 行业动态 更新时间:2024-10-24 01:56:52
本文介绍了在 Python 中编写和注册自定义 Tensorflow Op的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧! 问题描述

我想用 Python 编写一个自定义的 Tensorflow 操作,并将其注册到 Protobuf 注册表中以进行类似解释的操作 此处.Protobuf 注册是关键,因为我不会直接从 Python 使用这个 op,但是如果它像 C++ op 一样注册并加载到 Python 运行时环境中,那么我可以在我的环境中运行它.

I want to write a custom Tensorflow op in Python and register it in the Protobuf registry for operations like explained here. The Protobuf registration is key because I will not be using this op directly from Python, but if it is registered like a C++ op and loaded into the Python runtime environment then I can run it in my environment.

我希望代码看起来像,

import tensorflow as tf from google.protobuf import json_format from tensorflow.python.ops.data_flow_ops import QueueBase, _as_type_list, _as_shape_list, _as_name_list """ Missing the Python equivalent of, class HDF5QueueOp : public ResourceOpKernel<QueueInterface> { public: // Implementation }; REGISTER_OP("HDF5Queue") .Output("handle: resource") .Attr("filename: string") .Attr("datasets: list(string)") .Attr("overwrite: bool = false") .Attr("component_types: list(type) >= 0 = []") .Attr("shapes: list(shape) >= 0 = []") .Attr("shared_name: string = ''") .Attr("container: string = ''") .Attr("capacity: int = -1") .SetIsStateful() .SetShapeFn(TwoElementOutput); """ class HDF5Queue(QueueBase): def __init__(self, stream_id, stream_columns, dtypes=None, capacity=100, shapes=None, names=None, name="hdf5_queue"): if not dtypes: dtypes = [tf.int64, tf.float32] if not shapes: shapes = [[1], [1]] dtypes = _as_type_list(dtypes) shapes = _as_shape_list(shapes, dtypes) names = _as_name_list(names, dtypes) queue_ref = _op_def_lib.apply_op("HDF5Queue", stream_id=stream_id, stream_columns=stream_columns, capacity=capacity, component_types=dtypes, shapes=shapes, name=name, container=None, shared_name=None) super(HDF5Queue, self).__init__(dtypes, shapes, names, queue_ref)

以上是 TF 的标准.例如,可以使用 FIFOQueue 来查看.Python 包装器、Protobuf 注册,C++ 实现.在编译期间生成了一个我不喜欢的 Python 包装器,但是您可以通过运行 grep -A 10 -B 10 -n FIFO $(find/usr/local -name "*gen_data_flow*.py")/dev/null

The above is pretty standard from TF. It can be seen for example with FIFOQueue. Python Wrapper, Protobuf Registration, C++ Implementation. There is a Python wrapper generated during compilation that I can't like to, but you see where its used by running grep -A 10 -B 10 -n FIFO $(find /usr/local -name "*gen_data_flow*.py") /dev/null

下面将以 JSON 格式为 TF Graph 转储 Protobuf 消息.我希望这会与 HDF5Queue 操作的块一起转储,就像我编写 C++ 操作一样.

Below will dump a Protobuf message for the TF Graph in JSON format. I would expect this to dump with a block for the HDF5Queue operation as it does if I write C++ operations.

with tf.Session() as sess: queue = HDF5Queue(stream_id=0xa) write = queue.enqueue([[1], [1.2]]) read = queue.dequeue() print json_format.MessageToJson(tf.train.export_meta_graph())

推荐答案

这可以使用 py_func 来完成.这是一个例子.

This can sort of be done using py_func. Here is an example.

import tensorflow as tf from google.protobuf import json_format import sys, json, base64, numpy from tensorflow.python.ops.script_ops import _py_funcs as py_func_registry from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef graph = tf.Graph() graph2 = tf.Graph() def f(x): return x def g(x): return 2*x with graph.as_default(): x = tf.placeholder(tf.float32, shape=(3,), name='x') y = tf.py_func(f, [x], tf.float32, name='y') # py_func_registry._funcs.clear() # Optional line to clear the Python function registry msg = json.loads(json_format.MessageToJson(tf.train.export_meta_graph())) # Change the function being used by py_func msg['graphDef']['node'][1]['attr']['token']['s'] = base64.b64encode(py_func_registry.insert(g)) with graph2.as_default(): # Load graph meta_graph_def = MetaGraphDef() json_format.Parse(json.dumps(msg), meta_graph_def) tf.train.import_meta_graph(meta_graph_def) sess = tf.Session(graph=graph2) print sess.run('y:0', feed_dict={'x:0':numpy.array([1, 2, 3])}) print g(numpy.array([1, 2, 3]))

更多推荐

在 Python 中编写和注册自定义 Tensorflow Op

本文发布于:2023-11-24 03:27:46,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1623804.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:自定义   Python   Op   Tensorflow

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!