tfe mnist 使用dataset 分类 保存和载入

编程入门 行业动态 更新时间:2024-10-27 08:24:01

<a href="https://www.elefans.com/category/jswz/34/1719472.html">tfe mnist 使用dataset 分类 保存和载入</a>

tfe mnist 使用dataset 分类 保存和载入

原文链接: tfe mnist 使用dataset 分类 保存和载入

上一篇: tfe 模型保存和载入

下一篇: tfe 在静态图中的使用

使用Model构建模型

class Model(tf.keras.Model):
    """MNIST classifier: three stride-2 3x3 convolutions, flatten, 10-way dense head."""

    def __init__(self):
        super(Model, self).__init__()
        self.model = keras.Sequential()
        for _ in range(3):
            self.model.add(keras.layers.Conv2D(32, 3, 2, "SAME"))
        self.model.add(keras.layers.Flatten())
        self.model.add(keras.layers.Dense(10))

    def call(self, in_x):
        # slim layers are deliberately avoided here: they would create fresh
        # variables on every call(), so the network would differ between
        # invocations and its parameters could never be optimized.
        return self.model(in_x)

不使用slim模块的原因是:每次获取输出时都会重新运行call函数,这样的话slim构建的网络每次都不一样,参数也都不同,无法进行优化。

dataset的读取是直接使用迭代器

    filenames = ['d:/data/mnist/record/mnist_train.record']dataset = tf.data.TFRecordDataset(filenames)dataset = dataset.map(parser).shuffle(128)dataset = dataset.repeat(-1).batch(32)lr = .01optimizer = tf.train.AdamOptimizer(lr)model = Model()show_step = 100train_step = 10000for (i, (images, labels)) in zip(range(1, 1 + train_step), dataset):# print(i, images.shape, labels.shape)  # 1 (4, 28, 28, 1) (4,)optimizer.minimize(lambda: loss(model, images, labels))if not i % show_step:print(i, accuracy(model, images, labels).numpy())

先训练,训练完毕进行载入模型测试

完整代码

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow import keras
import tensorflow.contrib.eager as tfe
import os

# FIX: the original source fused "import os" and this call into the
# nonsense token "import ostf.enable_eager_execution()".
tf.enable_eager_execution()


def parser(record):
    """Decode one serialized MNIST example into an (image, label) pair.

    The image is stored as raw float32 bytes and reshaped to (28, 28, 1);
    the label is a scalar int64.
    """
    keys_to_features = {
        "image": tf.FixedLenFeature((), tf.string),
        "label": tf.FixedLenFeature((), tf.int64),
    }
    parsed = tf.parse_single_example(record, keys_to_features)
    images = tf.decode_raw(parsed["image"], tf.float32)
    images = tf.reshape(images, [28, 28, 1])
    labels = tf.cast(parsed['label'], tf.int64)
    print("IMAGES", images.shape)  # IMAGES (28, 28, 1)
    print("LABELS", labels.shape)  # LABELS ()
    return images, labels


class Model(tf.keras.Model):
    """MNIST classifier: three stride-2 3x3 convolutions, flatten, 10-way dense head."""

    def __init__(self):
        super(Model, self).__init__()
        self.model = keras.Sequential()
        self.model.add(keras.layers.Conv2D(32, 3, 2, "SAME"))
        self.model.add(keras.layers.Conv2D(32, 3, 2, "SAME"))
        self.model.add(keras.layers.Conv2D(32, 3, 2, "SAME"))
        self.model.add(keras.layers.Flatten())
        self.model.add(keras.layers.Dense(10))

    def call(self, in_x):
        # slim layers are deliberately avoided here: they would create fresh
        # variables on every call(), so the network would differ between
        # invocations and its parameters could never be optimized.
        return self.model(in_x)


def loss(model, images, labels):
    """Return the mean softmax cross-entropy of the model's logits vs. labels.

    FIX: the original returned the un-reduced per-example loss vector, but
    optimizer.minimize expects a scalar objective; reduce to the batch mean.
    (The first parameter was also misspelled `mode`; it is only ever called
    positionally in this script.)
    """
    logits = model(images)
    ls = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(ls)


def predict(model, images):
    """Return the arg-max class index for every image in the batch."""
    return tf.argmax(model(images), axis=1)


def accuracy(model, images, labels):
    """Return the fraction of the batch whose predicted class matches the label."""
    predict_labels = predict(model, images)
    return tf.reduce_mean(tf.cast(tf.equal(predict_labels, labels), dtype=tf.float32))


def main():
    """Train the model on the MNIST TFRecord file, then save a checkpoint."""
    filenames = ['d:/data/mnist/record/mnist_train.record']
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parser).shuffle(128)
    dataset = dataset.repeat(-1).batch(32)  # repeat(-1): loop forever
    lr = .01
    optimizer = tf.train.AdamOptimizer(lr)
    model = Model()
    show_step = 100
    train_step = 10000
    # Zipping with a finite range caps the infinite dataset at train_step batches.
    for (i, (images, labels)) in zip(range(1, 1 + train_step), dataset):
        optimizer.minimize(lambda: loss(model, images, labels))
        if not i % show_step:
            print(i, accuracy(model, images, labels).numpy())
    # Save the trained parameters.
    # NOTE(review): a *fresh* optimizer is checkpointed here, not the one used
    # for training, so Adam slot variables are not preserved — confirm intent.
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    checkpoint_dir = './save/'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    root = tfe.Checkpoint(optimizer=optimizer,
                          model=model,
                          optimizer_step=tf.train.get_or_create_global_step())
    root.save(file_prefix=checkpoint_prefix)


def test():
    """Restore the checkpoint into a fresh model and print per-batch accuracy."""
    # NOTE(review): this evaluates on the *training* record file; point it at
    # a mnist_test record for a true held-out accuracy.
    filenames = ['d:/data/mnist/record/mnist_train.record']
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parser).shuffle(128)
    dataset = dataset.repeat(5).batch(32)
    model = Model()
    # Restore the saved parameters into the fresh model/optimizer.
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    checkpoint_dir = './save/'
    root = tfe.Checkpoint(optimizer=optimizer,
                          model=model,
                          optimizer_step=tf.train.get_or_create_global_step())
    root.restore(tf.train.latest_checkpoint(checkpoint_dir))
    for (i, (images, labels)) in zip(range(1, 1 + 10), dataset):
        print(i, accuracy(model, images, labels).numpy())


if __name__ == '__main__':
    main()
    test()


运行结果,如果训练次数较少,载入模型后的准确率很小

9000 0.90625
9100 0.90625
9200 0.875
9300 0.84375
9400 0.8125
9500 0.90625
9600 0.84375
9700 0.875
9800 1.0
9900 0.96875
10000 0.875
IMAGES (28, 28, 1)
LABELS ()
1 0.84375
2 0.8125
3 0.84375
4 0.84375
5 0.84375
6 0.8125
7 0.8125
8 0.9375
9 0.78125
10 0.875

更多推荐

tfe mnist 使用dataset 分类 保存和载入

本文发布于:2024-03-08 01:18:09,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1719464.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:tfe   mnist   dataset

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!