Problem description
I'm trying to use two different MobileNet models. Here is how I initialize the models:
def initialSetup():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    start_time = timeit.default_timer()
    # This takes 2-5 seconds to run
    # Unpersists graph from file
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        tf.import_graph_def(age_graph_def, name='')
    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
        gender_graph_def = tf.GraphDef()
        gender_graph_def.ParseFromString(f.read())
        tf.import_graph_def(gender_graph_def, name='')
    print('Took {} seconds to unpersist the graph'.format(timeit.default_timer() - start_time))
Since these are two different models, how do I use them for predictions?
Update
initialSetup()
age_session = tf.Session(graph=age_graph_def)
gender_session = tf.Session(graph=gender_graph_def)
with tf.Session() as sess:
    start_time = timeit.default_timer()
    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = age_session.graph.get_tensor_by_name('final_result:0')
    print('Took {} seconds to feed data to graph'.format(timeit.default_timer() - start_time))
    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()
Error
Traceback (most recent call last):
  File "C:/Users/Desktop/untitled/testimg/testimg/combo.py", line 48, in <module>
    age_session = tf.Session(graph=age_graph_def)
  File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1292, in __init__
    super(Session, self).__init__(target, graph, config=config)
  File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 529, in __init__
    raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))
TypeError: graph must be a tf.Graph, but got
Exception ignored in:
Traceback (most recent call last):
  File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 587, in __del__
    if self._session is not None:
AttributeError: 'Session' object has no attribute '_session'
Recommended answer
When you are working with multiple models in the same graph, use name scoping to give the individual tensors predictable names. For example, you could rewrite initialSetup() as follows:
import os
import timeit

import tensorflow as tf


def initialSetup():
    # Hide TensorFlow's C++ INFO and WARNING log messages
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    start_time = timeit.default_timer()
    # This takes 2-5 seconds to run
    # Unpersists graph from file
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        # Every node from this model gets the "age_model/" prefix
        tf.import_graph_def(age_graph_def, name='age_model')
    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
        gender_graph_def = tf.GraphDef()
        gender_graph_def.ParseFromString(f.read())
        # Every node from this model gets the "gender_model/" prefix
        tf.import_graph_def(gender_graph_def, name='gender_model')
    print('Took {} seconds to unpersist the graph'.format(timeit.default_timer() - start_time))
Now the names of all of the nodes from age_graph_def will be prefixed with "age_model/", and the names of all of the nodes from gender_graph_def will be prefixed with "gender_model/". They are all part of the same default graph, so you can use a single tf.Session with no graph argument to access either model.
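To see the effect of the name prefixes without the real .pb files, here is a minimal self-contained sketch. It assumes a TensorFlow version that provides tf.compat.v1 (recent 1.x releases and TF 2.x). make_graph_def() is a hypothetical stand-in for the parsed output_graph.pb files, and the "input" node name is likewise an assumption; both fake models deliberately share the node name final_result, as in the question. After importing them with different `name` arguments, both outputs can be fetched from one session in a single run() call.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style API, available in recent 1.x releases and TF 2.x


def make_graph_def(scale):
    """Hypothetical stand-in for a loaded output_graph.pb: a tiny GraphDef
    with an 'input' placeholder and a 'final_result' output node."""
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None], name="input")
        tf.multiply(x, scale, name="final_result")
    return g.as_graph_def()


# Import both models into ONE graph; `name` is prepended to every node name,
# so the two identically named 'final_result' nodes no longer clash.
graph = tf.Graph()
with graph.as_default():
    tf1.import_graph_def(make_graph_def(2.0), name="age_model")
    tf1.import_graph_def(make_graph_def(10.0), name="gender_model")

op_names = [op.name for op in graph.get_operations()]
print(sorted(op_names))

# Look tensors up by their prefixed names.
age_in = graph.get_tensor_by_name("age_model/input:0")
age_out = graph.get_tensor_by_name("age_model/final_result:0")
gender_in = graph.get_tensor_by_name("gender_model/input:0")
gender_out = graph.get_tensor_by_name("gender_model/final_result:0")

# A single session serves both models; one run() call fetches both outputs.
with tf1.Session(graph=graph) as sess:
    age_pred, gender_pred = sess.run(
        [age_out, gender_out],
        feed_dict={age_in: [1.0, 2.0], gender_in: [1.0, 2.0]},
    )
print(age_pred, gender_pred)
```

In the answer's TF1 code the default graph plays the role of `graph` here, which is why tf.Session() needs no graph argument there; the sketch uses an explicit tf.Graph so it also runs under TF 2.x, where eager execution is on by default.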
initialSetup()
with tf.Session() as sess:
    start_time = timeit.default_timer()
    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = sess.graph.get_tensor_by_name('age_model/final_result:0')
    # Alternatively, to get a tensor from the gender model:
    # tensor = sess.graph.get_tensor_by_name('gender_model/...')
    print('Took {} seconds to feed data to graph'.format(timeit.default_timer() - start_time))
    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()
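As for the TypeError in the question: the posted update passes age_graph_def, a tf.GraphDef (the parsed protobuf), to tf.Session(graph=...), which only accepts a tf.Graph. If you would rather keep each model in its own graph and session instead of sharing one graph, a possible alternative is to import each GraphDef into a fresh tf.Graph first. The sketch below assumes tf.compat.v1 is available; make_graph_def() and load_graph() are hypothetical helpers, with make_graph_def() standing in for reading output_graph.pb.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style API, available in recent 1.x releases and TF 2.x


def make_graph_def(scale):
    """Hypothetical stand-in for reading output_graph.pb: a tiny GraphDef
    with an 'input' placeholder and a 'final_result' output node."""
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None], name="input")
        tf.multiply(x, scale, name="final_result")
    return g.as_graph_def()


def load_graph(graph_def):
    """Import a GraphDef into its own tf.Graph, which tf.Session accepts."""
    graph = tf.Graph()
    with graph.as_default():
        tf1.import_graph_def(graph_def, name="")  # keep original node names
    return graph


age_graph = load_graph(make_graph_def(2.0))
gender_graph = load_graph(make_graph_def(10.0))

# One session per graph; each session only sees its own model's tensors,
# so both can keep the unprefixed name 'final_result:0'.
age_session = tf1.Session(graph=age_graph)
gender_session = tf1.Session(graph=gender_graph)

age_pred = age_session.run(
    age_graph.get_tensor_by_name("final_result:0"),
    feed_dict={age_graph.get_tensor_by_name("input:0"): [1.0, 2.0]},
)
gender_pred = gender_session.run(
    gender_graph.get_tensor_by_name("final_result:0"),
    feed_dict={gender_graph.get_tensor_by_name("input:0"): [1.0, 2.0]},
)
print(age_pred, gender_pred)

age_session.close()
gender_session.close()
```

The trade-off is that you manage two sessions and two run() calls per frame, whereas the single-graph approach needs only one session.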