admin管理员组

文章数量:1576283

错误信息全部信息:

WARNING:tensorflow:11 out of the last 11 calls to <function
train..train_step at 0x7f6d1843e840> triggered tf.function
retracing. Tracing is expensive and the excessive number of tracings
is likely due to passing python objects instead of tensors. Also,
tf.function has experimental_relax_shapes=True option that relaxes
argument shapes that can avoid unnecessary retracing. Please refer to
https://www.tensorflow/tutorials/customization/performance#python_or_tensor_args
and https://www.tensorflow/api_docs/python/tf/function for more
details.

StackOverflow上的解答多有出入,一般是他们写的代码本身问题。和上面错误信息描述的不一致。经过查看tf2.0官方的示例代码:https://tensorflow.google/tutorials/text/transformer
找到如下代码:

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]
  
  enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
  
  with tf.GradientTape() as tape:
    predictions, _ = transformer(inp, tar_inp, 
                                 True, 
                                 enc_padding_mask, 
                                 combined_mask, 
                                 dec_padding_mask)
    loss = loss_function(tar_real, predictions)

注意那个train_step_signature和tf.function(input_signature=train_step_signature)。
个人推测train函数有入参时候,这里必须声明入参的类型。比如这个例子中有两个入参则在train_step_signature中声明

本文标签: 错误TFfunction