Fully Connected Networks: Implementing Your First Fully Connected Network


Handwritten digit recognition with a fully connected network

The program is split into three parts:

mnist_forward.py: forward propagation.

mnist_backward.py: backpropagation (training).

mnist_test.py: model testing.

Forward Propagation

Here we build the fully connected network. I use a three-layer network: the input layer has 784 neurons (the size of one MNIST image), the two hidden layers have 500 and 200 neurons respectively, and the output layer has 10 neurons, which gives the prediction.

Below is the mnist_forward.py file:

import tensorflow as tf
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # suppress TensorFlow warning messages


def get_weight(shape, regularizer):
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    if regularizer is not None:
        tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w


def get_bias(shape):
    b = tf.Variable(tf.zeros(shape))
    return b


def forward(x, regularizer):
    '''
    Define a 3-layer fully connected network:
    Layer 0: input layer, 784 neurons
    Layer 1: hidden layer, 500 neurons
    Layer 2: hidden layer, 200 neurons
    Layer 3: output layer, 10 neurons
    '''
    w1 = get_weight([784, 500], regularizer)
    b1 = get_bias([500])
    y1 = tf.nn.leaky_relu(tf.matmul(x, w1) + b1)

    w2 = get_weight([500, 200], regularizer)
    b2 = get_bias([200])
    y2 = tf.nn.leaky_relu(tf.matmul(y1, w2) + b2)

    w3 = get_weight([200, 10], regularizer)
    b3 = get_bias([10])
    y = tf.matmul(y2, w3) + b3
    return y
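As a quick sanity check (a minimal sketch, assuming TensorFlow 1.x and that the file above is saved as mnist_forward.py), you can pass a placeholder through forward and confirm that the output has shape [None, 10]:

import tensorflow as tf

import mnist_forward

# Minimal shape check: no regularizer is needed just to inspect the graph.
x = tf.placeholder(tf.float32, [None, 784])
y = mnist_forward.forward(x, None)
print(y.get_shape())  # expected: (?, 10)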

Backpropagation

1) An exponential moving average is applied to the parameters.

2) An exponentially decaying learning rate is used (see the sketch after this list).

3) The model is saved periodically during training.
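For reference, this is roughly what tf.train.exponential_decay with staircase=True computes. A minimal sketch with illustrative numbers (the decay_steps value assumes the 55,000-image MNIST training split and BATCH_SIZE = 200):

LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99
decay_steps = 55000 // 200  # roughly one decay step per epoch

def decayed_learning_rate(global_step):
    # staircase=True means the exponent is an integer division
    return LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step // decay_steps)

print(decayed_learning_rate(0))       # 0.1
print(decayed_learning_rate(10000))   # about 0.07 after roughly 36 epochs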

Below is mnist_backward.py:

import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # suppress TensorFlow warning messages

BATCH_SIZE = 200
LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99
STEPS = 50000
regularizer = 0.0001
moving_average_decay = 0.99
model_saver_path = './model/'  # path where the model is saved
model_name = 'mnist_model'     # name of the model


def backward(mnist):
    x = tf.placeholder(tf.float32, [None, 784])    # input
    y_ = tf.placeholder(tf.float32, [None, 10])    # labels
    y = mnist_forward.forward(x, regularizer)      # output of forward propagation
    global_step = tf.Variable(0, trainable=False)  # step counter

    # cross-entropy loss plus the L2 regularization collected in mnist_forward
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))

    # exponentially decaying learning rate
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    # define the training step
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    # apply an exponential moving average to the trainable variables
    ema = tf.train.ExponentialMovingAverage(moving_average_decay, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    # instantiate a saver
    saver = tf.train.Saver()

    # create a session
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)  # load BATCH_SIZE MNIST images and labels
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print('After %d training step, loss on training batch is %g' % (step, loss_value))
                # save the model in the current session
                saver.save(sess, os.path.join(model_saver_path, model_name), global_step=global_step)


def main():
    mnist = input_data.read_data_sets('./data/', one_hot=True)
    backward(mnist)


if __name__ == '__main__':
    main()
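One detail in backward() that is easy to miss is the last graph-building step: train_op itself does nothing, but because of the control dependencies, running it forces both the gradient step and the moving-average update to run first. A minimal standalone sketch of the same pattern (assuming TF 1.x, with a toy variable in place of the real network):

import tensorflow as tf

v = tf.Variable(1.0)
step = tf.assign_add(v, 1.0)              # stands in for the gradient update
ema = tf.train.ExponentialMovingAverage(0.99)
ema_op = ema.apply([v])                   # stands in for ema.apply(tf.trainable_variables())

with tf.control_dependencies([step, ema_op]):
    train_op = tf.no_op(name='train')     # running this runs both dependencies first

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    print(sess.run([v, ema.average(v)]))  # the variable and its shadow (averaged) value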


Model Testing

Because a moving average is applied during training, testing needs to restore the moving-average values of the parameters.

 with tf.Graph().as_default() as g:  # nodes defined inside belong to computation graph g

In this way, the neural network is rebuilt inside its own computation graph.
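What actually restores the averaged weights is variables_to_restore(): it maps each variable's shadow name to the variable itself, so the Saver loads the shadow (averaged) values into the ordinary variables. A tiny sketch of what it returns (assuming TF 1.x):

import tensorflow as tf

v = tf.Variable(0.0, name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
# e.g. {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32>}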

Below is mnist_test.py:

import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # suppress TensorFlow warning messages


def test(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, 784])
        y_ = tf.placeholder(tf.float32, [None, 10])
        y = mnist_forward.forward(x, None)  # output of forward propagation

        # instantiate a saver that restores the moving-average (shadow) values
        ema = tf.train.ExponentialMovingAverage(mnist_backward.moving_average_decay)
        ema_restore = ema.variables_to_restore()
        saver = tf.train.Saver(ema_restore)

        # accuracy calculation
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))  # compare predictions with the labels
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(mnist_backward.model_saver_path)
            # if a checkpoint exists, restore the model into the current session
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)  # restore the model
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]  # recover the step count
                # feed the test images and labels from the data set
                accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
                print('After %s training step, test accuracy = %g' % (global_step, accuracy_score))
            else:
                print('No checkpoint file found')


def main():
    mnist = input_data.read_data_sets('./data/', one_hot=True)
    test(mnist)


if __name__ == '__main__':
    main()

Results

Training results

After 1 training step, loss on training batch is 2.7423
After 1001 training step, loss on training batch is 0.377551
After 2001 training step, loss on training batch is 0.243816

......

After 36001 training step, loss on training batch is 0.144181
After 37001 training step, loss on training batch is 0.143775
After 38001 training step, loss on training batch is 0.14224
After 39001 training step, loss on training batch is 0.144268
After 40001 training step, loss on training batch is 0.142971
After 41001 training step, loss on training batch is 0.14316
After 42001 training step, loss on training batch is 0.142349
After 43001 training step, loss on training batch is 0.142429
After 44001 training step, loss on training batch is 0.140229
After 45001 training step, loss on training batch is 0.140057
After 46001 training step, loss on training batch is 0.138865
After 47001 training step, loss on training batch is 0.137823
After 48001 training step, loss on training batch is 0.138767
After 49001 training step, loss on training batch is 0.139221

Test result

After 49001 training step, test accuracy = 0.98

How to resume training from a checkpoint?

Resuming from a checkpoint means that if training is interrupted unexpectedly, the model trained so far is preserved and the next run can continue from it.

To add this capability to the example above, we only need to add a model-restoring step to the backpropagation file:

...... 
    # create a session
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # to resume from a checkpoint, only the following three lines are added
        ckpt = tf.train.get_checkpoint_state(model_saver_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)  # restore the model

        for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)  # load BATCH_SIZE MNIST images and labels
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
......
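Note that global_step is an ordinary saved variable, so restoring the checkpoint also restores the step counter, and the decayed learning rate and the printed step numbers continue from where the interrupted run stopped. A small sketch (assuming TF 1.x and the ./model/ directory used above) of how the latest checkpoint and its step number can be inspected:

import tensorflow as tf

ckpt = tf.train.get_checkpoint_state('./model/')
if ckpt and ckpt.model_checkpoint_path:
    # saver.save(..., global_step=...) appends the step, e.g. './model/mnist_model-49001'
    last_step = ckpt.model_checkpoint_path.split('-')[-1]
    print('Latest checkpoint was saved at step', last_step)
else:
    print('No checkpoint found; training will start from step 0')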

