tfe 模型保存和载入

编程入门 行业动态 更新时间:2024-10-27 10:33:16

tfe <a href=https://www.elefans.com/category/jswz/34/1771358.html style=模型保存和载入"/>

tfe 模型保存和载入

原文链接: tfe 模型保存和载入

上一篇: tfe 配合 Keras model 线性拟合 和 自己处理梯度进行线性拟合

下一篇: tfe mnist 使用dataset 分类 保存和载入

简单参数保存和载入

如果路径不存在回自动创建

import tensorflow as tf
import tensorflow.contrib.eager as tfetf.enable_eager_execution()
x = tfe.Variable(10.)checkpoint = tfe.Checkpoint(x=x)
x.assign(2.)  # Assign a new value to the variables and save.
print(x.numpy())  # 2.0save_path = checkpoint.save('./ckpt/')
print(save_path)  # ./ckpt/-1x.assign(11.)  # Change the variable after saving.
print(x.numpy())  # 11.0# Restore values from the checkpoint
checkpoint.restore(save_path)
print(x.numpy())  # 2.0

使用Keras的Model时,需要保存很多参数,此时使用对象保存的方式

载入时使用的是模型的文件夹路径

主要代码

# 保存训练参数
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
checkpoint_dir = './save/'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
root = tfe.Checkpoint(optimizer=optimizer,model=model,optimizer_step=tf.train.get_or_create_global_step())root.save(file_prefix=checkpoint_prefix)
# or
# root.restore(tf.train.latest_checkpoint(checkpoint_dir))

完整代码

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import ostf.enable_eager_execution()class Model(tf.keras.Model):def __init__(self):super(Model, self).__init__()self.W = tfe.Variable(5., name='weight')self.B = tfe.Variable(10., name='bias')def call(self, inputs):return inputs * self.W + self.B# A toy dataset of points around 3 * x + 2
NUM_EXAMPLES = 2000
inputs = tf.random_normal([NUM_EXAMPLES])
noise = tf.random_normal([NUM_EXAMPLES])
targets = inputs * 3 + 2 + noise# The loss function to be optimized
def loss():error = model(inputs) - targetsreturn tf.reduce_mean(tf.square(error))# Define:
# 1. A model.
# 2. Derivatives of a loss function with respect to model parameters.
# 3. A strategy for updating the variables based on the derivatives.
model = Model()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)# 载入训练参数
# checkpoint_dir = './save/'
# checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# root = tfe.Checkpoint(optimizer=optimizer,
#                       model=model,
#                       optimizer_step=tf.train.get_or_create_global_step())
# root.restore(tf.train.latest_checkpoint(checkpoint_dir))
## Training loop
for i in range(300):optimizer.minimize(loss)if i % 20 == 0:print(model.W.numpy(), model.B.numpy())# 保存训练参数
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
checkpoint_dir = './save/'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
root = tfe.Checkpoint(optimizer=optimizer,model=model,optimizer_step=tf.train.get_or_create_global_step())root.save(file_prefix=checkpoint_prefix)
# or
# root.restore(tf.train.latest_checkpoint(checkpoint_dir))

更多推荐

tfe 模型保存和载入

本文发布于:2024-03-08 01:18:35,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1719465.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:模型   tfe

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!