在 tensorflow 中使用两种不同的模型

编程入门 行业动态 更新时间:2024-10-06 01:42:07
本文介绍了在 tensorflow 中使用两种不同的模型的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

限时送ChatGPT账号..

我正在尝试使用两种不同的 mobilenet 模型.以下是我如何初始化模型的代码.

Im trying to use two different mobilenet models. Following is the code as of how I initialize the model.

def initialSetup():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    start_time = timeit.default_timer()

    # This takes 2-5 seconds to run
    # Unpersists graph from file
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        tf.import_graph_def(age_graph_def, name='')

    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
        gender_graph_def = tf.GraphDef()
        gender_graph_def.ParseFromString(f.read())
        tf.import_graph_def(gender_graph_def, name='')

    print ('Took {} seconds to unpersist the graph'.format(timeit.default_timer() - start_time))

由于两者是两种不同的模型,我该如何使用它进行预测?

since both are two different models, how do I use it for predictions?

更新

initialSetup()

age_session = tf.Session(graph=age_graph_def)
gender_session = tf.Session(graph=gender_graph_def)

with tf.Session() as sess:
    start_time = timeit.default_timer()

    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = age_session.graph.get_tensor_by_name('final_result:0')

    print ('Took {} seconds to feed data to graph'.format(timeit.default_timer() - start_time))

    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()

错误

回溯(最近一次调用):文件C:/Users/Desktop/untitled/testimg/testimg/combo.py",第 48 行,在age_session = tf.Session(graph=age_graph_def) 文件C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py",第 1292 行,在 init 中super(Session, self).init(target, graph, config=config) 文件C:\ProgramFiles\Anaconda3\lib\site-packages\tensorflow\python\client\session.py",第 529 行,在 init 中raise TypeError('graph must be a tf.Graph, but got %s' % type(graph)) TypeError: graph must be a tf.Graph, but got Exception异常被忽略:> 回溯(最近一次通话):文件"C:\程序Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py",第 587 行,del如果 self._session 不是 None: AttributeError: 'Session' object has no attribute '_session'

Traceback (most recent call last): File "C:/Users/Desktop/untitled/testimg/testimg/combo.py", line 48, in age_session = tf.Session(graph=age_graph_def) File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1292, in init super(Session, self).init(target, graph, config=config) File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 529, in init raise TypeError('graph must be a tf.Graph, but got %s' % type(graph)) TypeError: graph must be a tf.Graph, but got Exception ignored in: > Traceback (most recent call last): File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 587, in del if self._session is not None: AttributeError: 'Session' object has no attribute '_session'

推荐答案

当您在同一图中处理多个模型时,请使用名称范围为各个张量提供可预测的名称.例如,您可以将 initial_setup() 重写如下:

When you are working with multiple models in the same graph, use name scoping to give the individual tensors predictable names. For example, you could rewrite initial_setup() as follows:

def initialSetup():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    start_time = timeit.default_timer()

    # This takes 2-5 seconds to run
    # Unpersists graph from file
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        tf.import_graph_def(age_graph_def, name='age_model')

    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
        gender_graph_def = tf.GraphDef()
        gender_graph_def.ParseFromString(f.read())
        tf.import_graph_def(gender_graph_def, name='gender_model')

    print ('Took {} seconds to unpersist the graph'.format(timeit.default_timer() - start_time))

现在 age_graph_def 中所有节点的名称将以 "age_model/" 为前缀,gender_graph_def 中所有节点的名称code> 将以 "gender_model/" 为前缀.它们都是同一个默认图形的一部分,因此您可以使用一个没有 graph 参数的 tf.Session 来访问任一模型.

Now the names of all of the nodes from age_graph_def will be prefixed with "age_model/" and the names of all of the nodes from gender_graph_def will be prefixed with "gender_model/". They are all part of the same default graph, so you can use a single tf.Session with no graph argument to access either model.

initialSetup()

with tf.Session() as sess:
    start_time = timeit.default_timer()

    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = sess.graph.get_tensor_by_name('age_model/final_result:0')

    # Alternatively, to get a tensor from the gender model:
    # tensor = sess.graph.get_tensor_by_name('gender_model/...')

    print ('Took {} seconds to feed data to graph'.format(timeit.default_timer() - start_time))

    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()

这篇关于在 tensorflow 中使用两种不同的模型的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

更多推荐

[db:关键词]

本文发布于:2023-05-01 01:55:21,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1402437.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:两种   模型   tensorflow

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!