TensorFlow Model Persistence

Updated: 2024-10-22 11:11:40


To make training results reusable, this article shows how to persist a trained network model.

Implementation

tf.train.Saver

For the tf.train.Saver class, see the official documentation or the source on GitHub.

Basic usage

Saving

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
# Declare a Saver for saving the model
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # Save the model to Saved_model/model.ckpt
    saver.save(sess, "Saved_model/model.ckpt")

saver.save returns the checkpoint path prefix, which is shown when the code is run interactively:

'Saved_model/model.ckpt'

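A new Saved_model folder appears in the current directory, containing four files:

  • checkpoint: lists all the model files saved in this directory.
  • model.ckpt.data-00000-of-00001: stores the current values of the TensorFlow variables.
  • model.ckpt.index: stores the variable names, indexing their values in the data file.
  • model.ckpt.meta: stores the structure of the TensorFlow computation graph.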
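To check this on a current installation, the sketch below repeats the save step in a temporary directory and lists what it produces. It assumes TensorFlow 2.x and reaches the TF1-style graph API through tf.compat.v1; the temporary directory stands in for Saved_model. tf.train.latest_checkpoint reads the checkpoint file to find the newest path prefix.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x, using the TF1-style graph API via tf.compat.v1.
tf.compat.v1.disable_eager_execution()
tf1 = tf.compat.v1

save_dir = tempfile.mkdtemp()  # stands in for the Saved_model folder

with tf.Graph().as_default():
    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
    v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
    result = v1 + v2
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        prefix = saver.save(sess, os.path.join(save_dir, "model.ckpt"))

files = sorted(os.listdir(save_dir))
print(files)
# latest_checkpoint reads the `checkpoint` file to find the newest prefix.
print(tf.train.latest_checkpoint(save_dir))
```

The listing should show the same four files described above.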

Loading

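import tensorflow as tf

# Redefine the same graph structure that was used when saving
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

saver = tf.train.Saver()
with tf.Session() as sess:
    # Restore the variable values from the checkpoint; no initialization needed
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))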

Output:

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[ 3.]

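This code first defines every operation of the TensorFlow computation graph and then reads the variable values from the local files, so the variables do not need to be initialized.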

Loading the persisted graph

If you do not want to define the whole structure again in code, you can load the saved graph structure instead:

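import tensorflow as tf

# Rebuild the graph from the .meta file instead of redefining it in code
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    # "add:0" is the output tensor of the v1 + v2 operation
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))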

Output:

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[ 3.]

By default, all of the code above saves and loads every variable defined in the TensorFlow computation graph.

Saving and loading specific variables

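Passing a list of variables to tf.train.Saver restricts it to saving and restoring only those variables. In the following loading code the Saver covers only v1: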

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

# This Saver only saves and restores v1
saver = tf.train.Saver([v1])
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))

The program fails with the following error, because v2 is neither restored from the checkpoint nor initialized:

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value v2
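One way to make the partial restore work is to initialize the variables the Saver does not cover before restoring. The sketch below assumes TensorFlow 2.x with the tf.compat.v1 graph API and writes its own checkpoint to a temporary directory; v1 starts at 10.0 in the loading graph only to show that the restored value 1.0 replaces it.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x, using the TF1-style graph API via tf.compat.v1.
tf.compat.v1.disable_eager_execution()
tf1 = tf.compat.v1

ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Save both variables, as in the original saving code.
with tf.Graph().as_default():
    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
    v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_path)

# Restore only v1; v2 has to be initialized explicitly.
with tf.Graph().as_default():
    v1 = tf.Variable(tf.constant(10.0, shape=[1]), name="v1")
    v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
    result = v1 + v2
    saver = tf1.train.Saver([v1])
    with tf1.Session() as sess:
        sess.run(tf1.variables_initializer([v2]))  # covers what the Saver skips
        saver.restore(sess, ckpt_path)             # v1 becomes 1.0
        value = sess.run(result)

print(value)
```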

Renaming variables when loading

Saving code:

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")

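A dictionary renames the variables: each key is the variable name stored in the saved graph, and each value is the variable in the current program. Loading code: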

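import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
saver = tf.train.Saver({"v1": v1, "v2": v2})  # saved name -> current variable
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(v1 + v2))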

Saving and loading moving-average models

Using variable renaming

Saving code:

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables():
    print(variables.name)

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
# ema.apply has created a shadow variable for v
for variables in tf.global_variables():
    print(variables.name)

saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # Saving stores both v:0 and v/ExponentialMovingAverage:0
    saver.save(sess, "Saved_model/model2.ckpt")
    print(sess.run([v, ema.average(v)]))

Output:

v:0
v:0
v/ExponentialMovingAverage:0
[10.0, 0.099999905]
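The value 0.099999905 follows from the update rule documented for tf.train.ExponentialMovingAverage, shadow = decay * shadow + (1 - decay) * variable. The shadow starts at v's initial value 0, v is then assigned 10, and one update gives 0.99 * 0 + 0.01 * 10 = 0.1, minus float32 rounding. The sketch below replays that single step in plain Python, rounding every intermediate to float32 with the struct module; it uses the algebraically equivalent form shadow -= (1 - decay) * (shadow - variable). The helper names are illustrative, not TensorFlow APIs.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def ema_step(shadow: float, value: float, decay: float) -> float:
    """One EMA update, shadow -= (1 - decay) * (shadow - value), in float32."""
    decay32 = to_float32(decay)
    one_minus_decay = to_float32(1.0 - decay32)
    delta = to_float32((shadow - value) * one_minus_decay)
    return to_float32(shadow - delta)


shadow = 0.0  # the shadow variable starts at v's initial value
v = 10.0      # the value assigned by tf.assign(v, 10)
shadow = ema_step(shadow, v, 0.99)
print(shadow)  # slightly below 0.1 because 0.99 is not exact in float32
```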

When loading, the point of a moving-average model is that reading variable v should actually return v's moving average. Loading code:

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
# Renaming assigns the saved moving average of v directly to v
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print(sess.run(v))

Output:

INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.0999999

Using variables_to_restore

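To make it easier to rename moving-average variables when loading, the tf.train.ExponentialMovingAverage class provides the variables_to_restore function (docs, GitHub), which generates the renaming dictionary that tf.train.Saver expects.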

Code:

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
# variables_to_restore() builds the renaming dictionary automatically
print(ema.variables_to_restore())
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print(sess.run(v))

Output:

{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.0999999
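The dictionary printed above can be mimicked in plain Python. The sketch below is a simplified model of what variables_to_restore builds, assuming variables are identified by their op names: each averaged variable is looked up under its "<name>/ExponentialMovingAverage" key in the checkpoint, and any other variable keeps its own name. The function name is hypothetical and strings stand in for tf.Variable objects.

```python
def ema_restore_map(all_vars, averaged_vars, ema_name="ExponentialMovingAverage"):
    """Map checkpoint keys to variables, mimicking variables_to_restore.

    all_vars: names of every variable in the current graph.
    averaged_vars: names of the variables that have moving averages.
    """
    name_map = {}
    for var in averaged_vars:
        # Restore the shadow value saved under "<var>/<ema_name>" into var.
        name_map[f"{var}/{ema_name}"] = var
    for var in all_vars:
        if var not in averaged_vars and var not in name_map:
            # Variables without a moving average are restored under their own name.
            name_map[var] = var
    return name_map


print(ema_restore_map(["v"], ["v"]))
print(ema_restore_map(["v", "global_step"], ["v"]))
```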

Saving as a PB file

Saving

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    # Export the GraphDef of the current graph
    graph_def = tf.get_default_graph().as_graph_def()
    # Turn variables into constants, keeping what the "add" node needs
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    # Write the frozen graph to a single .pb file
    with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())

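graph_def = tf.get_default_graph().as_graph_def(): exports the GraphDef part of the current computation graph. This part alone is enough to carry out the computation from the input layer to the output layer.

graph_util.convert_variables_to_constants: converts the variables in the graph, together with their current values, into constants. The last argument, ['add'], lists the names of the output nodes to keep.

Only a single file, combined_model.pb, is produced.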

Output:

INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
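Conceptually, freezing does two things: it keeps only the nodes that the requested outputs depend on, and it replaces each variable node with a constant holding the variable's current value. The toy sketch below models the v1 + v2 graph as a plain dictionary (a hypothetical representation, not TensorFlow's GraphDef) to make those two steps concrete; it is an illustration, not how graph_util is implemented.

```python
# Toy graph for v1 + v2: node name -> {"op": ..., "inputs": [...]}.
# This dict layout is a hypothetical stand-in for a GraphDef.
GRAPH = {
    "Const": {"op": "Const", "inputs": [], "value": [1.0]},
    "v1": {"op": "VariableV2", "inputs": []},
    "v1/Assign": {"op": "Assign", "inputs": ["v1", "Const"]},
    "v1/read": {"op": "Identity", "inputs": ["v1"]},
    "Const_1": {"op": "Const", "inputs": [], "value": [2.0]},
    "v2": {"op": "VariableV2", "inputs": []},
    "v2/Assign": {"op": "Assign", "inputs": ["v2", "Const_1"]},
    "v2/read": {"op": "Identity", "inputs": ["v2"]},
    "add": {"op": "Add", "inputs": ["v1/read", "v2/read"]},
}
# Current variable values, as a session would report them.
SESSION_VALUES = {"v1": [1.0], "v2": [2.0]}


def extract_subgraph(graph, output_names):
    """Keep only the nodes reachable backwards from the requested outputs."""
    needed, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name not in needed:
            needed.add(name)
            stack.extend(graph[name]["inputs"])
    return {name: node for name, node in graph.items() if name in needed}


def freeze(graph, values, output_names):
    """Replace every variable node in the needed subgraph with a constant."""
    frozen, converted = {}, 0
    for name, node in extract_subgraph(graph, output_names).items():
        if node["op"] == "VariableV2":
            frozen[name] = {"op": "Const", "inputs": [], "value": values[name]}
            converted += 1
        else:
            frozen[name] = node
    return frozen, converted


frozen_graph, n_converted = freeze(GRAPH, SESSION_VALUES, ["add"])
print(f"Converted {n_converted} variables to const ops.")
print(sorted(frozen_graph))  # Assign and initializer Const nodes are gone
```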

Loading

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = "Saved_model/combined_model.pb"
    # Read the serialized GraphDef from the .pb file
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import the graph and fetch the output tensor "add:0"
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))

Output:

[array([ 3.], dtype=float32)]

How persistence works and the data formats

TensorFlow saves its files in the Protocol Buffer format, so this section first introduces that format.

Protocol Buffer

Protocol Buffer is a tool developed by Google for handling structured data, similar in purpose to XML and JSON.

For example, suppose we need to store the following structured information:

name: 张三
id: 12345
email: zhangsan@abc

In XML:

<user>
    <name>张三</name>
    <id>12345</id>
    <email>zhangsan@abc</email>
</user>

In JSON:

{
    "name": "张三",
    "id": "12345",
    "email": "zhangsan@abc"
}

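How Protocol Buffer differs from these two:

  • XML and JSON data serialize to human-readable strings that contain all of the information.
  • Protocol Buffer data serializes to an unreadable binary stream. Using Protocol Buffer requires defining the data format (schema) first, and the same schema is needed to restore the data.
  • Google cites serialized Protocol Buffer data as 3 to 10 times smaller than the equivalent XML and 20 to 100 times faster to parse.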

The schema file is defined as follows:

message user {
    optional string name = 1;
    required int32 id = 2;
    repeated string email = 3;
}
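To make the binary encoding concrete, the sketch below hand-encodes the user record with the standard protobuf wire format: each field is a key, (field_number << 3) | wire_type, followed by either a varint or a length-prefixed payload. It uses plain Python with no protobuf library, the helper names are illustrative, it handles only non-negative integers, and it writes id as an integer as the schema declares. For this tiny record it yields 25 bytes against 51 for compact JSON; the ratios quoted above describe typical data, not this example.

```python
import json


def encode_varint(n: int) -> bytes:
    """Encode a non-negative integer as a protobuf base-128 varint."""
    out = bytearray()
    while True:
        byte = n & 0x7F
        n >>= 7
        if n:
            out.append(byte | 0x80)  # high bit set: more bytes follow
        else:
            out.append(byte)
            return bytes(out)


def encode_key(field_number: int, wire_type: int) -> bytes:
    """Field key = (field_number << 3) | wire_type, itself a varint."""
    return encode_varint((field_number << 3) | wire_type)


def encode_string(field_number: int, text: str) -> bytes:
    """Length-delimited field (wire type 2): key, byte length, UTF-8 bytes."""
    data = text.encode("utf-8")
    return encode_key(field_number, 2) + encode_varint(len(data)) + data


def encode_user(name: str, user_id: int, emails: list) -> bytes:
    """Serialize the `user` message by hand, using its field numbers."""
    buf = encode_string(1, name)                      # name = 1
    buf += encode_key(2, 0) + encode_varint(user_id)  # id = 2 (varint)
    for email in emails:                              # email = 3 (repeated)
        buf += encode_string(3, email)
    return buf


pb_bytes = encode_user("张三", 12345, ["zhangsan@abc"])
json_bytes = json.dumps(
    {"name": "张三", "id": 12345, "email": "zhangsan@abc"},
    ensure_ascii=False,
    separators=(",", ":"),
).encode("utf-8")

print(len(pb_bytes), pb_bytes.hex())
print(len(json_bytes))
```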

.ckpt.meta: MetaGraphDef

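TensorFlow is a programming system that expresses computation as graphs: every computation in a TensorFlow program is represented as a node of a computation graph. TensorFlow uses a metagraph (MetaGraph) to record the information about the nodes in the computation graph and the metadata needed to run those nodes.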

Its type is defined as follows (see GitHub for details):

message MetaGraphDef {
    MetaInfoDef meta_info_def = 1;
    GraphDef graph_def = 2;
    SaverDef saver_def = 3;
    map<string, CollectionDef> collection_def = 4;
    map<string, SignatureDef> signature_def = 5;
    repeated AssetFileDef asset_file_def = 6;
}

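All of this information is stored in model.ckpt.meta. Because it is a binary file it cannot be viewed directly, so for debugging TensorFlow provides the export_meta_graph function, which with as_text=True exports the Protocol Buffer in a human-readable text format. The code is: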

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result1 = v1 + v2

saver = tf.train.Saver()
# Export the MetaGraphDef in human-readable text format
saver.export_meta_graph("Saved_model/model.ckpt.meta.json", as_text=True)

The exported file (protobuf text format, despite the .json extension):

meta_info_def {
  stripped_op_list {
    op { name: "Add" input_arg { name: "x" type_attr: "T" } input_arg { name: "y" type_attr: "T" } output_arg { name: "z" type_attr: "T" } attr { name: "T" type: "type" allowed_values { list { type: DT_HALF type: DT_FLOAT type: DT_DOUBLE type: DT_UINT8 type: DT_INT8 type: DT_INT16 type: DT_INT32 type: DT_INT64 type: DT_COMPLEX64 type: DT_COMPLEX128 type: DT_STRING } } } }
    op { name: "Assign" input_arg { name: "ref" type_attr: "T" is_ref: true } input_arg { name: "value" type_attr: "T" } output_arg { name: "output_ref" type_attr: "T" is_ref: true } attr { name: "T" type: "type" } attr { name: "validate_shape" type: "bool" default_value { b: true } } attr { name: "use_locking" type: "bool" default_value { b: true } } allows_uninitialized_input: true }
    op { name: "Const" output_arg { name: "output" type_attr: "dtype" } attr { name: "value" type: "tensor" } attr { name: "dtype" type: "type" } }
    op { name: "Identity" input_arg { name: "input" type_attr: "T" } output_arg { name: "output" type_attr: "T" } attr { name: "T" type: "type" } }
    op { name: "NoOp" }
    op { name: "RestoreV2" input_arg { name: "prefix" type: DT_STRING } input_arg { name: "tensor_names" type: DT_STRING } input_arg { name: "shape_and_slices" type: DT_STRING } output_arg { name: "tensors" type_list_attr: "dtypes" } attr { name: "dtypes" type: "list(type)" has_minimum: true minimum: 1 } is_stateful: true }
    op { name: "SaveV2" input_arg { name: "prefix" type: DT_STRING } input_arg { name: "tensor_names" type: DT_STRING } input_arg { name: "shape_and_slices" type: DT_STRING } input_arg { name: "tensors" type_list_attr: "dtypes" } attr { name: "dtypes" type: "list(type)" has_minimum: true minimum: 1 } is_stateful: true }
    op { name: "VariableV2" output_arg { name: "ref" type_attr: "dtype" is_ref: true } attr { name: "shape" type: "shape" } attr { name: "dtype" type: "type" } attr { name: "container" type: "string" default_value { s: "" } } attr { name: "shared_name" type: "string" default_value { s: "" } } is_stateful: true }
  }
  tensorflow_version: "1.3.0"
  tensorflow_git_version: "v1.3.0-rc2-20-g0787eee"
}
graph_def {
  node { name: "Const" op: "Const" attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "value" value { tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } float_val: 1.0 } } } }
  node { name: "v1" op: "VariableV2" attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "container" value { s: "" } } attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { dim { size: 1 } } } } attr { key: "shared_name" value { s: "" } } }
  node { name: "v1/Assign" op: "Assign" input: "v1" input: "Const" attr { key: "T" value { type: DT_FLOAT } } attr { key: "_class" value { list { s: "loc:@v1" } } } attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "use_locking" value { b: true } } attr { key: "validate_shape" value { b: true } } }
  node { name: "v1/read" op: "Identity" input: "v1" attr { key: "T" value { type: DT_FLOAT } } attr { key: "_class" value { list { s: "loc:@v1" } } } attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } }
  node { name: "Const_1" op: "Const" attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "value" value { tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } float_val: 2.0 } } } }
  node { name: "v2" op: "VariableV2" attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "container" value { s: "" } } attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { dim { size: 1 } } } } attr { key: "shared_name" value { s: "" } } }
  node { name: "v2/Assign" op: "Assign" input: "v2" input: "Const_1" attr { key: "T" value { type: DT_FLOAT } } attr { key: "_class" value { list { s: "loc:@v2" } } } attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "use_locking" value { b: true } } attr { key: "validate_shape" value { b: true } } }
  node { name: "v2/read" op: "Identity" input: "v2" attr { key: "T" value { type: DT_FLOAT } } attr { key: "_class" value { list { s: "loc:@v2" } } } attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } }
  node { name: "add" op: "Add" input: "v1/read" input: "v2/read" attr { key: "T" value { type: DT_FLOAT } } attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } }
  node { name: "save/Const" op: "Const" attr { key: "_output_shapes" value { list { shape { } } } } attr { key: "dtype" value { type: DT_STRING } } attr { key: "value" value { tensor { dtype: DT_STRING tensor_shape { } string_val: "model" } } } }
  node { name: "save/SaveV2/tensor_names" op: "Const" attr { key: "_output_shapes" value { list { shape { dim { size: 2 } } } } } attr { key: "dtype" value { type: DT_STRING } } attr { key: "value" value { tensor { dtype: DT_STRING tensor_shape { dim { size: 2 } } string_val: "v1" string_val: "v2" } } } }
  node { name: "save/SaveV2/shape_and_slices" op: "Const" attr { key: "_output_shapes" value { list { shape { dim { size: 2 } } } } } attr { key: "dtype" value { type: DT_STRING } } attr { key: "value" value { tensor { dtype: DT_STRING tensor_shape { dim { size: 2 } } string_val: "" string_val: "" } } } }
  node { name: "save/SaveV2" op: "SaveV2" input: "save/Const" input: "save/SaveV2/tensor_names" input: "save/SaveV2/shape_and_slices" input: "v1" input: "v2" attr { key: "dtypes" value { list { type: DT_FLOAT type: DT_FLOAT } } } }
  node { name: "save/control_dependency" op: "Identity" input: "save/Const" input: "^save/SaveV2" attr { key: "T" value { type: DT_STRING } } attr { key: "_class" value { list { s: "loc:@save/Const" } } } attr { key: "_output_shapes" value { list { shape { } } } } }
  node { name: "save/RestoreV2/tensor_names" op: "Const" attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "dtype" value { type: DT_STRING } } attr { key: "value" value { tensor { dtype: DT_STRING tensor_shape { dim { size: 1 } } string_val: "v1" } } } }
  node { name: "save/RestoreV2/shape_and_slices" op: "Const" attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "dtype" value { type: DT_STRING } } attr { key: "value" value { tensor { dtype: DT_STRING tensor_shape { dim { size: 1 } } string_val: "" } } } }
  node { name: "save/RestoreV2" op: "RestoreV2" input: "save/Const" input: "save/RestoreV2/tensor_names" input: "save/RestoreV2/shape_and_slices" attr { key: "_output_shapes" value { list { shape { unknown_rank: true } } } } attr { key: "dtypes" value { list { type: DT_FLOAT } } } }
  node { name: "save/Assign" op: "Assign" input: "v1" input: "save/RestoreV2" attr { key: "T" value { type: DT_FLOAT } } attr { key: "_class" value { list { s: "loc:@v1" } } } attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "use_locking" value { b: true } } attr { key: "validate_shape" value { b: true } } }
  node { name: "save/RestoreV2_1/tensor_names" op: "Const" attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "dtype" value { type: DT_STRING } } attr { key: "value" value { tensor { dtype: DT_STRING tensor_shape { dim { size: 1 } } string_val: "v2" } } } }
  node { name: "save/RestoreV2_1/shape_and_slices" op: "Const" attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "dtype" value { type: DT_STRING } } attr { key: "value" value { tensor { dtype: DT_STRING tensor_shape { dim { size: 1 } } string_val: "" } } } }
  node { name: "save/RestoreV2_1" op: "RestoreV2" input: "save/Const" input: "save/RestoreV2_1/tensor_names" input: "save/RestoreV2_1/shape_and_slices" attr { key: "_output_shapes" value { list { shape { unknown_rank: true } } } } attr { key: "dtypes" value { list { type: DT_FLOAT } } } }
  node { name: "save/Assign_1" op: "Assign" input: "v2" input: "save/RestoreV2_1" attr { key: "T" value { type: DT_FLOAT } } attr { key: "_class" value { list { s: "loc:@v2" } } } attr { key: "_output_shapes" value { list { shape { dim { size: 1 } } } } } attr { key: "use_locking" value { b: true } } attr { key: "validate_shape" value { b: true } } }
  node { name: "save/restore_all" op: "NoOp" input: "^save/Assign" input: "^save/Assign_1" }
  versions { producer: 24 }
}
saver_def {
  filename_tensor_name: "save/Const:0"
  save_tensor_name: "save/control_dependency:0"
  restore_op_name: "save/restore_all"
  max_to_keep: 5
  keep_checkpoint_every_n_hours: 10000.0
  version: V2
}
collection_def {
  key: "trainable_variables"
  value {
    bytes_list {
      value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
      value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
    }
  }
}
collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
      value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
    }
  }
}

The meta_info_def field

Stores the metadata of the TensorFlow computation graph and information about every operation used in the program.

Definition:

message MetaInfoDef {
    string meta_graph_version = 1;
    OpList stripped_op_list = 2;
    google.protobuf.Any any_info = 3;
    repeated string tags = 4;
    string tensorflow_version = 5;
    string tensorflow_git_version = 6;
}

See GitHub for the definition of OpList.

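Each OpDef in stripped_op_list describes one operation type, and its attr entries declare that operation's attributes. For example, the Add op has an attr named T that specifies which types its inputs and outputs may take (allowed_values); ops such as NoOp have no attr at all.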

graph_def

GraphDef, NodeDef

Mainly records information about the nodes of the computation graph.

saver_def

SaverDef

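Mainly records the parameters used when persisting the model, such as the file name to save to, the names of the save and restore operations, the save frequency, and how old checkpoints are cleaned up.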
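The max_to_keep: 5 entry in the exported saver_def means only the five most recent checkpoints are kept. The sketch below is a simplified, hypothetical model of that bookkeeping: it tracks checkpoint prefixes in a deque and drops the oldest once the limit is exceeded. It ignores keep_checkpoint_every_n_hours and never touches the file system, so it illustrates the idea rather than reproducing tf.train.Saver's code; the "model.ckpt-<step>" names are illustrative.

```python
from collections import deque


class CheckpointHistory:
    """Hypothetical model of the max_to_keep rotation described by SaverDef."""

    def __init__(self, max_to_keep: int = 5):
        self.max_to_keep = max_to_keep
        self.paths = deque()  # oldest on the left, newest on the right
        self.deleted = []     # prefixes whose files would be removed

    def record(self, path: str) -> None:
        if path in self.paths:
            self.paths.remove(path)  # re-saving a prefix moves it to the end
        self.paths.append(path)
        while len(self.paths) > self.max_to_keep:
            self.deleted.append(self.paths.popleft())  # oldest goes first

    @property
    def latest(self) -> str:
        return self.paths[-1]


history = CheckpointHistory(max_to_keep=5)
for step in range(7):
    history.record(f"model.ckpt-{step}")
print(list(history.paths))
print(history.deleted)
```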

collection_def

CollectionDef

Maintains the various collections: it is a map from collection name to collection contents.
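Each bytes_list value in the exported collection_def is itself a serialized protobuf, a VariableDef whose fields 1, 2 and 3 are variable_name, initializer_name and snapshot_name (per TensorFlow's variable.proto). The sketch below decodes the first entry with a minimal wire-format reader that handles only varint keys and length-delimited fields, which is all this entry contains; it is an illustration, not a general protobuf parser.

```python
def read_varint(buf: bytes, pos: int):
    """Decode a base-128 varint starting at pos; return (value, new_pos)."""
    result, shift = 0, 0
    while True:
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7


def decode_length_delimited(buf: bytes) -> dict:
    """Return {field_number: bytes} for a message made only of wire-type-2 fields."""
    fields, pos = {}, 0
    while pos < len(buf):
        key, pos = read_varint(buf, pos)
        field_number, wire_type = key >> 3, key & 0x7
        if wire_type != 2:
            raise ValueError(f"unsupported wire type {wire_type}")
        length, pos = read_varint(buf, pos)
        fields[field_number] = buf[pos:pos + length]
        pos += length
    return fields


# First entry of the "variables" collection; Python decodes the same escapes.
entry = b"\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
VARIABLE_DEF_FIELDS = {1: "variable_name", 2: "initializer_name", 3: "snapshot_name"}
for number, value in decode_length_delimited(entry).items():
    print(VARIABLE_DEF_FIELDS[number], value.decode("utf-8"))
```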

.ckpt

TensorFlow provides tf.train.NewCheckpointReader for reading all the variable information in a ckpt file.

import tensorflow as tf

reader = tf.train.NewCheckpointReader("Saved_model/model.ckpt")
# Map from variable name to shape for every variable in the checkpoint
all_variables = reader.get_variable_to_shape_map()
for variable_name in all_variables:
    print(variable_name, all_variables[variable_name])
print("Value for variable v1 is ", reader.get_tensor("v1"))

tf.train.NewCheckpointReader reads all the variables in the ckpt file. In the loop, variable_name is the variable's name and all_variables[variable_name] is its shape.

Output:

v2 [1]
v1 [1]
Value for variable v1 is  [ 1.]
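In TensorFlow 2.x the same inspection is available as tf.train.list_variables and tf.train.load_variable, which work directly on a checkpoint path. The sketch below assumes TensorFlow 2.x, writes a fresh checkpoint with the tf.compat.v1 Saver into a temporary directory, and then reads it back.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x; the checkpoint is written with the tf.compat.v1 Saver.
tf.compat.v1.disable_eager_execution()
tf1 = tf.compat.v1

prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")
with tf.Graph().as_default():
    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
    v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, prefix)

# list_variables returns (name, shape) pairs; load_variable returns a NumPy array.
shapes = dict(tf.train.list_variables(prefix))
print(shapes)
print("Value for variable v1 is", tf.train.load_variable(prefix, "v1"))
```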

checkpoint

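Generated and maintained automatically by the tf.train.Saver class, this file records the file names of all TensorFlow model checkpoints. It is human-readable.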

Format:

message CheckpointState {
    string model_checkpoint_path = 1;
    repeated string all_model_checkpoint_paths = 2;
}

Example:

model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"
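Because the checkpoint file is plain protobuf text format, a few lines of Python are enough to read it. The sketch below is a hypothetical helper that handles only the two string fields defined above and skips any line it does not recognize; it assumes one field per line and does not unescape quoted characters. TensorFlow's own tf.train.get_checkpoint_state is the supported way to do this. The sample contents, with several saved steps, are illustrative.

```python
import re

# Hypothetical contents of a `checkpoint` file after several saves.
CHECKPOINT_TEXT = """model_checkpoint_path: "model.ckpt-6"
all_model_checkpoint_paths: "model.ckpt-5"
all_model_checkpoint_paths: "model.ckpt-6"
"""

# One `key: "value"` pair per line; escaped quotes are not handled.
FIELD_RE = re.compile(r'^\s*(\w+)\s*:\s*"([^"]*)"\s*$')


def parse_checkpoint_state(text):
    """Extract the CheckpointState fields from protobuf text format."""
    state = {"model_checkpoint_path": None, "all_model_checkpoint_paths": []}
    for line in text.splitlines():
        match = FIELD_RE.match(line)
        if match is None:
            continue
        key, value = match.groups()
        if key == "model_checkpoint_path":
            state[key] = value  # singular field: the newest checkpoint
        elif key == "all_model_checkpoint_paths":
            state[key].append(value)  # repeated field: one line per entry
    return state


print(parse_checkpoint_state(CHECKPOINT_TEXT))
```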


Published: 2024-02-07 05:34:43
Original link: https://www.elefans.com/category/jswz/34/1753537.html